From 59fdb3df67ebdfcca48fa86677c77031aefd72fd Mon Sep 17 00:00:00 2001 From: Rick Parker Date: Fri, 25 May 2018 13:26:00 +0100 Subject: [PATCH] CORDA-1475 CORDA-1465 Allow flows to retry from last checkpoint (#3204) --- .../corda/core/internal/FlowStateMachine.kt | 1 + .../internal/persistence/CordaPersistence.kt | 17 +- .../persistence/DatabaseTransaction.kt | 16 +- .../net/corda/node/flows/FlowRetryTest.kt | 158 ++++++ .../events/ScheduledFlowIntegrationTests.kt | 6 +- .../net/corda/node/internal/AbstractNode.kt | 33 +- .../node/services/api/CheckpointStorage.kt | 6 + .../node/services/api/ServiceHubInternal.kt | 14 +- .../services/events/NodeSchedulerService.kt | 71 ++- .../node/services/messaging/Messaging.kt | 28 +- .../messaging/P2PMessageDeduplicator.kt | 3 - .../services/messaging/P2PMessagingClient.kt | 25 +- .../persistence/DBCheckpointStorage.kt | 7 +- .../persistence/DBTransactionStorage.kt | 7 +- .../node/services/statemachine/Action.kt | 12 +- .../statemachine/ActionExecutorImpl.kt | 9 +- .../services/statemachine/DeduplicationId.kt | 6 + .../corda/node/services/statemachine/Event.kt | 19 +- .../services/statemachine/FlowHospital.kt | 7 +- .../services/statemachine/FlowMessaging.kt | 4 +- .../statemachine/FlowStateMachineImpl.kt | 15 +- .../statemachine/PropagatingFlowHospital.kt | 9 +- .../SingleThreadedStateMachineManager.kt | 105 +++- .../statemachine/StaffedFlowHospital.kt | 127 +++++ .../statemachine/StateMachineManager.kt | 57 ++- .../statemachine/StateMachineState.kt | 4 +- .../DumpHistoryOnErrorInterceptor.kt | 3 +- .../interceptors/HospitalisingInterceptor.kt | 23 +- .../DeliverSessionMessageTransition.kt | 20 +- .../transitions/ErrorFlowTransition.kt | 2 +- .../transitions/StartedFlowTransition.kt | 6 +- .../transitions/TopLevelTransition.kt | 37 +- .../transitions/UnstartedFlowTransition.kt | 2 +- .../node/utilities/AppendOnlyPersistentMap.kt | 287 +++++++++-- .../events/NodeSchedulerServiceTest.kt | 33 +- .../AppendOnlyPersistentMapTest.kt | 290 +++++++++++ .../persistence/TransactionCallbackTest.kt | 49 ++ .../statemachine/RetryFlowMockTest.kt | 166 ++++++ .../node/services/vault/VaultQueryTests.kt | 483 +++++++++--------- .../corda/node/utilities/ObservablesTests.kt | 69 ++- .../testing/node/InMemoryMessagingNetwork.kt | 76 ++- 41 files changed, 1843 insertions(+), 469 deletions(-) create mode 100644 node/src/integration-test/kotlin/net/corda/node/flows/FlowRetryTest.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/StaffedFlowHospital.kt create mode 100644 node/src/test/kotlin/net/corda/node/services/persistence/AppendOnlyPersistentMapTest.kt create mode 100644 node/src/test/kotlin/net/corda/node/services/persistence/TransactionCallbackTest.kt create mode 100644 node/src/test/kotlin/net/corda/node/services/statemachine/RetryFlowMockTest.kt diff --git a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt index 36002556c8..fc6a5a529e 100644 --- a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt +++ b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt @@ -41,4 +41,5 @@ interface FlowStateMachine { val resultFuture: CordaFuture val context: InvocationContext val ourIdentity: Party + val ourSenderUUID: String? } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/CordaPersistence.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/CordaPersistence.kt index 31ca0cf166..dbec43d561 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/CordaPersistence.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/CordaPersistence.kt @@ -5,7 +5,6 @@ import net.corda.core.schemas.MappedSchema import net.corda.core.utilities.contextLogger import rx.Observable import rx.Subscriber -import rx.subjects.PublishSubject import rx.subjects.UnicastSubject import java.io.Closeable import java.sql.Connection @@ -67,9 +66,7 @@ class CordaPersistence( } val entityManagerFactory get() = hibernateConfig.sessionFactoryForRegisteredSchemas - data class Boundary(val txId: UUID) - - internal val transactionBoundaries = PublishSubject.create().toSerialized() + data class Boundary(val txId: UUID, val success: Boolean) init { // Found a unit test that was forgetting to close the database transactions. When you close() on the top level @@ -186,15 +183,19 @@ class CordaPersistence( * * For examples, see the call hierarchy of this function. */ -fun rx.Observer.bufferUntilDatabaseCommit(): rx.Observer { - val currentTxId = contextTransaction.id - val databaseTxBoundary: Observable = contextDatabase.transactionBoundaries.first { it.txId == currentTxId } +fun rx.Observer.bufferUntilDatabaseCommit(propagateRollbackAsError: Boolean = false): rx.Observer { + val currentTx = contextTransaction val subject = UnicastSubject.create() + val databaseTxBoundary: Observable = currentTx.boundary.filter { it.success } + if (propagateRollbackAsError) { + currentTx.boundary.filter { !it.success }.subscribe { this.onError(DatabaseTransactionRolledBackException(it.txId)) } + } subject.delaySubscription(databaseTxBoundary).subscribe(this) - databaseTxBoundary.doOnCompleted { subject.onCompleted() } return subject } +class DatabaseTransactionRolledBackException(txId: UUID) : Exception("Database transaction $txId was rolled back") + // A subscriber that delegates to multiple others, wrapping a database transaction around the combination. private class DatabaseTransactionWrappingSubscriber(private val db: CordaPersistence?) : Subscriber() { // Some unsubscribes happen inside onNext() so need something that supports concurrent modification. diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/DatabaseTransaction.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/DatabaseTransaction.kt index ecc48300e3..578479d39a 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/DatabaseTransaction.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/DatabaseTransaction.kt @@ -3,6 +3,7 @@ package net.corda.nodeapi.internal.persistence import co.paralleluniverse.strands.Strand import org.hibernate.Session import org.hibernate.Transaction +import rx.subjects.PublishSubject import java.sql.Connection import java.util.* @@ -35,11 +36,16 @@ class DatabaseTransaction( val session: Session by sessionDelegate private lateinit var hibernateTransaction: Transaction + + internal val boundary = PublishSubject.create() + private var committed = false + fun commit() { if (sessionDelegate.isInitialized()) { hibernateTransaction.commit() } connection.commit() + committed = true } fun rollback() { @@ -58,7 +64,15 @@ class DatabaseTransaction( connection.close() contextTransactionOrNull = outerTransaction if (outerTransaction == null) { - database.transactionBoundaries.onNext(CordaPersistence.Boundary(id)) + boundary.onNext(CordaPersistence.Boundary(id, committed)) } } + + fun onCommit(callback: () -> Unit) { + boundary.filter { it.success }.subscribe { callback() } + } + + fun onRollback(callback: () -> Unit) { + boundary.filter { !it.success }.subscribe { callback() } + } } diff --git a/node/src/integration-test/kotlin/net/corda/node/flows/FlowRetryTest.kt b/node/src/integration-test/kotlin/net/corda/node/flows/FlowRetryTest.kt new file mode 100644 index 0000000000..40bdb29444 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/flows/FlowRetryTest.kt @@ -0,0 +1,158 @@ +package net.corda.node.flows + +import co.paralleluniverse.fibers.Suspendable +import net.corda.client.rpc.CordaRPCClient +import net.corda.core.flows.* +import net.corda.core.identity.Party +import net.corda.core.messaging.startFlow +import net.corda.core.serialization.CordaSerializable +import net.corda.core.utilities.ProgressTracker +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.unwrap +import net.corda.node.services.Permissions +import net.corda.testing.core.singleIdentity +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.driver +import net.corda.testing.driver.internal.RandomFree +import net.corda.testing.node.User +import org.junit.Before +import org.junit.Test +import java.lang.management.ManagementFactory +import java.sql.SQLException +import java.util.* +import kotlin.test.assertEquals +import kotlin.test.assertNotNull + + +class FlowRetryTest { + @Before + fun resetCounters() { + InitiatorFlow.seen.clear() + InitiatedFlow.seen.clear() + } + + @Test + fun `flows continue despite errors`() { + val numSessions = 2 + val numIterations = 10 + val user = User("mark", "dadada", setOf(Permissions.startFlow())) + val result: Any? = driver(DriverParameters(isDebug = true, startNodesInProcess = isQuasarAgentSpecified(), + portAllocation = RandomFree)) { + + val nodeAHandle = startNode(rpcUsers = listOf(user)).getOrThrow() + val nodeBHandle = startNode(rpcUsers = listOf(user)).getOrThrow() + + val result = CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use { + it.proxy.startFlow(::InitiatorFlow, numSessions, numIterations, nodeBHandle.nodeInfo.singleIdentity()).returnValue.getOrThrow() + } + result + } + assertNotNull(result) + assertEquals("$numSessions:$numIterations", result) + } +} + +fun isQuasarAgentSpecified(): Boolean { + val jvmArgs = ManagementFactory.getRuntimeMXBean().inputArguments + return jvmArgs.any { it.startsWith("-javaagent:") && it.contains("quasar") } +} + +class ExceptionToCauseRetry : SQLException("deadlock") + +@StartableByRPC +@InitiatingFlow +class InitiatorFlow(private val sessionsCount: Int, private val iterationsCount: Int, private val other: Party) : FlowLogic() { + companion object { + object FIRST_STEP : ProgressTracker.Step("Step one") + + fun tracker() = ProgressTracker(FIRST_STEP) + + val seen = Collections.synchronizedSet(HashSet()) + + fun visit(sessionNum: Int, iterationNum: Int, step: Step) { + val visited = Visited(sessionNum, iterationNum, step) + if (visited !in seen) { + seen += visited + throw ExceptionToCauseRetry() + } + } + } + + override val progressTracker = tracker() + + @Suspendable + override fun call(): Any { + progressTracker.currentStep = FIRST_STEP + var received: Any? = null + visit(-1, -1, Step.First) + for (sessionNum in 1..sessionsCount) { + visit(sessionNum, -1, Step.BeforeInitiate) + val session = initiateFlow(other) + visit(sessionNum, -1, Step.AfterInitiate) + session.send(SessionInfo(sessionNum, iterationsCount)) + visit(sessionNum, -1, Step.AfterInitiateSendReceive) + for (iteration in 1..iterationsCount) { + visit(sessionNum, iteration, Step.BeforeSend) + logger.info("A Sending $sessionNum:$iteration") + session.send("$sessionNum:$iteration") + visit(sessionNum, iteration, Step.AfterSend) + received = session.receive().unwrap { it } + visit(sessionNum, iteration, Step.AfterReceive) + logger.info("A Got $sessionNum:$iteration") + } + doSleep() + } + return received!! + } + + // This non-flow-friendly sleep triggered a bug with session end messages and non-retryable checkpoints. + private fun doSleep() { + Thread.sleep(2000) + } +} + +@InitiatedBy(InitiatorFlow::class) +class InitiatedFlow(val session: FlowSession) : FlowLogic() { + companion object { + object FIRST_STEP : ProgressTracker.Step("Step one") + + fun tracker() = ProgressTracker(FIRST_STEP) + + val seen = Collections.synchronizedSet(HashSet()) + + fun visit(sessionNum: Int, iterationNum: Int, step: Step) { + val visited = Visited(sessionNum, iterationNum, step) + if (visited !in seen) { + seen += visited + throw ExceptionToCauseRetry() + } + } + } + + override val progressTracker = tracker() + + @Suspendable + override fun call() { + progressTracker.currentStep = FIRST_STEP + visit(-1, -1, Step.AfterInitiate) + val sessionInfo = session.receive().unwrap { it } + visit(sessionInfo.sessionNum, -1, Step.AfterInitiateSendReceive) + for (iteration in 1..sessionInfo.iterationsCount) { + visit(sessionInfo.sessionNum, iteration, Step.BeforeReceive) + val got = session.receive().unwrap { it } + visit(sessionInfo.sessionNum, iteration, Step.AfterReceive) + logger.info("B Got $got") + logger.info("B Sending $got") + visit(sessionInfo.sessionNum, iteration, Step.BeforeSend) + session.send(got) + visit(sessionInfo.sessionNum, iteration, Step.AfterSend) + } + } +} + +@CordaSerializable +data class SessionInfo(val sessionNum: Int, val iterationsCount: Int) + +enum class Step { First, BeforeInitiate, AfterInitiate, AfterInitiateSendReceive, BeforeSend, AfterSend, BeforeReceive, AfterReceive } + +data class Visited(val sessionNum: Int, val iterationNum: Int, val step: Step) \ No newline at end of file diff --git a/node/src/integration-test/kotlin/net/corda/node/services/events/ScheduledFlowIntegrationTests.kt b/node/src/integration-test/kotlin/net/corda/node/services/events/ScheduledFlowIntegrationTests.kt index d6970efa58..110bf1ebb9 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/events/ScheduledFlowIntegrationTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/events/ScheduledFlowIntegrationTests.kt @@ -25,6 +25,7 @@ import net.corda.core.node.services.vault.QueryCriteria import net.corda.core.transactions.TransactionBuilder import net.corda.core.utilities.NonEmptySet import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.seconds import net.corda.testMessage.ScheduledState import net.corda.testMessage.SpentState import net.corda.testing.contracts.DummyContract @@ -96,7 +97,7 @@ class ScheduledFlowIntegrationTests { val aliceClient = CordaRPCClient(alice.rpcAddress).start(rpcUser.username, rpcUser.password) val bobClient = CordaRPCClient(bob.rpcAddress).start(rpcUser.username, rpcUser.password) - val scheduledFor = Instant.now().plusSeconds(20) + val scheduledFor = Instant.now().plusSeconds(10) val initialiseFutures = mutableListOf>() for (i in 0 until N) { initialiseFutures.add(aliceClient.proxy.startFlow(::InsertInitialStateFlow, bob.nodeInfo.legalIdentities.first(), defaultNotaryIdentity, i, scheduledFor).returnValue) @@ -111,6 +112,9 @@ class ScheduledFlowIntegrationTests { } spendAttemptFutures.getOrThrowAll() + // TODO: the queries below are not atomic so we need to allow enough time for the scheduler to finish. Would be better to query scheduler. + Thread.sleep(20.seconds.toMillis()) + val aliceStates = aliceClient.proxy.vaultQuery(ScheduledState::class.java).states.filter { it.state.data.processed } val aliceSpentStates = aliceClient.proxy.vaultQuery(SpentState::class.java).states diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index d3a356cbd8..8a0b1c3203 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -985,8 +985,37 @@ internal fun logVendorString(database: CordaPersistence, log: Logger) { } internal class FlowStarterImpl(private val smm: StateMachineManager, private val flowLogicRefFactory: FlowLogicRefFactory) : FlowStarter { - override fun startFlow(logic: FlowLogic, context: InvocationContext, deduplicationHandler: DeduplicationHandler?): CordaFuture> { - return smm.startFlow(logic, context, ourIdentity = null, deduplicationHandler = deduplicationHandler) + override fun startFlow(event: ExternalEvent.ExternalStartFlowEvent): CordaFuture> { + smm.deliverExternalEvent(event) + return event.future + } + + override fun startFlow(logic: FlowLogic, context: InvocationContext): CordaFuture> { + val startFlowEvent = object : ExternalEvent.ExternalStartFlowEvent, DeduplicationHandler { + override fun insideDatabaseTransaction() {} + + override fun afterDatabaseTransaction() {} + + override val externalCause: ExternalEvent + get() = this + override val deduplicationHandler: DeduplicationHandler + get() = this + + override val flowLogic: FlowLogic + get() = logic + override val context: InvocationContext + get() = context + + override fun wireUpFuture(flowFuture: CordaFuture>) { + _future.captureLater(flowFuture) + } + + private val _future = openFuture>() + override val future: CordaFuture> + get() = _future + + } + return startFlow(startFlowEvent) } override fun invokeFlowAsync( diff --git a/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt b/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt index 5d0cf99dde..7901ea7f1e 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt @@ -20,6 +20,12 @@ interface CheckpointStorage { */ fun removeCheckpoint(id: StateMachineRunId): Boolean + /** + * Load an existing checkpoint from the store. + * @return the checkpoint, still in serialized form, or null if not found. + */ + fun getCheckpoint(id: StateMachineRunId): SerializedBytes? + /** * Stream all checkpoints from the store. If this is backed by a database the stream will be valid until the * underlying database connection is closed, so any processing should happen before it is closed. diff --git a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt index df04be4140..a32f3c771a 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt @@ -19,9 +19,9 @@ import net.corda.core.utilities.contextLogger import net.corda.node.internal.InitiatedFlowFactory import net.corda.node.internal.cordapp.CordappProviderInternal import net.corda.node.services.config.NodeConfiguration -import net.corda.node.services.messaging.DeduplicationHandler import net.corda.node.services.messaging.MessagingService import net.corda.node.services.network.NetworkMapUpdater +import net.corda.node.services.statemachine.ExternalEvent import net.corda.node.services.statemachine.FlowStateMachineImpl import net.corda.nodeapi.internal.persistence.CordaPersistence @@ -134,11 +134,17 @@ interface ServiceHubInternal : ServiceHub { interface FlowStarter { /** - * Starts an already constructed flow. Note that you must be on the server thread to call this method. + * Starts an already constructed flow. Note that you must be on the server thread to call this method. This method + * just synthesizes an [ExternalEvent.ExternalStartFlowEvent] and calls the method below. * @param context indicates who started the flow, see: [InvocationContext]. - * @param deduplicationHandler allows exactly-once start of the flow, see [DeduplicationHandler] */ - fun startFlow(logic: FlowLogic, context: InvocationContext, deduplicationHandler: DeduplicationHandler? = null): CordaFuture> + fun startFlow(logic: FlowLogic, context: InvocationContext): CordaFuture> + + /** + * Starts a flow as described by an [ExternalEvent.ExternalStartFlowEvent]. If a transient error + * occurs during invocation, it will re-attempt to start the flow. + */ + fun startFlow(event: ExternalEvent.ExternalStartFlowEvent): CordaFuture> /** * Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the flow. diff --git a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt index 22de33c70e..a57b2b0f05 100644 --- a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt +++ b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt @@ -2,6 +2,7 @@ package net.corda.node.services.events import co.paralleluniverse.fibers.Suspendable import com.google.common.util.concurrent.ListenableFuture +import net.corda.core.concurrent.CordaFuture import net.corda.core.context.InvocationContext import net.corda.core.context.InvocationOrigin import net.corda.core.contracts.SchedulableState @@ -10,11 +11,9 @@ import net.corda.core.contracts.ScheduledStateRef import net.corda.core.contracts.StateRef import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogicRefFactory -import net.corda.core.internal.ThreadBox -import net.corda.core.internal.VisibleForTesting +import net.corda.core.internal.* import net.corda.core.internal.concurrent.flatMap -import net.corda.core.internal.join -import net.corda.core.internal.until +import net.corda.core.internal.concurrent.openFuture import net.corda.core.node.ServicesForResolution import net.corda.core.schemas.PersistentStateRef import net.corda.core.serialization.SingletonSerializeAsToken @@ -26,8 +25,10 @@ import net.corda.node.services.api.FlowStarter import net.corda.node.services.api.NodePropertiesStore import net.corda.node.services.api.SchedulerService import net.corda.node.services.messaging.DeduplicationHandler +import net.corda.node.services.statemachine.ExternalEvent import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX +import net.corda.nodeapi.internal.persistence.contextTransaction import org.apache.activemq.artemis.utils.ReusableLatch import org.apache.mina.util.ConcurrentHashSet import org.slf4j.Logger @@ -156,29 +157,31 @@ class NodeSchedulerService(private val clock: CordaClock, override fun scheduleStateActivity(action: ScheduledStateRef) { log.trace { "Schedule $action" } - if (!schedulerRepo.merge(action)) { - // Only increase the number of unfinished schedules if the state didn't already exist on the queue - unfinishedSchedules.countUp() - } - mutex.locked { - if (action.scheduledAt < nextScheduledAction?.scheduledAt ?: Instant.MAX) { - // We are earliest - rescheduleWakeUp() - } else if (action.ref == nextScheduledAction?.ref && action.scheduledAt != nextScheduledAction?.scheduledAt) { - // We were earliest but might not be any more - rescheduleWakeUp() + // Only increase the number of unfinished schedules if the state didn't already exist on the queue + val countUp = !schedulerRepo.merge(action) + contextTransaction.onCommit { + if (countUp) unfinishedSchedules.countUp() + mutex.locked { + if (action.scheduledAt < nextScheduledAction?.scheduledAt ?: Instant.MAX) { + // We are earliest + rescheduleWakeUp() + } else if (action.ref == nextScheduledAction?.ref && action.scheduledAt != nextScheduledAction?.scheduledAt) { + // We were earliest but might not be any more + rescheduleWakeUp() + } } } } override fun unscheduleStateActivity(ref: StateRef) { log.trace { "Unschedule $ref" } - if (startingStateRefs.all { it.ref != ref } && schedulerRepo.delete(ref)) { - unfinishedSchedules.countDown() - } - mutex.locked { - if (nextScheduledAction?.ref == ref) { - rescheduleWakeUp() + val countDown = startingStateRefs.all { it.ref != ref } && schedulerRepo.delete(ref) + contextTransaction.onCommit { + if (countDown) unfinishedSchedules.countDown() + mutex.locked { + if (nextScheduledAction?.ref == ref) { + rescheduleWakeUp() + } } } } @@ -227,7 +230,12 @@ class NodeSchedulerService(private val clock: CordaClock, schedulerTimerExecutor.join() } - private inner class FlowStartDeduplicationHandler(val scheduledState: ScheduledStateRef) : DeduplicationHandler { + private inner class FlowStartDeduplicationHandler(val scheduledState: ScheduledStateRef, override val flowLogic: FlowLogic, override val context: InvocationContext) : DeduplicationHandler, ExternalEvent.ExternalStartFlowEvent { + override val externalCause: ExternalEvent + get() = this + override val deduplicationHandler: FlowStartDeduplicationHandler + get() = this + override fun insideDatabaseTransaction() { schedulerRepo.delete(scheduledState.ref) } @@ -239,6 +247,18 @@ class NodeSchedulerService(private val clock: CordaClock, override fun toString(): String { return "${javaClass.simpleName}($scheduledState)" } + + override fun wireUpFuture(flowFuture: CordaFuture>) { + _future.captureLater(flowFuture) + val future = _future.flatMap { it.resultFuture } + future.then { + unfinishedSchedules.countDown() + } + } + + private val _future = openFuture>() + override val future: CordaFuture> + get() = _future } private fun onTimeReached(scheduledState: ScheduledStateRef) { @@ -250,11 +270,8 @@ class NodeSchedulerService(private val clock: CordaClock, flowName = scheduledFlow.javaClass.name // TODO refactor the scheduler to store and propagate the original invocation context val context = InvocationContext.newInstance(InvocationOrigin.Scheduled(scheduledState)) - val deduplicationHandler = FlowStartDeduplicationHandler(scheduledState) - val future = flowStarter.startFlow(scheduledFlow, context, deduplicationHandler).flatMap { it.resultFuture } - future.then { - unfinishedSchedules.countDown() - } + val startFlowEvent = FlowStartDeduplicationHandler(scheduledState, scheduledFlow, context) + flowStarter.startFlow(startFlowEvent) } } } catch (e: Exception) { diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt index 4d13f0df12..9dcb8658dd 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt @@ -10,6 +10,8 @@ import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.serialize import net.corda.core.utilities.ByteSequence import net.corda.node.services.statemachine.DeduplicationId +import net.corda.node.services.statemachine.ExternalEvent +import net.corda.node.services.statemachine.SenderDeduplicationId import java.time.Instant import javax.annotation.concurrent.ThreadSafe @@ -25,6 +27,12 @@ import javax.annotation.concurrent.ThreadSafe */ @ThreadSafe interface MessagingService { + /** + * A unique identifier for this sender that changes whenever a node restarts. This is used in conjunction with a sequence + * number for message de-duplication at the recipient. + */ + val ourSenderUUID: String + /** * The provided function will be invoked for each received message whose topic and session matches. The callback * will run on the main server thread provided when the messaging service is constructed, and a database @@ -93,11 +101,12 @@ interface MessagingService { /** * Returns an initialised [Message] with the current time, etc, already filled in. * - * @param topicSession identifier for the topic and session the message is sent to. - * @param additionalProperties optional additional message headers. * @param topic identifier for the topic the message is sent to. + * @param data the payload for the message. + * @param deduplicationId optional message deduplication ID including sender identifier. + * @param additionalHeaders optional additional message headers. */ - fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId = DeduplicationId.createRandom(newSecureRandom()), additionalHeaders: Map = emptyMap()): Message + fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId = SenderDeduplicationId(DeduplicationId.createRandom(newSecureRandom()), ourSenderUUID), additionalHeaders: Map = emptyMap()): Message /** Given information about either a specific node or a service returns its corresponding address */ fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients @@ -106,7 +115,7 @@ interface MessagingService { val myAddress: SingleMessageRecipient } -fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: DeduplicationId = DeduplicationId.createRandom(newSecureRandom()), retryId: Long? = null, additionalHeaders: Map = emptyMap()) = send(createMessage(topicSession, payload.serialize().bytes, deduplicationId, additionalHeaders), to, retryId) +fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: SenderDeduplicationId = SenderDeduplicationId(DeduplicationId.createRandom(newSecureRandom()), ourSenderUUID), retryId: Long? = null, additionalHeaders: Map = emptyMap()) = send(createMessage(topicSession, payload.serialize().bytes, deduplicationId, additionalHeaders), to, retryId) interface MessageHandlerRegistration @@ -152,15 +161,17 @@ object TopicStringValidator { } /** - * This handler is used to implement exactly-once delivery of an event on top of a possibly duplicated one. This is done + * This handler is used to implement exactly-once delivery of an external event on top of an at-least-once delivery. This is done * using two hooks that are called from the event processor, one called from the database transaction committing the - * side-effect caused by the event, and another one called after the transaction has committed successfully. + * side-effect caused by the external event, and another one called after the transaction has committed successfully. * * For example for messaging we can use [insideDatabaseTransaction] to store the message's unique ID for later * deduplication, and [afterDatabaseTransaction] to acknowledge the message and stop retries. * * We also use this for exactly-once start of a scheduled flow, [insideDatabaseTransaction] is used to remove the * to-be-scheduled state of the flow, [afterDatabaseTransaction] is used for cleanup of in-memory bookkeeping. + * + * It holds a reference back to the causing external event. */ interface DeduplicationHandler { /** @@ -174,6 +185,11 @@ interface DeduplicationHandler { * cleanup/acknowledgement/stopping of retries. */ fun afterDatabaseTransaction() + + /** + * The external event for which we are trying to reduce from at-least-once delivery to exactly-once. + */ + val externalCause: ExternalEvent } typealias MessageHandler = (ReceivedMessage, MessageHandlerRegistration, DeduplicationHandler) -> Unit diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessageDeduplicator.kt b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessageDeduplicator.kt index 76075477c3..dd666f51a3 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessageDeduplicator.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessageDeduplicator.kt @@ -8,7 +8,6 @@ import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import java.io.Serializable import java.time.Instant -import java.util.* import java.util.concurrent.ConcurrentHashMap import javax.persistence.Column import javax.persistence.Entity @@ -18,8 +17,6 @@ import javax.persistence.Id * Encapsulate the de-duplication logic. */ class P2PMessageDeduplicator(private val database: CordaPersistence) { - val ourSenderUUID = UUID.randomUUID().toString() - // A temporary in-memory set of deduplication IDs and associated high water mark details. // When we receive a message we don't persist the ID immediately, // so we store the ID here in the meantime (until the persisting db tx has committed). This is because Artemis may diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt index c8e65c7c21..84b9f19490 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt @@ -23,6 +23,8 @@ import net.corda.node.internal.artemis.ReactiveArtemisConsumer.Companion.multipl import net.corda.node.services.api.NetworkMapCacheInternal import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.statemachine.DeduplicationId +import net.corda.node.services.statemachine.ExternalEvent +import net.corda.node.services.statemachine.SenderDeduplicationId import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.PersistentMap import net.corda.nodeapi.ArtemisTcpTransport.Companion.p2pConnectorTcpTransport @@ -124,7 +126,7 @@ class P2PMessagingClient(val config: NodeConfiguration, ) } - private class NodeClientMessage(override val topic: String, override val data: ByteSequence, override val uniqueMessageId: DeduplicationId, override val senderUUID: String?, override val additionalHeaders: Map) : Message { + class NodeClientMessage(override val topic: String, override val data: ByteSequence, override val uniqueMessageId: DeduplicationId, override val senderUUID: String?, override val additionalHeaders: Map) : Message { override val debugTimestamp: Instant = Instant.now() override fun toString() = "$topic#${String(data.bytes)}" } @@ -158,6 +160,8 @@ class P2PMessagingClient(val config: NodeConfiguration, data class HandlerRegistration(val topic: String, val callback: Any) : MessageHandlerRegistration override val myAddress: SingleMessageRecipient = NodeAddress(myIdentity, advertisedAddress) + override val ourSenderUUID = UUID.randomUUID().toString() + private val messageRedeliveryDelaySeconds = config.p2pMessagingRetry.messageRedeliveryDelay.seconds private val state = ThreadBox(InnerState()) private val knownQueues = Collections.newSetFromMap(ConcurrentHashMap()) @@ -227,7 +231,7 @@ class P2PMessagingClient(val config: NodeConfiguration, producer!!, versionInfo, this@P2PMessagingClient, - ourSenderUUID = deduplicator.ourSenderUUID + ourSenderUUID = ourSenderUUID ) registerBridgeControl(bridgeSession!!, inboxes.toList()) @@ -435,18 +439,23 @@ class P2PMessagingClient(val config: NodeConfiguration, } } - inner class MessageDeduplicationHandler(val artemisMessage: ClientMessage, val cordaMessage: ReceivedMessage) : DeduplicationHandler { + private inner class MessageDeduplicationHandler(val artemisMessage: ClientMessage, override val receivedMessage: ReceivedMessage) : DeduplicationHandler, ExternalEvent.ExternalMessageEvent { + override val externalCause: ExternalEvent + get() = this + override val deduplicationHandler: MessageDeduplicationHandler + get() = this + override fun insideDatabaseTransaction() { - deduplicator.persistDeduplicationId(cordaMessage.uniqueMessageId) + deduplicator.persistDeduplicationId(receivedMessage.uniqueMessageId) } override fun afterDatabaseTransaction() { - deduplicator.signalMessageProcessFinish(cordaMessage.uniqueMessageId) + deduplicator.signalMessageProcessFinish(receivedMessage.uniqueMessageId) messagingExecutor!!.acknowledge(artemisMessage) } override fun toString(): String { - return "${javaClass.simpleName}(${cordaMessage.uniqueMessageId})" + return "${javaClass.simpleName}(${receivedMessage.uniqueMessageId})" } } @@ -610,8 +619,8 @@ class P2PMessagingClient(val config: NodeConfiguration, handlers.remove(registration.topic) } - override fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId, additionalHeaders: Map): Message { - return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId, deduplicator.ourSenderUUID, additionalHeaders) + override fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId, additionalHeaders: Map): Message { + return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId.deduplicationId, deduplicationId.senderUUID, additionalHeaders) } override fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients { diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt index be1e01b1a5..73105d598e 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt @@ -10,9 +10,9 @@ import net.corda.nodeapi.internal.persistence.currentDBSession import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY import org.slf4j.Logger import org.slf4j.LoggerFactory +import java.io.Serializable import java.util.* import java.util.stream.Stream -import java.io.Serializable import javax.persistence.Column import javax.persistence.Entity import javax.persistence.Id @@ -53,6 +53,11 @@ class DBCheckpointStorage : CheckpointStorage { return session.createQuery(delete).executeUpdate() > 0 } + override fun getCheckpoint(id: StateMachineRunId): SerializedBytes? { + val bytes = currentDBSession().get(DBCheckpoint::class.java, id.uuid.toString())?.checkpoint ?: return null + return SerializedBytes(bytes) + } + override fun getAllCheckpoints(): Stream>> { val session = currentDBSession() val criteriaQuery = session.criteriaBuilder.createQuery(DBCheckpoint::class.java) diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt index 3271208e95..05b99bc25b 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt @@ -21,7 +21,6 @@ import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY import rx.Observable import rx.subjects.PublishSubject import java.io.Serializable -import java.util.* import javax.persistence.* // cache value type to just store the immutable bits of a signed transaction plus conversion helpers @@ -71,11 +70,11 @@ class DBTransactionStorage(cacheSizeBytes: Long) : WritableTransactionStorage, S // to the memory pressure at all here. private const val transactionSignatureOverheadEstimate = 1024 - private fun weighTx(tx: Optional): Int { - if (!tx.isPresent) { + private fun weighTx(tx: AppendOnlyPersistentMapBase.Transactional): Int { + val actTx = tx.valueWithoutIsolation + if (actTx == null) { return 0 } - val actTx = tx.get() return actTx.second.sumBy { it.size + transactionSignatureOverheadEstimate } + actTx.first.size } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt index 07d4ceb866..a48d03be7c 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt @@ -24,7 +24,7 @@ sealed class Action { data class SendInitial( val party: Party, val initialise: InitialSessionMessage, - val deduplicationId: DeduplicationId + val deduplicationId: SenderDeduplicationId ) : Action() /** @@ -33,7 +33,7 @@ sealed class Action { data class SendExisting( val peerParty: Party, val message: ExistingSessionMessage, - val deduplicationId: DeduplicationId + val deduplicationId: SenderDeduplicationId ) : Action() /** @@ -62,7 +62,8 @@ sealed class Action { */ data class PropagateErrors( val errorMessages: List, - val sessions: List + val sessions: List, + val senderUUID: String? ) : Action() /** @@ -129,6 +130,11 @@ sealed class Action { * Release soft locks associated with given ID (currently the flow ID). */ data class ReleaseSoftLocks(val uuid: UUID?) : Action() + + /** + * Retry a flow from the last checkpoint, or if there is no checkpoint, restart the flow with the same invocation details. + */ + data class RetryFlowFromSafePoint(val currentState: StateMachineState) : Action() } /** diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt index 683cd720e3..8aa57d69ef 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt @@ -73,6 +73,7 @@ class ActionExecutorImpl( is Action.CommitTransaction -> executeCommitTransaction() is Action.ExecuteAsyncOperation -> executeAsyncOperation(fiber, action) is Action.ReleaseSoftLocks -> executeReleaseSoftLocks(action) + is Action.RetryFlowFromSafePoint -> executeRetryFlowFromSafePoint(action) } } @@ -125,7 +126,7 @@ class ActionExecutorImpl( @Suspendable private fun executePropagateErrors(action: Action.PropagateErrors) { action.errorMessages.forEach { (exception) -> - log.debug("Propagating error", exception) + log.warn("Propagating error", exception) } for (sessionState in action.sessions) { // We cannot propagate if the session isn't live. @@ -137,7 +138,7 @@ class ActionExecutorImpl( val sinkSessionId = sessionState.initiatedState.peerSinkSessionId val existingMessage = ExistingSessionMessage(sinkSessionId, errorMessage) val deduplicationId = DeduplicationId.createForError(errorMessage.errorId, sinkSessionId) - flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, deduplicationId) + flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, action.senderUUID)) } } } @@ -226,6 +227,10 @@ class ActionExecutorImpl( ) } + private fun executeRetryFlowFromSafePoint(action: Action.RetryFlowFromSafePoint) { + stateMachineManager.retryFlowFromSafePoint(action.currentState) + } + private fun serializeCheckpoint(checkpoint: Checkpoint): SerializedBytes { return checkpoint.serialize(context = checkpointSerializationContext) } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/DeduplicationId.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/DeduplicationId.kt index 7e853dc5a8..b6d2bbb89b 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/DeduplicationId.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/DeduplicationId.kt @@ -45,3 +45,9 @@ data class DeduplicationId(val toString: String) { } } } + +/** + * Represents the deduplication ID of a flow message, and the sender identifier for the flow doing the sending. The identifier might be + * null if the flow is trying to replay messages and doesn't want an optimisation to ignore the deduplication ID. + */ +data class SenderDeduplicationId(val deduplicationId: DeduplicationId, val senderUUID: String?) \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt index 921795038c..c750a05d1e 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt @@ -31,9 +31,9 @@ sealed class Event { */ data class DeliverSessionMessage( val sessionMessage: ExistingSessionMessage, - val deduplicationHandler: DeduplicationHandler, + override val deduplicationHandler: DeduplicationHandler, val sender: Party - ) : Event() + ) : Event(), GeneratedByExternalEvent /** * Signal that an error has happened. This may be due to an uncaught exception in the flow or some external error. @@ -133,4 +133,19 @@ sealed class Event { * @param returnValue the result of the operation. */ data class AsyncOperationCompletion(val returnValue: Any?) : Event() + + /** + * Retry a flow from the last checkpoint, or if there is no checkpoint, restart the flow with the same invocation details. + */ + object RetryFlowFromSafePoint : Event() { + override fun toString() = "RetryFlowFromSafePoint" + } + + /** + * Indicates that an event was generated by an external event and that external event needs to be replayed if we retry the flow, + * even if it has not yet been processed and placed on the pending de-duplication handlers list. + */ + interface GeneratedByExternalEvent { + val deduplicationHandler: DeduplicationHandler + } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowHospital.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowHospital.kt index bcd60557df..68ea53724e 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowHospital.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowHospital.kt @@ -9,10 +9,15 @@ interface FlowHospital { /** * The flow running in [flowFiber] has errored. */ - fun flowErrored(flowFiber: FlowFiber) + fun flowErrored(flowFiber: FlowFiber, currentState: StateMachineState, errors: List) /** * The flow running in [flowFiber] has cleaned, possibly as a result of a flow hospital resume. */ fun flowCleaned(flowFiber: FlowFiber) + + /** + * The flow has been removed from the state machine. + */ + fun flowRemoved(flowFiber: FlowFiber) } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt index 0f47417eb8..1ca7490e7d 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt @@ -24,7 +24,7 @@ interface FlowMessaging { * listen on the send acknowledgement. */ @Suspendable - fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId) + fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: SenderDeduplicationId) /** * Start the messaging using the [onMessage] message handler. @@ -49,7 +49,7 @@ class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging { } @Suspendable - override fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId) { + override fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: SenderDeduplicationId) { log.trace { "Sending message $deduplicationId $message to party $party" } val networkMessage = serviceHub.networkService.createMessage(sessionTopic, serializeSessionMessage(message).bytes, deduplicationId, message.additionalHeaders(party)) val partyInfo = serviceHub.networkMapCache.getPartyInfo(party) ?: throw IllegalArgumentException("Don't know about $party") diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index fe60503096..3d2b460152 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -47,6 +47,8 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, fun currentStateMachine(): FlowStateMachineImpl<*>? = Strand.currentStrand() as? FlowStateMachineImpl<*> private val log: Logger = LoggerFactory.getLogger("net.corda.flow") + + private val SERIALIZER_BLOCKER = Fiber::class.java.getDeclaredField("SERIALIZER_BLOCKER").apply { isAccessible = true }.get(null) } override val serviceHub get() = getTransientField(TransientValues::serviceHub) @@ -65,6 +67,14 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, internal var transientValues: TransientReference? = null internal var transientState: TransientReference? = null + /** + * What sender identifier to put on messages sent by this flow. This will either be the identifier for the current + * state machine manager / messaging client, or null to indicate this flow is restored from a checkpoint and + * the de-duplication of messages it sends should not be optimised since this could be unreliable. + */ + override val ourSenderUUID: String? + get() = transientState?.value?.senderUUID + private fun getTransientField(field: KProperty1): A { val suppliedValues = transientValues ?: throw IllegalStateException("${field.name} wasn't supplied!") return field.get(suppliedValues.value) @@ -168,6 +178,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, fun setLoggingContext() { context.pushToLoggingContext() MDC.put("flow-id", id.uuid.toString()) + MDC.put("fiber-id", this.getId().toString()) } @Suspendable @@ -185,7 +196,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, suspend(FlowIORequest.WaitForSessionConfirmations, maySkipCheckpoint = true) Try.Success(result) } catch (throwable: Throwable) { - logger.warn("Flow threw exception", throwable) + logger.info("Flow threw exception... sending to flow hospital", throwable) Try.Failure(throwable) } val softLocksId = if (hasSoftLockedStates) logic.runId.uuid else null @@ -325,7 +336,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, isDbTransactionOpenOnExit = false ) require(continuation == FlowContinuation.ProcessEvents) - Fiber.unparkDeserialized(this, scheduler) + unpark(SERIALIZER_BLOCKER) } setLoggingContext() return uncheckedCast(processEventsUntilFlowIsResumed( diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/PropagatingFlowHospital.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/PropagatingFlowHospital.kt index a31656db5d..49f5fdb167 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/PropagatingFlowHospital.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/PropagatingFlowHospital.kt @@ -9,12 +9,17 @@ import net.corda.core.utilities.loggerFor object PropagatingFlowHospital : FlowHospital { private val log = loggerFor() - override fun flowErrored(flowFiber: FlowFiber) { - log.debug { "Flow ${flowFiber.id} dirtied ${flowFiber.snapshot().checkpoint.errorState}" } + override fun flowErrored(flowFiber: FlowFiber, currentState: StateMachineState, errors: List) { + log.debug { "Flow ${flowFiber.id} in state $currentState encountered error" } flowFiber.scheduleEvent(Event.StartErrorPropagation) + for ((index, error) in errors.withIndex()) { + log.warn("Flow ${flowFiber.id} is propagating error [$index] ", error) + } } override fun flowCleaned(flowFiber: FlowFiber) { throw IllegalStateException("Flow ${flowFiber.id} cleaned after error propagation triggered") } + + override fun flowRemoved(flowFiber: FlowFiber) {} } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt index 1fca4d4501..a0d492f50e 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt @@ -46,13 +46,15 @@ import java.security.SecureRandom import java.util.* import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ExecutorService +import java.util.concurrent.locks.ReentrantLock import javax.annotation.concurrent.ThreadSafe import kotlin.collections.ArrayList +import kotlin.concurrent.withLock import kotlin.streams.toList /** * The StateMachineManagerImpl will always invoke the flow fibers on the given [AffinityExecutor], regardless of which - * thread actually starts them via [startFlow]. + * thread actually starts them via [deliverExternalEvent]. */ @ThreadSafe class SingleThreadedStateMachineManager( @@ -90,6 +92,7 @@ class SingleThreadedStateMachineManager( private val flowMessaging: FlowMessaging = FlowMessagingImpl(serviceHub) private val fiberDeserializationChecker = if (serviceHub.configuration.shouldCheckCheckpoints()) FiberDeserializationChecker() else null private val transitionExecutor = makeTransitionExecutor() + private val ourSenderUUID = serviceHub.networkService.ourSenderUUID private var checkpointSerializationContext: SerializationContext? = null private var tokenizableServices: List? = null @@ -128,7 +131,7 @@ class SingleThreadedStateMachineManager( resumeRestoredFlows(fibers) flowMessaging.start { receivedMessage, deduplicationHandler -> executor.execute { - onSessionMessage(receivedMessage, deduplicationHandler) + deliverExternalEvent(deduplicationHandler.externalCause) } } } @@ -176,7 +179,7 @@ class SingleThreadedStateMachineManager( } } - override fun startFlow( + private fun startFlow( flowLogic: FlowLogic, context: InvocationContext, ourIdentity: Party?, @@ -310,7 +313,73 @@ class SingleThreadedStateMachineManager( } } - private fun onSessionMessage(message: ReceivedMessage, deduplicationHandler: DeduplicationHandler) { + override fun retryFlowFromSafePoint(currentState: StateMachineState) { + // Get set of external events + val flowId = currentState.flowLogic.runId + val oldFlowLeftOver = mutex.locked { flows[flowId] }?.fiber?.transientValues?.value?.eventQueue + if (oldFlowLeftOver == null) { + logger.error("Unable to find flow for flow $flowId. Something is very wrong. The flow will not retry.") + return + } + val flow = if (currentState.isAnyCheckpointPersisted) { + val serializedCheckpoint = checkpointStorage.getCheckpoint(flowId) + if (serializedCheckpoint == null) { + logger.error("Unable to find database checkpoint for flow $flowId. Something is very wrong. The flow will not retry.") + return + } + val checkpoint = deserializeCheckpoint(serializedCheckpoint) + if (checkpoint == null) { + logger.error("Unable to deserialize database checkpoint for flow $flowId. Something is very wrong. The flow will not retry.") + return + } + // Resurrect flow + createFlowFromCheckpoint( + id = flowId, + checkpoint = checkpoint, + initialDeduplicationHandler = null, + isAnyCheckpointPersisted = true, + isStartIdempotent = false, + senderUUID = null + ) + } else { + // Just flow initiation message + null + } + externalEventMutex.withLock { + if (flow != null) addAndStartFlow(flowId, flow) + // Deliver all the external events from the old flow instance. + val unprocessedExternalEvents = mutableListOf() + do { + val event = oldFlowLeftOver.tryReceive() + if (event is Event.GeneratedByExternalEvent) { + unprocessedExternalEvents += event.deduplicationHandler.externalCause + } + } while (event != null) + val externalEvents = currentState.pendingDeduplicationHandlers.map { it.externalCause } + unprocessedExternalEvents + for (externalEvent in externalEvents) { + deliverExternalEvent(externalEvent) + } + } + } + + private val externalEventMutex = ReentrantLock() + override fun deliverExternalEvent(event: ExternalEvent) { + externalEventMutex.withLock { + when (event) { + is ExternalEvent.ExternalMessageEvent -> onSessionMessage(event) + is ExternalEvent.ExternalStartFlowEvent<*> -> onExternalStartFlow(event) + } + } + } + + private fun onExternalStartFlow(event: ExternalEvent.ExternalStartFlowEvent) { + val future = startFlow(event.flowLogic, event.context, ourIdentity = null, deduplicationHandler = event.deduplicationHandler) + event.wireUpFuture(future) + } + + private fun onSessionMessage(event: ExternalEvent.ExternalMessageEvent) { + val message: ReceivedMessage = event.receivedMessage + val deduplicationHandler: DeduplicationHandler = event.deduplicationHandler val peer = message.peer val sessionMessage = try { message.data.deserialize() @@ -384,7 +453,7 @@ class SingleThreadedStateMachineManager( } if (replyError != null) { - flowMessaging.sendSessionMessage(sender, replyError, DeduplicationId.createRandom(secureRandom)) + flowMessaging.sendSessionMessage(sender, replyError, SenderDeduplicationId(DeduplicationId.createRandom(secureRandom), ourSenderUUID)) deduplicationHandler.afterDatabaseTransaction() } } @@ -458,7 +527,8 @@ class SingleThreadedStateMachineManager( isAnyCheckpointPersisted = false, isStartIdempotent = isStartIdempotent, isRemoved = false, - flowLogic = flowLogic + flowLogic = flowLogic, + senderUUID = ourSenderUUID ) flowStateMachineImpl.transientState = TransientReference(initialState) mutex.locked { @@ -493,7 +563,7 @@ class SingleThreadedStateMachineManager( private fun createTransientValues(id: StateMachineRunId, resultFuture: CordaFuture): FlowStateMachineImpl.TransientValues { return FlowStateMachineImpl.TransientValues( - eventQueue = Channels.newChannel(stateMachineConfiguration.eventQueueSize, Channels.OverflowPolicy.BLOCK), + eventQueue = Channels.newChannel(-1, Channels.OverflowPolicy.BLOCK), resultFuture = resultFuture, database = database, transitionExecutor = transitionExecutor, @@ -509,7 +579,8 @@ class SingleThreadedStateMachineManager( checkpoint: Checkpoint, isAnyCheckpointPersisted: Boolean, isStartIdempotent: Boolean, - initialDeduplicationHandler: DeduplicationHandler? + initialDeduplicationHandler: DeduplicationHandler?, + senderUUID: String? = ourSenderUUID ): Flow { val flowState = checkpoint.flowState val resultFuture = openFuture() @@ -524,7 +595,8 @@ class SingleThreadedStateMachineManager( isAnyCheckpointPersisted = isAnyCheckpointPersisted, isStartIdempotent = isStartIdempotent, isRemoved = false, - flowLogic = logic + flowLogic = logic, + senderUUID = senderUUID ) val fiber = FlowStateMachineImpl(id, logic, scheduler) fiber.transientValues = TransientReference(createTransientValues(id, resultFuture)) @@ -542,7 +614,8 @@ class SingleThreadedStateMachineManager( isAnyCheckpointPersisted = isAnyCheckpointPersisted, isStartIdempotent = isStartIdempotent, isRemoved = false, - flowLogic = fiber.logic + flowLogic = fiber.logic, + senderUUID = senderUUID ) fiber.transientValues = TransientReference(createTransientValues(id, resultFuture)) fiber.transientState = TransientReference(state) @@ -566,9 +639,13 @@ class SingleThreadedStateMachineManager( startedFutures[id]?.setException(IllegalStateException("Will not start flow as SMM is stopping")) logger.trace("Not resuming as SMM is stopping.") } else { - incrementLiveFibers() - unfinishedFibers.countUp() - flows[id] = flow + val oldFlow = flows.put(id, flow) + if (oldFlow == null) { + incrementLiveFibers() + unfinishedFibers.countUp() + } else { + oldFlow.resultFuture.captureLater(flow.resultFuture) + } flow.fiber.scheduleEvent(Event.DoRemainingWork) when (checkpoint.flowState) { is FlowState.Unstarted -> { @@ -604,7 +681,7 @@ class SingleThreadedStateMachineManager( private fun makeTransitionExecutor(): TransitionExecutor { val interceptors = ArrayList() - interceptors.add { HospitalisingInterceptor(PropagatingFlowHospital, it) } + interceptors.add { HospitalisingInterceptor(StaffedFlowHospital, it) } if (serviceHub.configuration.devMode) { interceptors.add { DumpHistoryOnErrorInterceptor(it) } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StaffedFlowHospital.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StaffedFlowHospital.kt new file mode 100644 index 0000000000..b0fb7943f0 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StaffedFlowHospital.kt @@ -0,0 +1,127 @@ +package net.corda.node.services.statemachine + +import net.corda.core.flows.StateMachineRunId +import net.corda.core.utilities.loggerFor +import java.sql.SQLException +import java.time.Instant +import java.util.concurrent.ConcurrentHashMap + +/** + * This hospital consults "staff" to see if they can automatically diagnose and treat flows. + */ +object StaffedFlowHospital : FlowHospital { + private val log = loggerFor() + + private val staff = listOf(DeadlockNurse, DuplicateInsertSpecialist) + + private val patients = ConcurrentHashMap() + + val numberOfPatients = patients.size + + class MedicalHistory { + val records: MutableList = mutableListOf() + + sealed class Record(val suspendCount: Int) { + class Admitted(val at: Instant, suspendCount: Int) : Record(suspendCount) { + override fun toString() = "Admitted(at=$at, suspendCount=$suspendCount)" + } + + class Discharged(val at: Instant, suspendCount: Int, val by: Staff, val error: Throwable) : Record(suspendCount) { + override fun toString() = "Discharged(at=$at, suspendCount=$suspendCount, by=$by)" + } + } + + fun notDischargedForTheSameThingMoreThan(max: Int, by: Staff): Boolean { + val lastAdmittanceSuspendCount = (records.last() as MedicalHistory.Record.Admitted).suspendCount + return records.filterIsInstance(MedicalHistory.Record.Discharged::class.java).filter { it.by == by && it.suspendCount == lastAdmittanceSuspendCount }.count() <= max + } + + override fun toString(): String = "${this.javaClass.simpleName}(records = $records)" + } + + override fun flowErrored(flowFiber: FlowFiber, currentState: StateMachineState, errors: List) { + log.info("Flow ${flowFiber.id} admitted to hospital in state $currentState") + val medicalHistory = patients.computeIfAbsent(flowFiber.id) { MedicalHistory() } + medicalHistory.records += MedicalHistory.Record.Admitted(Instant.now(), currentState.checkpoint.numberOfSuspends) + for ((index, error) in errors.withIndex()) { + log.info("Flow ${flowFiber.id} has error [$index]", error) + if (!errorIsDischarged(flowFiber, currentState, error, medicalHistory)) { + // If any error isn't discharged, then we propagate. + log.warn("Flow ${flowFiber.id} error was not discharged, propagating.") + flowFiber.scheduleEvent(Event.StartErrorPropagation) + return + } + } + // If all are discharged, retry. + flowFiber.scheduleEvent(Event.RetryFlowFromSafePoint) + } + + private fun errorIsDischarged(flowFiber: FlowFiber, currentState: StateMachineState, error: Throwable, medicalHistory: MedicalHistory): Boolean { + for (staffMember in staff) { + val diagnosis = staffMember.consult(flowFiber, currentState, error, medicalHistory) + if (diagnosis == Diagnosis.DISCHARGE) { + medicalHistory.records += MedicalHistory.Record.Discharged(Instant.now(), currentState.checkpoint.numberOfSuspends, staffMember, error) + log.info("Flow ${flowFiber.id} error discharged from hospital by $staffMember") + return true + } + } + return false + } + + // It's okay for flows to be cleaned... we fix them now! + override fun flowCleaned(flowFiber: FlowFiber) {} + + override fun flowRemoved(flowFiber: FlowFiber) { + patients.remove(flowFiber.id) + } + + enum class Diagnosis { + /** + * Retry from last safe point. + */ + DISCHARGE, + /** + * Please try another member of staff. + */ + NOT_MY_SPECIALTY + } + + interface Staff { + fun consult(flowFiber: FlowFiber, currentState: StateMachineState, newError: Throwable, history: MedicalHistory): Diagnosis + } + + /** + * SQL Deadlock detection. + */ + object DeadlockNurse : Staff { + override fun consult(flowFiber: FlowFiber, currentState: StateMachineState, newError: Throwable, history: MedicalHistory): Diagnosis { + return if (mentionsDeadlock(newError)) { + Diagnosis.DISCHARGE + } else { + Diagnosis.NOT_MY_SPECIALTY + } + } + + private fun mentionsDeadlock(exception: Throwable?): Boolean { + return exception != null && (exception is SQLException && ((exception.message?.toLowerCase()?.contains("deadlock") + ?: false)) || mentionsDeadlock(exception.cause)) + } + } + + /** + * Primary key violation detection for duplicate inserts. Will detect other constraint violations too. + */ + object DuplicateInsertSpecialist : Staff { + override fun consult(flowFiber: FlowFiber, currentState: StateMachineState, newError: Throwable, history: MedicalHistory): Diagnosis { + return if (mentionsConstraintViolation(newError) && history.notDischargedForTheSameThingMoreThan(3, this)) { + Diagnosis.DISCHARGE + } else { + Diagnosis.NOT_MY_SPECIALTY + } + } + + private fun mentionsConstraintViolation(exception: Throwable?): Boolean { + return exception != null && (exception is org.hibernate.exception.ConstraintViolationException || mentionsConstraintViolation(exception.cause)) + } + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index 087408c265..084452aa25 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -4,11 +4,11 @@ import net.corda.core.concurrent.CordaFuture import net.corda.core.context.InvocationContext import net.corda.core.flows.FlowLogic import net.corda.core.flows.StateMachineRunId -import net.corda.core.identity.Party import net.corda.core.internal.FlowStateMachine import net.corda.core.messaging.DataFeed import net.corda.core.utilities.Try import net.corda.node.services.messaging.DeduplicationHandler +import net.corda.node.services.messaging.ReceivedMessage import rx.Observable /** @@ -40,21 +40,6 @@ interface StateMachineManager { */ fun stop(allowedUnsuspendedFiberCount: Int) - /** - * Starts a new flow. - * - * @param flowLogic The flow's code. - * @param context The context of the flow. - * @param ourIdentity The identity to use for the flow. - * @param deduplicationHandler Allows exactly-once start of the flow, see [DeduplicationHandler]. - */ - fun startFlow( - flowLogic: FlowLogic, - context: InvocationContext, - ourIdentity: Party?, - deduplicationHandler: DeduplicationHandler? - ): CordaFuture> - /** * Represents an addition/removal of a state machine. */ @@ -91,6 +76,12 @@ interface StateMachineManager { * @return whether the flow existed and was killed. */ fun killFlow(id: StateMachineRunId): Boolean + + /** + * Deliver an external event to the state machine. Such an event might be a new P2P message, or a request to start a flow. + * The event may be replayed if a flow fails and attempts to retry. + */ + fun deliverExternalEvent(event: ExternalEvent) } // These must be idempotent! A later failure in the state transition may error the flow state, and a replay may call @@ -100,4 +91,38 @@ interface StateMachineManagerInternal { fun addSessionBinding(flowId: StateMachineRunId, sessionId: SessionId) fun removeSessionBindings(sessionIds: Set) fun removeFlow(flowId: StateMachineRunId, removalReason: FlowRemovalReason, lastState: StateMachineState) + fun retryFlowFromSafePoint(currentState: StateMachineState) +} + +/** + * Represents an external event that can be injected into the state machine and that might need to be replayed if + * a flow retries. They always have de-duplication handlers to assist with the at-most once logic where required. + */ +interface ExternalEvent { + val deduplicationHandler: DeduplicationHandler + + /** + * An external P2P message event. + */ + interface ExternalMessageEvent : ExternalEvent { + val receivedMessage: ReceivedMessage + } + + /** + * An external request to start a flow, from the scheduler for example. + */ + interface ExternalStartFlowEvent : ExternalEvent { + val flowLogic: FlowLogic + val context: InvocationContext + + /** + * A callback for the state machine to pass back the [Future] associated with the flow start to the submitter. + */ + fun wireUpFuture(flowFuture: CordaFuture>) + + /** + * The future representing the flow start, passed back from the state machine to the submitter of this event. + */ + val future: CordaFuture> + } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt index b2d572b475..9ff1edd3ca 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt @@ -26,6 +26,7 @@ import net.corda.node.services.messaging.DeduplicationHandler * possible. * @param isRemoved true if the flow has been removed from the state machine manager. This is used to avoid any further * work. + * @param senderUUID the identifier of the sending state machine or null if this flow is resumed from a checkpoint so that it does not participate in de-duplication high-water-marking. */ // TODO perhaps add a read-only environment to the state machine for things that don't change over time? // TODO evaluate persistent datastructure libraries to replace the inefficient copying we currently do. @@ -37,7 +38,8 @@ data class StateMachineState( val isTransactionTracked: Boolean, val isAnyCheckpointPersisted: Boolean, val isStartIdempotent: Boolean, - val isRemoved: Boolean + val isRemoved: Boolean, + val senderUUID: String? ) /** diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/DumpHistoryOnErrorInterceptor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/DumpHistoryOnErrorInterceptor.kt index 23ee4a2d9f..2e57e0bb14 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/DumpHistoryOnErrorInterceptor.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/DumpHistoryOnErrorInterceptor.kt @@ -34,7 +34,8 @@ class DumpHistoryOnErrorInterceptor(val delegate: TransitionExecutor) : Transiti (record ?: ArrayList()).apply { add(transitionRecord) } } - if (nextState.checkpoint.errorState is ErrorState.Errored) { + // Just if we decide to propagate, and not if just on the way to the hospital. + if (nextState.checkpoint.errorState is ErrorState.Errored && nextState.checkpoint.errorState.propagating) { log.warn("Flow ${fiber.id} errored, dumping all transitions:\n${record!!.joinToString("\n")}") for (error in nextState.checkpoint.errorState.errors) { log.warn("Flow ${fiber.id} error", error.exception) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/HospitalisingInterceptor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/HospitalisingInterceptor.kt index 8ed5f67b95..f463017d62 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/HospitalisingInterceptor.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/HospitalisingInterceptor.kt @@ -26,20 +26,23 @@ class HospitalisingInterceptor( actionExecutor: ActionExecutor ): Pair { val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor) - when (nextState.checkpoint.errorState) { - ErrorState.Clean -> { - if (hospitalisedFlows.remove(fiber.id) != null) { - flowHospital.flowCleaned(fiber) + + when (nextState.checkpoint.errorState) { + is ErrorState.Clean -> { + if (hospitalisedFlows.remove(fiber.id) != null) { + flowHospital.flowCleaned(fiber) + } + } + is ErrorState.Errored -> { + val exceptionsToHandle = nextState.checkpoint.errorState.errors.map { it.exception } + if (hospitalisedFlows.putIfAbsent(fiber.id, fiber) == null) { + flowHospital.flowErrored(fiber, previousState, exceptionsToHandle) + } } } - is ErrorState.Errored -> { - if (hospitalisedFlows.putIfAbsent(fiber.id, fiber) == null) { - flowHospital.flowErrored(fiber) - } - } - } if (nextState.isRemoved) { hospitalisedFlows.remove(fiber.id) + flowHospital.flowRemoved(fiber) } return Pair(continuation, nextState) } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt index 87fcb49ca6..11d9d771cd 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt @@ -46,9 +46,6 @@ class DeliverSessionMessageTransition( is EndSessionMessage -> endMessageTransition() } } - if (!isErrored()) { - persistCheckpoint() - } // Schedule a DoRemainingWork to check whether the flow needs to be woken up. actions.add(Action.ScheduleEvent(Event.DoRemainingWork)) FlowContinuation.ProcessEvents @@ -73,7 +70,7 @@ class DeliverSessionMessageTransition( // Send messages that were buffered pending confirmation of session. val sendActions = sessionState.bufferedMessages.map { (deduplicationId, bufferedMessage) -> val existingMessage = ExistingSessionMessage(message.initiatedSessionId, bufferedMessage) - Action.SendExisting(initiatedSession.peerParty, existingMessage, deduplicationId) + Action.SendExisting(initiatedSession.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID)) } actions.addAll(sendActions) currentState = currentState.copy(checkpoint = newCheckpoint) @@ -146,21 +143,6 @@ class DeliverSessionMessageTransition( } } - private fun TransitionBuilder.persistCheckpoint() { - // We persist the message as soon as it arrives. - actions.addAll(arrayOf( - Action.CreateTransaction, - Action.PersistCheckpoint(context.id, currentState.checkpoint), - Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers), - Action.CommitTransaction, - Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers) - )) - currentState = currentState.copy( - pendingDeduplicationHandlers = emptyList(), - isAnyCheckpointPersisted = true - ) - } - private fun TransitionBuilder.endMessageTransition() { val sessionId = event.sessionMessage.recipientSessionId val sessions = currentState.checkpoint.sessions diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt index 97cb3be926..89b1b00a29 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt @@ -46,7 +46,7 @@ class ErrorFlowTransition( sessions = newSessions ) currentState = currentState.copy(checkpoint = newCheckpoint) - actions.add(Action.PropagateErrors(errorMessages, initiatedSessions)) + actions.add(Action.PropagateErrors(errorMessages, initiatedSessions, startingState.senderUUID)) } // If we're errored but not propagating keep processing events. diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt index 28d1d04486..f2a09939b4 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt @@ -216,7 +216,7 @@ class StartedFlowTransition( } val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++) val initialMessage = createInitialSessionMessage(sessionState.initiatingSubFlow, sourceSessionId, null) - actions.add(Action.SendInitial(sessionState.party, initialMessage, deduplicationId)) + actions.add(Action.SendInitial(sessionState.party, initialMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID))) newSessions[sourceSessionId] = SessionState.Initiating( bufferedMessages = emptyList(), rejectionError = null @@ -253,7 +253,7 @@ class StartedFlowTransition( when (existingSessionState) { is SessionState.Uninitiated -> { val initialMessage = createInitialSessionMessage(existingSessionState.initiatingSubFlow, sourceSessionId, message) - actions.add(Action.SendInitial(existingSessionState.party, initialMessage, deduplicationId)) + actions.add(Action.SendInitial(existingSessionState.party, initialMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID))) newSessions[sourceSessionId] = SessionState.Initiating( bufferedMessages = emptyList(), rejectionError = null @@ -270,7 +270,7 @@ class StartedFlowTransition( is InitiatedSessionState.Live -> { val sinkSessionId = existingSessionState.initiatedState.peerSinkSessionId val existingMessage = ExistingSessionMessage(sinkSessionId, sessionMessage) - actions.add(Action.SendExisting(existingSessionState.peerParty, existingMessage, deduplicationId)) + actions.add(Action.SendExisting(existingSessionState.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID))) Unit } InitiatedSessionState.Ended -> { diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt index a0cdc389f2..5fc4133e98 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt @@ -18,18 +18,19 @@ class TopLevelTransition( ) : Transition { override fun transition(): TransitionResult { return when (event) { - is Event.DoRemainingWork -> DoRemainingWorkTransition(context, startingState).transition() - is Event.DeliverSessionMessage -> DeliverSessionMessageTransition(context, startingState, event).transition() - is Event.Error -> errorTransition(event) - is Event.TransactionCommitted -> transactionCommittedTransition(event) - is Event.SoftShutdown -> softShutdownTransition() - is Event.StartErrorPropagation -> startErrorPropagationTransition() - is Event.EnterSubFlow -> enterSubFlowTransition(event) - is Event.LeaveSubFlow -> leaveSubFlowTransition() - is Event.Suspend -> suspendTransition(event) - is Event.FlowFinish -> flowFinishTransition(event) - is Event.InitiateFlow -> initiateFlowTransition(event) - is Event.AsyncOperationCompletion -> asyncOperationCompletionTransition(event) + is Event.DoRemainingWork -> DoRemainingWorkTransition(context, startingState).transition() + is Event.DeliverSessionMessage -> DeliverSessionMessageTransition(context, startingState, event).transition() + is Event.Error -> errorTransition(event) + is Event.TransactionCommitted -> transactionCommittedTransition(event) + is Event.SoftShutdown -> softShutdownTransition() + is Event.StartErrorPropagation -> startErrorPropagationTransition() + is Event.EnterSubFlow -> enterSubFlowTransition(event) + is Event.LeaveSubFlow -> leaveSubFlowTransition() + is Event.Suspend -> suspendTransition(event) + is Event.FlowFinish -> flowFinishTransition(event) + is Event.InitiateFlow -> initiateFlowTransition(event) + is Event.AsyncOperationCompletion -> asyncOperationCompletionTransition(event) + is Event.RetryFlowFromSafePoint -> retryFlowFromSafePointTransition(startingState) } } @@ -191,7 +192,7 @@ class TopLevelTransition( if (state is SessionState.Initiated && state.initiatedState is InitiatedSessionState.Live) { val message = ExistingSessionMessage(state.initiatedState.peerSinkSessionId, EndSessionMessage) val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index) - Action.SendExisting(state.peerParty, message, deduplicationId) + Action.SendExisting(state.peerParty, message, SenderDeduplicationId(deduplicationId, currentState.senderUUID)) } else { null } @@ -230,4 +231,14 @@ class TopLevelTransition( resumeFlowLogic(event.returnValue) } } + + private fun retryFlowFromSafePointTransition(startingState: StateMachineState): TransitionResult { + return builder { + // Need to create a flow from the prior checkpoint or flow initiation. + actions.add(Action.CreateTransaction) + actions.add(Action.RetryFlowFromSafePoint(startingState)) + actions.add(Action.CommitTransaction) + FlowContinuation.Abort + } + } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt index bd9f30c65c..7737127e16 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt @@ -58,7 +58,7 @@ class UnstartedFlowTransition( Action.SendExisting( flowStart.peerSession.counterparty, sessionMessage, - DeduplicationId.createForNormal(currentState.checkpoint, 0) + SenderDeduplicationId(DeduplicationId.createForNormal(currentState.checkpoint, 0), currentState.senderUUID) ) ) } diff --git a/node/src/main/kotlin/net/corda/node/utilities/AppendOnlyPersistentMap.kt b/node/src/main/kotlin/net/corda/node/utilities/AppendOnlyPersistentMap.kt index f7065c9807..639a02ff85 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/AppendOnlyPersistentMap.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/AppendOnlyPersistentMap.kt @@ -3,14 +3,21 @@ package net.corda.node.utilities import com.github.benmanes.caffeine.cache.LoadingCache import com.github.benmanes.caffeine.cache.Weigher import net.corda.core.utilities.contextLogger +import net.corda.nodeapi.internal.persistence.DatabaseTransaction +import net.corda.nodeapi.internal.persistence.contextTransaction import net.corda.nodeapi.internal.persistence.currentDBSession +import java.lang.ref.WeakReference import java.util.* +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicReference /** - * Implements a caching layer on top of an *append-only* table accessed via Hibernate mapping. Note that if the same key is [set] twice the - * behaviour is unpredictable! There is a best-effort check for double inserts, but this should *not* be relied on, so - * ONLY USE THIS IF YOUR TABLE IS APPEND-ONLY + * Implements a caching layer on top of an *append-only* table accessed via Hibernate mapping. Note that if the same key is [set] twice, + * typically this will result in a duplicate insert if this is racing with another transaction. The flow framework will then retry. + * + * This class relies heavily on the fact that compute operations in the cache are atomic for a particular key. */ abstract class AppendOnlyPersistentMapBase( val toPersistentEntityKey: (K) -> EK, @@ -23,7 +30,8 @@ abstract class AppendOnlyPersistentMapBase( private val log = contextLogger() } - protected abstract val cache: LoadingCache> + protected abstract val cache: LoadingCache> + protected val pendingKeys = ConcurrentHashMap>() /** * Returns the value associated with the key, first loading that value from the storage if necessary. @@ -47,32 +55,31 @@ abstract class AppendOnlyPersistentMapBase( return result.map { x -> fromPersistentEntity(x) }.asSequence() } - private tailrec fun set(key: K, value: V, logWarning: Boolean, store: (K, V) -> V?): Boolean { - var insertionAttempt = false - var isUnique = true - val existingInCache = cache.get(key) { - // Thread safe, if multiple threads may wait until the first one has loaded. - insertionAttempt = true - // Key wasn't in the cache and might be in the underlying storage. - // Depending on 'store' method, this may insert without checking key duplication or it may avoid inserting a duplicated key. - val existingInDb = store(key, value) - if (existingInDb != null) { // Always reuse an existing value from the storage of a duplicated key. - isUnique = false - Optional.of(existingInDb) - } else { - Optional.of(value) - } - }!! - if (!insertionAttempt) { - if (existingInCache.isPresent) { - // Key already exists in cache, do nothing. - isUnique = false - } else { - // This happens when the key was queried before with no value associated. We invalidate the cached null - // value and recursively call set again. This is to avoid race conditions where another thread queries after - // the invalidate but before the set. - cache.invalidate(key!!) - return set(key, value, logWarning, store) + private fun set(key: K, value: V, logWarning: Boolean, store: (K, V) -> V?): Boolean { + // Will be set to true if store says it isn't in the database. + var isUnique = false + cache.asMap().compute(key) { _, oldValue -> + // Always write to the database, unless we can see it's already committed. + when (oldValue) { + is Transactional.InFlight<*, V> -> { + // Someone else is writing, so store away! + // TODO: we can do collision detection here and prevent it happening in the database. But we also have to do deadlock detection, so a bit of work. + isUnique = (store(key, value) == null) + oldValue.apply { alsoWrite(value) } + } + is Transactional.Committed -> oldValue // The value is already globally visible and cached. So do nothing since the values are always the same. + else -> { + // Null or Missing. Store away! + isUnique = (store(key, value) == null) + if (!isUnique && !weAreWriting(key)) { + // If we found a value already in the database, and we were not already writing, then it's already committed but got evicted. + Transactional.Committed(value) + } else { + // Some database transactions, including us, writing, with readers seeing whatever is in the database and writers seeing the (in memory) value. + Transactional.InFlight(this, key, { loadValue(key) }).apply { alsoWrite(value) } + } + } + } } if (logWarning && !isUnique) { @@ -93,7 +100,8 @@ abstract class AppendOnlyPersistentMapBase( /** * Associates the specified value with the specified key in this map and persists it. - * If the map previously contained a mapping for the key, the old value is not replaced. + * If the map previously contained a committed mapping for the key, the old value is not replaced. It may throw an error from the + * underlying storage if this races with another database transaction to store a value for the same key. * @return true if added key was unique, otherwise false */ fun addWithDuplicatesAllowed(key: K, value: V, logWarning: Boolean = true): Boolean = @@ -116,7 +124,7 @@ abstract class AppendOnlyPersistentMapBase( protected fun loadValue(key: K): V? { val result = currentDBSession().find(persistentEntityClass, toPersistentEntityKey(key)) - return result?.let(fromPersistentEntity)?.second + return result?.apply { currentDBSession().detach(result) }?.let(fromPersistentEntity)?.second } operator fun contains(key: K) = get(key) != null @@ -132,9 +140,161 @@ abstract class AppendOnlyPersistentMapBase( session.createQuery(deleteQuery).executeUpdate() cache.invalidateAll() } + + // Helpers to know if transaction(s) are currently writing the given key. + protected fun weAreWriting(key: K): Boolean = pendingKeys.get(key)?.contains(contextTransaction) ?: false + protected fun anyoneWriting(key: K): Boolean = pendingKeys.get(key)?.isNotEmpty() ?: false + + // Indicate this database transaction is a writer of this key. + private fun addPendingKey(key: K, databaseTransaction: DatabaseTransaction): Boolean { + var added = true + pendingKeys.compute(key) { k, oldSet -> + if (oldSet == null) { + val newSet = HashSet(0) + newSet += databaseTransaction + newSet + } else { + added = oldSet.add(databaseTransaction) + oldSet + } + } + return added + } + + // Remove this database transaction as a writer of this key, because the transaction committed or rolled back. + private fun removePendingKey(key: K, databaseTransaction: DatabaseTransaction) { + pendingKeys.compute(key) { k, oldSet -> + if (oldSet == null) { + oldSet + } else { + oldSet -= databaseTransaction + if (oldSet.size == 0) null else oldSet + } + } + } + + /** + * Represents a value in the cache, with transaction isolation semantics. + * + * There are 3 states. Globally missing, globally visible, and being written in a transaction somewhere now or in + * the past (and it rolled back). + */ + sealed class Transactional { + abstract val value: T + abstract val isPresent: Boolean + abstract val valueWithoutIsolation: T? + + fun orElse(alt: T?) = if (isPresent) value else alt + + // Everyone can see it, and database transaction committed. + class Committed(override val value: T) : Transactional() { + override val isPresent: Boolean + get() = true + override val valueWithoutIsolation: T? + get() = value + } + + // No one can see it. + class Missing() : Transactional() { + override val value: T + get() = throw NoSuchElementException("Not present") + override val isPresent: Boolean + get() = false + override val valueWithoutIsolation: T? + get() = null + } + + // Written in a transaction (uncommitted) somewhere, but there's a small window when this might be seen after commit, + // hence the committed flag. + class InFlight(private val map: AppendOnlyPersistentMapBase, + private val key: K, + private val _readerValueLoader: () -> T?, + private val _writerValueLoader: () -> T = { throw IllegalAccessException("No value loader provided") }) : Transactional() { + + // A flag to indicate this has now been committed, but hasn't yet been replaced with Committed. This also + // de-duplicates writes of the Committed value to the cache. + private val committed = AtomicBoolean(false) + + // What to do if a non-writer needs to see the value and it hasn't yet been committed to the database. + // Can be updated into a no-op once evaluated. + private val readerValueLoader = AtomicReference<() -> T?>(_readerValueLoader) + // What to do if a writer needs to see the value and it hasn't yet been committed to the database. + // Can be updated into a no-op once evaluated. + private val writerValueLoader = AtomicReference<() -> T>(_writerValueLoader) + + fun alsoWrite(_value: T) { + // Make the lazy loader the writers see actually just return the value that has been set. + writerValueLoader.set({ _value }) + // We make all these vals so that the lambdas do not need a reference to this, and so the onCommit only has a weak ref to the value. + // We want this so that the cache could evict the value (due to memory constraints etc) without the onCommit callback + // retaining what could be a large memory footprint object. + val tx = contextTransaction + val strongKey = key + val weakValue = WeakReference(_value) + val strongComitted = committed + val strongMap = map + if (map.addPendingKey(key, tx)) { + // If the transaction commits, update cache to make globally visible if we're first for this key, + // and then stop saying the transaction is writing the key. + tx.onCommit { + if (strongComitted.compareAndSet(false, true)) { + val dereferencedKey = strongKey + val dereferencedValue = weakValue.get() + if (dereferencedValue != null) { + strongMap.cache.put(dereferencedKey, Committed(dereferencedValue)) + } + } + strongMap.removePendingKey(strongKey, tx) + } + // If the transaction rolls back, stop saying this transaction is writing the key. + tx.onRollback { + strongMap.removePendingKey(strongKey, tx) + } + } + } + + // Lazy load the value a "writer" would see. If the original loader hasn't been replaced, replace it + // with one that just returns the value once evaluated. + private fun loadAsWriter(): T { + val _value = writerValueLoader.get()() + if (writerValueLoader.get() == _writerValueLoader) { + writerValueLoader.set({ _value }) + } + return _value + } + + // Lazy load the value a "reader" would see. If the original loader hasn't been replaced, replace it + // with one that just returns the value once evaluated. + private fun loadAsReader(): T? { + val _value = readerValueLoader.get()() + if (readerValueLoader.get() == _readerValueLoader) { + readerValueLoader.set({ _value }) + } + return _value + } + + // Whether someone reading (only) can see the entry. + private val isPresentAsReader: Boolean get() = (loadAsReader() != null) + // Whether the entry is already written and committed, or we are writing it (and thus it can be seen). + private val isPresentAsWriter: Boolean get() = committed.get() || map.weAreWriting(key) + + override val isPresent: Boolean + get() = isPresentAsWriter || isPresentAsReader + + // If it is committed or we are writing, reveal the value, potentially lazy loading from the database. + // If none of the above, see what was already in the database, potentially lazily. + override val value: T + get() = if (isPresentAsWriter) loadAsWriter() else if (isPresentAsReader) loadAsReader()!! else throw NoSuchElementException("Not present") + + // The value from the perspective of the eviction algorithm of the cache. i.e. we want to reveal memory footprint to it etc. + override val valueWithoutIsolation: T? + get() = if (writerValueLoader.get() != _writerValueLoader) writerValueLoader.get()() else if (readerValueLoader.get() != _writerValueLoader) readerValueLoader.get()() else null + } + } } -class AppendOnlyPersistentMap( +// Open for tests to override +open class AppendOnlyPersistentMap( toPersistentEntityKey: (K) -> EK, fromPersistentEntity: (E) -> Pair, toPersistentEntity: (key: K, value: V) -> E, @@ -146,26 +306,71 @@ class AppendOnlyPersistentMap( toPersistentEntity, persistentEntityClass) { //TODO determine cacheBound based on entity class later or with node config allowing tuning, or using some heuristic based on heap size - override val cache = NonInvalidatingCache>( + override val cache = NonInvalidatingCache>( bound = cacheBound, - loadFunction = { key -> Optional.ofNullable(loadValue(key)) }) + loadFunction = { key: K -> + // This gets called if a value is read and the cache has no Transactional for this key yet. + val value: V? = loadValue(key) + if (value == null) { + // No visible value + if (anyoneWriting(key)) { + // If someone is writing (but not us) + // For those not writing, the value cannot be seen. + // For those writing, they need to re-load the value from the database (which their database transaction CAN see). + Transactional.InFlight(this, key, { null }, { loadValue(key)!! }) + } else { + // If no one is writing, then the value does not exist. + Transactional.Missing() + } + } else { + // A value was found + if (weAreWriting(key)) { + // If we are writing, it might not be globally visible, and was evicted from the cache. + // For those not writing, they need to check the database again. + // For those writing, they can see the value found. + Transactional.InFlight(this, key, { loadValue(key) }, { value }) + } else { + // If no one is writing, then make it globally visible. + Transactional.Committed(value) + } + } + }) } +// Same as above, but with weighted values (e.g. memory footprint sensitive). class WeightBasedAppendOnlyPersistentMap( toPersistentEntityKey: (K) -> EK, fromPersistentEntity: (E) -> Pair, toPersistentEntity: (key: K, value: V) -> E, persistentEntityClass: Class, maxWeight: Long, - weighingFunc: (K, Optional) -> Int + weighingFunc: (K, Transactional) -> Int ) : AppendOnlyPersistentMapBase( toPersistentEntityKey, fromPersistentEntity, toPersistentEntity, persistentEntityClass) { - override val cache = NonInvalidatingWeightBasedCache( + override val cache = NonInvalidatingWeightBasedCache>( maxWeight = maxWeight, - weigher = Weigher> { key, value -> weighingFunc(key, value) }, - loadFunction = { key -> Optional.ofNullable(loadValue(key)) } - ) -} \ No newline at end of file + weigher = object : Weigher> { + override fun weigh(key: K, value: Transactional): Int { + return weighingFunc(key, value) + } + }, + loadFunction = { key: K -> + val value: V? = loadValue(key) + if (value == null) { + if (anyoneWriting(key)) { + Transactional.InFlight(this, key, { null }, { loadValue(key)!! }) + } else { + Transactional.Missing() + } + } else { + if (weAreWriting(key)) { + Transactional.InFlight(this, key, { loadValue(key) }, { value }) + } else { + Transactional.Committed(value) + } + } + }) +} diff --git a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt index ec6eb06604..3255e92b7a 100644 --- a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt @@ -14,14 +14,15 @@ import net.corda.node.internal.configureDatabase import net.corda.node.services.api.FlowStarter import net.corda.node.services.api.NodePropertiesStore import net.corda.node.services.messaging.DeduplicationHandler +import net.corda.node.services.statemachine.ExternalEvent import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseConfig -import net.corda.nodeapi.internal.persistence.DatabaseTransaction import net.corda.testing.internal.doLookup import net.corda.testing.internal.rigorousMock import net.corda.testing.internal.spectator import net.corda.testing.node.MockServices import net.corda.testing.node.TestClock +import org.junit.After import org.junit.Ignore import org.junit.Rule import org.junit.Test @@ -50,7 +51,7 @@ open class NodeSchedulerServiceTestBase { dedupe.insideDatabaseTransaction() dedupe.afterDatabaseTransaction() openFuture>() - }.whenever(it).startFlow(any>(), any(), any()) + }.whenever(it).startFlow(any>()) } private val flowsDraingMode = rigorousMock().also { doReturn(false).whenever(it).isEnabled() @@ -80,7 +81,7 @@ open class NodeSchedulerServiceTestBase { protected fun assertStarted(flowLogic: FlowLogic<*>) { // Like in assertWaitingFor, use timeout to make verify wait as we often race the call to startFlow: - verify(flowStarter, timeout(5000)).startFlow(same(flowLogic), any(), any()) + verify(flowStarter, timeout(5000)).startFlow(argForWhich> { this.flowLogic == flowLogic }) } protected fun assertStarted(event: Event) = assertStarted(event.flowLogic) @@ -112,11 +113,11 @@ class MockScheduledFlowRepository : ScheduledFlowRepository { } class NodeSchedulerServiceTest : NodeSchedulerServiceTestBase() { - private val database = rigorousMock().also { - doAnswer { - val block: DatabaseTransaction.() -> Any? = it.getArgument(0) - rigorousMock().block() - }.whenever(it).transaction(any()) + private val database = configureDatabase(MockServices.makeTestDataSourceProperties(), DatabaseConfig(), rigorousMock()) + + @After + fun closeDatabase() { + database.close() } private val scheduler = NodeSchedulerService( @@ -148,7 +149,9 @@ class NodeSchedulerServiceTest : NodeSchedulerServiceTestBase() { }).whenever(it).data } flows[logicRef] = flowLogic - scheduler.scheduleStateActivity(ssr) + database.transaction { + scheduler.scheduleStateActivity(ssr) + } } @Test @@ -207,7 +210,9 @@ class NodeSchedulerServiceTest : NodeSchedulerServiceTestBase() { fun `test activity due in the future and schedule another for same time then unschedule second`() { val eventA = schedule(mark + 1.days) val eventB = schedule(mark + 1.days) - scheduler.unscheduleStateActivity(eventB.stateRef) + database.transaction { + scheduler.unscheduleStateActivity(eventB.stateRef) + } assertWaitingFor(eventA) testClock.advanceBy(1.days) assertStarted(eventA) @@ -217,7 +222,9 @@ class NodeSchedulerServiceTest : NodeSchedulerServiceTestBase() { fun `test activity due in the future and schedule another for same time then unschedule original`() { val eventA = schedule(mark + 1.days) val eventB = schedule(mark + 1.days) - scheduler.unscheduleStateActivity(eventA.stateRef) + database.transaction { + scheduler.unscheduleStateActivity(eventA.stateRef) + } assertWaitingFor(eventB) testClock.advanceBy(1.days) assertStarted(eventB) @@ -225,7 +232,9 @@ class NodeSchedulerServiceTest : NodeSchedulerServiceTestBase() { @Test fun `test activity due in the future then unschedule`() { - scheduler.unscheduleStateActivity(schedule(mark + 1.days).stateRef) + database.transaction { + scheduler.unscheduleStateActivity(schedule(mark + 1.days).stateRef) + } testClock.advanceBy(1.days) } } diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/AppendOnlyPersistentMapTest.kt b/node/src/test/kotlin/net/corda/node/services/persistence/AppendOnlyPersistentMapTest.kt new file mode 100644 index 0000000000..b10042dcc6 --- /dev/null +++ b/node/src/test/kotlin/net/corda/node/services/persistence/AppendOnlyPersistentMapTest.kt @@ -0,0 +1,290 @@ +package net.corda.node.services.persistence + +import net.corda.core.schemas.MappedSchema +import net.corda.core.utilities.loggerFor +import net.corda.node.internal.configureDatabase +import net.corda.node.services.schema.NodeSchemaService +import net.corda.node.utilities.AppendOnlyPersistentMap +import net.corda.nodeapi.internal.persistence.DatabaseConfig +import net.corda.testing.internal.rigorousMock +import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties +import org.junit.After +import org.junit.Assert.* +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import java.io.Serializable +import java.util.concurrent.CountDownLatch +import javax.persistence.Column +import javax.persistence.Entity +import javax.persistence.Id +import javax.persistence.PersistenceException + +@RunWith(Parameterized::class) +class AppendOnlyPersistentMapTest(var scenario: Scenario) { + companion object { + + private val scenarios = arrayOf( + Scenario(false, ReadOrWrite.Read, ReadOrWrite.Read, Outcome.Fail, Outcome.Fail), + Scenario(false, ReadOrWrite.Write, ReadOrWrite.Read, Outcome.Success, Outcome.Fail, Outcome.Success), + Scenario(false, ReadOrWrite.Read, ReadOrWrite.Write, Outcome.Fail, Outcome.Success), + Scenario(false, ReadOrWrite.Write, ReadOrWrite.Write, Outcome.Success, Outcome.SuccessButErrorOnCommit), + Scenario(false, ReadOrWrite.WriteDuplicateAllowed, ReadOrWrite.Read, Outcome.Success, Outcome.Fail, Outcome.Success), + Scenario(false, ReadOrWrite.Read, ReadOrWrite.WriteDuplicateAllowed, Outcome.Fail, Outcome.Success), + Scenario(false, ReadOrWrite.WriteDuplicateAllowed, ReadOrWrite.WriteDuplicateAllowed, Outcome.Success, Outcome.SuccessButErrorOnCommit, Outcome.Fail), + Scenario(true, ReadOrWrite.Read, ReadOrWrite.Read, Outcome.Success, Outcome.Success), + Scenario(true, ReadOrWrite.Write, ReadOrWrite.Read, Outcome.SuccessButErrorOnCommit, Outcome.Success), + Scenario(true, ReadOrWrite.Read, ReadOrWrite.Write, Outcome.Success, Outcome.Fail), + Scenario(true, ReadOrWrite.Write, ReadOrWrite.Write, Outcome.SuccessButErrorOnCommit, Outcome.SuccessButErrorOnCommit), + Scenario(true, ReadOrWrite.WriteDuplicateAllowed, ReadOrWrite.Read, Outcome.Fail, Outcome.Success), + Scenario(true, ReadOrWrite.Read, ReadOrWrite.WriteDuplicateAllowed, Outcome.Success, Outcome.Fail), + Scenario(true, ReadOrWrite.WriteDuplicateAllowed, ReadOrWrite.WriteDuplicateAllowed, Outcome.Fail, Outcome.Fail) + ) + + @Parameterized.Parameters(name = "{0}") + @JvmStatic + fun data(): Array> = scenarios.map { arrayOf(it) }.toTypedArray() + } + + enum class ReadOrWrite { Read, Write, WriteDuplicateAllowed } + enum class Outcome { Success, Fail, SuccessButErrorOnCommit } + + data class Scenario(val prePopulated: Boolean, + val a: ReadOrWrite, + val b: ReadOrWrite, + val aExpected: Outcome, + val bExpected: Outcome, + val bExpectedIfSingleThreaded: Outcome = bExpected) + + private val database = configureDatabase(makeTestDataSourceProperties(), + DatabaseConfig(), + rigorousMock(), + NodeSchemaService(setOf(MappedSchema(AppendOnlyPersistentMapTest::class.java, 1, listOf(PersistentMapEntry::class.java))))) + + @After + fun closeDatabase() { + database.close() + } + + @Test + fun `concurrent test no purge between A and B`() { + prepopulateIfRequired() + val map = createMap() + val a = TestThread("A", map).apply { start() } + val b = TestThread("B", map).apply { start() } + + // Begin A + a.phase1.countDown() + a.await(a::phase2) + + // Begin B + b.phase1.countDown() + b.await(b::phase2) + + // Commit A + a.phase3.countDown() + a.await(a::phase4) + + // Commit B + b.phase3.countDown() + b.await(b::phase4) + + // End + a.join() + b.join() + assertTrue(map.pendingKeysIsEmpty()) + } + + @Test + fun `test no purge with only a single transaction`() { + prepopulateIfRequired() + val map = createMap() + val a = TestThread("A", map, true).apply { + phase1.countDown() + phase3.countDown() + } + val b = TestThread("B", map, true).apply { + phase1.countDown() + phase3.countDown() + } + try { + database.transaction { + a.run() + b.run() + } + } catch (t: PersistenceException) { + // This only helps if thrown on commit, otherwise other latches not counted down. + assertEquals(t.message, Outcome.SuccessButErrorOnCommit, a.outcome) + } + a.await(a::phase4) + b.await(b::phase4) + assertTrue(map.pendingKeysIsEmpty()) + } + + + @Test + fun `concurrent test purge between A and B`() { + // Writes intentionally do not check the database first, so purging between read and write changes behaviour + val remapped = mapOf(Scenario(true, ReadOrWrite.Read, ReadOrWrite.Write, Outcome.Success, Outcome.Fail) to Scenario(true, ReadOrWrite.Read, ReadOrWrite.Write, Outcome.Success, Outcome.SuccessButErrorOnCommit)) + scenario = remapped[scenario] ?: scenario + prepopulateIfRequired() + val map = createMap() + val a = TestThread("A", map).apply { start() } + val b = TestThread("B", map).apply { start() } + + // Begin A + a.phase1.countDown() + a.await(a::phase2) + + map.invalidate() + + // Begin B + b.phase1.countDown() + b.await(b::phase2) + + // Commit A + a.phase3.countDown() + a.await(a::phase4) + + // Commit B + b.phase3.countDown() + b.await(b::phase4) + + // End + a.join() + b.join() + assertTrue(map.pendingKeysIsEmpty()) + } + + @Test + fun `test purge mid-way in a single transaction`() { + // Writes intentionally do not check the database first, so purging between read and write changes behaviour + val remapped = mapOf(Scenario(true, ReadOrWrite.Read, ReadOrWrite.Write, Outcome.Success, Outcome.Fail) to Scenario(true, ReadOrWrite.Read, ReadOrWrite.Write, Outcome.SuccessButErrorOnCommit, Outcome.SuccessButErrorOnCommit)) + scenario = remapped[scenario] ?: scenario + prepopulateIfRequired() + val map = createMap() + val a = TestThread("A", map, true).apply { + phase1.countDown() + phase3.countDown() + } + val b = TestThread("B", map, true).apply { + phase1.countDown() + phase3.countDown() + } + try { + database.transaction { + a.run() + map.invalidate() + b.run() + } + } catch (t: PersistenceException) { + // This only helps if thrown on commit, otherwise other latches not counted down. + assertEquals(t.message, Outcome.SuccessButErrorOnCommit, a.outcome) + } + a.await(a::phase4) + b.await(b::phase4) + assertTrue(map.pendingKeysIsEmpty()) + } + + inner class TestThread(name: String, val map: AppendOnlyPersistentMap, singleThreaded: Boolean = false) : Thread(name) { + private val log = loggerFor() + + val readOrWrite = if (name == "A") scenario.a else scenario.b + val outcome = if (name == "A") scenario.aExpected else if (singleThreaded) scenario.bExpectedIfSingleThreaded else scenario.bExpected + + val phase1 = latch() + val phase2 = latch() + val phase3 = latch() + val phase4 = latch() + + override fun run() { + try { + database.transaction { + await(::phase1) + doActivity() + phase2.countDown() + await(::phase3) + } + } catch (t: PersistenceException) { + // This only helps if thrown on commit, otherwise other latches not counted down. + assertEquals(t.message, Outcome.SuccessButErrorOnCommit, outcome) + } + phase4.countDown() + } + + private fun doActivity() { + if (readOrWrite == ReadOrWrite.Read) { + log.info("Reading") + val value = map.get(1) + log.info("Read $value") + if (outcome == Outcome.Success || outcome == Outcome.SuccessButErrorOnCommit) { + assertEquals("X", value) + } else { + assertNull(value) + } + } else if (readOrWrite == ReadOrWrite.Write) { + log.info("Writing") + val wasSet = map.set(1, "X") + log.info("Write $wasSet") + if (outcome == Outcome.Success || outcome == Outcome.SuccessButErrorOnCommit) { + assertEquals(true, wasSet) + } else { + assertEquals(false, wasSet) + } + } else if (readOrWrite == ReadOrWrite.WriteDuplicateAllowed) { + log.info("Writing with duplicates allowed") + val wasSet = map.addWithDuplicatesAllowed(1, "X") + log.info("Write with duplicates allowed $wasSet") + if (outcome == Outcome.Success || outcome == Outcome.SuccessButErrorOnCommit) { + assertEquals(true, wasSet) + } else { + assertEquals(false, wasSet) + } + } + } + + private fun latch() = CountDownLatch(1) + fun await(latch: () -> CountDownLatch) { + log.info("Awaiting $latch") + latch().await() + } + } + + private fun prepopulateIfRequired() { + if (scenario.prePopulated) { + database.transaction { + val map = createMap() + map.set(1, "X") + } + } + } + + @Entity + @javax.persistence.Table(name = "persist_map_test") + class PersistentMapEntry( + @Id + @Column(name = "key") + var key: Long = -1, + + @Column(name = "value", length = 16) + var value: String = "" + ) : Serializable + + class TestMap : AppendOnlyPersistentMap( + toPersistentEntityKey = { it }, + fromPersistentEntity = { Pair(it.key, it.value) }, + toPersistentEntity = { key: Long, value: String -> + PersistentMapEntry().apply { + this.key = key + this.value = value + } + }, + persistentEntityClass = PersistentMapEntry::class.java + ) { + fun pendingKeysIsEmpty() = pendingKeys.isEmpty() + + fun invalidate() = cache.invalidateAll() + } + + fun createMap() = TestMap() +} \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/TransactionCallbackTest.kt b/node/src/test/kotlin/net/corda/node/services/persistence/TransactionCallbackTest.kt new file mode 100644 index 0000000000..cd46899392 --- /dev/null +++ b/node/src/test/kotlin/net/corda/node/services/persistence/TransactionCallbackTest.kt @@ -0,0 +1,49 @@ +package net.corda.node.services.persistence + +import net.corda.node.internal.configureDatabase +import net.corda.nodeapi.internal.persistence.DatabaseConfig +import net.corda.testing.internal.rigorousMock +import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties +import org.junit.After +import org.junit.Test +import kotlin.test.assertEquals + + +class TransactionCallbackTest { + private val database = configureDatabase(makeTestDataSourceProperties(), DatabaseConfig(), rigorousMock()) + + @After + fun closeDatabase() { + database.close() + } + + @Test + fun `onCommit called and onRollback not called on commit`() { + var onCommitCount = 0 + var onRollbackCount = 0 + database.transaction { + onCommit { onCommitCount++ } + onRollback { onRollbackCount++ } + } + assertEquals(1, onCommitCount) + assertEquals(0, onRollbackCount) + } + + @Test + fun `onCommit not called and onRollback called on rollback`() { + class TestException : Exception() + + var onCommitCount = 0 + var onRollbackCount = 0 + try { + database.transaction { + onCommit { onCommitCount++ } + onRollback { onRollbackCount++ } + throw TestException() + } + } catch (e: TestException) { + } + assertEquals(0, onCommitCount) + assertEquals(1, onRollbackCount) + } +} \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/RetryFlowMockTest.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/RetryFlowMockTest.kt new file mode 100644 index 0000000000..12b8d8af23 --- /dev/null +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/RetryFlowMockTest.kt @@ -0,0 +1,166 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.concurrent.CordaFuture +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.FlowSession +import net.corda.core.flows.InitiatedBy +import net.corda.core.flows.InitiatingFlow +import net.corda.core.identity.Party +import net.corda.core.messaging.MessageRecipients +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.unwrap +import net.corda.node.internal.StartedNode +import net.corda.node.services.messaging.Message +import net.corda.node.services.persistence.DBTransactionStorage +import net.corda.nodeapi.internal.persistence.contextTransaction +import net.corda.testing.node.internal.InternalMockNetwork +import net.corda.testing.node.internal.MessagingServiceSpy +import net.corda.testing.node.internal.newContext +import net.corda.testing.node.internal.setMessagingServiceSpy +import org.assertj.core.api.Assertions +import org.junit.After +import org.junit.Before +import org.junit.Test +import java.sql.SQLException +import java.time.Duration +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertNull + +class RetryFlowMockTest { + private lateinit var mockNet: InternalMockNetwork + private lateinit var internalNodeA: StartedNode + private lateinit var internalNodeB: StartedNode + + @Before + fun start() { + mockNet = InternalMockNetwork(threadPerNode = true, cordappPackages = listOf(this.javaClass.`package`.name)) + internalNodeA = mockNet.createNode() + internalNodeB = mockNet.createNode() + mockNet.startNodes() + RetryFlow.count = 0 + SendAndRetryFlow.count = 0 + RetryInsertFlow.count = 0 + } + + private fun StartedNode.startFlow(logic: FlowLogic): CordaFuture = this.services.startFlow(logic, this.services.newContext()).getOrThrow().resultFuture + + @After + fun cleanUp() { + mockNet.stopNodes() + } + + @Test + fun `Single retry`() { + assertEquals(Unit, internalNodeA.startFlow(RetryFlow(1)).get()) + assertEquals(2, RetryFlow.count) + } + + @Test + fun `Retry forever`() { + Assertions.assertThatThrownBy { + internalNodeA.startFlow(RetryFlow(Int.MAX_VALUE)).getOrThrow() + }.isInstanceOf(LimitedRetryCausingError::class.java) + assertEquals(5, RetryFlow.count) + } + + @Test + fun `Retry does not set senderUUID`() { + val messagesSent = mutableListOf() + val partyB = internalNodeB.info.legalIdentities.first() + internalNodeA.setMessagingServiceSpy(object : MessagingServiceSpy(internalNodeA.network) { + override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) { + messagesSent.add(message) + messagingService.send(message, target, retryId) + } + }) + internalNodeA.startFlow(SendAndRetryFlow(1, partyB)).get() + assertNotNull(messagesSent.first().senderUUID) + assertNull(messagesSent.last().senderUUID) + assertEquals(2, SendAndRetryFlow.count) + } + + @Test + fun `Retry duplicate insert`() { + assertEquals(Unit, internalNodeA.startFlow(RetryInsertFlow(1)).get()) + assertEquals(2, RetryInsertFlow.count) + } + + @Test + fun `Patient records do not leak in hospital`() { + assertEquals(Unit, internalNodeA.startFlow(RetryFlow(1)).get()) + assertEquals(0, StaffedFlowHospital.numberOfPatients) + assertEquals(2, RetryFlow.count) + } +} + +class LimitedRetryCausingError : org.hibernate.exception.ConstraintViolationException("Test message", SQLException(), "Test constraint") + +class RetryCausingError : SQLException("deadlock") + +class RetryFlow(val i: Int) : FlowLogic() { + companion object { + var count = 0 + } + + @Suspendable + override fun call() { + logger.info("Hello $count") + if (count++ < i) { + if (i == Int.MAX_VALUE) { + throw LimitedRetryCausingError() + } else { + throw RetryCausingError() + } + } + } +} + +@InitiatingFlow +class SendAndRetryFlow(val i: Int, val other: Party) : FlowLogic() { + companion object { + var count = 0 + } + + @Suspendable + override fun call() { + logger.info("Sending...") + val session = initiateFlow(other) + session.send("Boo") + if (count++ < i) { + throw RetryCausingError() + } + } +} + +@InitiatedBy(SendAndRetryFlow::class) +class ReceiveFlow2(val other: FlowSession) : FlowLogic() { + @Suspendable + override fun call() { + val received = other.receive().unwrap { it } + logger.info("Received... $received") + } +} + +class RetryInsertFlow(val i: Int) : FlowLogic() { + companion object { + var count = 0 + } + + @Suspendable + override fun call() { + logger.info("Hello") + doInsert() + // Checkpoint so we roll back to here + FlowLogic.sleep(Duration.ofSeconds(0)) + if (count++ < i) { + doInsert() + } + } + + private fun doInsert() { + val tx = DBTransactionStorage.DBTransaction("Foo") + contextTransaction.session.save(tx) + } +} \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt index cdba4ced28..f64efcf72f 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt @@ -28,7 +28,10 @@ import net.corda.nodeapi.internal.persistence.DatabaseTransaction import net.corda.testing.core.* import net.corda.testing.internal.TEST_TX_TIME import net.corda.testing.internal.rigorousMock -import net.corda.testing.internal.vault.* +import net.corda.testing.internal.vault.DUMMY_LINEAR_CONTRACT_PROGRAM_ID +import net.corda.testing.internal.vault.DummyLinearContract +import net.corda.testing.internal.vault.DummyLinearStateSchemaV1 +import net.corda.testing.internal.vault.VaultFiller import net.corda.testing.node.MockServices import net.corda.testing.node.MockServices.Companion.makeTestDatabaseAndMockServices import net.corda.testing.node.makeTestIdentityService @@ -171,17 +174,11 @@ abstract class VaultQueryTestsBase : VaultQueryParties { @JvmField val expectedEx: ExpectedException = ExpectedException.none() - @Suppress("LeakingThis") - @Rule - @JvmField - val transactionRule = VaultQueryRollbackRule(this) - companion object { @ClassRule @JvmField val testSerialization = SerializationEnvironmentRule() } - /** * Helper method for generating a Persistent H2 test database */ @@ -194,7 +191,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { database.close() } - private fun consumeCash(amount: Amount) = vaultFiller.consumeCash(amount, CHARLIE) + protected fun consumeCash(amount: Amount) = vaultFiller.consumeCash(amount, CHARLIE) private fun setUpDb(_database: CordaPersistence, delay: Long = 0) { _database.transaction { // create new states @@ -1988,239 +1985,6 @@ abstract class VaultQueryTestsBase : VaultQueryParties { } } - /** - * Dynamic trackBy() tests - */ - - @Test - fun trackCashStates_unconsumed() { - val updates = database.transaction { - val updates = - // DOCSTART VaultQueryExample15 - vaultService.trackBy().updates // UNCONSUMED default - // DOCEND VaultQueryExample15 - - vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 5, DUMMY_CASH_ISSUER) - val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states - val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states - // add more cash - vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER) - // add another deal - vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL")) - this.session.flush() - - // consume stuff - consumeCash(100.DOLLARS) - vaultFiller.consumeDeals(dealStates.toList()) - vaultFiller.consumeLinearStates(linearStates.toList()) - - close() // transaction needs to be closed to trigger events - updates - } - - updates.expectEvents { - sequence( - expect { (consumed, produced, flowId) -> - require(flowId == null) {} - require(consumed.isEmpty()) {} - require(produced.size == 5) {} - }, - expect { (consumed, produced, flowId) -> - require(flowId == null) {} - require(consumed.isEmpty()) {} - require(produced.size == 1) {} - } - ) - } - } - - @Test - fun trackCashStates_consumed() { - - val updates = database.transaction { - val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED) - val updates = vaultService.trackBy(criteria).updates - - vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 5, DUMMY_CASH_ISSUER) - val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states - val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states - // add more cash - vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER) - // add another deal - vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL")) - this.session.flush() - - consumeCash(100.POUNDS) - - // consume more stuff - consumeCash(100.DOLLARS) - vaultFiller.consumeDeals(dealStates.toList()) - vaultFiller.consumeLinearStates(linearStates.toList()) - - close() // transaction needs to be closed to trigger events - updates - } - - updates.expectEvents { - sequence( - expect { (consumed, produced, flowId) -> - require(flowId == null) {} - require(consumed.size == 1) {} - require(produced.isEmpty()) {} - }, - expect { (consumed, produced, flowId) -> - require(flowId == null) {} - require(consumed.size == 5) {} - require(produced.isEmpty()) {} - } - ) - } - } - - @Test - fun trackCashStates_all() { - val updates = database.transaction { - val updates = - database.transaction { - val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL) - vaultService.trackBy(criteria).updates - } - vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 5, DUMMY_CASH_ISSUER) - val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states - val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states - // add more cash - vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER) - // add another deal - vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL")) - this.session.flush() - -// consume stuff - consumeCash(99.POUNDS) - - consumeCash(100.DOLLARS) - vaultFiller.consumeDeals(dealStates.toList()) - vaultFiller.consumeLinearStates(linearStates.toList()) - - close() // transaction needs to be closed to trigger events - updates - } - - updates.expectEvents { - sequence( - expect { (consumed, produced, flowId) -> - require(flowId == null) {} - require(consumed.isEmpty()) {} - require(produced.size == 5) {} - }, - expect { (consumed, produced, flowId) -> - require(flowId == null) {} - require(consumed.isEmpty()) {} - require(produced.size == 1) {} - }, - expect { (consumed, produced, flowId) -> - require(flowId == null) {} - require(consumed.size == 1) {} - require(produced.size == 1) {} - }, - expect { (consumed, produced, flowId) -> - require(flowId == null) {} - require(consumed.size == 5) {} - require(produced.isEmpty()) {} - } - ) - } - } - - @Test - fun trackLinearStates() { - - val updates = database.transaction { - // DOCSTART VaultQueryExample16 - val (snapshot, updates) = vaultService.trackBy() - // DOCEND VaultQueryExample16 - assertThat(snapshot.states).hasSize(0) - - vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 3, DUMMY_CASH_ISSUER) - val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states - val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states - // add more cash - vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER) - // add another deal - vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL")) - this.session.flush() - - // consume stuff - consumeCash(100.DOLLARS) - vaultFiller.consumeDeals(dealStates.toList()) - vaultFiller.consumeLinearStates(linearStates.toList()) - - close() // transaction needs to be closed to trigger events - updates - } - - updates.expectEvents { - sequence( - expect { (consumed, produced, flowId) -> - require(flowId == null) {} - require(consumed.isEmpty()) {} - require(produced.size == 10) {} - }, - expect { (consumed, produced, flowId) -> - require(flowId == null) {} - require(consumed.isEmpty()) {} - require(produced.size == 3) {} - }, - expect { (consumed, produced, flowId) -> - require(flowId == null) {} - require(consumed.isEmpty()) {} - require(produced.size == 1) {} - } - ) - } - } - - @Test - fun trackDealStates() { - val updates = database.transaction { - // DOCSTART VaultQueryExample17 - val (snapshot, updates) = vaultService.trackBy() - // DOCEND VaultQueryExample17 - assertThat(snapshot.states).hasSize(0) - - vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 3, DUMMY_CASH_ISSUER) - val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states - val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states - // add more cash - vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER) - // add another deal - vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL")) - this.session.flush() - - // consume stuff - consumeCash(100.DOLLARS) - vaultFiller.consumeDeals(dealStates.toList()) - vaultFiller.consumeLinearStates(linearStates.toList()) - - close() - updates - } - - updates.expectEvents { - sequence( - expect { (consumed, produced, flowId) -> - require(flowId == null) {} - require(consumed.isEmpty()) {} - require(produced.size == 3) {} - }, - expect { (consumed, produced, flowId) -> - require(flowId == null) {} - require(consumed.isEmpty()) {} - require(produced.size == 1) {} - } - ) - } - } - @Test fun unconsumedCashStatesForSpending_single_issuer_reference() { database.transaction { @@ -2281,10 +2045,241 @@ abstract class VaultQueryTestsBase : VaultQueryParties { */ } -class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by vaultQueryTestRule { +class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate { companion object { - @ClassRule @JvmField - val vaultQueryTestRule = VaultQueryTestRule() + val delegate = VaultQueryTestRule() + } + + @Rule + @JvmField + val vaultQueryTestRule = delegate + + /** + * Dynamic trackBy() tests are H2 only, since rollback stops events being emitted. + */ + + @Test + fun trackCashStates_unconsumed() { + val updates = database.transaction { + val updates = + // DOCSTART VaultQueryExample15 + vaultService.trackBy().updates // UNCONSUMED default + // DOCEND VaultQueryExample15 + + vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 5, DUMMY_CASH_ISSUER) + val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states + val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states + // add more cash + vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER) + // add another deal + vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL")) + this.session.flush() + + // consume stuff + consumeCash(100.DOLLARS) + vaultFiller.consumeDeals(dealStates.toList()) + vaultFiller.consumeLinearStates(linearStates.toList()) + + updates + } + + updates.expectEvents { + sequence( + expect { (consumed, produced, flowId) -> + require(flowId == null) {} + require(consumed.isEmpty()) {} + require(produced.size == 5) {} + }, + expect { (consumed, produced, flowId) -> + require(flowId == null) {} + require(consumed.isEmpty()) {} + require(produced.size == 1) {} + } + ) + } + } + + @Test + fun trackCashStates_consumed() { + + val updates = database.transaction { + val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED) + val updates = vaultService.trackBy(criteria).updates + + vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 5, DUMMY_CASH_ISSUER) + val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states + val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states + // add more cash + vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER) + // add another deal + vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL")) + this.session.flush() + + consumeCash(100.POUNDS) + + // consume more stuff + consumeCash(100.DOLLARS) + vaultFiller.consumeDeals(dealStates.toList()) + vaultFiller.consumeLinearStates(linearStates.toList()) + + updates + } + + updates.expectEvents { + sequence( + expect { (consumed, produced, flowId) -> + require(flowId == null) {} + require(consumed.size == 1) {} + require(produced.isEmpty()) {} + }, + expect { (consumed, produced, flowId) -> + require(flowId == null) {} + require(consumed.size == 5) {} + require(produced.isEmpty()) {} + } + ) + } + } + + @Test + fun trackCashStates_all() { + val updates = database.transaction { + val updates = + database.transaction { + val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL) + vaultService.trackBy(criteria).updates + } + vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 5, DUMMY_CASH_ISSUER) + val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states + val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states + // add more cash + vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER) + // add another deal + vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL")) + this.session.flush() + + // consume stuff + consumeCash(99.POUNDS) + + consumeCash(100.DOLLARS) + vaultFiller.consumeDeals(dealStates.toList()) + vaultFiller.consumeLinearStates(linearStates.toList()) + + updates + } + + updates.expectEvents { + sequence( + expect { (consumed, produced, flowId) -> + require(flowId == null) {} + require(consumed.isEmpty()) {} + require(produced.size == 5) {} + }, + expect { (consumed, produced, flowId) -> + require(flowId == null) {} + require(consumed.isEmpty()) {} + require(produced.size == 1) {} + }, + expect { (consumed, produced, flowId) -> + require(flowId == null) {} + require(consumed.size == 1) {} + require(produced.size == 1) {} + }, + expect { (consumed, produced, flowId) -> + require(flowId == null) {} + require(consumed.size == 5) {} + require(produced.isEmpty()) {} + } + ) + } + } + + @Test + fun trackLinearStates() { + + val updates = database.transaction { + // DOCSTART VaultQueryExample16 + val (snapshot, updates) = vaultService.trackBy() + // DOCEND VaultQueryExample16 + assertThat(snapshot.states).hasSize(0) + + vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 3, DUMMY_CASH_ISSUER) + val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states + val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states + // add more cash + vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER) + // add another deal + vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL")) + this.session.flush() + + // consume stuff + consumeCash(100.DOLLARS) + vaultFiller.consumeDeals(dealStates.toList()) + vaultFiller.consumeLinearStates(linearStates.toList()) + + updates + } + + updates.expectEvents { + sequence( + expect { (consumed, produced, flowId) -> + require(flowId == null) {} + require(consumed.isEmpty()) {} + require(produced.size == 10) {} + }, + expect { (consumed, produced, flowId) -> + require(flowId == null) {} + require(consumed.isEmpty()) {} + require(produced.size == 3) {} + }, + expect { (consumed, produced, flowId) -> + require(flowId == null) {} + require(consumed.isEmpty()) {} + require(produced.size == 1) {} + } + ) + } + } + + @Test + fun trackDealStates() { + val updates = database.transaction { + // DOCSTART VaultQueryExample17 + val (snapshot, updates) = vaultService.trackBy() + // DOCEND VaultQueryExample17 + assertThat(snapshot.states).hasSize(0) + + vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 3, DUMMY_CASH_ISSUER) + val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states + val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states + // add more cash + vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER) + // add another deal + vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL")) + this.session.flush() + + // consume stuff + consumeCash(100.DOLLARS) + vaultFiller.consumeDeals(dealStates.toList()) + vaultFiller.consumeLinearStates(linearStates.toList()) + + updates + } + + updates.expectEvents { + sequence( + expect { (consumed, produced, flowId) -> + require(flowId == null) {} + require(consumed.isEmpty()) {} + require(produced.size == 3) {} + }, + expect { (consumed, produced, flowId) -> + require(flowId == null) {} + require(consumed.isEmpty()) {} + require(produced.size == 1) {} + } + ) + } } } \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/utilities/ObservablesTests.kt b/node/src/test/kotlin/net/corda/node/utilities/ObservablesTests.kt index ab4227230b..c2986223cb 100644 --- a/node/src/test/kotlin/net/corda/node/utilities/ObservablesTests.kt +++ b/node/src/test/kotlin/net/corda/node/utilities/ObservablesTests.kt @@ -5,8 +5,8 @@ import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.internal.tee import net.corda.node.internal.configureDatabase import net.corda.nodeapi.internal.persistence.* -import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties import net.corda.testing.internal.rigorousMock +import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties import org.assertj.core.api.Assertions.assertThat import org.junit.After import org.junit.Test @@ -14,6 +14,7 @@ import rx.Observable import rx.subjects.PublishSubject import java.io.Closeable import java.util.* +import kotlin.test.fail class ObservablesTests { private fun isInDatabaseTransaction() = contextTransactionOrNull != null @@ -58,6 +59,72 @@ class ObservablesTests { assertThat(secondEvent.get()).isEqualTo(0 to false) } + class TestException : Exception("Synthetic exception for tests") {} + + @Test + fun `bufferUntilDatabaseCommit swallows if transaction rolled back`() { + val database = createDatabase() + + val source = PublishSubject.create() + val observable: Observable = source + + val firstEvent = SettableFuture.create>() + val secondEvent = SettableFuture.create>() + + observable.first().subscribe { firstEvent.set(it to isInDatabaseTransaction()) } + observable.skip(1).first().subscribe { secondEvent.set(it to isInDatabaseTransaction()) } + + try { + database.transaction { + val delayedSubject = source.bufferUntilDatabaseCommit() + assertThat(source).isNotEqualTo(delayedSubject) + delayedSubject.onNext(0) + source.onNext(1) + assertThat(firstEvent.isDone).isTrue() + assertThat(secondEvent.isDone).isFalse() + throw TestException() + } + fail("Should not have successfully completed transaction") + } catch (e: TestException) { + } + assertThat(secondEvent.isDone).isFalse() + + assertThat(firstEvent.get()).isEqualTo(1 to true) + } + + @Test + fun `bufferUntilDatabaseCommit propagates error if transaction rolled back`() { + val database = createDatabase() + + val source = PublishSubject.create() + val observable: Observable = source + + val firstEvent = SettableFuture.create>() + val secondEvent = SettableFuture.create>() + + observable.first().subscribe({ firstEvent.set(it to isInDatabaseTransaction()) }, {}) + observable.skip(1).subscribe({ secondEvent.set(it to isInDatabaseTransaction()) }, {}) + observable.skip(1).subscribe({}, { secondEvent.set(2 to isInDatabaseTransaction()) }) + + try { + database.transaction { + val delayedSubject = source.bufferUntilDatabaseCommit(propagateRollbackAsError = true) + assertThat(source).isNotEqualTo(delayedSubject) + delayedSubject.onNext(0) + source.onNext(1) + assertThat(firstEvent.isDone).isTrue() + assertThat(secondEvent.isDone).isFalse() + throw TestException() + } + fail("Should not have successfully completed transaction") + } catch (e: TestException) { + } + assertThat(secondEvent.isDone).isTrue() + + assertThat(firstEvent.get()).isEqualTo(1 to true) + assertThat(secondEvent.get()).isEqualTo(2 to false) + } + @Test fun `bufferUntilDatabaseCommit delays until transaction closed repeatable`() { val database = createDatabase() diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt index 5fe858355f..a67fd09f35 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt @@ -20,6 +20,8 @@ import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.trace import net.corda.node.services.messaging.* import net.corda.node.services.statemachine.DeduplicationId +import net.corda.node.services.statemachine.ExternalEvent +import net.corda.node.services.statemachine.SenderDeduplicationId import net.corda.node.utilities.AffinityExecutor import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.testing.node.internal.InMemoryMessage @@ -108,7 +110,7 @@ class InMemoryMessagingNetwork private constructor( get() = _receivedMessages internal val endpoints: List @Synchronized get() = handleEndpointMap.values.toList() /** Get a [List] of all the [MockMessagingService] endpoints **/ - val endpointsExternal: List @Synchronized get() = handleEndpointMap.values.map{ MockMessagingService.createMockMessagingService(it) }.toList() + val endpointsExternal: List @Synchronized get() = handleEndpointMap.values.map { MockMessagingService.createMockMessagingService(it) }.toList() /** * Creates a node at the given address: useful if you want to recreate a node to simulate a restart. @@ -135,7 +137,10 @@ class InMemoryMessagingNetwork private constructor( ?: emptyList() //TODO only notary can be distributed? synchronized(this) { val node = InMemoryMessaging(manuallyPumped, peerHandle, executor, database) - handleEndpointMap[peerHandle] = node + val oldNode = handleEndpointMap.put(peerHandle, node) + if (oldNode != null) { + node.inheritPendingRedelivery(oldNode) + } serviceHandles.forEach { serviceToPeersMapping.getOrPut(it) { LinkedHashSet() }.add(peerHandle) } @@ -161,7 +166,10 @@ class InMemoryMessagingNetwork private constructor( @Synchronized private fun netNodeHasShutdown(peerHandle: PeerHandle) { - handleEndpointMap.remove(peerHandle) + val endpoint = handleEndpointMap[peerHandle] + if (!(endpoint?.hasPendingDeliveries() ?: false)) { + handleEndpointMap.remove(peerHandle) + } } @Synchronized @@ -266,6 +274,30 @@ class InMemoryMessagingNetwork private constructor( return transfer } + /** + * When a new message handler is added, this implies we have started a new node. The add handler logic uses this to + * push back any un-acknowledged messages for this peer onto the head of the queue (rather than the tail) to maintain message + * delivery order. We push them back because their consumption was not complete and a restarted node would + * see them re-delivered if this was Artemis. + */ + @Synchronized + private fun unPopMessages(transfers: Collection, us: PeerHandle) { + messageReceiveQueues.compute(us) { _, existing -> + if (existing == null) { + LinkedBlockingQueue().apply { + addAll(transfers) + } + } else { + existing.apply { + val drained = mutableListOf() + existing.drainTo(drained) + existing.addAll(transfers) + existing.addAll(drained) + } + } + } + } + private fun pumpSendInternal(transfer: MessageTransfer) { when (transfer.recipients) { is PeerHandle -> getQueueForPeerHandle(transfer.recipients).add(transfer) @@ -338,6 +370,7 @@ class InMemoryMessagingNetwork private constructor( private val processedMessages: MutableSet = Collections.synchronizedSet(HashSet()) override val myAddress: PeerHandle get() = peerHandle + override val ourSenderUUID: String = UUID.randomUUID().toString() private val backgroundThread = if (manuallyPumped) null else thread(isDaemon = true, name = "In-memory message dispatcher") { @@ -370,10 +403,16 @@ class InMemoryMessagingNetwork private constructor( Pair(handler, pending) } - transfers.forEach { pumpSendInternal(it) } + unPopMessages(transfers, peerHandle) return handler } + fun inheritPendingRedelivery(other: InMemoryMessaging) { + state.locked { + pendingRedelivery.addAll(other.state.locked { pendingRedelivery }) + } + } + override fun removeMessageHandler(registration: MessageHandlerRegistration) { check(running) state.locked { check(handlers.remove(registration as Handler)) } @@ -405,8 +444,8 @@ class InMemoryMessagingNetwork private constructor( override fun cancelRedelivery(retryId: Long) {} /** Returns the given (topic & session, data) pair as a newly created message object. */ - override fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId, additionalHeaders: Map): Message { - return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId) + override fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId, additionalHeaders: Map): Message { + return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId.deduplicationId, senderUUID = deduplicationId.senderUUID) } /** @@ -470,13 +509,14 @@ class InMemoryMessagingNetwork private constructor( database.transaction { for (handler in deliverTo) { try { - handler.callback(transfer.toReceivedMessage(), handler, DummyDeduplicationHandler()) + val receivedMessage = transfer.toReceivedMessage() + state.locked { pendingRedelivery.add(transfer) } + handler.callback(receivedMessage, handler, InMemoryDeduplicationHandler(receivedMessage, transfer)) } catch (e: Exception) { log.error("Caught exception in handler for $this/${handler.topicSession}", e) } } _receivedMessages.onNext(transfer) - processedMessages += transfer.message.uniqueMessageId messagesInFlight.countDown() } } @@ -493,13 +533,23 @@ class InMemoryMessagingNetwork private constructor( message.uniqueMessageId, message.debugTimestamp, sender.name) - } - private class DummyDeduplicationHandler : DeduplicationHandler { - override fun afterDatabaseTransaction() { - } - override fun insideDatabaseTransaction() { + private inner class InMemoryDeduplicationHandler(override val receivedMessage: ReceivedMessage, val transfer: MessageTransfer) : DeduplicationHandler, ExternalEvent.ExternalMessageEvent { + override val externalCause: ExternalEvent + get() = this + override val deduplicationHandler: DeduplicationHandler + get() = this + + override fun afterDatabaseTransaction() { + this@InMemoryMessaging.state.locked { pendingRedelivery.remove(transfer) } + } + + override fun insideDatabaseTransaction() { + processedMessages += transfer.message.uniqueMessageId + } } + + fun hasPendingDeliveries(): Boolean = state.locked { pendingRedelivery.isNotEmpty() } } }