diff --git a/.ci/api-current.txt b/.ci/api-current.txt index 0d701b9a3f..fae0f1dfcf 100644 --- a/.ci/api-current.txt +++ b/.ci/api-current.txt @@ -3996,10 +3996,6 @@ public static final class net.corda.testing.node.InMemoryMessagingNetwork$Compan public int hashCode() @org.jetbrains.annotations.NotNull public String toString() ## -public static final class net.corda.testing.node.InMemoryMessagingNetwork$InMemoryMessaging$pumpReceiveInternal$1$1$acknowledgeHandle$1 extends java.lang.Object implements net.corda.node.services.messaging.AcknowledgeHandle - public void acknowledge() - public void persistDeduplicationId() -## public static interface net.corda.testing.node.InMemoryMessagingNetwork$LatencyCalculator @org.jetbrains.annotations.NotNull public abstract java.time.Duration between(net.corda.core.messaging.SingleMessageRecipient, net.corda.core.messaging.SingleMessageRecipient) ## diff --git a/node/src/integration-test/kotlin/net/corda/node/services/statemachine/HardRestartTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/statemachine/HardRestartTest.kt new file mode 100644 index 0000000000..0a6bbb16c4 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/services/statemachine/HardRestartTest.kt @@ -0,0 +1,162 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Suspendable +import net.corda.client.rpc.CordaRPCClient +import net.corda.core.flows.* +import net.corda.core.identity.Party +import net.corda.core.internal.concurrent.fork +import net.corda.core.internal.concurrent.transpose +import net.corda.core.messaging.CordaRPCOps +import net.corda.core.messaging.startFlow +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.unwrap +import net.corda.node.services.Permissions +import net.corda.testing.core.DUMMY_BANK_A_NAME +import net.corda.testing.core.DUMMY_BANK_B_NAME +import net.corda.testing.core.singleIdentity +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.OutOfProcess +import net.corda.testing.driver.driver +import net.corda.testing.node.User +import org.junit.Test +import java.util.* +import java.util.concurrent.CountDownLatch +import java.util.concurrent.Executors +import kotlin.concurrent.thread + +class HardRestartTest { + @StartableByRPC + @InitiatingFlow + class Ping(val pongParty: Party) : FlowLogic() { + @Suspendable + override fun call() { + val pongSession = initiateFlow(pongParty) + pongSession.sendAndReceive(Unit) + } + } + + @InitiatedBy(Ping::class) + class Pong(val pingSession: FlowSession) : FlowLogic() { + @Suspendable + override fun call() { + pingSession.sendAndReceive(Unit) + } + } + + @Test + fun restartPingPongFlowRandomly() { + val demoUser = User("demo", "demo", setOf(Permissions.startFlow(), Permissions.all())) + driver(DriverParameters(isDebug = true, startNodesInProcess = false)) { + val (a, b) = listOf( + startNode(providedName = DUMMY_BANK_A_NAME, rpcUsers = listOf(demoUser), customOverrides = mapOf("p2pAddress" to "localhost:30000")), + startNode(providedName = DUMMY_BANK_B_NAME, rpcUsers = listOf(demoUser), customOverrides = mapOf("p2pAddress" to "localhost:40000")) + ).transpose().getOrThrow() + + val latch = CountDownLatch(1) + + // We kill -9 and restart the Pong node after a random sleep + val pongRestartThread = thread { + latch.await() + val ms = Random().nextInt(1000) + println("Sleeping $ms ms before kill") + Thread.sleep(ms.toLong()) + (b as OutOfProcess).process.destroyForcibly() + b.stop() + startNode(providedName = DUMMY_BANK_B_NAME, rpcUsers = listOf(demoUser), customOverrides = mapOf("p2pAddress" to "localhost:40000")) + } + CordaRPCClient(a.rpcAddress).use(demoUser.username, demoUser.password) { + val returnValue = it.proxy.startFlow(::Ping, b.nodeInfo.singleIdentity()).returnValue + latch.countDown() + // No matter the kill + returnValue.getOrThrow() + } + + pongRestartThread.join() + } + } + + sealed class RecursiveMode { + data class Top(val otherParty: Party, val initialDepth: Int) : RecursiveMode() + data class Recursive(val otherSession: FlowSession) : RecursiveMode() + } + + @StartableByRPC + @InitiatingFlow + @InitiatedBy(RecursiveB::class) + class RecursiveA(val mode: RecursiveMode) : FlowLogic() { + constructor(otherSession: FlowSession) : this(RecursiveMode.Recursive(otherSession)) + constructor(otherParty: Party, initialDepth: Int) : this(RecursiveMode.Top(otherParty, initialDepth)) + @Suspendable + override fun call(): String { + return when (mode) { + is HardRestartTest.RecursiveMode.Top -> { + val session = initiateFlow(mode.otherParty) + session.sendAndReceive(mode.initialDepth).unwrap { it } + } + is HardRestartTest.RecursiveMode.Recursive -> { + val depth = mode.otherSession.receive().unwrap { it } + val string = if (depth > 0) { + val newSession = initiateFlow(mode.otherSession.counterparty) + newSession.sendAndReceive(depth).unwrap { it } + } else { + "-" + } + mode.otherSession.send(string) + string + } + } + } + } + + @InitiatingFlow + @InitiatedBy(RecursiveA::class) + class RecursiveB(val otherSession: FlowSession) : FlowLogic() { + @Suspendable + override fun call() { + val depth = otherSession.receive().unwrap { it } + val newSession = initiateFlow(otherSession.counterparty) + val string = newSession.sendAndReceive(depth - 1).unwrap { it } + otherSession.send(string + ":" + depth) + } + } + + @Test + fun restartRecursiveFlowRandomly() { + val demoUser = User("demo", "demo", setOf(Permissions.startFlow(), Permissions.all())) + driver(DriverParameters(isDebug = true, startNodesInProcess = false)) { + val (a, b) = listOf( + startNode(providedName = DUMMY_BANK_A_NAME, rpcUsers = listOf(demoUser), customOverrides = mapOf("p2pAddress" to "localhost:30000")), + startNode(providedName = DUMMY_BANK_B_NAME, rpcUsers = listOf(demoUser), customOverrides = mapOf("p2pAddress" to "localhost:40000")) + ).transpose().getOrThrow() + + val latch = CountDownLatch(1) + + // We kill -9 and restart the node B after a random sleep + val bRestartThread = thread { + latch.await() + val ms = Random().nextInt(1000) + println("Sleeping $ms ms before kill") + Thread.sleep(ms.toLong()) + (b as OutOfProcess).process.destroyForcibly() + b.stop() + startNode(providedName = DUMMY_BANK_B_NAME, rpcUsers = listOf(demoUser), customOverrides = mapOf("p2pAddress" to "localhost:40000")) + } + val executor = Executors.newFixedThreadPool(8) + try { + val tlRpc = ThreadLocal() + (1 .. 10).map { num -> + executor.fork { + val rpc = tlRpc.get() ?: CordaRPCClient(a.rpcAddress).start(demoUser.username, demoUser.password).proxy.also { tlRpc.set(it) } + val string = rpc.startFlow(::RecursiveA, b.nodeInfo.singleIdentity(), 10).returnValue.getOrThrow() + latch.countDown() + println("$num: $string") + } + }.transpose().getOrThrow() + + bRestartThread.join() + } finally { + executor.shutdown() + } + } + } +} \ No newline at end of file diff --git a/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt b/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt index 7608b10765..877ef2471f 100644 --- a/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt +++ b/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt @@ -180,7 +180,7 @@ class P2PMessagingTest : IntegrationTest() { val response = it.internalServices.networkService.createMessage("test.response", responseMessage.serialize().bytes) it.internalServices.networkService.send(response, request.replyTo) } - handler.acknowledge() + handler.afterDatabaseTransaction() } } return crashingNodes @@ -212,7 +212,7 @@ class P2PMessagingTest : IntegrationTest() { val request = netMessage.data.deserialize() val response = internalServices.networkService.createMessage("test.response", message.serialize().bytes) internalServices.networkService.send(response, request.replyTo) - handle.acknowledge() + handle.afterDatabaseTransaction() } } @@ -239,7 +239,7 @@ class P2PMessagingTest : IntegrationTest() { check(!consumed.getAndSet(true)) { "Called more than once" } check(msg.topic == topic) { "Topic/session mismatch: ${msg.topic} vs $topic" } callback(msg) - handle.acknowledge() + handle.afterDatabaseTransaction() } } diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index c323fdd011..6506e377fc 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -57,6 +57,7 @@ import net.corda.node.services.events.NodeSchedulerService import net.corda.node.services.events.ScheduledActivityObserver import net.corda.node.services.identity.PersistentIdentityService import net.corda.node.services.keys.PersistentKeyManagementService +import net.corda.node.services.messaging.DeduplicationHandler import net.corda.node.services.messaging.MessagingService import net.corda.node.services.network.* import net.corda.node.services.persistence.* @@ -74,7 +75,6 @@ import net.corda.nodeapi.internal.DevIdentityGenerator import net.corda.nodeapi.internal.NodeInfoAndSigned import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.persistence.* -import net.corda.nodeapi.internal.sign import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.nodeapi.internal.persistence.HibernateConfiguration @@ -873,8 +873,8 @@ internal fun logVendorString(database: CordaPersistence, log: Logger) { } internal class FlowStarterImpl(private val smm: StateMachineManager, private val flowLogicRefFactory: FlowLogicRefFactory) : FlowStarter { - override fun startFlow(logic: FlowLogic, context: InvocationContext): CordaFuture> { - return smm.startFlow(logic, context) + override fun startFlow(logic: FlowLogic, context: InvocationContext, deduplicationHandler: DeduplicationHandler?): CordaFuture> { + return smm.startFlow(logic, context, ourIdentity = null, deduplicationHandler = deduplicationHandler) } override fun invokeFlowAsync( diff --git a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt index dca685d5ad..d589a6c8a6 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt @@ -29,6 +29,7 @@ import net.corda.core.utilities.contextLogger import net.corda.node.internal.InitiatedFlowFactory import net.corda.node.internal.cordapp.CordappProviderInternal import net.corda.node.services.config.NodeConfiguration +import net.corda.node.services.messaging.DeduplicationHandler import net.corda.node.services.messaging.MessagingService import net.corda.node.services.network.NetworkMapUpdater import net.corda.node.services.statemachine.FlowStateMachineImpl @@ -137,8 +138,9 @@ interface FlowStarter { /** * Starts an already constructed flow. Note that you must be on the server thread to call this method. * @param context indicates who started the flow, see: [InvocationContext]. + * @param deduplicationHandler allows exactly-once start of the flow, see [DeduplicationHandler] */ - fun startFlow(logic: FlowLogic, context: InvocationContext): CordaFuture> + fun startFlow(logic: FlowLogic, context: InvocationContext, deduplicationHandler: DeduplicationHandler? = null): CordaFuture> /** * Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the flow. diff --git a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt index 8a5ec8de03..db3c6a01d7 100644 --- a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt +++ b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt @@ -36,10 +36,12 @@ import net.corda.node.MutableClock import net.corda.node.services.api.FlowStarter import net.corda.node.services.api.NodePropertiesStore import net.corda.node.services.api.SchedulerService +import net.corda.node.services.messaging.DeduplicationHandler import net.corda.node.utilities.PersistentMap import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import org.apache.activemq.artemis.utils.ReusableLatch +import org.apache.mina.util.ConcurrentHashSet import org.slf4j.Logger import java.time.Duration import java.time.Instant @@ -172,6 +174,10 @@ class NodeSchedulerService(private val clock: CordaClock, var rescheduled: GuavaSettableFuture? = null } + // Used to de-duplicate flow starts in case a flow is starting but the corresponding entry hasn't been removed yet + // from the database + private val startingStateRefs = ConcurrentHashSet() + private val mutex = ThreadBox(InnerState()) // We need the [StateMachineManager] to be constructed before this is called in case it schedules a flow. fun start() { @@ -212,7 +218,7 @@ class NodeSchedulerService(private val clock: CordaClock, val previousEarliest = scheduledStatesQueue.peek() scheduledStatesQueue.remove(previousState) scheduledStatesQueue.add(action) - if (previousState == null) { + if (previousState == null && action !in startingStateRefs) { unfinishedSchedules.countUp() } @@ -279,16 +285,34 @@ class NodeSchedulerService(private val clock: CordaClock, schedulerTimerExecutor.join() } + private inner class FlowStartDeduplicationHandler(val scheduledState: ScheduledStateRef) : DeduplicationHandler { + override fun insideDatabaseTransaction() { + scheduledStates.remove(scheduledState.ref) + } + + override fun afterDatabaseTransaction() { + startingStateRefs.remove(scheduledState) + } + + override fun toString(): String { + return "${javaClass.simpleName}($scheduledState)" + } + } + private fun onTimeReached(scheduledState: ScheduledStateRef) { var flowName: String? = "(unknown)" try { - database.transaction { - val scheduledFlow = getScheduledFlow(scheduledState) + // We need to check this before the database transaction, otherwise there is a subtle race between a + // doubly-reached deadline and the removal from [startingStateRefs]. + if (scheduledState !in startingStateRefs) { + val scheduledFlow = database.transaction { getScheduledFlow(scheduledState) } if (scheduledFlow != null) { + startingStateRefs.add(scheduledState) flowName = scheduledFlow.javaClass.name // TODO refactor the scheduler to store and propagate the original invocation context val context = InvocationContext.newInstance(InvocationOrigin.Scheduled(scheduledState)) - val future = flowStarter.startFlow(scheduledFlow, context).flatMap { it.resultFuture } + val deduplicationHandler = FlowStartDeduplicationHandler(scheduledState) + val future = flowStarter.startFlow(scheduledFlow, context, deduplicationHandler).flatMap { it.resultFuture } future.then { unfinishedSchedules.countDown() } @@ -327,7 +351,6 @@ class NodeSchedulerService(private val clock: CordaClock, } else -> { log.trace { "Scheduler starting FlowLogic $flowLogic" } - scheduledStates.remove(scheduledState.ref) scheduledStatesQueue.remove(scheduledState) flowLogic } diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt index 63a87f013d..3544334342 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt @@ -163,19 +163,29 @@ object TopicStringValidator { } /** - * Represents a to-be-acknowledged message. It has an associated deduplication ID. + * This handler is used to implement exactly-once delivery of an event on top of a possibly duplicated one. This is done + * using two hooks that are called from the event processor, one called from the database transaction committing the + * side-effect caused by the event, and another one called after the transaction has committed successfully. + * + * For example for messaging we can use [insideDatabaseTransaction] to store the message's unique ID for later + * deduplication, and [afterDatabaseTransaction] to acknowledge the message and stop retries. + * + * We also use this for exactly-once start of a scheduled flow, [insideDatabaseTransaction] is used to remove the + * to-be-scheduled state of the flow, [afterDatabaseTransaction] is used for cleanup of in-memory bookkeeping. */ -interface AcknowledgeHandle { +interface DeduplicationHandler { /** - * Acknowledge the message. + * This will be run inside a database transaction that commits the side-effect of the event, allowing the + * implementor to persist the event delivery fact atomically with the side-effect. */ - fun acknowledge() + fun insideDatabaseTransaction() /** - * Store the deduplication ID. TODO this should be moved into the flow state machine completely. + * This will be run strictly after the side-effect has been committed successfully and may be used for + * cleanup/acknowledgement/stopping of retries. */ - fun persistDeduplicationId() + fun afterDatabaseTransaction() } -typealias MessageHandler = (ReceivedMessage, MessageHandlerRegistration, AcknowledgeHandle) -> Unit +typealias MessageHandler = (ReceivedMessage, MessageHandlerRegistration, DeduplicationHandler) -> Unit diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessageDeduplicator.kt b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessageDeduplicator.kt index 89f8a69d3e..4782fbcfdd 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessageDeduplicator.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessageDeduplicator.kt @@ -16,6 +16,7 @@ import net.corda.node.services.statemachine.DeduplicationId import net.corda.node.utilities.AppendOnlyPersistentMap import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX +import org.apache.mina.util.ConcurrentHashSet import java.time.Instant import java.util.* import java.util.concurrent.TimeUnit @@ -29,6 +30,10 @@ import javax.persistence.Id class P2PMessageDeduplicator(private val database: CordaPersistence) { val ourSenderUUID = UUID.randomUUID().toString() + // A temporary in-memory set of deduplication IDs. When we receive a message we don't persist the ID immediately, + // so we store the ID here in the meantime (until the persisting db tx has committed). This is because Artemis may + // redeliver messages to the same consumer if they weren't ACKed. + private val beingProcessedMessages = ConcurrentHashSet() private val processedMessages = createProcessedMessages() // We add the peer to the key, so other peers cannot attempt malicious meddling with sequence numbers. // Expire after 7 days since we last touched an entry, to avoid infinite growth. @@ -79,6 +84,9 @@ class P2PMessageDeduplicator(private val database: CordaPersistence) { * @return true if we have seen this message before. */ fun isDuplicate(msg: ReceivedMessage): Boolean { + if (msg.uniqueMessageId in beingProcessedMessages) { + return true + } val receivedSenderUUID = msg.senderUUID val receivedSenderSeqNo = msg.senderSeqNo // If we have received a new higher sequence number, then it cannot be a duplicate, and we don't need to check database. @@ -91,8 +99,26 @@ class P2PMessageDeduplicator(private val database: CordaPersistence) { } } - fun persistDeduplicationId(msg: ReceivedMessage) { - processedMessages[msg.uniqueMessageId] = Instant.now() + /** + * Called the first time we encounter [deduplicationId]. + */ + fun signalMessageProcessStart(deduplicationId: DeduplicationId) { + beingProcessedMessages.add(deduplicationId) + } + + /** + * Called inside a DB transaction to persist [deduplicationId]. + */ + fun persistDeduplicationId(deduplicationId: DeduplicationId) { + processedMessages[deduplicationId] = Instant.now() + } + + /** + * Called after the DB transaction persisting [deduplicationId] committed. + * Any subsequent redelivery will be deduplicated using the DB. + */ + fun signalMessageProcessFinish(deduplicationId: DeduplicationId) { + beingProcessedMessages.remove(deduplicationId) } @Entity diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt index 354b2edd95..a35cadfa85 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt @@ -431,6 +431,7 @@ class P2PMessagingClient(val config: NodeConfiguration, artemisToCordaMessage(artemisMessage)?.let { cordaMessage -> if (!deduplicator.isDuplicate(cordaMessage)) { + deduplicator.signalMessageProcessStart(cordaMessage.uniqueMessageId) deliver(cordaMessage, artemisMessage) } else { log.trace { "Discard duplicate message ${cordaMessage.uniqueMessageId} for ${cordaMessage.topic}" } @@ -444,7 +445,7 @@ class P2PMessagingClient(val config: NodeConfiguration, val deliverTo = handlers[msg.topic] if (deliverTo != null) { try { - deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), acknowledgeHandleFor(artemisMessage, msg)) + deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), MessageDeduplicationHandler(artemisMessage, msg)) } catch (e: Exception) { log.error("Caught exception whilst executing message handler for ${msg.topic}", e) } @@ -453,21 +454,18 @@ class P2PMessagingClient(val config: NodeConfiguration, } } - private fun acknowledgeHandleFor(artemisMessage: ClientMessage, cordaMessage: ReceivedMessage): AcknowledgeHandle { + inner class MessageDeduplicationHandler(val artemisMessage: ClientMessage, val cordaMessage: ReceivedMessage) : DeduplicationHandler { + override fun insideDatabaseTransaction() { + deduplicator.persistDeduplicationId(cordaMessage.uniqueMessageId) + } - return object : AcknowledgeHandle { + override fun afterDatabaseTransaction() { + deduplicator.signalMessageProcessFinish(cordaMessage.uniqueMessageId) + messagingExecutor!!.acknowledge(artemisMessage) + } - override fun persistDeduplicationId() { - deduplicator.persistDeduplicationId(cordaMessage) - } - - // ACKing a message calls back into the session which isn't thread safe, so we have to ensure it - // doesn't collide with a send here. Note that stop() could have been called whilst we were - // processing a message but if so, it'll be parked waiting for us to count down the latch, so - // the session itself is still around and we can still ack messages as a result. - override fun acknowledge() { - messagingExecutor!!.acknowledge(artemisMessage) - } + override fun toString(): String { + return "${javaClass.simpleName}(${cordaMessage.uniqueMessageId})" } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt index a04dee7869..3e9dc3cb0e 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt @@ -13,7 +13,7 @@ package net.corda.node.services.statemachine import net.corda.core.crypto.SecureHash import net.corda.core.flows.StateMachineRunId import net.corda.core.identity.Party -import net.corda.node.services.messaging.AcknowledgeHandle +import net.corda.node.services.messaging.DeduplicationHandler import java.time.Instant /** @@ -55,14 +55,14 @@ sealed class Action { data class RemoveCheckpoint(val id: StateMachineRunId) : Action() /** - * Persist the deduplication IDs of [acknowledgeHandles]. + * Persist the deduplication facts of [deduplicationHandlers]. */ - data class PersistDeduplicationIds(val acknowledgeHandles: List) : Action() + data class PersistDeduplicationFacts(val deduplicationHandlers: List) : Action() /** - * Acknowledge messages in [acknowledgeHandles]. + * Acknowledge messages in [deduplicationHandlers]. */ - data class AcknowledgeMessages(val acknowledgeHandles: List) : Action() + data class AcknowledgeMessages(val deduplicationHandlers: List) : Action() /** * Propagate [errorMessages] to [sessions]. diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt index bf4d7b9d99..4b9cb88a06 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt @@ -71,7 +71,7 @@ class ActionExecutorImpl( return when (action) { is Action.TrackTransaction -> executeTrackTransaction(fiber, action) is Action.PersistCheckpoint -> executePersistCheckpoint(action) - is Action.PersistDeduplicationIds -> executePersistDeduplicationIds(action) + is Action.PersistDeduplicationFacts -> executePersistDeduplicationIds(action) is Action.AcknowledgeMessages -> executeAcknowledgeMessages(action) is Action.PropagateErrors -> executePropagateErrors(action) is Action.ScheduleEvent -> executeScheduleEvent(fiber, action) @@ -113,16 +113,16 @@ class ActionExecutorImpl( } @Suspendable - private fun executePersistDeduplicationIds(action: Action.PersistDeduplicationIds) { - for (handle in action.acknowledgeHandles) { - handle.persistDeduplicationId() + private fun executePersistDeduplicationIds(action: Action.PersistDeduplicationFacts) { + for (handle in action.deduplicationHandlers) { + handle.insideDatabaseTransaction() } } @Suspendable private fun executeAcknowledgeMessages(action: Action.AcknowledgeMessages) { - action.acknowledgeHandles.forEach { - it.acknowledge() + action.deduplicationHandlers.forEach { + it.afterDatabaseTransaction() } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt index a9861b23a4..9d18ebb18c 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt @@ -15,7 +15,7 @@ import net.corda.core.identity.Party import net.corda.core.internal.FlowIORequest import net.corda.core.serialization.SerializedBytes import net.corda.core.transactions.SignedTransaction -import net.corda.node.services.messaging.AcknowledgeHandle +import net.corda.node.services.messaging.DeduplicationHandler /** * Transitions in the flow state machine are triggered by [Event]s that may originate from the flow itself or from @@ -33,12 +33,12 @@ sealed class Event { /** * Deliver a session message. * @param sessionMessage the message itself. - * @param acknowledgeHandle the handle to acknowledge the message after checkpointing. + * @param deduplicationHandler the handle to acknowledge the message after checkpointing. * @param sender the sender [Party]. */ data class DeliverSessionMessage( val sessionMessage: ExistingSessionMessage, - val acknowledgeHandle: AcknowledgeHandle, + val deduplicationHandler: DeduplicationHandler, val sender: Party ) : Event() diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt index 017e8dda6b..496252a5f8 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt @@ -19,7 +19,7 @@ import net.corda.core.serialization.serialize import net.corda.core.utilities.contextLogger import net.corda.core.utilities.trace import net.corda.node.services.api.ServiceHubInternal -import net.corda.node.services.messaging.AcknowledgeHandle +import net.corda.node.services.messaging.DeduplicationHandler import net.corda.node.services.messaging.ReceivedMessage import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders import java.io.NotSerializableException @@ -38,7 +38,7 @@ interface FlowMessaging { /** * Start the messaging using the [onMessage] message handler. */ - fun start(onMessage: (ReceivedMessage, acknowledgeHandle: AcknowledgeHandle) -> Unit) + fun start(onMessage: (ReceivedMessage, deduplicationHandler: DeduplicationHandler) -> Unit) } /** @@ -52,9 +52,9 @@ class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging { val sessionTopic = "platform.session" } - override fun start(onMessage: (ReceivedMessage, acknowledgeHandle: AcknowledgeHandle) -> Unit) { - serviceHub.networkService.addMessageHandler(sessionTopic) { receivedMessage, _, acknowledgeHandle -> - onMessage(receivedMessage, acknowledgeHandle) + override fun start(onMessage: (ReceivedMessage, deduplicationHandler: DeduplicationHandler) -> Unit) { + serviceHub.networkService.addMessageHandler(sessionTopic) { receivedMessage, _, deduplicationHandler -> + onMessage(receivedMessage, deduplicationHandler) } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/MultiThreadedStateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/MultiThreadedStateMachineManager.kt index 3d7e7c611f..7cc94b361e 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/MultiThreadedStateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/MultiThreadedStateMachineManager.kt @@ -37,11 +37,10 @@ import net.corda.node.internal.InitiatedFlowFactory import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.config.shouldCheckCheckpoints -import net.corda.node.services.messaging.AcknowledgeHandle +import net.corda.node.services.messaging.DeduplicationHandler import net.corda.node.services.messaging.ReceivedMessage import net.corda.node.services.statemachine.interceptors.* import net.corda.node.services.statemachine.transitions.StateMachine -import net.corda.node.services.statemachine.transitions.StateMachineConfiguration import net.corda.node.utilities.AffinityExecutor import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl @@ -139,9 +138,9 @@ class MultiThreadedStateMachineManager( } serviceHub.networkMapCache.nodeReady.then { resumeRestoredFlows(fibers) - flowMessaging.start { receivedMessage, acknowledgeHandle -> + flowMessaging.start { receivedMessage, deduplicationHandler -> lifeCycle.requireState(State.STARTED) { - onSessionMessage(receivedMessage, acknowledgeHandle) + onSessionMessage(receivedMessage, deduplicationHandler) } } } @@ -202,7 +201,8 @@ class MultiThreadedStateMachineManager( override fun startFlow( flowLogic: FlowLogic, context: InvocationContext, - ourIdentity: Party? + ourIdentity: Party?, + deduplicationHandler: DeduplicationHandler? ): CordaFuture> { return lifeCycle.requireState(State.STARTED) { startFlowInternal( @@ -210,7 +210,7 @@ class MultiThreadedStateMachineManager( flowLogic = flowLogic, flowStart = FlowStart.Explicit, ourIdentity = ourIdentity ?: getOurFirstIdentity(), - initialUnacknowledgedMessage = null, + deduplicationHandler = deduplicationHandler, isStartIdempotent = false ) } @@ -320,7 +320,7 @@ class MultiThreadedStateMachineManager( createFlowFromCheckpoint( id = id, checkpoint = checkpoint, - initialUnacknowledgedMessage = null, + initialDeduplicationHandler = null, isAnyCheckpointPersisted = true, isStartIdempotent = false ) @@ -333,32 +333,32 @@ class MultiThreadedStateMachineManager( } } - private fun onSessionMessage(message: ReceivedMessage, acknowledgeHandle: AcknowledgeHandle) { + private fun onSessionMessage(message: ReceivedMessage, deduplicationHandler: DeduplicationHandler) { val peer = message.peer val sessionMessage = try { message.data.deserialize() } catch (ex: Exception) { logger.error("Received corrupt SessionMessage data from $peer") - acknowledgeHandle.acknowledge() + deduplicationHandler.afterDatabaseTransaction() return } val sender = serviceHub.networkMapCache.getPeerByLegalName(peer) if (sender != null) { when (sessionMessage) { - is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, acknowledgeHandle, sender) - is InitialSessionMessage -> onSessionInit(sessionMessage, message.platformVersion, acknowledgeHandle, sender) + is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, deduplicationHandler, sender) + is InitialSessionMessage -> onSessionInit(sessionMessage, message.platformVersion, deduplicationHandler, sender) } } else { logger.error("Unknown peer $peer in $sessionMessage") } } - private fun onExistingSessionMessage(sessionMessage: ExistingSessionMessage, acknowledgeHandle: AcknowledgeHandle, sender: Party) { + private fun onExistingSessionMessage(sessionMessage: ExistingSessionMessage, deduplicationHandler: DeduplicationHandler, sender: Party) { try { val recipientId = sessionMessage.recipientSessionId val flowId = sessionToFlow[recipientId] if (flowId == null) { - acknowledgeHandle.acknowledge() + deduplicationHandler.afterDatabaseTransaction() if (sessionMessage.payload is EndSessionMessage) { logger.debug { "Got ${EndSessionMessage::class.java.simpleName} for " + @@ -369,7 +369,7 @@ class MultiThreadedStateMachineManager( } } else { val flow = concurrentBox.content.flows[flowId] ?: throw IllegalStateException("Cannot find fiber corresponding to ID $flowId") - flow.fiber.scheduleEvent(Event.DeliverSessionMessage(sessionMessage, acknowledgeHandle, sender)) + flow.fiber.scheduleEvent(Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender)) } } catch (exception: Exception) { logger.error("Exception while routing $sessionMessage", exception) @@ -377,7 +377,7 @@ class MultiThreadedStateMachineManager( } } - private fun onSessionInit(sessionMessage: InitialSessionMessage, senderPlatformVersion: Int, acknowledgeHandle: AcknowledgeHandle, sender: Party) { + private fun onSessionInit(sessionMessage: InitialSessionMessage, senderPlatformVersion: Int, deduplicationHandler: DeduplicationHandler, sender: Party) { fun createErrorMessage(initiatorSessionId: SessionId, message: String): ExistingSessionMessage { val errorId = secureRandom.nextLong() val payload = RejectSessionMessage(message, errorId) @@ -396,7 +396,7 @@ class MultiThreadedStateMachineManager( is InitiatedFlowFactory.Core -> senderPlatformVersion is InitiatedFlowFactory.CorDapp -> null } - startInitiatedFlow(flowLogic, acknowledgeHandle, senderSession, initiatedSessionId, sessionMessage, senderCoreFlowVersion, initiatedFlowInfo) + startInitiatedFlow(flowLogic, deduplicationHandler, senderSession, initiatedSessionId, sessionMessage, senderCoreFlowVersion, initiatedFlowInfo) null } catch (exception: Exception) { logger.warn("Exception while creating initiated flow", exception) @@ -408,7 +408,7 @@ class MultiThreadedStateMachineManager( if (replyError != null) { flowMessaging.sendSessionMessage(sender, replyError, DeduplicationId.createRandom(secureRandom)) - acknowledgeHandle.acknowledge() + deduplicationHandler.afterDatabaseTransaction() } } @@ -431,7 +431,7 @@ class MultiThreadedStateMachineManager( private fun startInitiatedFlow( flowLogic: FlowLogic, - triggeringUnacknowledgedMessage: AcknowledgeHandle, + initiatingMessageDeduplicationHandler: DeduplicationHandler, peerSession: FlowSessionImpl, initiatedSessionId: SessionId, initiatingMessage: InitialSessionMessage, @@ -442,7 +442,7 @@ class MultiThreadedStateMachineManager( val ourIdentity = getOurFirstIdentity() startFlowInternal( InvocationContext.peer(peerSession.counterparty.name), flowLogic, flowStart, ourIdentity, - triggeringUnacknowledgedMessage, + initiatingMessageDeduplicationHandler, isStartIdempotent = false ) } @@ -452,7 +452,7 @@ class MultiThreadedStateMachineManager( flowLogic: FlowLogic, flowStart: FlowStart, ourIdentity: Party, - initialUnacknowledgedMessage: AcknowledgeHandle?, + deduplicationHandler: DeduplicationHandler?, isStartIdempotent: Boolean ): CordaFuture> { val flowId = StateMachineRunId.createRandom() @@ -475,7 +475,7 @@ class MultiThreadedStateMachineManager( val startedFuture = openFuture() val initialState = StateMachineState( checkpoint = initialCheckpoint, - unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(), + pendingDeduplicationHandlers = deduplicationHandler?.let { listOf(it) } ?: emptyList(), isFlowResumed = false, isTransactionTracked = false, isAnyCheckpointPersisted = false, @@ -532,7 +532,7 @@ class MultiThreadedStateMachineManager( checkpoint: Checkpoint, isAnyCheckpointPersisted: Boolean, isStartIdempotent: Boolean, - initialUnacknowledgedMessage: AcknowledgeHandle? + initialDeduplicationHandler: DeduplicationHandler? ): Flow { val flowState = checkpoint.flowState val resultFuture = openFuture() @@ -541,7 +541,7 @@ class MultiThreadedStateMachineManager( val logic = flowState.frozenFlowLogic.deserialize(context = checkpointSerializationContext!!) val state = StateMachineState( checkpoint = checkpoint, - unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(), + pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(), isFlowResumed = false, isTransactionTracked = false, isAnyCheckpointPersisted = isAnyCheckpointPersisted, @@ -559,7 +559,7 @@ class MultiThreadedStateMachineManager( val fiber = flowState.frozenFiber.deserialize(context = checkpointSerializationContext!!) val state = StateMachineState( checkpoint = checkpoint, - unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(), + pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(), isFlowResumed = false, isTransactionTracked = false, isAnyCheckpointPersisted = isAnyCheckpointPersisted, @@ -644,7 +644,7 @@ class MultiThreadedStateMachineManager( totalSuccessFlows.inc() drainFlowEventQueue(flow) // final sanity checks - require(lastState.unacknowledgedMessages.isEmpty()) + require(lastState.pendingDeduplicationHandlers.isEmpty()) require(lastState.isRemoved) require(lastState.checkpoint.subFlowStack.size == 1) sessionToFlow.none { it.value == flow.fiber.id } @@ -676,7 +676,7 @@ class MultiThreadedStateMachineManager( is Event.DoRemainingWork -> {} is Event.DeliverSessionMessage -> { // Acknowledge the message so it doesn't leak in the broker. - event.acknowledgeHandle.acknowledge() + event.deduplicationHandler.afterDatabaseTransaction() when (event.sessionMessage.payload) { EndSessionMessage -> { logger.debug { "Unhandled message ${event.sessionMessage} by ${flow.fiber} due to flow shutting down" } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt index e3e0aa05a5..74b80db97f 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt @@ -40,11 +40,10 @@ import net.corda.node.internal.InitiatedFlowFactory import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.config.shouldCheckCheckpoints -import net.corda.node.services.messaging.AcknowledgeHandle +import net.corda.node.services.messaging.DeduplicationHandler import net.corda.node.services.messaging.ReceivedMessage import net.corda.node.services.statemachine.interceptors.* import net.corda.node.services.statemachine.transitions.StateMachine -import net.corda.node.services.statemachine.transitions.StateMachineConfiguration import net.corda.node.utilities.AffinityExecutor import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl @@ -136,9 +135,9 @@ class SingleThreadedStateMachineManager( } serviceHub.networkMapCache.nodeReady.then { resumeRestoredFlows(fibers) - flowMessaging.start { receivedMessage, acknowledgeHandle -> + flowMessaging.start { receivedMessage, deduplicationHandler -> executor.execute { - onSessionMessage(receivedMessage, acknowledgeHandle) + onSessionMessage(receivedMessage, deduplicationHandler) } } } @@ -203,14 +202,15 @@ class SingleThreadedStateMachineManager( override fun startFlow( flowLogic: FlowLogic, context: InvocationContext, - ourIdentity: Party? + ourIdentity: Party?, + deduplicationHandler: DeduplicationHandler? ): CordaFuture> { return startFlowInternal( invocationContext = context, flowLogic = flowLogic, flowStart = FlowStart.Explicit, ourIdentity = ourIdentity ?: getOurFirstIdentity(), - initialUnacknowledgedMessage = null, + deduplicationHandler = deduplicationHandler, isStartIdempotent = false ) } @@ -320,7 +320,7 @@ class SingleThreadedStateMachineManager( createFlowFromCheckpoint( id = id, checkpoint = checkpoint, - initialUnacknowledgedMessage = null, + initialDeduplicationHandler = null, isAnyCheckpointPersisted = true, isStartIdempotent = false ) @@ -333,32 +333,32 @@ class SingleThreadedStateMachineManager( } } - private fun onSessionMessage(message: ReceivedMessage, acknowledgeHandle: AcknowledgeHandle) { + private fun onSessionMessage(message: ReceivedMessage, deduplicationHandler: DeduplicationHandler) { val peer = message.peer val sessionMessage = try { message.data.deserialize() } catch (ex: Exception) { logger.error("Received corrupt SessionMessage data from $peer") - acknowledgeHandle.acknowledge() + deduplicationHandler.afterDatabaseTransaction() return } val sender = serviceHub.networkMapCache.getPeerByLegalName(peer) if (sender != null) { when (sessionMessage) { - is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, acknowledgeHandle, sender) - is InitialSessionMessage -> onSessionInit(sessionMessage, message.platformVersion, acknowledgeHandle, sender) + is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, deduplicationHandler, sender) + is InitialSessionMessage -> onSessionInit(sessionMessage, message.platformVersion, deduplicationHandler, sender) } } else { logger.error("Unknown peer $peer in $sessionMessage") } } - private fun onExistingSessionMessage(sessionMessage: ExistingSessionMessage, acknowledgeHandle: AcknowledgeHandle, sender: Party) { + private fun onExistingSessionMessage(sessionMessage: ExistingSessionMessage, deduplicationHandler: DeduplicationHandler, sender: Party) { try { val recipientId = sessionMessage.recipientSessionId val flowId = sessionToFlow[recipientId] if (flowId == null) { - acknowledgeHandle.acknowledge() + deduplicationHandler.afterDatabaseTransaction() if (sessionMessage.payload is EndSessionMessage) { logger.debug { "Got ${EndSessionMessage::class.java.simpleName} for " + @@ -369,7 +369,7 @@ class SingleThreadedStateMachineManager( } } else { val flow = mutex.locked { flows[flowId] } ?: throw IllegalStateException("Cannot find fiber corresponding to ID $flowId") - flow.fiber.scheduleEvent(Event.DeliverSessionMessage(sessionMessage, acknowledgeHandle, sender)) + flow.fiber.scheduleEvent(Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender)) } } catch (exception: Exception) { logger.error("Exception while routing $sessionMessage", exception) @@ -377,7 +377,7 @@ class SingleThreadedStateMachineManager( } } - private fun onSessionInit(sessionMessage: InitialSessionMessage, senderPlatformVersion: Int, acknowledgeHandle: AcknowledgeHandle, sender: Party) { + private fun onSessionInit(sessionMessage: InitialSessionMessage, senderPlatformVersion: Int, deduplicationHandler: DeduplicationHandler, sender: Party) { fun createErrorMessage(initiatorSessionId: SessionId, message: String): ExistingSessionMessage { val errorId = secureRandom.nextLong() val payload = RejectSessionMessage(message, errorId) @@ -396,7 +396,7 @@ class SingleThreadedStateMachineManager( is InitiatedFlowFactory.Core -> senderPlatformVersion is InitiatedFlowFactory.CorDapp -> null } - startInitiatedFlow(flowLogic, acknowledgeHandle, senderSession, initiatedSessionId, sessionMessage, senderCoreFlowVersion, initiatedFlowInfo) + startInitiatedFlow(flowLogic, deduplicationHandler, senderSession, initiatedSessionId, sessionMessage, senderCoreFlowVersion, initiatedFlowInfo) null } catch (exception: Exception) { logger.warn("Exception while creating initiated flow", exception) @@ -408,7 +408,7 @@ class SingleThreadedStateMachineManager( if (replyError != null) { flowMessaging.sendSessionMessage(sender, replyError, DeduplicationId.createRandom(secureRandom)) - acknowledgeHandle.acknowledge() + deduplicationHandler.afterDatabaseTransaction() } } @@ -431,7 +431,7 @@ class SingleThreadedStateMachineManager( private fun startInitiatedFlow( flowLogic: FlowLogic, - triggeringUnacknowledgedMessage: AcknowledgeHandle, + initiatingMessageDeduplicationHandler: DeduplicationHandler, peerSession: FlowSessionImpl, initiatedSessionId: SessionId, initiatingMessage: InitialSessionMessage, @@ -442,7 +442,7 @@ class SingleThreadedStateMachineManager( val ourIdentity = getOurFirstIdentity() startFlowInternal( InvocationContext.peer(peerSession.counterparty.name), flowLogic, flowStart, ourIdentity, - triggeringUnacknowledgedMessage, + initiatingMessageDeduplicationHandler, isStartIdempotent = false ) } @@ -452,7 +452,7 @@ class SingleThreadedStateMachineManager( flowLogic: FlowLogic, flowStart: FlowStart, ourIdentity: Party, - initialUnacknowledgedMessage: AcknowledgeHandle?, + deduplicationHandler: DeduplicationHandler?, isStartIdempotent: Boolean ): CordaFuture> { val flowId = StateMachineRunId.createRandom() @@ -475,7 +475,7 @@ class SingleThreadedStateMachineManager( val startedFuture = openFuture() val initialState = StateMachineState( checkpoint = initialCheckpoint, - unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(), + pendingDeduplicationHandlers = deduplicationHandler?.let { listOf(it) } ?: emptyList(), isFlowResumed = false, isTransactionTracked = false, isAnyCheckpointPersisted = false, @@ -532,7 +532,7 @@ class SingleThreadedStateMachineManager( checkpoint: Checkpoint, isAnyCheckpointPersisted: Boolean, isStartIdempotent: Boolean, - initialUnacknowledgedMessage: AcknowledgeHandle? + initialDeduplicationHandler: DeduplicationHandler? ): Flow { val flowState = checkpoint.flowState val resultFuture = openFuture() @@ -541,7 +541,7 @@ class SingleThreadedStateMachineManager( val logic = flowState.frozenFlowLogic.deserialize(context = checkpointSerializationContext!!) val state = StateMachineState( checkpoint = checkpoint, - unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(), + pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(), isFlowResumed = false, isTransactionTracked = false, isAnyCheckpointPersisted = isAnyCheckpointPersisted, @@ -559,7 +559,7 @@ class SingleThreadedStateMachineManager( val fiber = flowState.frozenFiber.deserialize(context = checkpointSerializationContext!!) val state = StateMachineState( checkpoint = checkpoint, - unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(), + pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(), isFlowResumed = false, isTransactionTracked = false, isAnyCheckpointPersisted = isAnyCheckpointPersisted, @@ -648,7 +648,7 @@ class SingleThreadedStateMachineManager( ) { drainFlowEventQueue(flow) // final sanity checks - require(lastState.unacknowledgedMessages.isEmpty()) + require(lastState.pendingDeduplicationHandlers.isEmpty()) require(lastState.isRemoved) require(lastState.checkpoint.subFlowStack.size == 1) sessionToFlow.none { it.value == flow.fiber.id } @@ -679,7 +679,7 @@ class SingleThreadedStateMachineManager( is Event.DoRemainingWork -> {} is Event.DeliverSessionMessage -> { // Acknowledge the message so it doesn't leak in the broker. - event.acknowledgeHandle.acknowledge() + event.deduplicationHandler.afterDatabaseTransaction() when (event.sessionMessage.payload) { EndSessionMessage -> { logger.debug { "Unhandled message ${event.sessionMessage} due to flow shutting down" } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index 14cc73367e..51e12e57db 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -18,6 +18,7 @@ import net.corda.core.identity.Party import net.corda.core.internal.FlowStateMachine import net.corda.core.messaging.DataFeed import net.corda.core.utilities.Try +import net.corda.node.services.messaging.DeduplicationHandler import rx.Observable /** @@ -58,11 +59,14 @@ interface StateMachineManager { * * @param flowLogic The flow's code. * @param context The context of the flow. + * @param ourIdentity The identity to use for the flow. + * @param deduplicationHandler Allows exactly-once start of the flow, see [DeduplicationHandler]. */ fun startFlow( flowLogic: FlowLogic, context: InvocationContext, - ourIdentity: Party? = null + ourIdentity: Party?, + deduplicationHandler: DeduplicationHandler? ): CordaFuture> /** diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt index 2efb8fcc69..fafdf5533e 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt @@ -17,7 +17,7 @@ import net.corda.core.identity.Party import net.corda.core.internal.FlowIORequest import net.corda.core.serialization.SerializedBytes import net.corda.core.utilities.Try -import net.corda.node.services.messaging.AcknowledgeHandle +import net.corda.node.services.messaging.DeduplicationHandler /** * The state of the state machine, capturing the state of a flow. It consists of two parts, an *immutable* part that is @@ -25,7 +25,7 @@ import net.corda.node.services.messaging.AcknowledgeHandle * * @param checkpoint the persisted part of the state. * @param flowLogic the [FlowLogic] associated with the flow. Note that this is mutable by the user. - * @param unacknowledgedMessages the list of currently unacknowledged messages. + * @param pendingDeduplicationHandlers the list of incomplete deduplication handlers. * @param isFlowResumed true if the control is returned (or being returned) to "user-space" flow code. This is used * to make [Event.DoRemainingWork] idempotent. * @param isTransactionTracked true if a ledger transaction has been tracked as part of a @@ -42,7 +42,7 @@ import net.corda.node.services.messaging.AcknowledgeHandle data class StateMachineState( val checkpoint: Checkpoint, val flowLogic: FlowLogic<*>, - val unacknowledgedMessages: List, + val pendingDeduplicationHandlers: List, val isFlowResumed: Boolean, val isTransactionTracked: Boolean, val isAnyCheckpointPersisted: Boolean, diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt index 0d110c2292..303eb60527 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt @@ -34,12 +34,12 @@ class DeliverSessionMessageTransition( ) : Transition { override fun transition(): TransitionResult { return builder { - // Add the AcknowledgeHandle to the unacknowledged messages ASAP so in case an error happens we still know + // Add the DeduplicationHandler to the pending ones ASAP so in case an error happens we still know // about the message. Note that in case of an error during deliver this message *will be acked*. // For example if the session corresponding to the message is not found the message is still acked to free // up the broker but the flow will error. currentState = currentState.copy( - unacknowledgedMessages = currentState.unacknowledgedMessages + event.acknowledgeHandle + pendingDeduplicationHandlers = currentState.pendingDeduplicationHandlers + event.deduplicationHandler ) // Check whether we have a session corresponding to the message. val existingSession = startingState.checkpoint.sessions[event.sessionMessage.recipientSessionId] @@ -163,12 +163,12 @@ class DeliverSessionMessageTransition( actions.addAll(arrayOf( Action.CreateTransaction, Action.PersistCheckpoint(context.id, currentState.checkpoint), - Action.PersistDeduplicationIds(currentState.unacknowledgedMessages), + Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers), Action.CommitTransaction, - Action.AcknowledgeMessages(currentState.unacknowledgedMessages) + Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers) )) currentState = currentState.copy( - unacknowledgedMessages = emptyList(), + pendingDeduplicationHandlers = emptyList(), isAnyCheckpointPersisted = true ) } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt index 08b5018c37..d07963a559 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt @@ -71,14 +71,14 @@ class ErrorFlowTransition( actions.add(Action.RemoveCheckpoint(context.id)) } actions.addAll(arrayOf( - Action.PersistDeduplicationIds(currentState.unacknowledgedMessages), + Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers), Action.CommitTransaction, - Action.AcknowledgeMessages(currentState.unacknowledgedMessages), + Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers), Action.RemoveSessionBindings(currentState.checkpoint.sessions.keys) )) currentState = currentState.copy( - unacknowledgedMessages = emptyList(), + pendingDeduplicationHandlers = emptyList(), isRemoved = true ) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt index f78debc99a..ce413287d1 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt @@ -151,14 +151,14 @@ class TopLevelTransition( } else { actions.addAll(arrayOf( Action.PersistCheckpoint(context.id, newCheckpoint), - Action.PersistDeduplicationIds(currentState.unacknowledgedMessages), + Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers), Action.CommitTransaction, - Action.AcknowledgeMessages(currentState.unacknowledgedMessages), + Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers), Action.ScheduleEvent(Event.DoRemainingWork) )) currentState = currentState.copy( checkpoint = newCheckpoint, - unacknowledgedMessages = emptyList(), + pendingDeduplicationHandlers = emptyList(), isFlowResumed = false, isAnyCheckpointPersisted = true ) @@ -172,12 +172,12 @@ class TopLevelTransition( val checkpoint = currentState.checkpoint when (checkpoint.errorState) { ErrorState.Clean -> { - val unacknowledgedMessages = currentState.unacknowledgedMessages + val pendingDeduplicationHandlers = currentState.pendingDeduplicationHandlers currentState = currentState.copy( checkpoint = checkpoint.copy( numberOfSuspends = checkpoint.numberOfSuspends + 1 ), - unacknowledgedMessages = emptyList(), + pendingDeduplicationHandlers = emptyList(), isFlowResumed = false, isRemoved = true ) @@ -186,9 +186,9 @@ class TopLevelTransition( actions.add(Action.RemoveCheckpoint(context.id)) } actions.addAll(arrayOf( - Action.PersistDeduplicationIds(unacknowledgedMessages), + Action.PersistDeduplicationFacts(pendingDeduplicationHandlers), Action.CommitTransaction, - Action.AcknowledgeMessages(unacknowledgedMessages), + Action.AcknowledgeMessages(pendingDeduplicationHandlers), Action.RemoveSessionBindings(allSourceSessionIds), Action.RemoveFlow(context.id, FlowRemovalReason.OrderlyFinish(event.returnValue), currentState) )) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt index 9278e5aa7c..e347bb1de6 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt @@ -78,12 +78,12 @@ class UnstartedFlowTransition( actions.addAll(arrayOf( Action.CreateTransaction, Action.PersistCheckpoint(context.id, currentState.checkpoint), - Action.PersistDeduplicationIds(currentState.unacknowledgedMessages), + Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers), Action.CommitTransaction, - Action.AcknowledgeMessages(currentState.unacknowledgedMessages) + Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers) )) currentState = currentState.copy( - unacknowledgedMessages = emptyList(), + pendingDeduplicationHandlers = emptyList(), isAnyCheckpointPersisted = true ) } diff --git a/node/src/main/resources/reference.conf b/node/src/main/resources/reference.conf index a0cf277f42..d460aa1eba 100644 --- a/node/src/main/resources/reference.conf +++ b/node/src/main/resources/reference.conf @@ -4,7 +4,7 @@ keyStorePassword = "cordacadevpass" trustStorePassword = "trustpass" dataSourceProperties = { dataSourceClassName = org.h2.jdbcx.JdbcDataSource - dataSource.url = "jdbc:h2:file:"${baseDirectory}"/persistence;DB_CLOSE_ON_EXIT=FALSE;LOCK_TIMEOUT=10000;WRITE_DELAY=100;AUTO_SERVER_PORT="${h2port} + dataSource.url = "jdbc:h2:file:"${baseDirectory}"/persistence;DB_CLOSE_ON_EXIT=FALSE;LOCK_TIMEOUT=10000;WRITE_DELAY=0;AUTO_SERVER_PORT="${h2port} dataSource.user = sa dataSource.password = "" } diff --git a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt index 97652b1c70..10dbec15ea 100644 --- a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt @@ -48,7 +48,7 @@ class NodeSchedulerServiceTest { }.whenever(it).transaction(any()) } private val flowStarter = rigorousMock().also { - doReturn(openFuture>()).whenever(it).startFlow(any>(), any()) + doReturn(openFuture>()).whenever(it).startFlow(any>(), any(), any()) } private val flowsDraingMode = rigorousMock().also { doReturn(false).whenever(it).isEnabled() @@ -111,7 +111,7 @@ class NodeSchedulerServiceTest { private fun assertStarted(event: Event) { // Like in assertWaitingFor, use timeout to make verify wait as we often race the call to startFlow: - verify(flowStarter, timeout(5000)).startFlow(same(event.flowLogic)!!, any()) + verify(flowStarter, timeout(5000)).startFlow(same(event.flowLogic)!!, any(), any()) } @Test diff --git a/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt b/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt index 261b6fc4fa..f653517196 100644 --- a/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt @@ -217,8 +217,8 @@ class ArtemisMessagingTest { try { val messagingClient2 = createMessagingClient() messagingClient2.addMessageHandler(TOPIC) { msg, _, handle -> - database.transaction { handle.persistDeduplicationId() } - handle.acknowledge() // We ACK first so that if it fails we won't get a duplicate in [receivedMessages] + database.transaction { handle.insideDatabaseTransaction() } + handle.afterDatabaseTransaction() // We ACK first so that if it fails we won't get a duplicate in [receivedMessages] receivedMessages.add(msg) } startNodeMessagingClient() @@ -252,8 +252,8 @@ class ArtemisMessagingTest { val messagingClient3 = createMessagingClient() messagingClient3.addMessageHandler(TOPIC) { msg, _, handle -> - database.transaction { handle.persistDeduplicationId() } - handle.acknowledge() // We ACK first so that if it fails we won't get a duplicate in [receivedMessages] + database.transaction { handle.insideDatabaseTransaction() } + handle.afterDatabaseTransaction() // We ACK first so that if it fails we won't get a duplicate in [receivedMessages] receivedMessages.add(msg) } startNodeMessagingClient() @@ -281,8 +281,8 @@ class ArtemisMessagingTest { val messagingClient = createMessagingClient(platformVersion = platformVersion) messagingClient.addMessageHandler(TOPIC) { message, _, handle -> - database.transaction { handle.persistDeduplicationId() } - handle.acknowledge() // We ACK first so that if it fails we won't get a duplicate in [receivedMessages] + database.transaction { handle.insideDatabaseTransaction() } + handle.afterDatabaseTransaction() // We ACK first so that if it fails we won't get a duplicate in [receivedMessages] receivedMessages.add(message) } startNodeMessagingClient() diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt index 883c29c1ea..d99836853f 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt @@ -480,13 +480,7 @@ class InMemoryMessagingNetwork private constructor( database.transaction { for (handler in deliverTo) { try { - val acknowledgeHandle = object : AcknowledgeHandle { - override fun acknowledge() { - } - override fun persistDeduplicationId() { - } - } - handler.callback(transfer.toReceivedMessage(), handler, acknowledgeHandle) + handler.callback(transfer.toReceivedMessage(), handler, DummyDeduplicationHandler()) } catch (e: Exception) { log.error("Caught exception in handler for $this/${handler.topicSession}", e) } @@ -510,5 +504,12 @@ class InMemoryMessagingNetwork private constructor( message.debugTimestamp, sender.name) } + + private class DummyDeduplicationHandler : DeduplicationHandler { + override fun afterDatabaseTransaction() { + } + override fun insideDatabaseTransaction() { + } + } }