Merge pull request #572 from corda/aslemmer-merge-DP3-bla

Squash of changes on DP3
This commit is contained in:
Andras Slemmer 2018-03-16 17:57:26 +00:00 committed by GitHub
commit 908a3badf1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 370 additions and 148 deletions

View File

@ -3996,10 +3996,6 @@ public static final class net.corda.testing.node.InMemoryMessagingNetwork$Compan
public int hashCode() public int hashCode()
@org.jetbrains.annotations.NotNull public String toString() @org.jetbrains.annotations.NotNull public String toString()
## ##
public static final class net.corda.testing.node.InMemoryMessagingNetwork$InMemoryMessaging$pumpReceiveInternal$1$1$acknowledgeHandle$1 extends java.lang.Object implements net.corda.node.services.messaging.AcknowledgeHandle
public void acknowledge()
public void persistDeduplicationId()
##
public static interface net.corda.testing.node.InMemoryMessagingNetwork$LatencyCalculator public static interface net.corda.testing.node.InMemoryMessagingNetwork$LatencyCalculator
@org.jetbrains.annotations.NotNull public abstract java.time.Duration between(net.corda.core.messaging.SingleMessageRecipient, net.corda.core.messaging.SingleMessageRecipient) @org.jetbrains.annotations.NotNull public abstract java.time.Duration between(net.corda.core.messaging.SingleMessageRecipient, net.corda.core.messaging.SingleMessageRecipient)
## ##

View File

@ -0,0 +1,162 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable
import net.corda.client.rpc.CordaRPCClient
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.internal.concurrent.fork
import net.corda.core.internal.concurrent.transpose
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.startFlow
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.node.services.Permissions
import net.corda.testing.core.DUMMY_BANK_A_NAME
import net.corda.testing.core.DUMMY_BANK_B_NAME
import net.corda.testing.core.singleIdentity
import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.OutOfProcess
import net.corda.testing.driver.driver
import net.corda.testing.node.User
import org.junit.Test
import java.util.*
import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executors
import kotlin.concurrent.thread
class HardRestartTest {
@StartableByRPC
@InitiatingFlow
class Ping(val pongParty: Party) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
val pongSession = initiateFlow(pongParty)
pongSession.sendAndReceive<Unit>(Unit)
}
}
@InitiatedBy(Ping::class)
class Pong(val pingSession: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
pingSession.sendAndReceive<Unit>(Unit)
}
}
@Test
fun restartPingPongFlowRandomly() {
val demoUser = User("demo", "demo", setOf(Permissions.startFlow<Ping>(), Permissions.all()))
driver(DriverParameters(isDebug = true, startNodesInProcess = false)) {
val (a, b) = listOf(
startNode(providedName = DUMMY_BANK_A_NAME, rpcUsers = listOf(demoUser), customOverrides = mapOf("p2pAddress" to "localhost:30000")),
startNode(providedName = DUMMY_BANK_B_NAME, rpcUsers = listOf(demoUser), customOverrides = mapOf("p2pAddress" to "localhost:40000"))
).transpose().getOrThrow()
val latch = CountDownLatch(1)
// We kill -9 and restart the Pong node after a random sleep
val pongRestartThread = thread {
latch.await()
val ms = Random().nextInt(1000)
println("Sleeping $ms ms before kill")
Thread.sleep(ms.toLong())
(b as OutOfProcess).process.destroyForcibly()
b.stop()
startNode(providedName = DUMMY_BANK_B_NAME, rpcUsers = listOf(demoUser), customOverrides = mapOf("p2pAddress" to "localhost:40000"))
}
CordaRPCClient(a.rpcAddress).use(demoUser.username, demoUser.password) {
val returnValue = it.proxy.startFlow(::Ping, b.nodeInfo.singleIdentity()).returnValue
latch.countDown()
// No matter the kill
returnValue.getOrThrow()
}
pongRestartThread.join()
}
}
sealed class RecursiveMode {
data class Top(val otherParty: Party, val initialDepth: Int) : RecursiveMode()
data class Recursive(val otherSession: FlowSession) : RecursiveMode()
}
@StartableByRPC
@InitiatingFlow
@InitiatedBy(RecursiveB::class)
class RecursiveA(val mode: RecursiveMode) : FlowLogic<String>() {
constructor(otherSession: FlowSession) : this(RecursiveMode.Recursive(otherSession))
constructor(otherParty: Party, initialDepth: Int) : this(RecursiveMode.Top(otherParty, initialDepth))
@Suspendable
override fun call(): String {
return when (mode) {
is HardRestartTest.RecursiveMode.Top -> {
val session = initiateFlow(mode.otherParty)
session.sendAndReceive<String>(mode.initialDepth).unwrap { it }
}
is HardRestartTest.RecursiveMode.Recursive -> {
val depth = mode.otherSession.receive<Int>().unwrap { it }
val string = if (depth > 0) {
val newSession = initiateFlow(mode.otherSession.counterparty)
newSession.sendAndReceive<String>(depth).unwrap { it }
} else {
"-"
}
mode.otherSession.send(string)
string
}
}
}
}
@InitiatingFlow
@InitiatedBy(RecursiveA::class)
class RecursiveB(val otherSession: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
val depth = otherSession.receive<Int>().unwrap { it }
val newSession = initiateFlow(otherSession.counterparty)
val string = newSession.sendAndReceive<String>(depth - 1).unwrap { it }
otherSession.send(string + ":" + depth)
}
}
@Test
fun restartRecursiveFlowRandomly() {
val demoUser = User("demo", "demo", setOf(Permissions.startFlow<RecursiveA>(), Permissions.all()))
driver(DriverParameters(isDebug = true, startNodesInProcess = false)) {
val (a, b) = listOf(
startNode(providedName = DUMMY_BANK_A_NAME, rpcUsers = listOf(demoUser), customOverrides = mapOf("p2pAddress" to "localhost:30000")),
startNode(providedName = DUMMY_BANK_B_NAME, rpcUsers = listOf(demoUser), customOverrides = mapOf("p2pAddress" to "localhost:40000"))
).transpose().getOrThrow()
val latch = CountDownLatch(1)
// We kill -9 and restart the node B after a random sleep
val bRestartThread = thread {
latch.await()
val ms = Random().nextInt(1000)
println("Sleeping $ms ms before kill")
Thread.sleep(ms.toLong())
(b as OutOfProcess).process.destroyForcibly()
b.stop()
startNode(providedName = DUMMY_BANK_B_NAME, rpcUsers = listOf(demoUser), customOverrides = mapOf("p2pAddress" to "localhost:40000"))
}
val executor = Executors.newFixedThreadPool(8)
try {
val tlRpc = ThreadLocal<CordaRPCOps>()
(1 .. 10).map { num ->
executor.fork {
val rpc = tlRpc.get() ?: CordaRPCClient(a.rpcAddress).start(demoUser.username, demoUser.password).proxy.also { tlRpc.set(it) }
val string = rpc.startFlow(::RecursiveA, b.nodeInfo.singleIdentity(), 10).returnValue.getOrThrow()
latch.countDown()
println("$num: $string")
}
}.transpose().getOrThrow()
bRestartThread.join()
} finally {
executor.shutdown()
}
}
}
}

View File

@ -180,7 +180,7 @@ class P2PMessagingTest : IntegrationTest() {
val response = it.internalServices.networkService.createMessage("test.response", responseMessage.serialize().bytes) val response = it.internalServices.networkService.createMessage("test.response", responseMessage.serialize().bytes)
it.internalServices.networkService.send(response, request.replyTo) it.internalServices.networkService.send(response, request.replyTo)
} }
handler.acknowledge() handler.afterDatabaseTransaction()
} }
} }
return crashingNodes return crashingNodes
@ -212,7 +212,7 @@ class P2PMessagingTest : IntegrationTest() {
val request = netMessage.data.deserialize<TestRequest>() val request = netMessage.data.deserialize<TestRequest>()
val response = internalServices.networkService.createMessage("test.response", message.serialize().bytes) val response = internalServices.networkService.createMessage("test.response", message.serialize().bytes)
internalServices.networkService.send(response, request.replyTo) internalServices.networkService.send(response, request.replyTo)
handle.acknowledge() handle.afterDatabaseTransaction()
} }
} }
@ -239,7 +239,7 @@ class P2PMessagingTest : IntegrationTest() {
check(!consumed.getAndSet(true)) { "Called more than once" } check(!consumed.getAndSet(true)) { "Called more than once" }
check(msg.topic == topic) { "Topic/session mismatch: ${msg.topic} vs $topic" } check(msg.topic == topic) { "Topic/session mismatch: ${msg.topic} vs $topic" }
callback(msg) callback(msg)
handle.acknowledge() handle.afterDatabaseTransaction()
} }
} }

View File

@ -57,6 +57,7 @@ import net.corda.node.services.events.NodeSchedulerService
import net.corda.node.services.events.ScheduledActivityObserver import net.corda.node.services.events.ScheduledActivityObserver
import net.corda.node.services.identity.PersistentIdentityService import net.corda.node.services.identity.PersistentIdentityService
import net.corda.node.services.keys.PersistentKeyManagementService import net.corda.node.services.keys.PersistentKeyManagementService
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.MessagingService import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.network.* import net.corda.node.services.network.*
import net.corda.node.services.persistence.* import net.corda.node.services.persistence.*
@ -74,7 +75,6 @@ import net.corda.nodeapi.internal.DevIdentityGenerator
import net.corda.nodeapi.internal.NodeInfoAndSigned import net.corda.nodeapi.internal.NodeInfoAndSigned
import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.persistence.* import net.corda.nodeapi.internal.persistence.*
import net.corda.nodeapi.internal.sign
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.nodeapi.internal.persistence.HibernateConfiguration import net.corda.nodeapi.internal.persistence.HibernateConfiguration
@ -873,8 +873,8 @@ internal fun logVendorString(database: CordaPersistence, log: Logger) {
} }
internal class FlowStarterImpl(private val smm: StateMachineManager, private val flowLogicRefFactory: FlowLogicRefFactory) : FlowStarter { internal class FlowStarterImpl(private val smm: StateMachineManager, private val flowLogicRefFactory: FlowLogicRefFactory) : FlowStarter {
override fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext): CordaFuture<FlowStateMachine<T>> { override fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext, deduplicationHandler: DeduplicationHandler?): CordaFuture<FlowStateMachine<T>> {
return smm.startFlow(logic, context) return smm.startFlow(logic, context, ourIdentity = null, deduplicationHandler = deduplicationHandler)
} }
override fun <T> invokeFlowAsync( override fun <T> invokeFlowAsync(

View File

@ -29,6 +29,7 @@ import net.corda.core.utilities.contextLogger
import net.corda.node.internal.InitiatedFlowFactory import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.internal.cordapp.CordappProviderInternal import net.corda.node.internal.cordapp.CordappProviderInternal
import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.MessagingService import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.network.NetworkMapUpdater import net.corda.node.services.network.NetworkMapUpdater
import net.corda.node.services.statemachine.FlowStateMachineImpl import net.corda.node.services.statemachine.FlowStateMachineImpl
@ -137,8 +138,9 @@ interface FlowStarter {
/** /**
* Starts an already constructed flow. Note that you must be on the server thread to call this method. * Starts an already constructed flow. Note that you must be on the server thread to call this method.
* @param context indicates who started the flow, see: [InvocationContext]. * @param context indicates who started the flow, see: [InvocationContext].
* @param deduplicationHandler allows exactly-once start of the flow, see [DeduplicationHandler]
*/ */
fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext): CordaFuture<FlowStateMachine<T>> fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext, deduplicationHandler: DeduplicationHandler? = null): CordaFuture<FlowStateMachine<T>>
/** /**
* Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the flow. * Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the flow.

View File

@ -36,10 +36,12 @@ import net.corda.node.MutableClock
import net.corda.node.services.api.FlowStarter import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.NodePropertiesStore import net.corda.node.services.api.NodePropertiesStore
import net.corda.node.services.api.SchedulerService import net.corda.node.services.api.SchedulerService
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.utilities.PersistentMap import net.corda.node.utilities.PersistentMap
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import org.apache.activemq.artemis.utils.ReusableLatch import org.apache.activemq.artemis.utils.ReusableLatch
import org.apache.mina.util.ConcurrentHashSet
import org.slf4j.Logger import org.slf4j.Logger
import java.time.Duration import java.time.Duration
import java.time.Instant import java.time.Instant
@ -172,6 +174,10 @@ class NodeSchedulerService(private val clock: CordaClock,
var rescheduled: GuavaSettableFuture<Boolean>? = null var rescheduled: GuavaSettableFuture<Boolean>? = null
} }
// Used to de-duplicate flow starts in case a flow is starting but the corresponding entry hasn't been removed yet
// from the database
private val startingStateRefs = ConcurrentHashSet<ScheduledStateRef>()
private val mutex = ThreadBox(InnerState()) private val mutex = ThreadBox(InnerState())
// We need the [StateMachineManager] to be constructed before this is called in case it schedules a flow. // We need the [StateMachineManager] to be constructed before this is called in case it schedules a flow.
fun start() { fun start() {
@ -212,7 +218,7 @@ class NodeSchedulerService(private val clock: CordaClock,
val previousEarliest = scheduledStatesQueue.peek() val previousEarliest = scheduledStatesQueue.peek()
scheduledStatesQueue.remove(previousState) scheduledStatesQueue.remove(previousState)
scheduledStatesQueue.add(action) scheduledStatesQueue.add(action)
if (previousState == null) { if (previousState == null && action !in startingStateRefs) {
unfinishedSchedules.countUp() unfinishedSchedules.countUp()
} }
@ -279,16 +285,34 @@ class NodeSchedulerService(private val clock: CordaClock,
schedulerTimerExecutor.join() schedulerTimerExecutor.join()
} }
private inner class FlowStartDeduplicationHandler(val scheduledState: ScheduledStateRef) : DeduplicationHandler {
override fun insideDatabaseTransaction() {
scheduledStates.remove(scheduledState.ref)
}
override fun afterDatabaseTransaction() {
startingStateRefs.remove(scheduledState)
}
override fun toString(): String {
return "${javaClass.simpleName}($scheduledState)"
}
}
private fun onTimeReached(scheduledState: ScheduledStateRef) { private fun onTimeReached(scheduledState: ScheduledStateRef) {
var flowName: String? = "(unknown)" var flowName: String? = "(unknown)"
try { try {
database.transaction { // We need to check this before the database transaction, otherwise there is a subtle race between a
val scheduledFlow = getScheduledFlow(scheduledState) // doubly-reached deadline and the removal from [startingStateRefs].
if (scheduledState !in startingStateRefs) {
val scheduledFlow = database.transaction { getScheduledFlow(scheduledState) }
if (scheduledFlow != null) { if (scheduledFlow != null) {
startingStateRefs.add(scheduledState)
flowName = scheduledFlow.javaClass.name flowName = scheduledFlow.javaClass.name
// TODO refactor the scheduler to store and propagate the original invocation context // TODO refactor the scheduler to store and propagate the original invocation context
val context = InvocationContext.newInstance(InvocationOrigin.Scheduled(scheduledState)) val context = InvocationContext.newInstance(InvocationOrigin.Scheduled(scheduledState))
val future = flowStarter.startFlow(scheduledFlow, context).flatMap { it.resultFuture } val deduplicationHandler = FlowStartDeduplicationHandler(scheduledState)
val future = flowStarter.startFlow(scheduledFlow, context, deduplicationHandler).flatMap { it.resultFuture }
future.then { future.then {
unfinishedSchedules.countDown() unfinishedSchedules.countDown()
} }
@ -327,7 +351,6 @@ class NodeSchedulerService(private val clock: CordaClock,
} }
else -> { else -> {
log.trace { "Scheduler starting FlowLogic $flowLogic" } log.trace { "Scheduler starting FlowLogic $flowLogic" }
scheduledStates.remove(scheduledState.ref)
scheduledStatesQueue.remove(scheduledState) scheduledStatesQueue.remove(scheduledState)
flowLogic flowLogic
} }

View File

@ -163,19 +163,29 @@ object TopicStringValidator {
} }
/** /**
* Represents a to-be-acknowledged message. It has an associated deduplication ID. * This handler is used to implement exactly-once delivery of an event on top of a possibly duplicated one. This is done
* using two hooks that are called from the event processor, one called from the database transaction committing the
* side-effect caused by the event, and another one called after the transaction has committed successfully.
*
* For example for messaging we can use [insideDatabaseTransaction] to store the message's unique ID for later
* deduplication, and [afterDatabaseTransaction] to acknowledge the message and stop retries.
*
* We also use this for exactly-once start of a scheduled flow, [insideDatabaseTransaction] is used to remove the
* to-be-scheduled state of the flow, [afterDatabaseTransaction] is used for cleanup of in-memory bookkeeping.
*/ */
interface AcknowledgeHandle { interface DeduplicationHandler {
/** /**
* Acknowledge the message. * This will be run inside a database transaction that commits the side-effect of the event, allowing the
* implementor to persist the event delivery fact atomically with the side-effect.
*/ */
fun acknowledge() fun insideDatabaseTransaction()
/** /**
* Store the deduplication ID. TODO this should be moved into the flow state machine completely. * This will be run strictly after the side-effect has been committed successfully and may be used for
* cleanup/acknowledgement/stopping of retries.
*/ */
fun persistDeduplicationId() fun afterDatabaseTransaction()
} }
typealias MessageHandler = (ReceivedMessage, MessageHandlerRegistration, AcknowledgeHandle) -> Unit typealias MessageHandler = (ReceivedMessage, MessageHandlerRegistration, DeduplicationHandler) -> Unit

View File

@ -16,6 +16,7 @@ import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.utilities.AppendOnlyPersistentMap import net.corda.node.utilities.AppendOnlyPersistentMap
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import org.apache.mina.util.ConcurrentHashSet
import java.time.Instant import java.time.Instant
import java.util.* import java.util.*
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
@ -29,6 +30,10 @@ import javax.persistence.Id
class P2PMessageDeduplicator(private val database: CordaPersistence) { class P2PMessageDeduplicator(private val database: CordaPersistence) {
val ourSenderUUID = UUID.randomUUID().toString() val ourSenderUUID = UUID.randomUUID().toString()
// A temporary in-memory set of deduplication IDs. When we receive a message we don't persist the ID immediately,
// so we store the ID here in the meantime (until the persisting db tx has committed). This is because Artemis may
// redeliver messages to the same consumer if they weren't ACKed.
private val beingProcessedMessages = ConcurrentHashSet<DeduplicationId>()
private val processedMessages = createProcessedMessages() private val processedMessages = createProcessedMessages()
// We add the peer to the key, so other peers cannot attempt malicious meddling with sequence numbers. // We add the peer to the key, so other peers cannot attempt malicious meddling with sequence numbers.
// Expire after 7 days since we last touched an entry, to avoid infinite growth. // Expire after 7 days since we last touched an entry, to avoid infinite growth.
@ -79,6 +84,9 @@ class P2PMessageDeduplicator(private val database: CordaPersistence) {
* @return true if we have seen this message before. * @return true if we have seen this message before.
*/ */
fun isDuplicate(msg: ReceivedMessage): Boolean { fun isDuplicate(msg: ReceivedMessage): Boolean {
if (msg.uniqueMessageId in beingProcessedMessages) {
return true
}
val receivedSenderUUID = msg.senderUUID val receivedSenderUUID = msg.senderUUID
val receivedSenderSeqNo = msg.senderSeqNo val receivedSenderSeqNo = msg.senderSeqNo
// If we have received a new higher sequence number, then it cannot be a duplicate, and we don't need to check database. // If we have received a new higher sequence number, then it cannot be a duplicate, and we don't need to check database.
@ -91,8 +99,26 @@ class P2PMessageDeduplicator(private val database: CordaPersistence) {
} }
} }
fun persistDeduplicationId(msg: ReceivedMessage) { /**
processedMessages[msg.uniqueMessageId] = Instant.now() * Called the first time we encounter [deduplicationId].
*/
fun signalMessageProcessStart(deduplicationId: DeduplicationId) {
beingProcessedMessages.add(deduplicationId)
}
/**
* Called inside a DB transaction to persist [deduplicationId].
*/
fun persistDeduplicationId(deduplicationId: DeduplicationId) {
processedMessages[deduplicationId] = Instant.now()
}
/**
* Called after the DB transaction persisting [deduplicationId] committed.
* Any subsequent redelivery will be deduplicated using the DB.
*/
fun signalMessageProcessFinish(deduplicationId: DeduplicationId) {
beingProcessedMessages.remove(deduplicationId)
} }
@Entity @Entity

View File

@ -431,6 +431,7 @@ class P2PMessagingClient(val config: NodeConfiguration,
artemisToCordaMessage(artemisMessage)?.let { cordaMessage -> artemisToCordaMessage(artemisMessage)?.let { cordaMessage ->
if (!deduplicator.isDuplicate(cordaMessage)) { if (!deduplicator.isDuplicate(cordaMessage)) {
deduplicator.signalMessageProcessStart(cordaMessage.uniqueMessageId)
deliver(cordaMessage, artemisMessage) deliver(cordaMessage, artemisMessage)
} else { } else {
log.trace { "Discard duplicate message ${cordaMessage.uniqueMessageId} for ${cordaMessage.topic}" } log.trace { "Discard duplicate message ${cordaMessage.uniqueMessageId} for ${cordaMessage.topic}" }
@ -444,7 +445,7 @@ class P2PMessagingClient(val config: NodeConfiguration,
val deliverTo = handlers[msg.topic] val deliverTo = handlers[msg.topic]
if (deliverTo != null) { if (deliverTo != null) {
try { try {
deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), acknowledgeHandleFor(artemisMessage, msg)) deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), MessageDeduplicationHandler(artemisMessage, msg))
} catch (e: Exception) { } catch (e: Exception) {
log.error("Caught exception whilst executing message handler for ${msg.topic}", e) log.error("Caught exception whilst executing message handler for ${msg.topic}", e)
} }
@ -453,21 +454,18 @@ class P2PMessagingClient(val config: NodeConfiguration,
} }
} }
private fun acknowledgeHandleFor(artemisMessage: ClientMessage, cordaMessage: ReceivedMessage): AcknowledgeHandle { inner class MessageDeduplicationHandler(val artemisMessage: ClientMessage, val cordaMessage: ReceivedMessage) : DeduplicationHandler {
override fun insideDatabaseTransaction() {
return object : AcknowledgeHandle { deduplicator.persistDeduplicationId(cordaMessage.uniqueMessageId)
override fun persistDeduplicationId() {
deduplicator.persistDeduplicationId(cordaMessage)
} }
// ACKing a message calls back into the session which isn't thread safe, so we have to ensure it override fun afterDatabaseTransaction() {
// doesn't collide with a send here. Note that stop() could have been called whilst we were deduplicator.signalMessageProcessFinish(cordaMessage.uniqueMessageId)
// processing a message but if so, it'll be parked waiting for us to count down the latch, so
// the session itself is still around and we can still ack messages as a result.
override fun acknowledge() {
messagingExecutor!!.acknowledge(artemisMessage) messagingExecutor!!.acknowledge(artemisMessage)
} }
override fun toString(): String {
return "${javaClass.simpleName}(${cordaMessage.uniqueMessageId})"
} }
} }

View File

@ -13,7 +13,7 @@ package net.corda.node.services.statemachine
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.node.services.messaging.AcknowledgeHandle import net.corda.node.services.messaging.DeduplicationHandler
import java.time.Instant import java.time.Instant
/** /**
@ -55,14 +55,14 @@ sealed class Action {
data class RemoveCheckpoint(val id: StateMachineRunId) : Action() data class RemoveCheckpoint(val id: StateMachineRunId) : Action()
/** /**
* Persist the deduplication IDs of [acknowledgeHandles]. * Persist the deduplication facts of [deduplicationHandlers].
*/ */
data class PersistDeduplicationIds(val acknowledgeHandles: List<AcknowledgeHandle>) : Action() data class PersistDeduplicationFacts(val deduplicationHandlers: List<DeduplicationHandler>) : Action()
/** /**
* Acknowledge messages in [acknowledgeHandles]. * Acknowledge messages in [deduplicationHandlers].
*/ */
data class AcknowledgeMessages(val acknowledgeHandles: List<AcknowledgeHandle>) : Action() data class AcknowledgeMessages(val deduplicationHandlers: List<DeduplicationHandler>) : Action()
/** /**
* Propagate [errorMessages] to [sessions]. * Propagate [errorMessages] to [sessions].

View File

@ -71,7 +71,7 @@ class ActionExecutorImpl(
return when (action) { return when (action) {
is Action.TrackTransaction -> executeTrackTransaction(fiber, action) is Action.TrackTransaction -> executeTrackTransaction(fiber, action)
is Action.PersistCheckpoint -> executePersistCheckpoint(action) is Action.PersistCheckpoint -> executePersistCheckpoint(action)
is Action.PersistDeduplicationIds -> executePersistDeduplicationIds(action) is Action.PersistDeduplicationFacts -> executePersistDeduplicationIds(action)
is Action.AcknowledgeMessages -> executeAcknowledgeMessages(action) is Action.AcknowledgeMessages -> executeAcknowledgeMessages(action)
is Action.PropagateErrors -> executePropagateErrors(action) is Action.PropagateErrors -> executePropagateErrors(action)
is Action.ScheduleEvent -> executeScheduleEvent(fiber, action) is Action.ScheduleEvent -> executeScheduleEvent(fiber, action)
@ -113,16 +113,16 @@ class ActionExecutorImpl(
} }
@Suspendable @Suspendable
private fun executePersistDeduplicationIds(action: Action.PersistDeduplicationIds) { private fun executePersistDeduplicationIds(action: Action.PersistDeduplicationFacts) {
for (handle in action.acknowledgeHandles) { for (handle in action.deduplicationHandlers) {
handle.persistDeduplicationId() handle.insideDatabaseTransaction()
} }
} }
@Suspendable @Suspendable
private fun executeAcknowledgeMessages(action: Action.AcknowledgeMessages) { private fun executeAcknowledgeMessages(action: Action.AcknowledgeMessages) {
action.acknowledgeHandles.forEach { action.deduplicationHandlers.forEach {
it.acknowledge() it.afterDatabaseTransaction()
} }
} }

View File

@ -15,7 +15,7 @@ import net.corda.core.identity.Party
import net.corda.core.internal.FlowIORequest import net.corda.core.internal.FlowIORequest
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.node.services.messaging.AcknowledgeHandle import net.corda.node.services.messaging.DeduplicationHandler
/** /**
* Transitions in the flow state machine are triggered by [Event]s that may originate from the flow itself or from * Transitions in the flow state machine are triggered by [Event]s that may originate from the flow itself or from
@ -33,12 +33,12 @@ sealed class Event {
/** /**
* Deliver a session message. * Deliver a session message.
* @param sessionMessage the message itself. * @param sessionMessage the message itself.
* @param acknowledgeHandle the handle to acknowledge the message after checkpointing. * @param deduplicationHandler the handle to acknowledge the message after checkpointing.
* @param sender the sender [Party]. * @param sender the sender [Party].
*/ */
data class DeliverSessionMessage( data class DeliverSessionMessage(
val sessionMessage: ExistingSessionMessage, val sessionMessage: ExistingSessionMessage,
val acknowledgeHandle: AcknowledgeHandle, val deduplicationHandler: DeduplicationHandler,
val sender: Party val sender: Party
) : Event() ) : Event()

View File

@ -19,7 +19,7 @@ import net.corda.core.serialization.serialize
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.trace import net.corda.core.utilities.trace
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.AcknowledgeHandle import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.ReceivedMessage import net.corda.node.services.messaging.ReceivedMessage
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders
import java.io.NotSerializableException import java.io.NotSerializableException
@ -38,7 +38,7 @@ interface FlowMessaging {
/** /**
* Start the messaging using the [onMessage] message handler. * Start the messaging using the [onMessage] message handler.
*/ */
fun start(onMessage: (ReceivedMessage, acknowledgeHandle: AcknowledgeHandle) -> Unit) fun start(onMessage: (ReceivedMessage, deduplicationHandler: DeduplicationHandler) -> Unit)
} }
/** /**
@ -52,9 +52,9 @@ class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging {
val sessionTopic = "platform.session" val sessionTopic = "platform.session"
} }
override fun start(onMessage: (ReceivedMessage, acknowledgeHandle: AcknowledgeHandle) -> Unit) { override fun start(onMessage: (ReceivedMessage, deduplicationHandler: DeduplicationHandler) -> Unit) {
serviceHub.networkService.addMessageHandler(sessionTopic) { receivedMessage, _, acknowledgeHandle -> serviceHub.networkService.addMessageHandler(sessionTopic) { receivedMessage, _, deduplicationHandler ->
onMessage(receivedMessage, acknowledgeHandle) onMessage(receivedMessage, deduplicationHandler)
} }
} }

View File

@ -37,11 +37,10 @@ import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.config.shouldCheckCheckpoints import net.corda.node.services.config.shouldCheckCheckpoints
import net.corda.node.services.messaging.AcknowledgeHandle import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.ReceivedMessage import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.services.statemachine.interceptors.* import net.corda.node.services.statemachine.interceptors.*
import net.corda.node.services.statemachine.transitions.StateMachine import net.corda.node.services.statemachine.transitions.StateMachine
import net.corda.node.services.statemachine.transitions.StateMachineConfiguration
import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.AffinityExecutor
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl
@ -139,9 +138,9 @@ class MultiThreadedStateMachineManager(
} }
serviceHub.networkMapCache.nodeReady.then { serviceHub.networkMapCache.nodeReady.then {
resumeRestoredFlows(fibers) resumeRestoredFlows(fibers)
flowMessaging.start { receivedMessage, acknowledgeHandle -> flowMessaging.start { receivedMessage, deduplicationHandler ->
lifeCycle.requireState(State.STARTED) { lifeCycle.requireState(State.STARTED) {
onSessionMessage(receivedMessage, acknowledgeHandle) onSessionMessage(receivedMessage, deduplicationHandler)
} }
} }
} }
@ -202,7 +201,8 @@ class MultiThreadedStateMachineManager(
override fun <A> startFlow( override fun <A> startFlow(
flowLogic: FlowLogic<A>, flowLogic: FlowLogic<A>,
context: InvocationContext, context: InvocationContext,
ourIdentity: Party? ourIdentity: Party?,
deduplicationHandler: DeduplicationHandler?
): CordaFuture<FlowStateMachine<A>> { ): CordaFuture<FlowStateMachine<A>> {
return lifeCycle.requireState(State.STARTED) { return lifeCycle.requireState(State.STARTED) {
startFlowInternal( startFlowInternal(
@ -210,7 +210,7 @@ class MultiThreadedStateMachineManager(
flowLogic = flowLogic, flowLogic = flowLogic,
flowStart = FlowStart.Explicit, flowStart = FlowStart.Explicit,
ourIdentity = ourIdentity ?: getOurFirstIdentity(), ourIdentity = ourIdentity ?: getOurFirstIdentity(),
initialUnacknowledgedMessage = null, deduplicationHandler = deduplicationHandler,
isStartIdempotent = false isStartIdempotent = false
) )
} }
@ -320,7 +320,7 @@ class MultiThreadedStateMachineManager(
createFlowFromCheckpoint( createFlowFromCheckpoint(
id = id, id = id,
checkpoint = checkpoint, checkpoint = checkpoint,
initialUnacknowledgedMessage = null, initialDeduplicationHandler = null,
isAnyCheckpointPersisted = true, isAnyCheckpointPersisted = true,
isStartIdempotent = false isStartIdempotent = false
) )
@ -333,32 +333,32 @@ class MultiThreadedStateMachineManager(
} }
} }
private fun onSessionMessage(message: ReceivedMessage, acknowledgeHandle: AcknowledgeHandle) { private fun onSessionMessage(message: ReceivedMessage, deduplicationHandler: DeduplicationHandler) {
val peer = message.peer val peer = message.peer
val sessionMessage = try { val sessionMessage = try {
message.data.deserialize<SessionMessage>() message.data.deserialize<SessionMessage>()
} catch (ex: Exception) { } catch (ex: Exception) {
logger.error("Received corrupt SessionMessage data from $peer") logger.error("Received corrupt SessionMessage data from $peer")
acknowledgeHandle.acknowledge() deduplicationHandler.afterDatabaseTransaction()
return return
} }
val sender = serviceHub.networkMapCache.getPeerByLegalName(peer) val sender = serviceHub.networkMapCache.getPeerByLegalName(peer)
if (sender != null) { if (sender != null) {
when (sessionMessage) { when (sessionMessage) {
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, acknowledgeHandle, sender) is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, deduplicationHandler, sender)
is InitialSessionMessage -> onSessionInit(sessionMessage, message.platformVersion, acknowledgeHandle, sender) is InitialSessionMessage -> onSessionInit(sessionMessage, message.platformVersion, deduplicationHandler, sender)
} }
} else { } else {
logger.error("Unknown peer $peer in $sessionMessage") logger.error("Unknown peer $peer in $sessionMessage")
} }
} }
private fun onExistingSessionMessage(sessionMessage: ExistingSessionMessage, acknowledgeHandle: AcknowledgeHandle, sender: Party) { private fun onExistingSessionMessage(sessionMessage: ExistingSessionMessage, deduplicationHandler: DeduplicationHandler, sender: Party) {
try { try {
val recipientId = sessionMessage.recipientSessionId val recipientId = sessionMessage.recipientSessionId
val flowId = sessionToFlow[recipientId] val flowId = sessionToFlow[recipientId]
if (flowId == null) { if (flowId == null) {
acknowledgeHandle.acknowledge() deduplicationHandler.afterDatabaseTransaction()
if (sessionMessage.payload is EndSessionMessage) { if (sessionMessage.payload is EndSessionMessage) {
logger.debug { logger.debug {
"Got ${EndSessionMessage::class.java.simpleName} for " + "Got ${EndSessionMessage::class.java.simpleName} for " +
@ -369,7 +369,7 @@ class MultiThreadedStateMachineManager(
} }
} else { } else {
val flow = concurrentBox.content.flows[flowId] ?: throw IllegalStateException("Cannot find fiber corresponding to ID $flowId") val flow = concurrentBox.content.flows[flowId] ?: throw IllegalStateException("Cannot find fiber corresponding to ID $flowId")
flow.fiber.scheduleEvent(Event.DeliverSessionMessage(sessionMessage, acknowledgeHandle, sender)) flow.fiber.scheduleEvent(Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender))
} }
} catch (exception: Exception) { } catch (exception: Exception) {
logger.error("Exception while routing $sessionMessage", exception) logger.error("Exception while routing $sessionMessage", exception)
@ -377,7 +377,7 @@ class MultiThreadedStateMachineManager(
} }
} }
private fun onSessionInit(sessionMessage: InitialSessionMessage, senderPlatformVersion: Int, acknowledgeHandle: AcknowledgeHandle, sender: Party) { private fun onSessionInit(sessionMessage: InitialSessionMessage, senderPlatformVersion: Int, deduplicationHandler: DeduplicationHandler, sender: Party) {
fun createErrorMessage(initiatorSessionId: SessionId, message: String): ExistingSessionMessage { fun createErrorMessage(initiatorSessionId: SessionId, message: String): ExistingSessionMessage {
val errorId = secureRandom.nextLong() val errorId = secureRandom.nextLong()
val payload = RejectSessionMessage(message, errorId) val payload = RejectSessionMessage(message, errorId)
@ -396,7 +396,7 @@ class MultiThreadedStateMachineManager(
is InitiatedFlowFactory.Core -> senderPlatformVersion is InitiatedFlowFactory.Core -> senderPlatformVersion
is InitiatedFlowFactory.CorDapp -> null is InitiatedFlowFactory.CorDapp -> null
} }
startInitiatedFlow(flowLogic, acknowledgeHandle, senderSession, initiatedSessionId, sessionMessage, senderCoreFlowVersion, initiatedFlowInfo) startInitiatedFlow(flowLogic, deduplicationHandler, senderSession, initiatedSessionId, sessionMessage, senderCoreFlowVersion, initiatedFlowInfo)
null null
} catch (exception: Exception) { } catch (exception: Exception) {
logger.warn("Exception while creating initiated flow", exception) logger.warn("Exception while creating initiated flow", exception)
@ -408,7 +408,7 @@ class MultiThreadedStateMachineManager(
if (replyError != null) { if (replyError != null) {
flowMessaging.sendSessionMessage(sender, replyError, DeduplicationId.createRandom(secureRandom)) flowMessaging.sendSessionMessage(sender, replyError, DeduplicationId.createRandom(secureRandom))
acknowledgeHandle.acknowledge() deduplicationHandler.afterDatabaseTransaction()
} }
} }
@ -431,7 +431,7 @@ class MultiThreadedStateMachineManager(
private fun <A> startInitiatedFlow( private fun <A> startInitiatedFlow(
flowLogic: FlowLogic<A>, flowLogic: FlowLogic<A>,
triggeringUnacknowledgedMessage: AcknowledgeHandle, initiatingMessageDeduplicationHandler: DeduplicationHandler,
peerSession: FlowSessionImpl, peerSession: FlowSessionImpl,
initiatedSessionId: SessionId, initiatedSessionId: SessionId,
initiatingMessage: InitialSessionMessage, initiatingMessage: InitialSessionMessage,
@ -442,7 +442,7 @@ class MultiThreadedStateMachineManager(
val ourIdentity = getOurFirstIdentity() val ourIdentity = getOurFirstIdentity()
startFlowInternal( startFlowInternal(
InvocationContext.peer(peerSession.counterparty.name), flowLogic, flowStart, ourIdentity, InvocationContext.peer(peerSession.counterparty.name), flowLogic, flowStart, ourIdentity,
triggeringUnacknowledgedMessage, initiatingMessageDeduplicationHandler,
isStartIdempotent = false isStartIdempotent = false
) )
} }
@ -452,7 +452,7 @@ class MultiThreadedStateMachineManager(
flowLogic: FlowLogic<A>, flowLogic: FlowLogic<A>,
flowStart: FlowStart, flowStart: FlowStart,
ourIdentity: Party, ourIdentity: Party,
initialUnacknowledgedMessage: AcknowledgeHandle?, deduplicationHandler: DeduplicationHandler?,
isStartIdempotent: Boolean isStartIdempotent: Boolean
): CordaFuture<FlowStateMachine<A>> { ): CordaFuture<FlowStateMachine<A>> {
val flowId = StateMachineRunId.createRandom() val flowId = StateMachineRunId.createRandom()
@ -475,7 +475,7 @@ class MultiThreadedStateMachineManager(
val startedFuture = openFuture<Unit>() val startedFuture = openFuture<Unit>()
val initialState = StateMachineState( val initialState = StateMachineState(
checkpoint = initialCheckpoint, checkpoint = initialCheckpoint,
unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(), pendingDeduplicationHandlers = deduplicationHandler?.let { listOf(it) } ?: emptyList(),
isFlowResumed = false, isFlowResumed = false,
isTransactionTracked = false, isTransactionTracked = false,
isAnyCheckpointPersisted = false, isAnyCheckpointPersisted = false,
@ -532,7 +532,7 @@ class MultiThreadedStateMachineManager(
checkpoint: Checkpoint, checkpoint: Checkpoint,
isAnyCheckpointPersisted: Boolean, isAnyCheckpointPersisted: Boolean,
isStartIdempotent: Boolean, isStartIdempotent: Boolean,
initialUnacknowledgedMessage: AcknowledgeHandle? initialDeduplicationHandler: DeduplicationHandler?
): Flow { ): Flow {
val flowState = checkpoint.flowState val flowState = checkpoint.flowState
val resultFuture = openFuture<Any?>() val resultFuture = openFuture<Any?>()
@ -541,7 +541,7 @@ class MultiThreadedStateMachineManager(
val logic = flowState.frozenFlowLogic.deserialize(context = checkpointSerializationContext!!) val logic = flowState.frozenFlowLogic.deserialize(context = checkpointSerializationContext!!)
val state = StateMachineState( val state = StateMachineState(
checkpoint = checkpoint, checkpoint = checkpoint,
unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(), pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
isFlowResumed = false, isFlowResumed = false,
isTransactionTracked = false, isTransactionTracked = false,
isAnyCheckpointPersisted = isAnyCheckpointPersisted, isAnyCheckpointPersisted = isAnyCheckpointPersisted,
@ -559,7 +559,7 @@ class MultiThreadedStateMachineManager(
val fiber = flowState.frozenFiber.deserialize(context = checkpointSerializationContext!!) val fiber = flowState.frozenFiber.deserialize(context = checkpointSerializationContext!!)
val state = StateMachineState( val state = StateMachineState(
checkpoint = checkpoint, checkpoint = checkpoint,
unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(), pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
isFlowResumed = false, isFlowResumed = false,
isTransactionTracked = false, isTransactionTracked = false,
isAnyCheckpointPersisted = isAnyCheckpointPersisted, isAnyCheckpointPersisted = isAnyCheckpointPersisted,
@ -644,7 +644,7 @@ class MultiThreadedStateMachineManager(
totalSuccessFlows.inc() totalSuccessFlows.inc()
drainFlowEventQueue(flow) drainFlowEventQueue(flow)
// final sanity checks // final sanity checks
require(lastState.unacknowledgedMessages.isEmpty()) require(lastState.pendingDeduplicationHandlers.isEmpty())
require(lastState.isRemoved) require(lastState.isRemoved)
require(lastState.checkpoint.subFlowStack.size == 1) require(lastState.checkpoint.subFlowStack.size == 1)
sessionToFlow.none { it.value == flow.fiber.id } sessionToFlow.none { it.value == flow.fiber.id }
@ -676,7 +676,7 @@ class MultiThreadedStateMachineManager(
is Event.DoRemainingWork -> {} is Event.DoRemainingWork -> {}
is Event.DeliverSessionMessage -> { is Event.DeliverSessionMessage -> {
// Acknowledge the message so it doesn't leak in the broker. // Acknowledge the message so it doesn't leak in the broker.
event.acknowledgeHandle.acknowledge() event.deduplicationHandler.afterDatabaseTransaction()
when (event.sessionMessage.payload) { when (event.sessionMessage.payload) {
EndSessionMessage -> { EndSessionMessage -> {
logger.debug { "Unhandled message ${event.sessionMessage} by ${flow.fiber} due to flow shutting down" } logger.debug { "Unhandled message ${event.sessionMessage} by ${flow.fiber} due to flow shutting down" }

View File

@ -40,11 +40,10 @@ import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.config.shouldCheckCheckpoints import net.corda.node.services.config.shouldCheckCheckpoints
import net.corda.node.services.messaging.AcknowledgeHandle import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.ReceivedMessage import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.services.statemachine.interceptors.* import net.corda.node.services.statemachine.interceptors.*
import net.corda.node.services.statemachine.transitions.StateMachine import net.corda.node.services.statemachine.transitions.StateMachine
import net.corda.node.services.statemachine.transitions.StateMachineConfiguration
import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.AffinityExecutor
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl
@ -136,9 +135,9 @@ class SingleThreadedStateMachineManager(
} }
serviceHub.networkMapCache.nodeReady.then { serviceHub.networkMapCache.nodeReady.then {
resumeRestoredFlows(fibers) resumeRestoredFlows(fibers)
flowMessaging.start { receivedMessage, acknowledgeHandle -> flowMessaging.start { receivedMessage, deduplicationHandler ->
executor.execute { executor.execute {
onSessionMessage(receivedMessage, acknowledgeHandle) onSessionMessage(receivedMessage, deduplicationHandler)
} }
} }
} }
@ -203,14 +202,15 @@ class SingleThreadedStateMachineManager(
override fun <A> startFlow( override fun <A> startFlow(
flowLogic: FlowLogic<A>, flowLogic: FlowLogic<A>,
context: InvocationContext, context: InvocationContext,
ourIdentity: Party? ourIdentity: Party?,
deduplicationHandler: DeduplicationHandler?
): CordaFuture<FlowStateMachine<A>> { ): CordaFuture<FlowStateMachine<A>> {
return startFlowInternal( return startFlowInternal(
invocationContext = context, invocationContext = context,
flowLogic = flowLogic, flowLogic = flowLogic,
flowStart = FlowStart.Explicit, flowStart = FlowStart.Explicit,
ourIdentity = ourIdentity ?: getOurFirstIdentity(), ourIdentity = ourIdentity ?: getOurFirstIdentity(),
initialUnacknowledgedMessage = null, deduplicationHandler = deduplicationHandler,
isStartIdempotent = false isStartIdempotent = false
) )
} }
@ -320,7 +320,7 @@ class SingleThreadedStateMachineManager(
createFlowFromCheckpoint( createFlowFromCheckpoint(
id = id, id = id,
checkpoint = checkpoint, checkpoint = checkpoint,
initialUnacknowledgedMessage = null, initialDeduplicationHandler = null,
isAnyCheckpointPersisted = true, isAnyCheckpointPersisted = true,
isStartIdempotent = false isStartIdempotent = false
) )
@ -333,32 +333,32 @@ class SingleThreadedStateMachineManager(
} }
} }
private fun onSessionMessage(message: ReceivedMessage, acknowledgeHandle: AcknowledgeHandle) { private fun onSessionMessage(message: ReceivedMessage, deduplicationHandler: DeduplicationHandler) {
val peer = message.peer val peer = message.peer
val sessionMessage = try { val sessionMessage = try {
message.data.deserialize<SessionMessage>() message.data.deserialize<SessionMessage>()
} catch (ex: Exception) { } catch (ex: Exception) {
logger.error("Received corrupt SessionMessage data from $peer") logger.error("Received corrupt SessionMessage data from $peer")
acknowledgeHandle.acknowledge() deduplicationHandler.afterDatabaseTransaction()
return return
} }
val sender = serviceHub.networkMapCache.getPeerByLegalName(peer) val sender = serviceHub.networkMapCache.getPeerByLegalName(peer)
if (sender != null) { if (sender != null) {
when (sessionMessage) { when (sessionMessage) {
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, acknowledgeHandle, sender) is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, deduplicationHandler, sender)
is InitialSessionMessage -> onSessionInit(sessionMessage, message.platformVersion, acknowledgeHandle, sender) is InitialSessionMessage -> onSessionInit(sessionMessage, message.platformVersion, deduplicationHandler, sender)
} }
} else { } else {
logger.error("Unknown peer $peer in $sessionMessage") logger.error("Unknown peer $peer in $sessionMessage")
} }
} }
private fun onExistingSessionMessage(sessionMessage: ExistingSessionMessage, acknowledgeHandle: AcknowledgeHandle, sender: Party) { private fun onExistingSessionMessage(sessionMessage: ExistingSessionMessage, deduplicationHandler: DeduplicationHandler, sender: Party) {
try { try {
val recipientId = sessionMessage.recipientSessionId val recipientId = sessionMessage.recipientSessionId
val flowId = sessionToFlow[recipientId] val flowId = sessionToFlow[recipientId]
if (flowId == null) { if (flowId == null) {
acknowledgeHandle.acknowledge() deduplicationHandler.afterDatabaseTransaction()
if (sessionMessage.payload is EndSessionMessage) { if (sessionMessage.payload is EndSessionMessage) {
logger.debug { logger.debug {
"Got ${EndSessionMessage::class.java.simpleName} for " + "Got ${EndSessionMessage::class.java.simpleName} for " +
@ -369,7 +369,7 @@ class SingleThreadedStateMachineManager(
} }
} else { } else {
val flow = mutex.locked { flows[flowId] } ?: throw IllegalStateException("Cannot find fiber corresponding to ID $flowId") val flow = mutex.locked { flows[flowId] } ?: throw IllegalStateException("Cannot find fiber corresponding to ID $flowId")
flow.fiber.scheduleEvent(Event.DeliverSessionMessage(sessionMessage, acknowledgeHandle, sender)) flow.fiber.scheduleEvent(Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender))
} }
} catch (exception: Exception) { } catch (exception: Exception) {
logger.error("Exception while routing $sessionMessage", exception) logger.error("Exception while routing $sessionMessage", exception)
@ -377,7 +377,7 @@ class SingleThreadedStateMachineManager(
} }
} }
private fun onSessionInit(sessionMessage: InitialSessionMessage, senderPlatformVersion: Int, acknowledgeHandle: AcknowledgeHandle, sender: Party) { private fun onSessionInit(sessionMessage: InitialSessionMessage, senderPlatformVersion: Int, deduplicationHandler: DeduplicationHandler, sender: Party) {
fun createErrorMessage(initiatorSessionId: SessionId, message: String): ExistingSessionMessage { fun createErrorMessage(initiatorSessionId: SessionId, message: String): ExistingSessionMessage {
val errorId = secureRandom.nextLong() val errorId = secureRandom.nextLong()
val payload = RejectSessionMessage(message, errorId) val payload = RejectSessionMessage(message, errorId)
@ -396,7 +396,7 @@ class SingleThreadedStateMachineManager(
is InitiatedFlowFactory.Core -> senderPlatformVersion is InitiatedFlowFactory.Core -> senderPlatformVersion
is InitiatedFlowFactory.CorDapp -> null is InitiatedFlowFactory.CorDapp -> null
} }
startInitiatedFlow(flowLogic, acknowledgeHandle, senderSession, initiatedSessionId, sessionMessage, senderCoreFlowVersion, initiatedFlowInfo) startInitiatedFlow(flowLogic, deduplicationHandler, senderSession, initiatedSessionId, sessionMessage, senderCoreFlowVersion, initiatedFlowInfo)
null null
} catch (exception: Exception) { } catch (exception: Exception) {
logger.warn("Exception while creating initiated flow", exception) logger.warn("Exception while creating initiated flow", exception)
@ -408,7 +408,7 @@ class SingleThreadedStateMachineManager(
if (replyError != null) { if (replyError != null) {
flowMessaging.sendSessionMessage(sender, replyError, DeduplicationId.createRandom(secureRandom)) flowMessaging.sendSessionMessage(sender, replyError, DeduplicationId.createRandom(secureRandom))
acknowledgeHandle.acknowledge() deduplicationHandler.afterDatabaseTransaction()
} }
} }
@ -431,7 +431,7 @@ class SingleThreadedStateMachineManager(
private fun <A> startInitiatedFlow( private fun <A> startInitiatedFlow(
flowLogic: FlowLogic<A>, flowLogic: FlowLogic<A>,
triggeringUnacknowledgedMessage: AcknowledgeHandle, initiatingMessageDeduplicationHandler: DeduplicationHandler,
peerSession: FlowSessionImpl, peerSession: FlowSessionImpl,
initiatedSessionId: SessionId, initiatedSessionId: SessionId,
initiatingMessage: InitialSessionMessage, initiatingMessage: InitialSessionMessage,
@ -442,7 +442,7 @@ class SingleThreadedStateMachineManager(
val ourIdentity = getOurFirstIdentity() val ourIdentity = getOurFirstIdentity()
startFlowInternal( startFlowInternal(
InvocationContext.peer(peerSession.counterparty.name), flowLogic, flowStart, ourIdentity, InvocationContext.peer(peerSession.counterparty.name), flowLogic, flowStart, ourIdentity,
triggeringUnacknowledgedMessage, initiatingMessageDeduplicationHandler,
isStartIdempotent = false isStartIdempotent = false
) )
} }
@ -452,7 +452,7 @@ class SingleThreadedStateMachineManager(
flowLogic: FlowLogic<A>, flowLogic: FlowLogic<A>,
flowStart: FlowStart, flowStart: FlowStart,
ourIdentity: Party, ourIdentity: Party,
initialUnacknowledgedMessage: AcknowledgeHandle?, deduplicationHandler: DeduplicationHandler?,
isStartIdempotent: Boolean isStartIdempotent: Boolean
): CordaFuture<FlowStateMachine<A>> { ): CordaFuture<FlowStateMachine<A>> {
val flowId = StateMachineRunId.createRandom() val flowId = StateMachineRunId.createRandom()
@ -475,7 +475,7 @@ class SingleThreadedStateMachineManager(
val startedFuture = openFuture<Unit>() val startedFuture = openFuture<Unit>()
val initialState = StateMachineState( val initialState = StateMachineState(
checkpoint = initialCheckpoint, checkpoint = initialCheckpoint,
unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(), pendingDeduplicationHandlers = deduplicationHandler?.let { listOf(it) } ?: emptyList(),
isFlowResumed = false, isFlowResumed = false,
isTransactionTracked = false, isTransactionTracked = false,
isAnyCheckpointPersisted = false, isAnyCheckpointPersisted = false,
@ -532,7 +532,7 @@ class SingleThreadedStateMachineManager(
checkpoint: Checkpoint, checkpoint: Checkpoint,
isAnyCheckpointPersisted: Boolean, isAnyCheckpointPersisted: Boolean,
isStartIdempotent: Boolean, isStartIdempotent: Boolean,
initialUnacknowledgedMessage: AcknowledgeHandle? initialDeduplicationHandler: DeduplicationHandler?
): Flow { ): Flow {
val flowState = checkpoint.flowState val flowState = checkpoint.flowState
val resultFuture = openFuture<Any?>() val resultFuture = openFuture<Any?>()
@ -541,7 +541,7 @@ class SingleThreadedStateMachineManager(
val logic = flowState.frozenFlowLogic.deserialize(context = checkpointSerializationContext!!) val logic = flowState.frozenFlowLogic.deserialize(context = checkpointSerializationContext!!)
val state = StateMachineState( val state = StateMachineState(
checkpoint = checkpoint, checkpoint = checkpoint,
unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(), pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
isFlowResumed = false, isFlowResumed = false,
isTransactionTracked = false, isTransactionTracked = false,
isAnyCheckpointPersisted = isAnyCheckpointPersisted, isAnyCheckpointPersisted = isAnyCheckpointPersisted,
@ -559,7 +559,7 @@ class SingleThreadedStateMachineManager(
val fiber = flowState.frozenFiber.deserialize(context = checkpointSerializationContext!!) val fiber = flowState.frozenFiber.deserialize(context = checkpointSerializationContext!!)
val state = StateMachineState( val state = StateMachineState(
checkpoint = checkpoint, checkpoint = checkpoint,
unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(), pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
isFlowResumed = false, isFlowResumed = false,
isTransactionTracked = false, isTransactionTracked = false,
isAnyCheckpointPersisted = isAnyCheckpointPersisted, isAnyCheckpointPersisted = isAnyCheckpointPersisted,
@ -648,7 +648,7 @@ class SingleThreadedStateMachineManager(
) { ) {
drainFlowEventQueue(flow) drainFlowEventQueue(flow)
// final sanity checks // final sanity checks
require(lastState.unacknowledgedMessages.isEmpty()) require(lastState.pendingDeduplicationHandlers.isEmpty())
require(lastState.isRemoved) require(lastState.isRemoved)
require(lastState.checkpoint.subFlowStack.size == 1) require(lastState.checkpoint.subFlowStack.size == 1)
sessionToFlow.none { it.value == flow.fiber.id } sessionToFlow.none { it.value == flow.fiber.id }
@ -679,7 +679,7 @@ class SingleThreadedStateMachineManager(
is Event.DoRemainingWork -> {} is Event.DoRemainingWork -> {}
is Event.DeliverSessionMessage -> { is Event.DeliverSessionMessage -> {
// Acknowledge the message so it doesn't leak in the broker. // Acknowledge the message so it doesn't leak in the broker.
event.acknowledgeHandle.acknowledge() event.deduplicationHandler.afterDatabaseTransaction()
when (event.sessionMessage.payload) { when (event.sessionMessage.payload) {
EndSessionMessage -> { EndSessionMessage -> {
logger.debug { "Unhandled message ${event.sessionMessage} due to flow shutting down" } logger.debug { "Unhandled message ${event.sessionMessage} due to flow shutting down" }

View File

@ -18,6 +18,7 @@ import net.corda.core.identity.Party
import net.corda.core.internal.FlowStateMachine import net.corda.core.internal.FlowStateMachine
import net.corda.core.messaging.DataFeed import net.corda.core.messaging.DataFeed
import net.corda.core.utilities.Try import net.corda.core.utilities.Try
import net.corda.node.services.messaging.DeduplicationHandler
import rx.Observable import rx.Observable
/** /**
@ -58,11 +59,14 @@ interface StateMachineManager {
* *
* @param flowLogic The flow's code. * @param flowLogic The flow's code.
* @param context The context of the flow. * @param context The context of the flow.
* @param ourIdentity The identity to use for the flow.
* @param deduplicationHandler Allows exactly-once start of the flow, see [DeduplicationHandler].
*/ */
fun <A> startFlow( fun <A> startFlow(
flowLogic: FlowLogic<A>, flowLogic: FlowLogic<A>,
context: InvocationContext, context: InvocationContext,
ourIdentity: Party? = null ourIdentity: Party?,
deduplicationHandler: DeduplicationHandler?
): CordaFuture<FlowStateMachine<A>> ): CordaFuture<FlowStateMachine<A>>
/** /**

View File

@ -17,7 +17,7 @@ import net.corda.core.identity.Party
import net.corda.core.internal.FlowIORequest import net.corda.core.internal.FlowIORequest
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.Try import net.corda.core.utilities.Try
import net.corda.node.services.messaging.AcknowledgeHandle import net.corda.node.services.messaging.DeduplicationHandler
/** /**
* The state of the state machine, capturing the state of a flow. It consists of two parts, an *immutable* part that is * The state of the state machine, capturing the state of a flow. It consists of two parts, an *immutable* part that is
@ -25,7 +25,7 @@ import net.corda.node.services.messaging.AcknowledgeHandle
* *
* @param checkpoint the persisted part of the state. * @param checkpoint the persisted part of the state.
* @param flowLogic the [FlowLogic] associated with the flow. Note that this is mutable by the user. * @param flowLogic the [FlowLogic] associated with the flow. Note that this is mutable by the user.
* @param unacknowledgedMessages the list of currently unacknowledged messages. * @param pendingDeduplicationHandlers the list of incomplete deduplication handlers.
* @param isFlowResumed true if the control is returned (or being returned) to "user-space" flow code. This is used * @param isFlowResumed true if the control is returned (or being returned) to "user-space" flow code. This is used
* to make [Event.DoRemainingWork] idempotent. * to make [Event.DoRemainingWork] idempotent.
* @param isTransactionTracked true if a ledger transaction has been tracked as part of a * @param isTransactionTracked true if a ledger transaction has been tracked as part of a
@ -42,7 +42,7 @@ import net.corda.node.services.messaging.AcknowledgeHandle
data class StateMachineState( data class StateMachineState(
val checkpoint: Checkpoint, val checkpoint: Checkpoint,
val flowLogic: FlowLogic<*>, val flowLogic: FlowLogic<*>,
val unacknowledgedMessages: List<AcknowledgeHandle>, val pendingDeduplicationHandlers: List<DeduplicationHandler>,
val isFlowResumed: Boolean, val isFlowResumed: Boolean,
val isTransactionTracked: Boolean, val isTransactionTracked: Boolean,
val isAnyCheckpointPersisted: Boolean, val isAnyCheckpointPersisted: Boolean,

View File

@ -34,12 +34,12 @@ class DeliverSessionMessageTransition(
) : Transition { ) : Transition {
override fun transition(): TransitionResult { override fun transition(): TransitionResult {
return builder { return builder {
// Add the AcknowledgeHandle to the unacknowledged messages ASAP so in case an error happens we still know // Add the DeduplicationHandler to the pending ones ASAP so in case an error happens we still know
// about the message. Note that in case of an error during deliver this message *will be acked*. // about the message. Note that in case of an error during deliver this message *will be acked*.
// For example if the session corresponding to the message is not found the message is still acked to free // For example if the session corresponding to the message is not found the message is still acked to free
// up the broker but the flow will error. // up the broker but the flow will error.
currentState = currentState.copy( currentState = currentState.copy(
unacknowledgedMessages = currentState.unacknowledgedMessages + event.acknowledgeHandle pendingDeduplicationHandlers = currentState.pendingDeduplicationHandlers + event.deduplicationHandler
) )
// Check whether we have a session corresponding to the message. // Check whether we have a session corresponding to the message.
val existingSession = startingState.checkpoint.sessions[event.sessionMessage.recipientSessionId] val existingSession = startingState.checkpoint.sessions[event.sessionMessage.recipientSessionId]
@ -163,12 +163,12 @@ class DeliverSessionMessageTransition(
actions.addAll(arrayOf( actions.addAll(arrayOf(
Action.CreateTransaction, Action.CreateTransaction,
Action.PersistCheckpoint(context.id, currentState.checkpoint), Action.PersistCheckpoint(context.id, currentState.checkpoint),
Action.PersistDeduplicationIds(currentState.unacknowledgedMessages), Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers),
Action.CommitTransaction, Action.CommitTransaction,
Action.AcknowledgeMessages(currentState.unacknowledgedMessages) Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers)
)) ))
currentState = currentState.copy( currentState = currentState.copy(
unacknowledgedMessages = emptyList(), pendingDeduplicationHandlers = emptyList(),
isAnyCheckpointPersisted = true isAnyCheckpointPersisted = true
) )
} }

View File

@ -71,14 +71,14 @@ class ErrorFlowTransition(
actions.add(Action.RemoveCheckpoint(context.id)) actions.add(Action.RemoveCheckpoint(context.id))
} }
actions.addAll(arrayOf( actions.addAll(arrayOf(
Action.PersistDeduplicationIds(currentState.unacknowledgedMessages), Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers),
Action.CommitTransaction, Action.CommitTransaction,
Action.AcknowledgeMessages(currentState.unacknowledgedMessages), Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers),
Action.RemoveSessionBindings(currentState.checkpoint.sessions.keys) Action.RemoveSessionBindings(currentState.checkpoint.sessions.keys)
)) ))
currentState = currentState.copy( currentState = currentState.copy(
unacknowledgedMessages = emptyList(), pendingDeduplicationHandlers = emptyList(),
isRemoved = true isRemoved = true
) )

View File

@ -151,14 +151,14 @@ class TopLevelTransition(
} else { } else {
actions.addAll(arrayOf( actions.addAll(arrayOf(
Action.PersistCheckpoint(context.id, newCheckpoint), Action.PersistCheckpoint(context.id, newCheckpoint),
Action.PersistDeduplicationIds(currentState.unacknowledgedMessages), Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers),
Action.CommitTransaction, Action.CommitTransaction,
Action.AcknowledgeMessages(currentState.unacknowledgedMessages), Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers),
Action.ScheduleEvent(Event.DoRemainingWork) Action.ScheduleEvent(Event.DoRemainingWork)
)) ))
currentState = currentState.copy( currentState = currentState.copy(
checkpoint = newCheckpoint, checkpoint = newCheckpoint,
unacknowledgedMessages = emptyList(), pendingDeduplicationHandlers = emptyList(),
isFlowResumed = false, isFlowResumed = false,
isAnyCheckpointPersisted = true isAnyCheckpointPersisted = true
) )
@ -172,12 +172,12 @@ class TopLevelTransition(
val checkpoint = currentState.checkpoint val checkpoint = currentState.checkpoint
when (checkpoint.errorState) { when (checkpoint.errorState) {
ErrorState.Clean -> { ErrorState.Clean -> {
val unacknowledgedMessages = currentState.unacknowledgedMessages val pendingDeduplicationHandlers = currentState.pendingDeduplicationHandlers
currentState = currentState.copy( currentState = currentState.copy(
checkpoint = checkpoint.copy( checkpoint = checkpoint.copy(
numberOfSuspends = checkpoint.numberOfSuspends + 1 numberOfSuspends = checkpoint.numberOfSuspends + 1
), ),
unacknowledgedMessages = emptyList(), pendingDeduplicationHandlers = emptyList(),
isFlowResumed = false, isFlowResumed = false,
isRemoved = true isRemoved = true
) )
@ -186,9 +186,9 @@ class TopLevelTransition(
actions.add(Action.RemoveCheckpoint(context.id)) actions.add(Action.RemoveCheckpoint(context.id))
} }
actions.addAll(arrayOf( actions.addAll(arrayOf(
Action.PersistDeduplicationIds(unacknowledgedMessages), Action.PersistDeduplicationFacts(pendingDeduplicationHandlers),
Action.CommitTransaction, Action.CommitTransaction,
Action.AcknowledgeMessages(unacknowledgedMessages), Action.AcknowledgeMessages(pendingDeduplicationHandlers),
Action.RemoveSessionBindings(allSourceSessionIds), Action.RemoveSessionBindings(allSourceSessionIds),
Action.RemoveFlow(context.id, FlowRemovalReason.OrderlyFinish(event.returnValue), currentState) Action.RemoveFlow(context.id, FlowRemovalReason.OrderlyFinish(event.returnValue), currentState)
)) ))

View File

@ -78,12 +78,12 @@ class UnstartedFlowTransition(
actions.addAll(arrayOf( actions.addAll(arrayOf(
Action.CreateTransaction, Action.CreateTransaction,
Action.PersistCheckpoint(context.id, currentState.checkpoint), Action.PersistCheckpoint(context.id, currentState.checkpoint),
Action.PersistDeduplicationIds(currentState.unacknowledgedMessages), Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers),
Action.CommitTransaction, Action.CommitTransaction,
Action.AcknowledgeMessages(currentState.unacknowledgedMessages) Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers)
)) ))
currentState = currentState.copy( currentState = currentState.copy(
unacknowledgedMessages = emptyList(), pendingDeduplicationHandlers = emptyList(),
isAnyCheckpointPersisted = true isAnyCheckpointPersisted = true
) )
} }

View File

@ -4,7 +4,7 @@ keyStorePassword = "cordacadevpass"
trustStorePassword = "trustpass" trustStorePassword = "trustpass"
dataSourceProperties = { dataSourceProperties = {
dataSourceClassName = org.h2.jdbcx.JdbcDataSource dataSourceClassName = org.h2.jdbcx.JdbcDataSource
dataSource.url = "jdbc:h2:file:"${baseDirectory}"/persistence;DB_CLOSE_ON_EXIT=FALSE;LOCK_TIMEOUT=10000;WRITE_DELAY=100;AUTO_SERVER_PORT="${h2port} dataSource.url = "jdbc:h2:file:"${baseDirectory}"/persistence;DB_CLOSE_ON_EXIT=FALSE;LOCK_TIMEOUT=10000;WRITE_DELAY=0;AUTO_SERVER_PORT="${h2port}
dataSource.user = sa dataSource.user = sa
dataSource.password = "" dataSource.password = ""
} }

View File

@ -48,7 +48,7 @@ class NodeSchedulerServiceTest {
}.whenever(it).transaction(any()) }.whenever(it).transaction(any())
} }
private val flowStarter = rigorousMock<FlowStarter>().also { private val flowStarter = rigorousMock<FlowStarter>().also {
doReturn(openFuture<FlowStateMachine<*>>()).whenever(it).startFlow(any<FlowLogic<*>>(), any()) doReturn(openFuture<FlowStateMachine<*>>()).whenever(it).startFlow(any<FlowLogic<*>>(), any(), any())
} }
private val flowsDraingMode = rigorousMock<NodePropertiesStore.FlowsDrainingModeOperations>().also { private val flowsDraingMode = rigorousMock<NodePropertiesStore.FlowsDrainingModeOperations>().also {
doReturn(false).whenever(it).isEnabled() doReturn(false).whenever(it).isEnabled()
@ -111,7 +111,7 @@ class NodeSchedulerServiceTest {
private fun assertStarted(event: Event) { private fun assertStarted(event: Event) {
// Like in assertWaitingFor, use timeout to make verify wait as we often race the call to startFlow: // Like in assertWaitingFor, use timeout to make verify wait as we often race the call to startFlow:
verify(flowStarter, timeout(5000)).startFlow(same(event.flowLogic)!!, any()) verify(flowStarter, timeout(5000)).startFlow(same(event.flowLogic)!!, any(), any())
} }
@Test @Test

View File

@ -217,8 +217,8 @@ class ArtemisMessagingTest {
try { try {
val messagingClient2 = createMessagingClient() val messagingClient2 = createMessagingClient()
messagingClient2.addMessageHandler(TOPIC) { msg, _, handle -> messagingClient2.addMessageHandler(TOPIC) { msg, _, handle ->
database.transaction { handle.persistDeduplicationId() } database.transaction { handle.insideDatabaseTransaction() }
handle.acknowledge() // We ACK first so that if it fails we won't get a duplicate in [receivedMessages] handle.afterDatabaseTransaction() // We ACK first so that if it fails we won't get a duplicate in [receivedMessages]
receivedMessages.add(msg) receivedMessages.add(msg)
} }
startNodeMessagingClient() startNodeMessagingClient()
@ -252,8 +252,8 @@ class ArtemisMessagingTest {
val messagingClient3 = createMessagingClient() val messagingClient3 = createMessagingClient()
messagingClient3.addMessageHandler(TOPIC) { msg, _, handle -> messagingClient3.addMessageHandler(TOPIC) { msg, _, handle ->
database.transaction { handle.persistDeduplicationId() } database.transaction { handle.insideDatabaseTransaction() }
handle.acknowledge() // We ACK first so that if it fails we won't get a duplicate in [receivedMessages] handle.afterDatabaseTransaction() // We ACK first so that if it fails we won't get a duplicate in [receivedMessages]
receivedMessages.add(msg) receivedMessages.add(msg)
} }
startNodeMessagingClient() startNodeMessagingClient()
@ -281,8 +281,8 @@ class ArtemisMessagingTest {
val messagingClient = createMessagingClient(platformVersion = platformVersion) val messagingClient = createMessagingClient(platformVersion = platformVersion)
messagingClient.addMessageHandler(TOPIC) { message, _, handle -> messagingClient.addMessageHandler(TOPIC) { message, _, handle ->
database.transaction { handle.persistDeduplicationId() } database.transaction { handle.insideDatabaseTransaction() }
handle.acknowledge() // We ACK first so that if it fails we won't get a duplicate in [receivedMessages] handle.afterDatabaseTransaction() // We ACK first so that if it fails we won't get a duplicate in [receivedMessages]
receivedMessages.add(message) receivedMessages.add(message)
} }
startNodeMessagingClient() startNodeMessagingClient()

View File

@ -480,13 +480,7 @@ class InMemoryMessagingNetwork private constructor(
database.transaction { database.transaction {
for (handler in deliverTo) { for (handler in deliverTo) {
try { try {
val acknowledgeHandle = object : AcknowledgeHandle { handler.callback(transfer.toReceivedMessage(), handler, DummyDeduplicationHandler())
override fun acknowledge() {
}
override fun persistDeduplicationId() {
}
}
handler.callback(transfer.toReceivedMessage(), handler, acknowledgeHandle)
} catch (e: Exception) { } catch (e: Exception) {
log.error("Caught exception in handler for $this/${handler.topicSession}", e) log.error("Caught exception in handler for $this/${handler.topicSession}", e)
} }
@ -510,5 +504,12 @@ class InMemoryMessagingNetwork private constructor(
message.debugTimestamp, message.debugTimestamp,
sender.name) sender.name)
} }
private class DummyDeduplicationHandler : DeduplicationHandler {
override fun afterDatabaseTransaction() {
}
override fun insideDatabaseTransaction() {
}
}
} }