CORDA-1475 CORDA-1465 Allow flows to retry from last checkpoint ()

This commit is contained in:
Rick Parker 2018-05-25 13:26:00 +01:00 committed by GitHub
parent 7cbc316b9d
commit 59fdb3df67
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
41 changed files with 1843 additions and 469 deletions

@ -41,4 +41,5 @@ interface FlowStateMachine<FLOWRETURN> {
val resultFuture: CordaFuture<FLOWRETURN>
val context: InvocationContext
val ourIdentity: Party
val ourSenderUUID: String?
}

@ -5,7 +5,6 @@ import net.corda.core.schemas.MappedSchema
import net.corda.core.utilities.contextLogger
import rx.Observable
import rx.Subscriber
import rx.subjects.PublishSubject
import rx.subjects.UnicastSubject
import java.io.Closeable
import java.sql.Connection
@ -67,9 +66,7 @@ class CordaPersistence(
}
val entityManagerFactory get() = hibernateConfig.sessionFactoryForRegisteredSchemas
data class Boundary(val txId: UUID)
internal val transactionBoundaries = PublishSubject.create<Boundary>().toSerialized()
data class Boundary(val txId: UUID, val success: Boolean)
init {
// Found a unit test that was forgetting to close the database transactions. When you close() on the top level
@ -186,15 +183,19 @@ class CordaPersistence(
*
* For examples, see the call hierarchy of this function.
*/
fun <T : Any> rx.Observer<T>.bufferUntilDatabaseCommit(): rx.Observer<T> {
val currentTxId = contextTransaction.id
val databaseTxBoundary: Observable<CordaPersistence.Boundary> = contextDatabase.transactionBoundaries.first { it.txId == currentTxId }
fun <T : Any> rx.Observer<T>.bufferUntilDatabaseCommit(propagateRollbackAsError: Boolean = false): rx.Observer<T> {
val currentTx = contextTransaction
val subject = UnicastSubject.create<T>()
val databaseTxBoundary: Observable<CordaPersistence.Boundary> = currentTx.boundary.filter { it.success }
if (propagateRollbackAsError) {
currentTx.boundary.filter { !it.success }.subscribe { this.onError(DatabaseTransactionRolledBackException(it.txId)) }
}
subject.delaySubscription(databaseTxBoundary).subscribe(this)
databaseTxBoundary.doOnCompleted { subject.onCompleted() }
return subject
}
class DatabaseTransactionRolledBackException(txId: UUID) : Exception("Database transaction $txId was rolled back")
// A subscriber that delegates to multiple others, wrapping a database transaction around the combination.
private class DatabaseTransactionWrappingSubscriber<U>(private val db: CordaPersistence?) : Subscriber<U>() {
// Some unsubscribes happen inside onNext() so need something that supports concurrent modification.

@ -3,6 +3,7 @@ package net.corda.nodeapi.internal.persistence
import co.paralleluniverse.strands.Strand
import org.hibernate.Session
import org.hibernate.Transaction
import rx.subjects.PublishSubject
import java.sql.Connection
import java.util.*
@ -35,11 +36,16 @@ class DatabaseTransaction(
val session: Session by sessionDelegate
private lateinit var hibernateTransaction: Transaction
internal val boundary = PublishSubject.create<CordaPersistence.Boundary>()
private var committed = false
fun commit() {
if (sessionDelegate.isInitialized()) {
hibernateTransaction.commit()
}
connection.commit()
committed = true
}
fun rollback() {
@ -58,7 +64,15 @@ class DatabaseTransaction(
connection.close()
contextTransactionOrNull = outerTransaction
if (outerTransaction == null) {
database.transactionBoundaries.onNext(CordaPersistence.Boundary(id))
boundary.onNext(CordaPersistence.Boundary(id, committed))
}
}
fun onCommit(callback: () -> Unit) {
boundary.filter { it.success }.subscribe { callback() }
}
fun onRollback(callback: () -> Unit) {
boundary.filter { !it.success }.subscribe { callback() }
}
}

@ -0,0 +1,158 @@
package net.corda.node.flows
import co.paralleluniverse.fibers.Suspendable
import net.corda.client.rpc.CordaRPCClient
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.messaging.startFlow
import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.node.services.Permissions
import net.corda.testing.core.singleIdentity
import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.driver
import net.corda.testing.driver.internal.RandomFree
import net.corda.testing.node.User
import org.junit.Before
import org.junit.Test
import java.lang.management.ManagementFactory
import java.sql.SQLException
import java.util.*
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
class FlowRetryTest {
@Before
fun resetCounters() {
InitiatorFlow.seen.clear()
InitiatedFlow.seen.clear()
}
@Test
fun `flows continue despite errors`() {
val numSessions = 2
val numIterations = 10
val user = User("mark", "dadada", setOf(Permissions.startFlow<InitiatorFlow>()))
val result: Any? = driver(DriverParameters(isDebug = true, startNodesInProcess = isQuasarAgentSpecified(),
portAllocation = RandomFree)) {
val nodeAHandle = startNode(rpcUsers = listOf(user)).getOrThrow()
val nodeBHandle = startNode(rpcUsers = listOf(user)).getOrThrow()
val result = CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use {
it.proxy.startFlow(::InitiatorFlow, numSessions, numIterations, nodeBHandle.nodeInfo.singleIdentity()).returnValue.getOrThrow()
}
result
}
assertNotNull(result)
assertEquals("$numSessions:$numIterations", result)
}
}
fun isQuasarAgentSpecified(): Boolean {
val jvmArgs = ManagementFactory.getRuntimeMXBean().inputArguments
return jvmArgs.any { it.startsWith("-javaagent:") && it.contains("quasar") }
}
class ExceptionToCauseRetry : SQLException("deadlock")
@StartableByRPC
@InitiatingFlow
class InitiatorFlow(private val sessionsCount: Int, private val iterationsCount: Int, private val other: Party) : FlowLogic<Any>() {
companion object {
object FIRST_STEP : ProgressTracker.Step("Step one")
fun tracker() = ProgressTracker(FIRST_STEP)
val seen = Collections.synchronizedSet(HashSet<Visited>())
fun visit(sessionNum: Int, iterationNum: Int, step: Step) {
val visited = Visited(sessionNum, iterationNum, step)
if (visited !in seen) {
seen += visited
throw ExceptionToCauseRetry()
}
}
}
override val progressTracker = tracker()
@Suspendable
override fun call(): Any {
progressTracker.currentStep = FIRST_STEP
var received: Any? = null
visit(-1, -1, Step.First)
for (sessionNum in 1..sessionsCount) {
visit(sessionNum, -1, Step.BeforeInitiate)
val session = initiateFlow(other)
visit(sessionNum, -1, Step.AfterInitiate)
session.send(SessionInfo(sessionNum, iterationsCount))
visit(sessionNum, -1, Step.AfterInitiateSendReceive)
for (iteration in 1..iterationsCount) {
visit(sessionNum, iteration, Step.BeforeSend)
logger.info("A Sending $sessionNum:$iteration")
session.send("$sessionNum:$iteration")
visit(sessionNum, iteration, Step.AfterSend)
received = session.receive<Any>().unwrap { it }
visit(sessionNum, iteration, Step.AfterReceive)
logger.info("A Got $sessionNum:$iteration")
}
doSleep()
}
return received!!
}
// This non-flow-friendly sleep triggered a bug with session end messages and non-retryable checkpoints.
private fun doSleep() {
Thread.sleep(2000)
}
}
@InitiatedBy(InitiatorFlow::class)
class InitiatedFlow(val session: FlowSession) : FlowLogic<Any>() {
companion object {
object FIRST_STEP : ProgressTracker.Step("Step one")
fun tracker() = ProgressTracker(FIRST_STEP)
val seen = Collections.synchronizedSet(HashSet<Visited>())
fun visit(sessionNum: Int, iterationNum: Int, step: Step) {
val visited = Visited(sessionNum, iterationNum, step)
if (visited !in seen) {
seen += visited
throw ExceptionToCauseRetry()
}
}
}
override val progressTracker = tracker()
@Suspendable
override fun call() {
progressTracker.currentStep = FIRST_STEP
visit(-1, -1, Step.AfterInitiate)
val sessionInfo = session.receive<SessionInfo>().unwrap { it }
visit(sessionInfo.sessionNum, -1, Step.AfterInitiateSendReceive)
for (iteration in 1..sessionInfo.iterationsCount) {
visit(sessionInfo.sessionNum, iteration, Step.BeforeReceive)
val got = session.receive<Any>().unwrap { it }
visit(sessionInfo.sessionNum, iteration, Step.AfterReceive)
logger.info("B Got $got")
logger.info("B Sending $got")
visit(sessionInfo.sessionNum, iteration, Step.BeforeSend)
session.send(got)
visit(sessionInfo.sessionNum, iteration, Step.AfterSend)
}
}
}
@CordaSerializable
data class SessionInfo(val sessionNum: Int, val iterationsCount: Int)
enum class Step { First, BeforeInitiate, AfterInitiate, AfterInitiateSendReceive, BeforeSend, AfterSend, BeforeReceive, AfterReceive }
data class Visited(val sessionNum: Int, val iterationNum: Int, val step: Step)

@ -25,6 +25,7 @@ import net.corda.core.node.services.vault.QueryCriteria
import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.seconds
import net.corda.testMessage.ScheduledState
import net.corda.testMessage.SpentState
import net.corda.testing.contracts.DummyContract
@ -96,7 +97,7 @@ class ScheduledFlowIntegrationTests {
val aliceClient = CordaRPCClient(alice.rpcAddress).start(rpcUser.username, rpcUser.password)
val bobClient = CordaRPCClient(bob.rpcAddress).start(rpcUser.username, rpcUser.password)
val scheduledFor = Instant.now().plusSeconds(20)
val scheduledFor = Instant.now().plusSeconds(10)
val initialiseFutures = mutableListOf<CordaFuture<*>>()
for (i in 0 until N) {
initialiseFutures.add(aliceClient.proxy.startFlow(::InsertInitialStateFlow, bob.nodeInfo.legalIdentities.first(), defaultNotaryIdentity, i, scheduledFor).returnValue)
@ -111,6 +112,9 @@ class ScheduledFlowIntegrationTests {
}
spendAttemptFutures.getOrThrowAll()
// TODO: the queries below are not atomic so we need to allow enough time for the scheduler to finish. Would be better to query scheduler.
Thread.sleep(20.seconds.toMillis())
val aliceStates = aliceClient.proxy.vaultQuery(ScheduledState::class.java).states.filter { it.state.data.processed }
val aliceSpentStates = aliceClient.proxy.vaultQuery(SpentState::class.java).states

@ -985,8 +985,37 @@ internal fun logVendorString(database: CordaPersistence, log: Logger) {
}
internal class FlowStarterImpl(private val smm: StateMachineManager, private val flowLogicRefFactory: FlowLogicRefFactory) : FlowStarter {
override fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext, deduplicationHandler: DeduplicationHandler?): CordaFuture<FlowStateMachine<T>> {
return smm.startFlow(logic, context, ourIdentity = null, deduplicationHandler = deduplicationHandler)
override fun <T> startFlow(event: ExternalEvent.ExternalStartFlowEvent<T>): CordaFuture<FlowStateMachine<T>> {
smm.deliverExternalEvent(event)
return event.future
}
override fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext): CordaFuture<FlowStateMachine<T>> {
val startFlowEvent = object : ExternalEvent.ExternalStartFlowEvent<T>, DeduplicationHandler {
override fun insideDatabaseTransaction() {}
override fun afterDatabaseTransaction() {}
override val externalCause: ExternalEvent
get() = this
override val deduplicationHandler: DeduplicationHandler
get() = this
override val flowLogic: FlowLogic<T>
get() = logic
override val context: InvocationContext
get() = context
override fun wireUpFuture(flowFuture: CordaFuture<FlowStateMachine<T>>) {
_future.captureLater(flowFuture)
}
private val _future = openFuture<FlowStateMachine<T>>()
override val future: CordaFuture<FlowStateMachine<T>>
get() = _future
}
return startFlow(startFlowEvent)
}
override fun <T> invokeFlowAsync(

@ -20,6 +20,12 @@ interface CheckpointStorage {
*/
fun removeCheckpoint(id: StateMachineRunId): Boolean
/**
* Load an existing checkpoint from the store.
* @return the checkpoint, still in serialized form, or null if not found.
*/
fun getCheckpoint(id: StateMachineRunId): SerializedBytes<Checkpoint>?
/**
* Stream all checkpoints from the store. If this is backed by a database the stream will be valid until the
* underlying database connection is closed, so any processing should happen before it is closed.

@ -19,9 +19,9 @@ import net.corda.core.utilities.contextLogger
import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.internal.cordapp.CordappProviderInternal
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.network.NetworkMapUpdater
import net.corda.node.services.statemachine.ExternalEvent
import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.nodeapi.internal.persistence.CordaPersistence
@ -134,11 +134,17 @@ interface ServiceHubInternal : ServiceHub {
interface FlowStarter {
/**
* Starts an already constructed flow. Note that you must be on the server thread to call this method.
* Starts an already constructed flow. Note that you must be on the server thread to call this method. This method
* just synthesizes an [ExternalEvent.ExternalStartFlowEvent] and calls the method below.
* @param context indicates who started the flow, see: [InvocationContext].
* @param deduplicationHandler allows exactly-once start of the flow, see [DeduplicationHandler]
*/
fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext, deduplicationHandler: DeduplicationHandler? = null): CordaFuture<FlowStateMachine<T>>
fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext): CordaFuture<FlowStateMachine<T>>
/**
* Starts a flow as described by an [ExternalEvent.ExternalStartFlowEvent]. If a transient error
* occurs during invocation, it will re-attempt to start the flow.
*/
fun <T> startFlow(event: ExternalEvent.ExternalStartFlowEvent<T>): CordaFuture<FlowStateMachine<T>>
/**
* Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the flow.

@ -2,6 +2,7 @@ package net.corda.node.services.events
import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext
import net.corda.core.context.InvocationOrigin
import net.corda.core.contracts.SchedulableState
@ -10,11 +11,9 @@ import net.corda.core.contracts.ScheduledStateRef
import net.corda.core.contracts.StateRef
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowLogicRefFactory
import net.corda.core.internal.ThreadBox
import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.*
import net.corda.core.internal.concurrent.flatMap
import net.corda.core.internal.join
import net.corda.core.internal.until
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.node.ServicesForResolution
import net.corda.core.schemas.PersistentStateRef
import net.corda.core.serialization.SingletonSerializeAsToken
@ -26,8 +25,10 @@ import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.NodePropertiesStore
import net.corda.node.services.api.SchedulerService
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.statemachine.ExternalEvent
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import net.corda.nodeapi.internal.persistence.contextTransaction
import org.apache.activemq.artemis.utils.ReusableLatch
import org.apache.mina.util.ConcurrentHashSet
import org.slf4j.Logger
@ -156,29 +157,31 @@ class NodeSchedulerService(private val clock: CordaClock,
override fun scheduleStateActivity(action: ScheduledStateRef) {
log.trace { "Schedule $action" }
if (!schedulerRepo.merge(action)) {
// Only increase the number of unfinished schedules if the state didn't already exist on the queue
unfinishedSchedules.countUp()
}
mutex.locked {
if (action.scheduledAt < nextScheduledAction?.scheduledAt ?: Instant.MAX) {
// We are earliest
rescheduleWakeUp()
} else if (action.ref == nextScheduledAction?.ref && action.scheduledAt != nextScheduledAction?.scheduledAt) {
// We were earliest but might not be any more
rescheduleWakeUp()
// Only increase the number of unfinished schedules if the state didn't already exist on the queue
val countUp = !schedulerRepo.merge(action)
contextTransaction.onCommit {
if (countUp) unfinishedSchedules.countUp()
mutex.locked {
if (action.scheduledAt < nextScheduledAction?.scheduledAt ?: Instant.MAX) {
// We are earliest
rescheduleWakeUp()
} else if (action.ref == nextScheduledAction?.ref && action.scheduledAt != nextScheduledAction?.scheduledAt) {
// We were earliest but might not be any more
rescheduleWakeUp()
}
}
}
}
override fun unscheduleStateActivity(ref: StateRef) {
log.trace { "Unschedule $ref" }
if (startingStateRefs.all { it.ref != ref } && schedulerRepo.delete(ref)) {
unfinishedSchedules.countDown()
}
mutex.locked {
if (nextScheduledAction?.ref == ref) {
rescheduleWakeUp()
val countDown = startingStateRefs.all { it.ref != ref } && schedulerRepo.delete(ref)
contextTransaction.onCommit {
if (countDown) unfinishedSchedules.countDown()
mutex.locked {
if (nextScheduledAction?.ref == ref) {
rescheduleWakeUp()
}
}
}
}
@ -227,7 +230,12 @@ class NodeSchedulerService(private val clock: CordaClock,
schedulerTimerExecutor.join()
}
private inner class FlowStartDeduplicationHandler(val scheduledState: ScheduledStateRef) : DeduplicationHandler {
private inner class FlowStartDeduplicationHandler(val scheduledState: ScheduledStateRef, override val flowLogic: FlowLogic<Any?>, override val context: InvocationContext) : DeduplicationHandler, ExternalEvent.ExternalStartFlowEvent<Any?> {
override val externalCause: ExternalEvent
get() = this
override val deduplicationHandler: FlowStartDeduplicationHandler
get() = this
override fun insideDatabaseTransaction() {
schedulerRepo.delete(scheduledState.ref)
}
@ -239,6 +247,18 @@ class NodeSchedulerService(private val clock: CordaClock,
override fun toString(): String {
return "${javaClass.simpleName}($scheduledState)"
}
override fun wireUpFuture(flowFuture: CordaFuture<FlowStateMachine<Any?>>) {
_future.captureLater(flowFuture)
val future = _future.flatMap { it.resultFuture }
future.then {
unfinishedSchedules.countDown()
}
}
private val _future = openFuture<FlowStateMachine<Any?>>()
override val future: CordaFuture<FlowStateMachine<Any?>>
get() = _future
}
private fun onTimeReached(scheduledState: ScheduledStateRef) {
@ -250,11 +270,8 @@ class NodeSchedulerService(private val clock: CordaClock,
flowName = scheduledFlow.javaClass.name
// TODO refactor the scheduler to store and propagate the original invocation context
val context = InvocationContext.newInstance(InvocationOrigin.Scheduled(scheduledState))
val deduplicationHandler = FlowStartDeduplicationHandler(scheduledState)
val future = flowStarter.startFlow(scheduledFlow, context, deduplicationHandler).flatMap { it.resultFuture }
future.then {
unfinishedSchedules.countDown()
}
val startFlowEvent = FlowStartDeduplicationHandler(scheduledState, scheduledFlow, context)
flowStarter.startFlow(startFlowEvent)
}
}
} catch (e: Exception) {

@ -10,6 +10,8 @@ import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.serialize
import net.corda.core.utilities.ByteSequence
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ExternalEvent
import net.corda.node.services.statemachine.SenderDeduplicationId
import java.time.Instant
import javax.annotation.concurrent.ThreadSafe
@ -25,6 +27,12 @@ import javax.annotation.concurrent.ThreadSafe
*/
@ThreadSafe
interface MessagingService {
/**
* A unique identifier for this sender that changes whenever a node restarts. This is used in conjunction with a sequence
* number for message de-duplication at the recipient.
*/
val ourSenderUUID: String
/**
* The provided function will be invoked for each received message whose topic and session matches. The callback
* will run on the main server thread provided when the messaging service is constructed, and a database
@ -93,11 +101,12 @@ interface MessagingService {
/**
* Returns an initialised [Message] with the current time, etc, already filled in.
*
* @param topicSession identifier for the topic and session the message is sent to.
* @param additionalProperties optional additional message headers.
* @param topic identifier for the topic the message is sent to.
* @param data the payload for the message.
* @param deduplicationId optional message deduplication ID including sender identifier.
* @param additionalHeaders optional additional message headers.
*/
fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId = DeduplicationId.createRandom(newSecureRandom()), additionalHeaders: Map<String, String> = emptyMap()): Message
fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId = SenderDeduplicationId(DeduplicationId.createRandom(newSecureRandom()), ourSenderUUID), additionalHeaders: Map<String, String> = emptyMap()): Message
/** Given information about either a specific node or a service returns its corresponding address */
fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients
@ -106,7 +115,7 @@ interface MessagingService {
val myAddress: SingleMessageRecipient
}
fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: DeduplicationId = DeduplicationId.createRandom(newSecureRandom()), retryId: Long? = null, additionalHeaders: Map<String, String> = emptyMap()) = send(createMessage(topicSession, payload.serialize().bytes, deduplicationId, additionalHeaders), to, retryId)
fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: SenderDeduplicationId = SenderDeduplicationId(DeduplicationId.createRandom(newSecureRandom()), ourSenderUUID), retryId: Long? = null, additionalHeaders: Map<String, String> = emptyMap()) = send(createMessage(topicSession, payload.serialize().bytes, deduplicationId, additionalHeaders), to, retryId)
interface MessageHandlerRegistration
@ -152,15 +161,17 @@ object TopicStringValidator {
}
/**
* This handler is used to implement exactly-once delivery of an event on top of a possibly duplicated one. This is done
* This handler is used to implement exactly-once delivery of an external event on top of an at-least-once delivery. This is done
* using two hooks that are called from the event processor, one called from the database transaction committing the
* side-effect caused by the event, and another one called after the transaction has committed successfully.
* side-effect caused by the external event, and another one called after the transaction has committed successfully.
*
* For example for messaging we can use [insideDatabaseTransaction] to store the message's unique ID for later
* deduplication, and [afterDatabaseTransaction] to acknowledge the message and stop retries.
*
* We also use this for exactly-once start of a scheduled flow, [insideDatabaseTransaction] is used to remove the
* to-be-scheduled state of the flow, [afterDatabaseTransaction] is used for cleanup of in-memory bookkeeping.
*
* It holds a reference back to the causing external event.
*/
interface DeduplicationHandler {
/**
@ -174,6 +185,11 @@ interface DeduplicationHandler {
* cleanup/acknowledgement/stopping of retries.
*/
fun afterDatabaseTransaction()
/**
* The external event for which we are trying to reduce from at-least-once delivery to exactly-once.
*/
val externalCause: ExternalEvent
}
typealias MessageHandler = (ReceivedMessage, MessageHandlerRegistration, DeduplicationHandler) -> Unit

@ -8,7 +8,6 @@ import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import java.io.Serializable
import java.time.Instant
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import javax.persistence.Column
import javax.persistence.Entity
@ -18,8 +17,6 @@ import javax.persistence.Id
* Encapsulate the de-duplication logic.
*/
class P2PMessageDeduplicator(private val database: CordaPersistence) {
val ourSenderUUID = UUID.randomUUID().toString()
// A temporary in-memory set of deduplication IDs and associated high water mark details.
// When we receive a message we don't persist the ID immediately,
// so we store the ID here in the meantime (until the persisting db tx has committed). This is because Artemis may

@ -23,6 +23,8 @@ import net.corda.node.internal.artemis.ReactiveArtemisConsumer.Companion.multipl
import net.corda.node.services.api.NetworkMapCacheInternal
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ExternalEvent
import net.corda.node.services.statemachine.SenderDeduplicationId
import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.PersistentMap
import net.corda.nodeapi.ArtemisTcpTransport.Companion.p2pConnectorTcpTransport
@ -124,7 +126,7 @@ class P2PMessagingClient(val config: NodeConfiguration,
)
}
private class NodeClientMessage(override val topic: String, override val data: ByteSequence, override val uniqueMessageId: DeduplicationId, override val senderUUID: String?, override val additionalHeaders: Map<String, String>) : Message {
class NodeClientMessage(override val topic: String, override val data: ByteSequence, override val uniqueMessageId: DeduplicationId, override val senderUUID: String?, override val additionalHeaders: Map<String, String>) : Message {
override val debugTimestamp: Instant = Instant.now()
override fun toString() = "$topic#${String(data.bytes)}"
}
@ -158,6 +160,8 @@ class P2PMessagingClient(val config: NodeConfiguration,
data class HandlerRegistration(val topic: String, val callback: Any) : MessageHandlerRegistration
override val myAddress: SingleMessageRecipient = NodeAddress(myIdentity, advertisedAddress)
override val ourSenderUUID = UUID.randomUUID().toString()
private val messageRedeliveryDelaySeconds = config.p2pMessagingRetry.messageRedeliveryDelay.seconds
private val state = ThreadBox(InnerState())
private val knownQueues = Collections.newSetFromMap(ConcurrentHashMap<String, Boolean>())
@ -227,7 +231,7 @@ class P2PMessagingClient(val config: NodeConfiguration,
producer!!,
versionInfo,
this@P2PMessagingClient,
ourSenderUUID = deduplicator.ourSenderUUID
ourSenderUUID = ourSenderUUID
)
registerBridgeControl(bridgeSession!!, inboxes.toList())
@ -435,18 +439,23 @@ class P2PMessagingClient(val config: NodeConfiguration,
}
}
inner class MessageDeduplicationHandler(val artemisMessage: ClientMessage, val cordaMessage: ReceivedMessage) : DeduplicationHandler {
private inner class MessageDeduplicationHandler(val artemisMessage: ClientMessage, override val receivedMessage: ReceivedMessage) : DeduplicationHandler, ExternalEvent.ExternalMessageEvent {
override val externalCause: ExternalEvent
get() = this
override val deduplicationHandler: MessageDeduplicationHandler
get() = this
override fun insideDatabaseTransaction() {
deduplicator.persistDeduplicationId(cordaMessage.uniqueMessageId)
deduplicator.persistDeduplicationId(receivedMessage.uniqueMessageId)
}
override fun afterDatabaseTransaction() {
deduplicator.signalMessageProcessFinish(cordaMessage.uniqueMessageId)
deduplicator.signalMessageProcessFinish(receivedMessage.uniqueMessageId)
messagingExecutor!!.acknowledge(artemisMessage)
}
override fun toString(): String {
return "${javaClass.simpleName}(${cordaMessage.uniqueMessageId})"
return "${javaClass.simpleName}(${receivedMessage.uniqueMessageId})"
}
}
@ -610,8 +619,8 @@ class P2PMessagingClient(val config: NodeConfiguration,
handlers.remove(registration.topic)
}
override fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId, additionalHeaders: Map<String, String>): Message {
return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId, deduplicator.ourSenderUUID, additionalHeaders)
override fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId, additionalHeaders: Map<String, String>): Message {
return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId.deduplicationId, deduplicationId.senderUUID, additionalHeaders)
}
override fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients {

@ -10,9 +10,9 @@ import net.corda.nodeapi.internal.persistence.currentDBSession
import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.io.Serializable
import java.util.*
import java.util.stream.Stream
import java.io.Serializable
import javax.persistence.Column
import javax.persistence.Entity
import javax.persistence.Id
@ -53,6 +53,11 @@ class DBCheckpointStorage : CheckpointStorage {
return session.createQuery(delete).executeUpdate() > 0
}
override fun getCheckpoint(id: StateMachineRunId): SerializedBytes<Checkpoint>? {
val bytes = currentDBSession().get(DBCheckpoint::class.java, id.uuid.toString())?.checkpoint ?: return null
return SerializedBytes<Checkpoint>(bytes)
}
override fun getAllCheckpoints(): Stream<Pair<StateMachineRunId, SerializedBytes<Checkpoint>>> {
val session = currentDBSession()
val criteriaQuery = session.criteriaBuilder.createQuery(DBCheckpoint::class.java)

@ -21,7 +21,6 @@ import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY
import rx.Observable
import rx.subjects.PublishSubject
import java.io.Serializable
import java.util.*
import javax.persistence.*
// cache value type to just store the immutable bits of a signed transaction plus conversion helpers
@ -71,11 +70,11 @@ class DBTransactionStorage(cacheSizeBytes: Long) : WritableTransactionStorage, S
// to the memory pressure at all here.
private const val transactionSignatureOverheadEstimate = 1024
private fun weighTx(tx: Optional<TxCacheValue>): Int {
if (!tx.isPresent) {
private fun weighTx(tx: AppendOnlyPersistentMapBase.Transactional<TxCacheValue>): Int {
val actTx = tx.valueWithoutIsolation
if (actTx == null) {
return 0
}
val actTx = tx.get()
return actTx.second.sumBy { it.size + transactionSignatureOverheadEstimate } + actTx.first.size
}
}

@ -24,7 +24,7 @@ sealed class Action {
data class SendInitial(
val party: Party,
val initialise: InitialSessionMessage,
val deduplicationId: DeduplicationId
val deduplicationId: SenderDeduplicationId
) : Action()
/**
@ -33,7 +33,7 @@ sealed class Action {
data class SendExisting(
val peerParty: Party,
val message: ExistingSessionMessage,
val deduplicationId: DeduplicationId
val deduplicationId: SenderDeduplicationId
) : Action()
/**
@ -62,7 +62,8 @@ sealed class Action {
*/
data class PropagateErrors(
val errorMessages: List<ErrorSessionMessage>,
val sessions: List<SessionState.Initiated>
val sessions: List<SessionState.Initiated>,
val senderUUID: String?
) : Action()
/**
@ -129,6 +130,11 @@ sealed class Action {
* Release soft locks associated with given ID (currently the flow ID).
*/
data class ReleaseSoftLocks(val uuid: UUID?) : Action()
/**
* Retry a flow from the last checkpoint, or if there is no checkpoint, restart the flow with the same invocation details.
*/
data class RetryFlowFromSafePoint(val currentState: StateMachineState) : Action()
}
/**

@ -73,6 +73,7 @@ class ActionExecutorImpl(
is Action.CommitTransaction -> executeCommitTransaction()
is Action.ExecuteAsyncOperation -> executeAsyncOperation(fiber, action)
is Action.ReleaseSoftLocks -> executeReleaseSoftLocks(action)
is Action.RetryFlowFromSafePoint -> executeRetryFlowFromSafePoint(action)
}
}
@ -125,7 +126,7 @@ class ActionExecutorImpl(
@Suspendable
private fun executePropagateErrors(action: Action.PropagateErrors) {
action.errorMessages.forEach { (exception) ->
log.debug("Propagating error", exception)
log.warn("Propagating error", exception)
}
for (sessionState in action.sessions) {
// We cannot propagate if the session isn't live.
@ -137,7 +138,7 @@ class ActionExecutorImpl(
val sinkSessionId = sessionState.initiatedState.peerSinkSessionId
val existingMessage = ExistingSessionMessage(sinkSessionId, errorMessage)
val deduplicationId = DeduplicationId.createForError(errorMessage.errorId, sinkSessionId)
flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, deduplicationId)
flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, action.senderUUID))
}
}
}
@ -226,6 +227,10 @@ class ActionExecutorImpl(
)
}
private fun executeRetryFlowFromSafePoint(action: Action.RetryFlowFromSafePoint) {
stateMachineManager.retryFlowFromSafePoint(action.currentState)
}
private fun serializeCheckpoint(checkpoint: Checkpoint): SerializedBytes<Checkpoint> {
return checkpoint.serialize(context = checkpointSerializationContext)
}

@ -45,3 +45,9 @@ data class DeduplicationId(val toString: String) {
}
}
}
/**
* Represents the deduplication ID of a flow message, and the sender identifier for the flow doing the sending. The identifier might be
* null if the flow is trying to replay messages and doesn't want an optimisation to ignore the deduplication ID.
*/
data class SenderDeduplicationId(val deduplicationId: DeduplicationId, val senderUUID: String?)

@ -31,9 +31,9 @@ sealed class Event {
*/
data class DeliverSessionMessage(
val sessionMessage: ExistingSessionMessage,
val deduplicationHandler: DeduplicationHandler,
override val deduplicationHandler: DeduplicationHandler,
val sender: Party
) : Event()
) : Event(), GeneratedByExternalEvent
/**
* Signal that an error has happened. This may be due to an uncaught exception in the flow or some external error.
@ -133,4 +133,19 @@ sealed class Event {
* @param returnValue the result of the operation.
*/
data class AsyncOperationCompletion(val returnValue: Any?) : Event()
/**
* Retry a flow from the last checkpoint, or if there is no checkpoint, restart the flow with the same invocation details.
*/
object RetryFlowFromSafePoint : Event() {
override fun toString() = "RetryFlowFromSafePoint"
}
/**
* Indicates that an event was generated by an external event and that external event needs to be replayed if we retry the flow,
* even if it has not yet been processed and placed on the pending de-duplication handlers list.
*/
interface GeneratedByExternalEvent {
val deduplicationHandler: DeduplicationHandler
}
}

@ -9,10 +9,15 @@ interface FlowHospital {
/**
* The flow running in [flowFiber] has errored.
*/
fun flowErrored(flowFiber: FlowFiber)
fun flowErrored(flowFiber: FlowFiber, currentState: StateMachineState, errors: List<Throwable>)
/**
* The flow running in [flowFiber] has cleaned, possibly as a result of a flow hospital resume.
*/
fun flowCleaned(flowFiber: FlowFiber)
/**
* The flow has been removed from the state machine.
*/
fun flowRemoved(flowFiber: FlowFiber)
}

@ -24,7 +24,7 @@ interface FlowMessaging {
* listen on the send acknowledgement.
*/
@Suspendable
fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId)
fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: SenderDeduplicationId)
/**
* Start the messaging using the [onMessage] message handler.
@ -49,7 +49,7 @@ class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging {
}
@Suspendable
override fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId) {
override fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: SenderDeduplicationId) {
log.trace { "Sending message $deduplicationId $message to party $party" }
val networkMessage = serviceHub.networkService.createMessage(sessionTopic, serializeSessionMessage(message).bytes, deduplicationId, message.additionalHeaders(party))
val partyInfo = serviceHub.networkMapCache.getPartyInfo(party) ?: throw IllegalArgumentException("Don't know about $party")

@ -47,6 +47,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
fun currentStateMachine(): FlowStateMachineImpl<*>? = Strand.currentStrand() as? FlowStateMachineImpl<*>
private val log: Logger = LoggerFactory.getLogger("net.corda.flow")
private val SERIALIZER_BLOCKER = Fiber::class.java.getDeclaredField("SERIALIZER_BLOCKER").apply { isAccessible = true }.get(null)
}
override val serviceHub get() = getTransientField(TransientValues::serviceHub)
@ -65,6 +67,14 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
internal var transientValues: TransientReference<TransientValues>? = null
internal var transientState: TransientReference<StateMachineState>? = null
/**
* What sender identifier to put on messages sent by this flow. This will either be the identifier for the current
* state machine manager / messaging client, or null to indicate this flow is restored from a checkpoint and
* the de-duplication of messages it sends should not be optimised since this could be unreliable.
*/
override val ourSenderUUID: String?
get() = transientState?.value?.senderUUID
private fun <A> getTransientField(field: KProperty1<TransientValues, A>): A {
val suppliedValues = transientValues ?: throw IllegalStateException("${field.name} wasn't supplied!")
return field.get(suppliedValues.value)
@ -168,6 +178,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
fun setLoggingContext() {
context.pushToLoggingContext()
MDC.put("flow-id", id.uuid.toString())
MDC.put("fiber-id", this.getId().toString())
}
@Suspendable
@ -185,7 +196,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
suspend(FlowIORequest.WaitForSessionConfirmations, maySkipCheckpoint = true)
Try.Success(result)
} catch (throwable: Throwable) {
logger.warn("Flow threw exception", throwable)
logger.info("Flow threw exception... sending to flow hospital", throwable)
Try.Failure<R>(throwable)
}
val softLocksId = if (hasSoftLockedStates) logic.runId.uuid else null
@ -325,7 +336,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
isDbTransactionOpenOnExit = false
)
require(continuation == FlowContinuation.ProcessEvents)
Fiber.unparkDeserialized(this, scheduler)
unpark(SERIALIZER_BLOCKER)
}
setLoggingContext()
return uncheckedCast(processEventsUntilFlowIsResumed(

@ -9,12 +9,17 @@ import net.corda.core.utilities.loggerFor
object PropagatingFlowHospital : FlowHospital {
private val log = loggerFor<PropagatingFlowHospital>()
override fun flowErrored(flowFiber: FlowFiber) {
log.debug { "Flow ${flowFiber.id} dirtied ${flowFiber.snapshot().checkpoint.errorState}" }
override fun flowErrored(flowFiber: FlowFiber, currentState: StateMachineState, errors: List<Throwable>) {
log.debug { "Flow ${flowFiber.id} in state $currentState encountered error" }
flowFiber.scheduleEvent(Event.StartErrorPropagation)
for ((index, error) in errors.withIndex()) {
log.warn("Flow ${flowFiber.id} is propagating error [$index] ", error)
}
}
override fun flowCleaned(flowFiber: FlowFiber) {
throw IllegalStateException("Flow ${flowFiber.id} cleaned after error propagation triggered")
}
override fun flowRemoved(flowFiber: FlowFiber) {}
}

@ -46,13 +46,15 @@ import java.security.SecureRandom
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import java.util.concurrent.locks.ReentrantLock
import javax.annotation.concurrent.ThreadSafe
import kotlin.collections.ArrayList
import kotlin.concurrent.withLock
import kotlin.streams.toList
/**
* The StateMachineManagerImpl will always invoke the flow fibers on the given [AffinityExecutor], regardless of which
* thread actually starts them via [startFlow].
* thread actually starts them via [deliverExternalEvent].
*/
@ThreadSafe
class SingleThreadedStateMachineManager(
@ -90,6 +92,7 @@ class SingleThreadedStateMachineManager(
private val flowMessaging: FlowMessaging = FlowMessagingImpl(serviceHub)
private val fiberDeserializationChecker = if (serviceHub.configuration.shouldCheckCheckpoints()) FiberDeserializationChecker() else null
private val transitionExecutor = makeTransitionExecutor()
private val ourSenderUUID = serviceHub.networkService.ourSenderUUID
private var checkpointSerializationContext: SerializationContext? = null
private var tokenizableServices: List<Any>? = null
@ -128,7 +131,7 @@ class SingleThreadedStateMachineManager(
resumeRestoredFlows(fibers)
flowMessaging.start { receivedMessage, deduplicationHandler ->
executor.execute {
onSessionMessage(receivedMessage, deduplicationHandler)
deliverExternalEvent(deduplicationHandler.externalCause)
}
}
}
@ -176,7 +179,7 @@ class SingleThreadedStateMachineManager(
}
}
override fun <A> startFlow(
private fun <A> startFlow(
flowLogic: FlowLogic<A>,
context: InvocationContext,
ourIdentity: Party?,
@ -310,7 +313,73 @@ class SingleThreadedStateMachineManager(
}
}
private fun onSessionMessage(message: ReceivedMessage, deduplicationHandler: DeduplicationHandler) {
override fun retryFlowFromSafePoint(currentState: StateMachineState) {
// Get set of external events
val flowId = currentState.flowLogic.runId
val oldFlowLeftOver = mutex.locked { flows[flowId] }?.fiber?.transientValues?.value?.eventQueue
if (oldFlowLeftOver == null) {
logger.error("Unable to find flow for flow $flowId. Something is very wrong. The flow will not retry.")
return
}
val flow = if (currentState.isAnyCheckpointPersisted) {
val serializedCheckpoint = checkpointStorage.getCheckpoint(flowId)
if (serializedCheckpoint == null) {
logger.error("Unable to find database checkpoint for flow $flowId. Something is very wrong. The flow will not retry.")
return
}
val checkpoint = deserializeCheckpoint(serializedCheckpoint)
if (checkpoint == null) {
logger.error("Unable to deserialize database checkpoint for flow $flowId. Something is very wrong. The flow will not retry.")
return
}
// Resurrect flow
createFlowFromCheckpoint(
id = flowId,
checkpoint = checkpoint,
initialDeduplicationHandler = null,
isAnyCheckpointPersisted = true,
isStartIdempotent = false,
senderUUID = null
)
} else {
// Just flow initiation message
null
}
externalEventMutex.withLock {
if (flow != null) addAndStartFlow(flowId, flow)
// Deliver all the external events from the old flow instance.
val unprocessedExternalEvents = mutableListOf<ExternalEvent>()
do {
val event = oldFlowLeftOver.tryReceive()
if (event is Event.GeneratedByExternalEvent) {
unprocessedExternalEvents += event.deduplicationHandler.externalCause
}
} while (event != null)
val externalEvents = currentState.pendingDeduplicationHandlers.map { it.externalCause } + unprocessedExternalEvents
for (externalEvent in externalEvents) {
deliverExternalEvent(externalEvent)
}
}
}
private val externalEventMutex = ReentrantLock()
override fun deliverExternalEvent(event: ExternalEvent) {
externalEventMutex.withLock {
when (event) {
is ExternalEvent.ExternalMessageEvent -> onSessionMessage(event)
is ExternalEvent.ExternalStartFlowEvent<*> -> onExternalStartFlow(event)
}
}
}
private fun <T> onExternalStartFlow(event: ExternalEvent.ExternalStartFlowEvent<T>) {
val future = startFlow(event.flowLogic, event.context, ourIdentity = null, deduplicationHandler = event.deduplicationHandler)
event.wireUpFuture(future)
}
private fun onSessionMessage(event: ExternalEvent.ExternalMessageEvent) {
val message: ReceivedMessage = event.receivedMessage
val deduplicationHandler: DeduplicationHandler = event.deduplicationHandler
val peer = message.peer
val sessionMessage = try {
message.data.deserialize<SessionMessage>()
@ -384,7 +453,7 @@ class SingleThreadedStateMachineManager(
}
if (replyError != null) {
flowMessaging.sendSessionMessage(sender, replyError, DeduplicationId.createRandom(secureRandom))
flowMessaging.sendSessionMessage(sender, replyError, SenderDeduplicationId(DeduplicationId.createRandom(secureRandom), ourSenderUUID))
deduplicationHandler.afterDatabaseTransaction()
}
}
@ -458,7 +527,8 @@ class SingleThreadedStateMachineManager(
isAnyCheckpointPersisted = false,
isStartIdempotent = isStartIdempotent,
isRemoved = false,
flowLogic = flowLogic
flowLogic = flowLogic,
senderUUID = ourSenderUUID
)
flowStateMachineImpl.transientState = TransientReference(initialState)
mutex.locked {
@ -493,7 +563,7 @@ class SingleThreadedStateMachineManager(
private fun createTransientValues(id: StateMachineRunId, resultFuture: CordaFuture<Any?>): FlowStateMachineImpl.TransientValues {
return FlowStateMachineImpl.TransientValues(
eventQueue = Channels.newChannel(stateMachineConfiguration.eventQueueSize, Channels.OverflowPolicy.BLOCK),
eventQueue = Channels.newChannel(-1, Channels.OverflowPolicy.BLOCK),
resultFuture = resultFuture,
database = database,
transitionExecutor = transitionExecutor,
@ -509,7 +579,8 @@ class SingleThreadedStateMachineManager(
checkpoint: Checkpoint,
isAnyCheckpointPersisted: Boolean,
isStartIdempotent: Boolean,
initialDeduplicationHandler: DeduplicationHandler?
initialDeduplicationHandler: DeduplicationHandler?,
senderUUID: String? = ourSenderUUID
): Flow {
val flowState = checkpoint.flowState
val resultFuture = openFuture<Any?>()
@ -524,7 +595,8 @@ class SingleThreadedStateMachineManager(
isAnyCheckpointPersisted = isAnyCheckpointPersisted,
isStartIdempotent = isStartIdempotent,
isRemoved = false,
flowLogic = logic
flowLogic = logic,
senderUUID = senderUUID
)
val fiber = FlowStateMachineImpl(id, logic, scheduler)
fiber.transientValues = TransientReference(createTransientValues(id, resultFuture))
@ -542,7 +614,8 @@ class SingleThreadedStateMachineManager(
isAnyCheckpointPersisted = isAnyCheckpointPersisted,
isStartIdempotent = isStartIdempotent,
isRemoved = false,
flowLogic = fiber.logic
flowLogic = fiber.logic,
senderUUID = senderUUID
)
fiber.transientValues = TransientReference(createTransientValues(id, resultFuture))
fiber.transientState = TransientReference(state)
@ -566,9 +639,13 @@ class SingleThreadedStateMachineManager(
startedFutures[id]?.setException(IllegalStateException("Will not start flow as SMM is stopping"))
logger.trace("Not resuming as SMM is stopping.")
} else {
incrementLiveFibers()
unfinishedFibers.countUp()
flows[id] = flow
val oldFlow = flows.put(id, flow)
if (oldFlow == null) {
incrementLiveFibers()
unfinishedFibers.countUp()
} else {
oldFlow.resultFuture.captureLater(flow.resultFuture)
}
flow.fiber.scheduleEvent(Event.DoRemainingWork)
when (checkpoint.flowState) {
is FlowState.Unstarted -> {
@ -604,7 +681,7 @@ class SingleThreadedStateMachineManager(
private fun makeTransitionExecutor(): TransitionExecutor {
val interceptors = ArrayList<TransitionInterceptor>()
interceptors.add { HospitalisingInterceptor(PropagatingFlowHospital, it) }
interceptors.add { HospitalisingInterceptor(StaffedFlowHospital, it) }
if (serviceHub.configuration.devMode) {
interceptors.add { DumpHistoryOnErrorInterceptor(it) }
}

@ -0,0 +1,127 @@
package net.corda.node.services.statemachine
import net.corda.core.flows.StateMachineRunId
import net.corda.core.utilities.loggerFor
import java.sql.SQLException
import java.time.Instant
import java.util.concurrent.ConcurrentHashMap
/**
* This hospital consults "staff" to see if they can automatically diagnose and treat flows.
*/
object StaffedFlowHospital : FlowHospital {
private val log = loggerFor<StaffedFlowHospital>()
private val staff = listOf(DeadlockNurse, DuplicateInsertSpecialist)
private val patients = ConcurrentHashMap<StateMachineRunId, MedicalHistory>()
val numberOfPatients = patients.size
class MedicalHistory {
val records: MutableList<Record> = mutableListOf()
sealed class Record(val suspendCount: Int) {
class Admitted(val at: Instant, suspendCount: Int) : Record(suspendCount) {
override fun toString() = "Admitted(at=$at, suspendCount=$suspendCount)"
}
class Discharged(val at: Instant, suspendCount: Int, val by: Staff, val error: Throwable) : Record(suspendCount) {
override fun toString() = "Discharged(at=$at, suspendCount=$suspendCount, by=$by)"
}
}
fun notDischargedForTheSameThingMoreThan(max: Int, by: Staff): Boolean {
val lastAdmittanceSuspendCount = (records.last() as MedicalHistory.Record.Admitted).suspendCount
return records.filterIsInstance(MedicalHistory.Record.Discharged::class.java).filter { it.by == by && it.suspendCount == lastAdmittanceSuspendCount }.count() <= max
}
override fun toString(): String = "${this.javaClass.simpleName}(records = $records)"
}
override fun flowErrored(flowFiber: FlowFiber, currentState: StateMachineState, errors: List<Throwable>) {
log.info("Flow ${flowFiber.id} admitted to hospital in state $currentState")
val medicalHistory = patients.computeIfAbsent(flowFiber.id) { MedicalHistory() }
medicalHistory.records += MedicalHistory.Record.Admitted(Instant.now(), currentState.checkpoint.numberOfSuspends)
for ((index, error) in errors.withIndex()) {
log.info("Flow ${flowFiber.id} has error [$index]", error)
if (!errorIsDischarged(flowFiber, currentState, error, medicalHistory)) {
// If any error isn't discharged, then we propagate.
log.warn("Flow ${flowFiber.id} error was not discharged, propagating.")
flowFiber.scheduleEvent(Event.StartErrorPropagation)
return
}
}
// If all are discharged, retry.
flowFiber.scheduleEvent(Event.RetryFlowFromSafePoint)
}
private fun errorIsDischarged(flowFiber: FlowFiber, currentState: StateMachineState, error: Throwable, medicalHistory: MedicalHistory): Boolean {
for (staffMember in staff) {
val diagnosis = staffMember.consult(flowFiber, currentState, error, medicalHistory)
if (diagnosis == Diagnosis.DISCHARGE) {
medicalHistory.records += MedicalHistory.Record.Discharged(Instant.now(), currentState.checkpoint.numberOfSuspends, staffMember, error)
log.info("Flow ${flowFiber.id} error discharged from hospital by $staffMember")
return true
}
}
return false
}
// It's okay for flows to be cleaned... we fix them now!
override fun flowCleaned(flowFiber: FlowFiber) {}
override fun flowRemoved(flowFiber: FlowFiber) {
patients.remove(flowFiber.id)
}
enum class Diagnosis {
/**
* Retry from last safe point.
*/
DISCHARGE,
/**
* Please try another member of staff.
*/
NOT_MY_SPECIALTY
}
interface Staff {
fun consult(flowFiber: FlowFiber, currentState: StateMachineState, newError: Throwable, history: MedicalHistory): Diagnosis
}
/**
* SQL Deadlock detection.
*/
object DeadlockNurse : Staff {
override fun consult(flowFiber: FlowFiber, currentState: StateMachineState, newError: Throwable, history: MedicalHistory): Diagnosis {
return if (mentionsDeadlock(newError)) {
Diagnosis.DISCHARGE
} else {
Diagnosis.NOT_MY_SPECIALTY
}
}
private fun mentionsDeadlock(exception: Throwable?): Boolean {
return exception != null && (exception is SQLException && ((exception.message?.toLowerCase()?.contains("deadlock")
?: false)) || mentionsDeadlock(exception.cause))
}
}
/**
* Primary key violation detection for duplicate inserts. Will detect other constraint violations too.
*/
object DuplicateInsertSpecialist : Staff {
override fun consult(flowFiber: FlowFiber, currentState: StateMachineState, newError: Throwable, history: MedicalHistory): Diagnosis {
return if (mentionsConstraintViolation(newError) && history.notDischargedForTheSameThingMoreThan(3, this)) {
Diagnosis.DISCHARGE
} else {
Diagnosis.NOT_MY_SPECIALTY
}
}
private fun mentionsConstraintViolation(exception: Throwable?): Boolean {
return exception != null && (exception is org.hibernate.exception.ConstraintViolationException || mentionsConstraintViolation(exception.cause))
}
}
}

@ -4,11 +4,11 @@ import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.Party
import net.corda.core.internal.FlowStateMachine
import net.corda.core.messaging.DataFeed
import net.corda.core.utilities.Try
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.ReceivedMessage
import rx.Observable
/**
@ -40,21 +40,6 @@ interface StateMachineManager {
*/
fun stop(allowedUnsuspendedFiberCount: Int)
/**
* Starts a new flow.
*
* @param flowLogic The flow's code.
* @param context The context of the flow.
* @param ourIdentity The identity to use for the flow.
* @param deduplicationHandler Allows exactly-once start of the flow, see [DeduplicationHandler].
*/
fun <A> startFlow(
flowLogic: FlowLogic<A>,
context: InvocationContext,
ourIdentity: Party?,
deduplicationHandler: DeduplicationHandler?
): CordaFuture<FlowStateMachine<A>>
/**
* Represents an addition/removal of a state machine.
*/
@ -91,6 +76,12 @@ interface StateMachineManager {
* @return whether the flow existed and was killed.
*/
fun killFlow(id: StateMachineRunId): Boolean
/**
* Deliver an external event to the state machine. Such an event might be a new P2P message, or a request to start a flow.
* The event may be replayed if a flow fails and attempts to retry.
*/
fun deliverExternalEvent(event: ExternalEvent)
}
// These must be idempotent! A later failure in the state transition may error the flow state, and a replay may call
@ -100,4 +91,38 @@ interface StateMachineManagerInternal {
fun addSessionBinding(flowId: StateMachineRunId, sessionId: SessionId)
fun removeSessionBindings(sessionIds: Set<SessionId>)
fun removeFlow(flowId: StateMachineRunId, removalReason: FlowRemovalReason, lastState: StateMachineState)
fun retryFlowFromSafePoint(currentState: StateMachineState)
}
/**
* Represents an external event that can be injected into the state machine and that might need to be replayed if
* a flow retries. They always have de-duplication handlers to assist with the at-most once logic where required.
*/
interface ExternalEvent {
val deduplicationHandler: DeduplicationHandler
/**
* An external P2P message event.
*/
interface ExternalMessageEvent : ExternalEvent {
val receivedMessage: ReceivedMessage
}
/**
* An external request to start a flow, from the scheduler for example.
*/
interface ExternalStartFlowEvent<T> : ExternalEvent {
val flowLogic: FlowLogic<T>
val context: InvocationContext
/**
* A callback for the state machine to pass back the [Future] associated with the flow start to the submitter.
*/
fun wireUpFuture(flowFuture: CordaFuture<FlowStateMachine<T>>)
/**
* The future representing the flow start, passed back from the state machine to the submitter of this event.
*/
val future: CordaFuture<FlowStateMachine<T>>
}
}

@ -26,6 +26,7 @@ import net.corda.node.services.messaging.DeduplicationHandler
* possible.
* @param isRemoved true if the flow has been removed from the state machine manager. This is used to avoid any further
* work.
* @param senderUUID the identifier of the sending state machine or null if this flow is resumed from a checkpoint so that it does not participate in de-duplication high-water-marking.
*/
// TODO perhaps add a read-only environment to the state machine for things that don't change over time?
// TODO evaluate persistent datastructure libraries to replace the inefficient copying we currently do.
@ -37,7 +38,8 @@ data class StateMachineState(
val isTransactionTracked: Boolean,
val isAnyCheckpointPersisted: Boolean,
val isStartIdempotent: Boolean,
val isRemoved: Boolean
val isRemoved: Boolean,
val senderUUID: String?
)
/**

@ -34,7 +34,8 @@ class DumpHistoryOnErrorInterceptor(val delegate: TransitionExecutor) : Transiti
(record ?: ArrayList()).apply { add(transitionRecord) }
}
if (nextState.checkpoint.errorState is ErrorState.Errored) {
// Just if we decide to propagate, and not if just on the way to the hospital.
if (nextState.checkpoint.errorState is ErrorState.Errored && nextState.checkpoint.errorState.propagating) {
log.warn("Flow ${fiber.id} errored, dumping all transitions:\n${record!!.joinToString("\n")}")
for (error in nextState.checkpoint.errorState.errors) {
log.warn("Flow ${fiber.id} error", error.exception)

@ -26,20 +26,23 @@ class HospitalisingInterceptor(
actionExecutor: ActionExecutor
): Pair<FlowContinuation, StateMachineState> {
val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor)
when (nextState.checkpoint.errorState) {
ErrorState.Clean -> {
if (hospitalisedFlows.remove(fiber.id) != null) {
flowHospital.flowCleaned(fiber)
when (nextState.checkpoint.errorState) {
is ErrorState.Clean -> {
if (hospitalisedFlows.remove(fiber.id) != null) {
flowHospital.flowCleaned(fiber)
}
}
is ErrorState.Errored -> {
val exceptionsToHandle = nextState.checkpoint.errorState.errors.map { it.exception }
if (hospitalisedFlows.putIfAbsent(fiber.id, fiber) == null) {
flowHospital.flowErrored(fiber, previousState, exceptionsToHandle)
}
}
}
is ErrorState.Errored -> {
if (hospitalisedFlows.putIfAbsent(fiber.id, fiber) == null) {
flowHospital.flowErrored(fiber)
}
}
}
if (nextState.isRemoved) {
hospitalisedFlows.remove(fiber.id)
flowHospital.flowRemoved(fiber)
}
return Pair(continuation, nextState)
}

@ -46,9 +46,6 @@ class DeliverSessionMessageTransition(
is EndSessionMessage -> endMessageTransition()
}
}
if (!isErrored()) {
persistCheckpoint()
}
// Schedule a DoRemainingWork to check whether the flow needs to be woken up.
actions.add(Action.ScheduleEvent(Event.DoRemainingWork))
FlowContinuation.ProcessEvents
@ -73,7 +70,7 @@ class DeliverSessionMessageTransition(
// Send messages that were buffered pending confirmation of session.
val sendActions = sessionState.bufferedMessages.map { (deduplicationId, bufferedMessage) ->
val existingMessage = ExistingSessionMessage(message.initiatedSessionId, bufferedMessage)
Action.SendExisting(initiatedSession.peerParty, existingMessage, deduplicationId)
Action.SendExisting(initiatedSession.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID))
}
actions.addAll(sendActions)
currentState = currentState.copy(checkpoint = newCheckpoint)
@ -146,21 +143,6 @@ class DeliverSessionMessageTransition(
}
}
private fun TransitionBuilder.persistCheckpoint() {
// We persist the message as soon as it arrives.
actions.addAll(arrayOf(
Action.CreateTransaction,
Action.PersistCheckpoint(context.id, currentState.checkpoint),
Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers),
Action.CommitTransaction,
Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers)
))
currentState = currentState.copy(
pendingDeduplicationHandlers = emptyList(),
isAnyCheckpointPersisted = true
)
}
private fun TransitionBuilder.endMessageTransition() {
val sessionId = event.sessionMessage.recipientSessionId
val sessions = currentState.checkpoint.sessions

@ -46,7 +46,7 @@ class ErrorFlowTransition(
sessions = newSessions
)
currentState = currentState.copy(checkpoint = newCheckpoint)
actions.add(Action.PropagateErrors(errorMessages, initiatedSessions))
actions.add(Action.PropagateErrors(errorMessages, initiatedSessions, startingState.senderUUID))
}
// If we're errored but not propagating keep processing events.

@ -216,7 +216,7 @@ class StartedFlowTransition(
}
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++)
val initialMessage = createInitialSessionMessage(sessionState.initiatingSubFlow, sourceSessionId, null)
actions.add(Action.SendInitial(sessionState.party, initialMessage, deduplicationId))
actions.add(Action.SendInitial(sessionState.party, initialMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID)))
newSessions[sourceSessionId] = SessionState.Initiating(
bufferedMessages = emptyList(),
rejectionError = null
@ -253,7 +253,7 @@ class StartedFlowTransition(
when (existingSessionState) {
is SessionState.Uninitiated -> {
val initialMessage = createInitialSessionMessage(existingSessionState.initiatingSubFlow, sourceSessionId, message)
actions.add(Action.SendInitial(existingSessionState.party, initialMessage, deduplicationId))
actions.add(Action.SendInitial(existingSessionState.party, initialMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID)))
newSessions[sourceSessionId] = SessionState.Initiating(
bufferedMessages = emptyList(),
rejectionError = null
@ -270,7 +270,7 @@ class StartedFlowTransition(
is InitiatedSessionState.Live -> {
val sinkSessionId = existingSessionState.initiatedState.peerSinkSessionId
val existingMessage = ExistingSessionMessage(sinkSessionId, sessionMessage)
actions.add(Action.SendExisting(existingSessionState.peerParty, existingMessage, deduplicationId))
actions.add(Action.SendExisting(existingSessionState.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID)))
Unit
}
InitiatedSessionState.Ended -> {

@ -18,18 +18,19 @@ class TopLevelTransition(
) : Transition {
override fun transition(): TransitionResult {
return when (event) {
is Event.DoRemainingWork -> DoRemainingWorkTransition(context, startingState).transition()
is Event.DeliverSessionMessage -> DeliverSessionMessageTransition(context, startingState, event).transition()
is Event.Error -> errorTransition(event)
is Event.TransactionCommitted -> transactionCommittedTransition(event)
is Event.SoftShutdown -> softShutdownTransition()
is Event.StartErrorPropagation -> startErrorPropagationTransition()
is Event.EnterSubFlow -> enterSubFlowTransition(event)
is Event.LeaveSubFlow -> leaveSubFlowTransition()
is Event.Suspend -> suspendTransition(event)
is Event.FlowFinish -> flowFinishTransition(event)
is Event.InitiateFlow -> initiateFlowTransition(event)
is Event.AsyncOperationCompletion -> asyncOperationCompletionTransition(event)
is Event.DoRemainingWork -> DoRemainingWorkTransition(context, startingState).transition()
is Event.DeliverSessionMessage -> DeliverSessionMessageTransition(context, startingState, event).transition()
is Event.Error -> errorTransition(event)
is Event.TransactionCommitted -> transactionCommittedTransition(event)
is Event.SoftShutdown -> softShutdownTransition()
is Event.StartErrorPropagation -> startErrorPropagationTransition()
is Event.EnterSubFlow -> enterSubFlowTransition(event)
is Event.LeaveSubFlow -> leaveSubFlowTransition()
is Event.Suspend -> suspendTransition(event)
is Event.FlowFinish -> flowFinishTransition(event)
is Event.InitiateFlow -> initiateFlowTransition(event)
is Event.AsyncOperationCompletion -> asyncOperationCompletionTransition(event)
is Event.RetryFlowFromSafePoint -> retryFlowFromSafePointTransition(startingState)
}
}
@ -191,7 +192,7 @@ class TopLevelTransition(
if (state is SessionState.Initiated && state.initiatedState is InitiatedSessionState.Live) {
val message = ExistingSessionMessage(state.initiatedState.peerSinkSessionId, EndSessionMessage)
val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index)
Action.SendExisting(state.peerParty, message, deduplicationId)
Action.SendExisting(state.peerParty, message, SenderDeduplicationId(deduplicationId, currentState.senderUUID))
} else {
null
}
@ -230,4 +231,14 @@ class TopLevelTransition(
resumeFlowLogic(event.returnValue)
}
}
private fun retryFlowFromSafePointTransition(startingState: StateMachineState): TransitionResult {
return builder {
// Need to create a flow from the prior checkpoint or flow initiation.
actions.add(Action.CreateTransaction)
actions.add(Action.RetryFlowFromSafePoint(startingState))
actions.add(Action.CommitTransaction)
FlowContinuation.Abort
}
}
}

@ -58,7 +58,7 @@ class UnstartedFlowTransition(
Action.SendExisting(
flowStart.peerSession.counterparty,
sessionMessage,
DeduplicationId.createForNormal(currentState.checkpoint, 0)
SenderDeduplicationId(DeduplicationId.createForNormal(currentState.checkpoint, 0), currentState.senderUUID)
)
)
}

@ -3,14 +3,21 @@ package net.corda.node.utilities
import com.github.benmanes.caffeine.cache.LoadingCache
import com.github.benmanes.caffeine.cache.Weigher
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.persistence.DatabaseTransaction
import net.corda.nodeapi.internal.persistence.contextTransaction
import net.corda.nodeapi.internal.persistence.currentDBSession
import java.lang.ref.WeakReference
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicReference
/**
* Implements a caching layer on top of an *append-only* table accessed via Hibernate mapping. Note that if the same key is [set] twice the
* behaviour is unpredictable! There is a best-effort check for double inserts, but this should *not* be relied on, so
* ONLY USE THIS IF YOUR TABLE IS APPEND-ONLY
* Implements a caching layer on top of an *append-only* table accessed via Hibernate mapping. Note that if the same key is [set] twice,
* typically this will result in a duplicate insert if this is racing with another transaction. The flow framework will then retry.
*
* This class relies heavily on the fact that compute operations in the cache are atomic for a particular key.
*/
abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
val toPersistentEntityKey: (K) -> EK,
@ -23,7 +30,8 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
private val log = contextLogger()
}
protected abstract val cache: LoadingCache<K, Optional<V>>
protected abstract val cache: LoadingCache<K, Transactional<V>>
protected val pendingKeys = ConcurrentHashMap<K, MutableSet<DatabaseTransaction>>()
/**
* Returns the value associated with the key, first loading that value from the storage if necessary.
@ -47,32 +55,31 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
return result.map { x -> fromPersistentEntity(x) }.asSequence()
}
private tailrec fun set(key: K, value: V, logWarning: Boolean, store: (K, V) -> V?): Boolean {
var insertionAttempt = false
var isUnique = true
val existingInCache = cache.get(key) {
// Thread safe, if multiple threads may wait until the first one has loaded.
insertionAttempt = true
// Key wasn't in the cache and might be in the underlying storage.
// Depending on 'store' method, this may insert without checking key duplication or it may avoid inserting a duplicated key.
val existingInDb = store(key, value)
if (existingInDb != null) { // Always reuse an existing value from the storage of a duplicated key.
isUnique = false
Optional.of(existingInDb)
} else {
Optional.of(value)
}
}!!
if (!insertionAttempt) {
if (existingInCache.isPresent) {
// Key already exists in cache, do nothing.
isUnique = false
} else {
// This happens when the key was queried before with no value associated. We invalidate the cached null
// value and recursively call set again. This is to avoid race conditions where another thread queries after
// the invalidate but before the set.
cache.invalidate(key!!)
return set(key, value, logWarning, store)
private fun set(key: K, value: V, logWarning: Boolean, store: (K, V) -> V?): Boolean {
// Will be set to true if store says it isn't in the database.
var isUnique = false
cache.asMap().compute(key) { _, oldValue ->
// Always write to the database, unless we can see it's already committed.
when (oldValue) {
is Transactional.InFlight<*, V> -> {
// Someone else is writing, so store away!
// TODO: we can do collision detection here and prevent it happening in the database. But we also have to do deadlock detection, so a bit of work.
isUnique = (store(key, value) == null)
oldValue.apply { alsoWrite(value) }
}
is Transactional.Committed<V> -> oldValue // The value is already globally visible and cached. So do nothing since the values are always the same.
else -> {
// Null or Missing. Store away!
isUnique = (store(key, value) == null)
if (!isUnique && !weAreWriting(key)) {
// If we found a value already in the database, and we were not already writing, then it's already committed but got evicted.
Transactional.Committed(value)
} else {
// Some database transactions, including us, writing, with readers seeing whatever is in the database and writers seeing the (in memory) value.
Transactional.InFlight(this, key, { loadValue(key) }).apply { alsoWrite(value) }
}
}
}
}
if (logWarning && !isUnique) {
@ -93,7 +100,8 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
/**
* Associates the specified value with the specified key in this map and persists it.
* If the map previously contained a mapping for the key, the old value is not replaced.
* If the map previously contained a committed mapping for the key, the old value is not replaced. It may throw an error from the
* underlying storage if this races with another database transaction to store a value for the same key.
* @return true if added key was unique, otherwise false
*/
fun addWithDuplicatesAllowed(key: K, value: V, logWarning: Boolean = true): Boolean =
@ -116,7 +124,7 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
protected fun loadValue(key: K): V? {
val result = currentDBSession().find(persistentEntityClass, toPersistentEntityKey(key))
return result?.let(fromPersistentEntity)?.second
return result?.apply { currentDBSession().detach(result) }?.let(fromPersistentEntity)?.second
}
operator fun contains(key: K) = get(key) != null
@ -132,9 +140,161 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
session.createQuery(deleteQuery).executeUpdate()
cache.invalidateAll()
}
// Helpers to know if transaction(s) are currently writing the given key.
protected fun weAreWriting(key: K): Boolean = pendingKeys.get(key)?.contains(contextTransaction) ?: false
protected fun anyoneWriting(key: K): Boolean = pendingKeys.get(key)?.isNotEmpty() ?: false
// Indicate this database transaction is a writer of this key.
private fun addPendingKey(key: K, databaseTransaction: DatabaseTransaction): Boolean {
var added = true
pendingKeys.compute(key) { k, oldSet ->
if (oldSet == null) {
val newSet = HashSet<DatabaseTransaction>(0)
newSet += databaseTransaction
newSet
} else {
added = oldSet.add(databaseTransaction)
oldSet
}
}
return added
}
// Remove this database transaction as a writer of this key, because the transaction committed or rolled back.
private fun removePendingKey(key: K, databaseTransaction: DatabaseTransaction) {
pendingKeys.compute(key) { k, oldSet ->
if (oldSet == null) {
oldSet
} else {
oldSet -= databaseTransaction
if (oldSet.size == 0) null else oldSet
}
}
}
/**
* Represents a value in the cache, with transaction isolation semantics.
*
* There are 3 states. Globally missing, globally visible, and being written in a transaction somewhere now or in
* the past (and it rolled back).
*/
sealed class Transactional<T> {
abstract val value: T
abstract val isPresent: Boolean
abstract val valueWithoutIsolation: T?
fun orElse(alt: T?) = if (isPresent) value else alt
// Everyone can see it, and database transaction committed.
class Committed<T>(override val value: T) : Transactional<T>() {
override val isPresent: Boolean
get() = true
override val valueWithoutIsolation: T?
get() = value
}
// No one can see it.
class Missing<T>() : Transactional<T>() {
override val value: T
get() = throw NoSuchElementException("Not present")
override val isPresent: Boolean
get() = false
override val valueWithoutIsolation: T?
get() = null
}
// Written in a transaction (uncommitted) somewhere, but there's a small window when this might be seen after commit,
// hence the committed flag.
class InFlight<K, T>(private val map: AppendOnlyPersistentMapBase<K, T, *, *>,
private val key: K,
private val _readerValueLoader: () -> T?,
private val _writerValueLoader: () -> T = { throw IllegalAccessException("No value loader provided") }) : Transactional<T>() {
// A flag to indicate this has now been committed, but hasn't yet been replaced with Committed. This also
// de-duplicates writes of the Committed value to the cache.
private val committed = AtomicBoolean(false)
// What to do if a non-writer needs to see the value and it hasn't yet been committed to the database.
// Can be updated into a no-op once evaluated.
private val readerValueLoader = AtomicReference<() -> T?>(_readerValueLoader)
// What to do if a writer needs to see the value and it hasn't yet been committed to the database.
// Can be updated into a no-op once evaluated.
private val writerValueLoader = AtomicReference<() -> T>(_writerValueLoader)
fun alsoWrite(_value: T) {
// Make the lazy loader the writers see actually just return the value that has been set.
writerValueLoader.set({ _value })
// We make all these vals so that the lambdas do not need a reference to this, and so the onCommit only has a weak ref to the value.
// We want this so that the cache could evict the value (due to memory constraints etc) without the onCommit callback
// retaining what could be a large memory footprint object.
val tx = contextTransaction
val strongKey = key
val weakValue = WeakReference<T>(_value)
val strongComitted = committed
val strongMap = map
if (map.addPendingKey(key, tx)) {
// If the transaction commits, update cache to make globally visible if we're first for this key,
// and then stop saying the transaction is writing the key.
tx.onCommit {
if (strongComitted.compareAndSet(false, true)) {
val dereferencedKey = strongKey
val dereferencedValue = weakValue.get()
if (dereferencedValue != null) {
strongMap.cache.put(dereferencedKey, Committed(dereferencedValue))
}
}
strongMap.removePendingKey(strongKey, tx)
}
// If the transaction rolls back, stop saying this transaction is writing the key.
tx.onRollback {
strongMap.removePendingKey(strongKey, tx)
}
}
}
// Lazy load the value a "writer" would see. If the original loader hasn't been replaced, replace it
// with one that just returns the value once evaluated.
private fun loadAsWriter(): T {
val _value = writerValueLoader.get()()
if (writerValueLoader.get() == _writerValueLoader) {
writerValueLoader.set({ _value })
}
return _value
}
// Lazy load the value a "reader" would see. If the original loader hasn't been replaced, replace it
// with one that just returns the value once evaluated.
private fun loadAsReader(): T? {
val _value = readerValueLoader.get()()
if (readerValueLoader.get() == _readerValueLoader) {
readerValueLoader.set({ _value })
}
return _value
}
// Whether someone reading (only) can see the entry.
private val isPresentAsReader: Boolean get() = (loadAsReader() != null)
// Whether the entry is already written and committed, or we are writing it (and thus it can be seen).
private val isPresentAsWriter: Boolean get() = committed.get() || map.weAreWriting(key)
override val isPresent: Boolean
get() = isPresentAsWriter || isPresentAsReader
// If it is committed or we are writing, reveal the value, potentially lazy loading from the database.
// If none of the above, see what was already in the database, potentially lazily.
override val value: T
get() = if (isPresentAsWriter) loadAsWriter() else if (isPresentAsReader) loadAsReader()!! else throw NoSuchElementException("Not present")
// The value from the perspective of the eviction algorithm of the cache. i.e. we want to reveal memory footprint to it etc.
override val valueWithoutIsolation: T?
get() = if (writerValueLoader.get() != _writerValueLoader) writerValueLoader.get()() else if (readerValueLoader.get() != _writerValueLoader) readerValueLoader.get()() else null
}
}
}
class AppendOnlyPersistentMap<K, V, E, out EK>(
// Open for tests to override
open class AppendOnlyPersistentMap<K, V, E, out EK>(
toPersistentEntityKey: (K) -> EK,
fromPersistentEntity: (E) -> Pair<K, V>,
toPersistentEntity: (key: K, value: V) -> E,
@ -146,26 +306,71 @@ class AppendOnlyPersistentMap<K, V, E, out EK>(
toPersistentEntity,
persistentEntityClass) {
//TODO determine cacheBound based on entity class later or with node config allowing tuning, or using some heuristic based on heap size
override val cache = NonInvalidatingCache<K, Optional<V>>(
override val cache = NonInvalidatingCache<K, Transactional<V>>(
bound = cacheBound,
loadFunction = { key -> Optional.ofNullable(loadValue(key)) })
loadFunction = { key: K ->
// This gets called if a value is read and the cache has no Transactional for this key yet.
val value: V? = loadValue(key)
if (value == null) {
// No visible value
if (anyoneWriting(key)) {
// If someone is writing (but not us)
// For those not writing, the value cannot be seen.
// For those writing, they need to re-load the value from the database (which their database transaction CAN see).
Transactional.InFlight<K, V>(this, key, { null }, { loadValue(key)!! })
} else {
// If no one is writing, then the value does not exist.
Transactional.Missing<V>()
}
} else {
// A value was found
if (weAreWriting(key)) {
// If we are writing, it might not be globally visible, and was evicted from the cache.
// For those not writing, they need to check the database again.
// For those writing, they can see the value found.
Transactional.InFlight<K, V>(this, key, { loadValue(key) }, { value })
} else {
// If no one is writing, then make it globally visible.
Transactional.Committed<V>(value)
}
}
})
}
// Same as above, but with weighted values (e.g. memory footprint sensitive).
class WeightBasedAppendOnlyPersistentMap<K, V, E, out EK>(
toPersistentEntityKey: (K) -> EK,
fromPersistentEntity: (E) -> Pair<K, V>,
toPersistentEntity: (key: K, value: V) -> E,
persistentEntityClass: Class<E>,
maxWeight: Long,
weighingFunc: (K, Optional<V>) -> Int
weighingFunc: (K, Transactional<V>) -> Int
) : AppendOnlyPersistentMapBase<K, V, E, EK>(
toPersistentEntityKey,
fromPersistentEntity,
toPersistentEntity,
persistentEntityClass) {
override val cache = NonInvalidatingWeightBasedCache(
override val cache = NonInvalidatingWeightBasedCache<K, Transactional<V>>(
maxWeight = maxWeight,
weigher = Weigher<K, Optional<V>> { key, value -> weighingFunc(key, value) },
loadFunction = { key -> Optional.ofNullable(loadValue(key)) }
)
}
weigher = object : Weigher<K, Transactional<V>> {
override fun weigh(key: K, value: Transactional<V>): Int {
return weighingFunc(key, value)
}
},
loadFunction = { key: K ->
val value: V? = loadValue(key)
if (value == null) {
if (anyoneWriting(key)) {
Transactional.InFlight<K, V>(this, key, { null }, { loadValue(key)!! })
} else {
Transactional.Missing<V>()
}
} else {
if (weAreWriting(key)) {
Transactional.InFlight<K, V>(this, key, { loadValue(key) }, { value })
} else {
Transactional.Committed<V>(value)
}
}
})
}

@ -14,14 +14,15 @@ import net.corda.node.internal.configureDatabase
import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.NodePropertiesStore
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.statemachine.ExternalEvent
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.nodeapi.internal.persistence.DatabaseTransaction
import net.corda.testing.internal.doLookup
import net.corda.testing.internal.rigorousMock
import net.corda.testing.internal.spectator
import net.corda.testing.node.MockServices
import net.corda.testing.node.TestClock
import org.junit.After
import org.junit.Ignore
import org.junit.Rule
import org.junit.Test
@ -50,7 +51,7 @@ open class NodeSchedulerServiceTestBase {
dedupe.insideDatabaseTransaction()
dedupe.afterDatabaseTransaction()
openFuture<FlowStateMachine<*>>()
}.whenever(it).startFlow(any<FlowLogic<*>>(), any(), any())
}.whenever(it).startFlow(any<ExternalEvent.ExternalStartFlowEvent<*>>())
}
private val flowsDraingMode = rigorousMock<NodePropertiesStore.FlowsDrainingModeOperations>().also {
doReturn(false).whenever(it).isEnabled()
@ -80,7 +81,7 @@ open class NodeSchedulerServiceTestBase {
protected fun assertStarted(flowLogic: FlowLogic<*>) {
// Like in assertWaitingFor, use timeout to make verify wait as we often race the call to startFlow:
verify(flowStarter, timeout(5000)).startFlow(same(flowLogic), any(), any())
verify(flowStarter, timeout(5000)).startFlow(argForWhich<ExternalEvent.ExternalStartFlowEvent<*>> { this.flowLogic == flowLogic })
}
protected fun assertStarted(event: Event) = assertStarted(event.flowLogic)
@ -112,11 +113,11 @@ class MockScheduledFlowRepository : ScheduledFlowRepository {
}
class NodeSchedulerServiceTest : NodeSchedulerServiceTestBase() {
private val database = rigorousMock<CordaPersistence>().also {
doAnswer {
val block: DatabaseTransaction.() -> Any? = it.getArgument(0)
rigorousMock<DatabaseTransaction>().block()
}.whenever(it).transaction(any())
private val database = configureDatabase(MockServices.makeTestDataSourceProperties(), DatabaseConfig(), rigorousMock())
@After
fun closeDatabase() {
database.close()
}
private val scheduler = NodeSchedulerService(
@ -148,7 +149,9 @@ class NodeSchedulerServiceTest : NodeSchedulerServiceTestBase() {
}).whenever(it).data
}
flows[logicRef] = flowLogic
scheduler.scheduleStateActivity(ssr)
database.transaction {
scheduler.scheduleStateActivity(ssr)
}
}
@Test
@ -207,7 +210,9 @@ class NodeSchedulerServiceTest : NodeSchedulerServiceTestBase() {
fun `test activity due in the future and schedule another for same time then unschedule second`() {
val eventA = schedule(mark + 1.days)
val eventB = schedule(mark + 1.days)
scheduler.unscheduleStateActivity(eventB.stateRef)
database.transaction {
scheduler.unscheduleStateActivity(eventB.stateRef)
}
assertWaitingFor(eventA)
testClock.advanceBy(1.days)
assertStarted(eventA)
@ -217,7 +222,9 @@ class NodeSchedulerServiceTest : NodeSchedulerServiceTestBase() {
fun `test activity due in the future and schedule another for same time then unschedule original`() {
val eventA = schedule(mark + 1.days)
val eventB = schedule(mark + 1.days)
scheduler.unscheduleStateActivity(eventA.stateRef)
database.transaction {
scheduler.unscheduleStateActivity(eventA.stateRef)
}
assertWaitingFor(eventB)
testClock.advanceBy(1.days)
assertStarted(eventB)
@ -225,7 +232,9 @@ class NodeSchedulerServiceTest : NodeSchedulerServiceTestBase() {
@Test
fun `test activity due in the future then unschedule`() {
scheduler.unscheduleStateActivity(schedule(mark + 1.days).stateRef)
database.transaction {
scheduler.unscheduleStateActivity(schedule(mark + 1.days).stateRef)
}
testClock.advanceBy(1.days)
}
}

@ -0,0 +1,290 @@
package net.corda.node.services.persistence
import net.corda.core.schemas.MappedSchema
import net.corda.core.utilities.loggerFor
import net.corda.node.internal.configureDatabase
import net.corda.node.services.schema.NodeSchemaService
import net.corda.node.utilities.AppendOnlyPersistentMap
import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.testing.internal.rigorousMock
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import org.junit.After
import org.junit.Assert.*
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import java.io.Serializable
import java.util.concurrent.CountDownLatch
import javax.persistence.Column
import javax.persistence.Entity
import javax.persistence.Id
import javax.persistence.PersistenceException
@RunWith(Parameterized::class)
class AppendOnlyPersistentMapTest(var scenario: Scenario) {
companion object {
private val scenarios = arrayOf<Scenario>(
Scenario(false, ReadOrWrite.Read, ReadOrWrite.Read, Outcome.Fail, Outcome.Fail),
Scenario(false, ReadOrWrite.Write, ReadOrWrite.Read, Outcome.Success, Outcome.Fail, Outcome.Success),
Scenario(false, ReadOrWrite.Read, ReadOrWrite.Write, Outcome.Fail, Outcome.Success),
Scenario(false, ReadOrWrite.Write, ReadOrWrite.Write, Outcome.Success, Outcome.SuccessButErrorOnCommit),
Scenario(false, ReadOrWrite.WriteDuplicateAllowed, ReadOrWrite.Read, Outcome.Success, Outcome.Fail, Outcome.Success),
Scenario(false, ReadOrWrite.Read, ReadOrWrite.WriteDuplicateAllowed, Outcome.Fail, Outcome.Success),
Scenario(false, ReadOrWrite.WriteDuplicateAllowed, ReadOrWrite.WriteDuplicateAllowed, Outcome.Success, Outcome.SuccessButErrorOnCommit, Outcome.Fail),
Scenario(true, ReadOrWrite.Read, ReadOrWrite.Read, Outcome.Success, Outcome.Success),
Scenario(true, ReadOrWrite.Write, ReadOrWrite.Read, Outcome.SuccessButErrorOnCommit, Outcome.Success),
Scenario(true, ReadOrWrite.Read, ReadOrWrite.Write, Outcome.Success, Outcome.Fail),
Scenario(true, ReadOrWrite.Write, ReadOrWrite.Write, Outcome.SuccessButErrorOnCommit, Outcome.SuccessButErrorOnCommit),
Scenario(true, ReadOrWrite.WriteDuplicateAllowed, ReadOrWrite.Read, Outcome.Fail, Outcome.Success),
Scenario(true, ReadOrWrite.Read, ReadOrWrite.WriteDuplicateAllowed, Outcome.Success, Outcome.Fail),
Scenario(true, ReadOrWrite.WriteDuplicateAllowed, ReadOrWrite.WriteDuplicateAllowed, Outcome.Fail, Outcome.Fail)
)
@Parameterized.Parameters(name = "{0}")
@JvmStatic
fun data(): Array<Array<Scenario>> = scenarios.map { arrayOf(it) }.toTypedArray()
}
enum class ReadOrWrite { Read, Write, WriteDuplicateAllowed }
enum class Outcome { Success, Fail, SuccessButErrorOnCommit }
data class Scenario(val prePopulated: Boolean,
val a: ReadOrWrite,
val b: ReadOrWrite,
val aExpected: Outcome,
val bExpected: Outcome,
val bExpectedIfSingleThreaded: Outcome = bExpected)
private val database = configureDatabase(makeTestDataSourceProperties(),
DatabaseConfig(),
rigorousMock(),
NodeSchemaService(setOf(MappedSchema(AppendOnlyPersistentMapTest::class.java, 1, listOf(PersistentMapEntry::class.java)))))
@After
fun closeDatabase() {
database.close()
}
@Test
fun `concurrent test no purge between A and B`() {
prepopulateIfRequired()
val map = createMap()
val a = TestThread("A", map).apply { start() }
val b = TestThread("B", map).apply { start() }
// Begin A
a.phase1.countDown()
a.await(a::phase2)
// Begin B
b.phase1.countDown()
b.await(b::phase2)
// Commit A
a.phase3.countDown()
a.await(a::phase4)
// Commit B
b.phase3.countDown()
b.await(b::phase4)
// End
a.join()
b.join()
assertTrue(map.pendingKeysIsEmpty())
}
@Test
fun `test no purge with only a single transaction`() {
prepopulateIfRequired()
val map = createMap()
val a = TestThread("A", map, true).apply {
phase1.countDown()
phase3.countDown()
}
val b = TestThread("B", map, true).apply {
phase1.countDown()
phase3.countDown()
}
try {
database.transaction {
a.run()
b.run()
}
} catch (t: PersistenceException) {
// This only helps if thrown on commit, otherwise other latches not counted down.
assertEquals(t.message, Outcome.SuccessButErrorOnCommit, a.outcome)
}
a.await(a::phase4)
b.await(b::phase4)
assertTrue(map.pendingKeysIsEmpty())
}
@Test
fun `concurrent test purge between A and B`() {
// Writes intentionally do not check the database first, so purging between read and write changes behaviour
val remapped = mapOf(Scenario(true, ReadOrWrite.Read, ReadOrWrite.Write, Outcome.Success, Outcome.Fail) to Scenario(true, ReadOrWrite.Read, ReadOrWrite.Write, Outcome.Success, Outcome.SuccessButErrorOnCommit))
scenario = remapped[scenario] ?: scenario
prepopulateIfRequired()
val map = createMap()
val a = TestThread("A", map).apply { start() }
val b = TestThread("B", map).apply { start() }
// Begin A
a.phase1.countDown()
a.await(a::phase2)
map.invalidate()
// Begin B
b.phase1.countDown()
b.await(b::phase2)
// Commit A
a.phase3.countDown()
a.await(a::phase4)
// Commit B
b.phase3.countDown()
b.await(b::phase4)
// End
a.join()
b.join()
assertTrue(map.pendingKeysIsEmpty())
}
@Test
fun `test purge mid-way in a single transaction`() {
// Writes intentionally do not check the database first, so purging between read and write changes behaviour
val remapped = mapOf(Scenario(true, ReadOrWrite.Read, ReadOrWrite.Write, Outcome.Success, Outcome.Fail) to Scenario(true, ReadOrWrite.Read, ReadOrWrite.Write, Outcome.SuccessButErrorOnCommit, Outcome.SuccessButErrorOnCommit))
scenario = remapped[scenario] ?: scenario
prepopulateIfRequired()
val map = createMap()
val a = TestThread("A", map, true).apply {
phase1.countDown()
phase3.countDown()
}
val b = TestThread("B", map, true).apply {
phase1.countDown()
phase3.countDown()
}
try {
database.transaction {
a.run()
map.invalidate()
b.run()
}
} catch (t: PersistenceException) {
// This only helps if thrown on commit, otherwise other latches not counted down.
assertEquals(t.message, Outcome.SuccessButErrorOnCommit, a.outcome)
}
a.await(a::phase4)
b.await(b::phase4)
assertTrue(map.pendingKeysIsEmpty())
}
inner class TestThread(name: String, val map: AppendOnlyPersistentMap<Long, String, PersistentMapEntry, Long>, singleThreaded: Boolean = false) : Thread(name) {
private val log = loggerFor<TestThread>()
val readOrWrite = if (name == "A") scenario.a else scenario.b
val outcome = if (name == "A") scenario.aExpected else if (singleThreaded) scenario.bExpectedIfSingleThreaded else scenario.bExpected
val phase1 = latch()
val phase2 = latch()
val phase3 = latch()
val phase4 = latch()
override fun run() {
try {
database.transaction {
await(::phase1)
doActivity()
phase2.countDown()
await(::phase3)
}
} catch (t: PersistenceException) {
// This only helps if thrown on commit, otherwise other latches not counted down.
assertEquals(t.message, Outcome.SuccessButErrorOnCommit, outcome)
}
phase4.countDown()
}
private fun doActivity() {
if (readOrWrite == ReadOrWrite.Read) {
log.info("Reading")
val value = map.get(1)
log.info("Read $value")
if (outcome == Outcome.Success || outcome == Outcome.SuccessButErrorOnCommit) {
assertEquals("X", value)
} else {
assertNull(value)
}
} else if (readOrWrite == ReadOrWrite.Write) {
log.info("Writing")
val wasSet = map.set(1, "X")
log.info("Write $wasSet")
if (outcome == Outcome.Success || outcome == Outcome.SuccessButErrorOnCommit) {
assertEquals(true, wasSet)
} else {
assertEquals(false, wasSet)
}
} else if (readOrWrite == ReadOrWrite.WriteDuplicateAllowed) {
log.info("Writing with duplicates allowed")
val wasSet = map.addWithDuplicatesAllowed(1, "X")
log.info("Write with duplicates allowed $wasSet")
if (outcome == Outcome.Success || outcome == Outcome.SuccessButErrorOnCommit) {
assertEquals(true, wasSet)
} else {
assertEquals(false, wasSet)
}
}
}
private fun latch() = CountDownLatch(1)
fun await(latch: () -> CountDownLatch) {
log.info("Awaiting $latch")
latch().await()
}
}
private fun prepopulateIfRequired() {
if (scenario.prePopulated) {
database.transaction {
val map = createMap()
map.set(1, "X")
}
}
}
@Entity
@javax.persistence.Table(name = "persist_map_test")
class PersistentMapEntry(
@Id
@Column(name = "key")
var key: Long = -1,
@Column(name = "value", length = 16)
var value: String = ""
) : Serializable
class TestMap : AppendOnlyPersistentMap<Long, String, PersistentMapEntry, Long>(
toPersistentEntityKey = { it },
fromPersistentEntity = { Pair(it.key, it.value) },
toPersistentEntity = { key: Long, value: String ->
PersistentMapEntry().apply {
this.key = key
this.value = value
}
},
persistentEntityClass = PersistentMapEntry::class.java
) {
fun pendingKeysIsEmpty() = pendingKeys.isEmpty()
fun invalidate() = cache.invalidateAll()
}
fun createMap() = TestMap()
}

@ -0,0 +1,49 @@
package net.corda.node.services.persistence
import net.corda.node.internal.configureDatabase
import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.testing.internal.rigorousMock
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import org.junit.After
import org.junit.Test
import kotlin.test.assertEquals
class TransactionCallbackTest {
private val database = configureDatabase(makeTestDataSourceProperties(), DatabaseConfig(), rigorousMock())
@After
fun closeDatabase() {
database.close()
}
@Test
fun `onCommit called and onRollback not called on commit`() {
var onCommitCount = 0
var onRollbackCount = 0
database.transaction {
onCommit { onCommitCount++ }
onRollback { onRollbackCount++ }
}
assertEquals(1, onCommitCount)
assertEquals(0, onRollbackCount)
}
@Test
fun `onCommit not called and onRollback called on rollback`() {
class TestException : Exception()
var onCommitCount = 0
var onRollbackCount = 0
try {
database.transaction {
onCommit { onCommitCount++ }
onRollback { onRollbackCount++ }
throw TestException()
}
} catch (e: TestException) {
}
assertEquals(0, onCommitCount)
assertEquals(1, onRollbackCount)
}
}

@ -0,0 +1,166 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.concurrent.CordaFuture
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.flows.InitiatedBy
import net.corda.core.flows.InitiatingFlow
import net.corda.core.identity.Party
import net.corda.core.messaging.MessageRecipients
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.node.internal.StartedNode
import net.corda.node.services.messaging.Message
import net.corda.node.services.persistence.DBTransactionStorage
import net.corda.nodeapi.internal.persistence.contextTransaction
import net.corda.testing.node.internal.InternalMockNetwork
import net.corda.testing.node.internal.MessagingServiceSpy
import net.corda.testing.node.internal.newContext
import net.corda.testing.node.internal.setMessagingServiceSpy
import org.assertj.core.api.Assertions
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.sql.SQLException
import java.time.Duration
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertNull
class RetryFlowMockTest {
private lateinit var mockNet: InternalMockNetwork
private lateinit var internalNodeA: StartedNode<InternalMockNetwork.MockNode>
private lateinit var internalNodeB: StartedNode<InternalMockNetwork.MockNode>
@Before
fun start() {
mockNet = InternalMockNetwork(threadPerNode = true, cordappPackages = listOf(this.javaClass.`package`.name))
internalNodeA = mockNet.createNode()
internalNodeB = mockNet.createNode()
mockNet.startNodes()
RetryFlow.count = 0
SendAndRetryFlow.count = 0
RetryInsertFlow.count = 0
}
private fun <T> StartedNode<InternalMockNetwork.MockNode>.startFlow(logic: FlowLogic<T>): CordaFuture<T> = this.services.startFlow(logic, this.services.newContext()).getOrThrow().resultFuture
@After
fun cleanUp() {
mockNet.stopNodes()
}
@Test
fun `Single retry`() {
assertEquals(Unit, internalNodeA.startFlow(RetryFlow(1)).get())
assertEquals(2, RetryFlow.count)
}
@Test
fun `Retry forever`() {
Assertions.assertThatThrownBy {
internalNodeA.startFlow(RetryFlow(Int.MAX_VALUE)).getOrThrow()
}.isInstanceOf(LimitedRetryCausingError::class.java)
assertEquals(5, RetryFlow.count)
}
@Test
fun `Retry does not set senderUUID`() {
val messagesSent = mutableListOf<Message>()
val partyB = internalNodeB.info.legalIdentities.first()
internalNodeA.setMessagingServiceSpy(object : MessagingServiceSpy(internalNodeA.network) {
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) {
messagesSent.add(message)
messagingService.send(message, target, retryId)
}
})
internalNodeA.startFlow(SendAndRetryFlow(1, partyB)).get()
assertNotNull(messagesSent.first().senderUUID)
assertNull(messagesSent.last().senderUUID)
assertEquals(2, SendAndRetryFlow.count)
}
@Test
fun `Retry duplicate insert`() {
assertEquals(Unit, internalNodeA.startFlow(RetryInsertFlow(1)).get())
assertEquals(2, RetryInsertFlow.count)
}
@Test
fun `Patient records do not leak in hospital`() {
assertEquals(Unit, internalNodeA.startFlow(RetryFlow(1)).get())
assertEquals(0, StaffedFlowHospital.numberOfPatients)
assertEquals(2, RetryFlow.count)
}
}
class LimitedRetryCausingError : org.hibernate.exception.ConstraintViolationException("Test message", SQLException(), "Test constraint")
class RetryCausingError : SQLException("deadlock")
class RetryFlow(val i: Int) : FlowLogic<Unit>() {
companion object {
var count = 0
}
@Suspendable
override fun call() {
logger.info("Hello $count")
if (count++ < i) {
if (i == Int.MAX_VALUE) {
throw LimitedRetryCausingError()
} else {
throw RetryCausingError()
}
}
}
}
@InitiatingFlow
class SendAndRetryFlow(val i: Int, val other: Party) : FlowLogic<Unit>() {
companion object {
var count = 0
}
@Suspendable
override fun call() {
logger.info("Sending...")
val session = initiateFlow(other)
session.send("Boo")
if (count++ < i) {
throw RetryCausingError()
}
}
}
@InitiatedBy(SendAndRetryFlow::class)
class ReceiveFlow2(val other: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
val received = other.receive<String>().unwrap { it }
logger.info("Received... $received")
}
}
class RetryInsertFlow(val i: Int) : FlowLogic<Unit>() {
companion object {
var count = 0
}
@Suspendable
override fun call() {
logger.info("Hello")
doInsert()
// Checkpoint so we roll back to here
FlowLogic.sleep(Duration.ofSeconds(0))
if (count++ < i) {
doInsert()
}
}
private fun doInsert() {
val tx = DBTransactionStorage.DBTransaction("Foo")
contextTransaction.session.save(tx)
}
}

@ -28,7 +28,10 @@ import net.corda.nodeapi.internal.persistence.DatabaseTransaction
import net.corda.testing.core.*
import net.corda.testing.internal.TEST_TX_TIME
import net.corda.testing.internal.rigorousMock
import net.corda.testing.internal.vault.*
import net.corda.testing.internal.vault.DUMMY_LINEAR_CONTRACT_PROGRAM_ID
import net.corda.testing.internal.vault.DummyLinearContract
import net.corda.testing.internal.vault.DummyLinearStateSchemaV1
import net.corda.testing.internal.vault.VaultFiller
import net.corda.testing.node.MockServices
import net.corda.testing.node.MockServices.Companion.makeTestDatabaseAndMockServices
import net.corda.testing.node.makeTestIdentityService
@ -171,17 +174,11 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
@JvmField
val expectedEx: ExpectedException = ExpectedException.none()
@Suppress("LeakingThis")
@Rule
@JvmField
val transactionRule = VaultQueryRollbackRule(this)
companion object {
@ClassRule @JvmField
val testSerialization = SerializationEnvironmentRule()
}
/**
* Helper method for generating a Persistent H2 test database
*/
@ -194,7 +191,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
database.close()
}
private fun consumeCash(amount: Amount<Currency>) = vaultFiller.consumeCash(amount, CHARLIE)
protected fun consumeCash(amount: Amount<Currency>) = vaultFiller.consumeCash(amount, CHARLIE)
private fun setUpDb(_database: CordaPersistence, delay: Long = 0) {
_database.transaction {
// create new states
@ -1988,239 +1985,6 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
}
}
/**
* Dynamic trackBy() tests
*/
@Test
fun trackCashStates_unconsumed() {
val updates = database.transaction {
val updates =
// DOCSTART VaultQueryExample15
vaultService.trackBy<Cash.State>().updates // UNCONSUMED default
// DOCEND VaultQueryExample15
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 5, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()
// consume stuff
consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())
close() // transaction needs to be closed to trigger events
updates
}
updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 5) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 1) {}
}
)
}
}
@Test
fun trackCashStates_consumed() {
val updates = database.transaction {
val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED)
val updates = vaultService.trackBy<Cash.State>(criteria).updates
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 5, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()
consumeCash(100.POUNDS)
// consume more stuff
consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())
close() // transaction needs to be closed to trigger events
updates
}
updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.size == 1) {}
require(produced.isEmpty()) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.size == 5) {}
require(produced.isEmpty()) {}
}
)
}
}
@Test
fun trackCashStates_all() {
val updates = database.transaction {
val updates =
database.transaction {
val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL)
vaultService.trackBy<Cash.State>(criteria).updates
}
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 5, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()
// consume stuff
consumeCash(99.POUNDS)
consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())
close() // transaction needs to be closed to trigger events
updates
}
updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 5) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 1) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.size == 1) {}
require(produced.size == 1) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.size == 5) {}
require(produced.isEmpty()) {}
}
)
}
}
@Test
fun trackLinearStates() {
val updates = database.transaction {
// DOCSTART VaultQueryExample16
val (snapshot, updates) = vaultService.trackBy<LinearState>()
// DOCEND VaultQueryExample16
assertThat(snapshot.states).hasSize(0)
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 3, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()
// consume stuff
consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())
close() // transaction needs to be closed to trigger events
updates
}
updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 10) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 3) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 1) {}
}
)
}
}
@Test
fun trackDealStates() {
val updates = database.transaction {
// DOCSTART VaultQueryExample17
val (snapshot, updates) = vaultService.trackBy<DealState>()
// DOCEND VaultQueryExample17
assertThat(snapshot.states).hasSize(0)
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 3, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()
// consume stuff
consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())
close()
updates
}
updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 3) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 1) {}
}
)
}
}
@Test
fun unconsumedCashStatesForSpending_single_issuer_reference() {
database.transaction {
@ -2281,10 +2045,241 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
*/
}
class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by vaultQueryTestRule {
class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate {
companion object {
@ClassRule @JvmField
val vaultQueryTestRule = VaultQueryTestRule()
val delegate = VaultQueryTestRule()
}
@Rule
@JvmField
val vaultQueryTestRule = delegate
/**
* Dynamic trackBy() tests are H2 only, since rollback stops events being emitted.
*/
@Test
fun trackCashStates_unconsumed() {
val updates = database.transaction {
val updates =
// DOCSTART VaultQueryExample15
vaultService.trackBy<Cash.State>().updates // UNCONSUMED default
// DOCEND VaultQueryExample15
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 5, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()
// consume stuff
consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())
updates
}
updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 5) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 1) {}
}
)
}
}
@Test
fun trackCashStates_consumed() {
val updates = database.transaction {
val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED)
val updates = vaultService.trackBy<Cash.State>(criteria).updates
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 5, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()
consumeCash(100.POUNDS)
// consume more stuff
consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())
updates
}
updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.size == 1) {}
require(produced.isEmpty()) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.size == 5) {}
require(produced.isEmpty()) {}
}
)
}
}
@Test
fun trackCashStates_all() {
val updates = database.transaction {
val updates =
database.transaction {
val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL)
vaultService.trackBy<Cash.State>(criteria).updates
}
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 5, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()
// consume stuff
consumeCash(99.POUNDS)
consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())
updates
}
updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 5) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 1) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.size == 1) {}
require(produced.size == 1) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.size == 5) {}
require(produced.isEmpty()) {}
}
)
}
}
@Test
fun trackLinearStates() {
val updates = database.transaction {
// DOCSTART VaultQueryExample16
val (snapshot, updates) = vaultService.trackBy<LinearState>()
// DOCEND VaultQueryExample16
assertThat(snapshot.states).hasSize(0)
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 3, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()
// consume stuff
consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())
updates
}
updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 10) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 3) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 1) {}
}
)
}
}
@Test
fun trackDealStates() {
val updates = database.transaction {
// DOCSTART VaultQueryExample17
val (snapshot, updates) = vaultService.trackBy<DealState>()
// DOCEND VaultQueryExample17
assertThat(snapshot.states).hasSize(0)
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 3, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()
// consume stuff
consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())
updates
}
updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 3) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 1) {}
}
)
}
}
}

@ -5,8 +5,8 @@ import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.internal.tee
import net.corda.node.internal.configureDatabase
import net.corda.nodeapi.internal.persistence.*
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import net.corda.testing.internal.rigorousMock
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Test
@ -14,6 +14,7 @@ import rx.Observable
import rx.subjects.PublishSubject
import java.io.Closeable
import java.util.*
import kotlin.test.fail
class ObservablesTests {
private fun isInDatabaseTransaction() = contextTransactionOrNull != null
@ -58,6 +59,72 @@ class ObservablesTests {
assertThat(secondEvent.get()).isEqualTo(0 to false)
}
class TestException : Exception("Synthetic exception for tests") {}
@Test
fun `bufferUntilDatabaseCommit swallows if transaction rolled back`() {
val database = createDatabase()
val source = PublishSubject.create<Int>()
val observable: Observable<Int> = source
val firstEvent = SettableFuture.create<Pair<Int, Boolean>>()
val secondEvent = SettableFuture.create<Pair<Int, Boolean>>()
observable.first().subscribe { firstEvent.set(it to isInDatabaseTransaction()) }
observable.skip(1).first().subscribe { secondEvent.set(it to isInDatabaseTransaction()) }
try {
database.transaction {
val delayedSubject = source.bufferUntilDatabaseCommit()
assertThat(source).isNotEqualTo(delayedSubject)
delayedSubject.onNext(0)
source.onNext(1)
assertThat(firstEvent.isDone).isTrue()
assertThat(secondEvent.isDone).isFalse()
throw TestException()
}
fail("Should not have successfully completed transaction")
} catch (e: TestException) {
}
assertThat(secondEvent.isDone).isFalse()
assertThat(firstEvent.get()).isEqualTo(1 to true)
}
@Test
fun `bufferUntilDatabaseCommit propagates error if transaction rolled back`() {
val database = createDatabase()
val source = PublishSubject.create<Int>()
val observable: Observable<Int> = source
val firstEvent = SettableFuture.create<Pair<Int, Boolean>>()
val secondEvent = SettableFuture.create<Pair<Int, Boolean>>()
observable.first().subscribe({ firstEvent.set(it to isInDatabaseTransaction()) }, {})
observable.skip(1).subscribe({ secondEvent.set(it to isInDatabaseTransaction()) }, {})
observable.skip(1).subscribe({}, { secondEvent.set(2 to isInDatabaseTransaction()) })
try {
database.transaction {
val delayedSubject = source.bufferUntilDatabaseCommit(propagateRollbackAsError = true)
assertThat(source).isNotEqualTo(delayedSubject)
delayedSubject.onNext(0)
source.onNext(1)
assertThat(firstEvent.isDone).isTrue()
assertThat(secondEvent.isDone).isFalse()
throw TestException()
}
fail("Should not have successfully completed transaction")
} catch (e: TestException) {
}
assertThat(secondEvent.isDone).isTrue()
assertThat(firstEvent.get()).isEqualTo(1 to true)
assertThat(secondEvent.get()).isEqualTo(2 to false)
}
@Test
fun `bufferUntilDatabaseCommit delays until transaction closed repeatable`() {
val database = createDatabase()

@ -20,6 +20,8 @@ import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.trace
import net.corda.node.services.messaging.*
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ExternalEvent
import net.corda.node.services.statemachine.SenderDeduplicationId
import net.corda.node.utilities.AffinityExecutor
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.testing.node.internal.InMemoryMessage
@ -108,7 +110,7 @@ class InMemoryMessagingNetwork private constructor(
get() = _receivedMessages
internal val endpoints: List<InternalMockMessagingService> @Synchronized get() = handleEndpointMap.values.toList()
/** Get a [List] of all the [MockMessagingService] endpoints **/
val endpointsExternal: List<MockMessagingService> @Synchronized get() = handleEndpointMap.values.map{ MockMessagingService.createMockMessagingService(it) }.toList()
val endpointsExternal: List<MockMessagingService> @Synchronized get() = handleEndpointMap.values.map { MockMessagingService.createMockMessagingService(it) }.toList()
/**
* Creates a node at the given address: useful if you want to recreate a node to simulate a restart.
@ -135,7 +137,10 @@ class InMemoryMessagingNetwork private constructor(
?: emptyList() //TODO only notary can be distributed?
synchronized(this) {
val node = InMemoryMessaging(manuallyPumped, peerHandle, executor, database)
handleEndpointMap[peerHandle] = node
val oldNode = handleEndpointMap.put(peerHandle, node)
if (oldNode != null) {
node.inheritPendingRedelivery(oldNode)
}
serviceHandles.forEach {
serviceToPeersMapping.getOrPut(it) { LinkedHashSet() }.add(peerHandle)
}
@ -161,7 +166,10 @@ class InMemoryMessagingNetwork private constructor(
@Synchronized
private fun netNodeHasShutdown(peerHandle: PeerHandle) {
handleEndpointMap.remove(peerHandle)
val endpoint = handleEndpointMap[peerHandle]
if (!(endpoint?.hasPendingDeliveries() ?: false)) {
handleEndpointMap.remove(peerHandle)
}
}
@Synchronized
@ -266,6 +274,30 @@ class InMemoryMessagingNetwork private constructor(
return transfer
}
/**
* When a new message handler is added, this implies we have started a new node. The add handler logic uses this to
* push back any un-acknowledged messages for this peer onto the head of the queue (rather than the tail) to maintain message
* delivery order. We push them back because their consumption was not complete and a restarted node would
* see them re-delivered if this was Artemis.
*/
@Synchronized
private fun unPopMessages(transfers: Collection<MessageTransfer>, us: PeerHandle) {
messageReceiveQueues.compute(us) { _, existing ->
if (existing == null) {
LinkedBlockingQueue<MessageTransfer>().apply {
addAll(transfers)
}
} else {
existing.apply {
val drained = mutableListOf<MessageTransfer>()
existing.drainTo(drained)
existing.addAll(transfers)
existing.addAll(drained)
}
}
}
}
private fun pumpSendInternal(transfer: MessageTransfer) {
when (transfer.recipients) {
is PeerHandle -> getQueueForPeerHandle(transfer.recipients).add(transfer)
@ -338,6 +370,7 @@ class InMemoryMessagingNetwork private constructor(
private val processedMessages: MutableSet<DeduplicationId> = Collections.synchronizedSet(HashSet<DeduplicationId>())
override val myAddress: PeerHandle get() = peerHandle
override val ourSenderUUID: String = UUID.randomUUID().toString()
private val backgroundThread = if (manuallyPumped) null else
thread(isDaemon = true, name = "In-memory message dispatcher") {
@ -370,10 +403,16 @@ class InMemoryMessagingNetwork private constructor(
Pair(handler, pending)
}
transfers.forEach { pumpSendInternal(it) }
unPopMessages(transfers, peerHandle)
return handler
}
fun inheritPendingRedelivery(other: InMemoryMessaging) {
state.locked {
pendingRedelivery.addAll(other.state.locked { pendingRedelivery })
}
}
override fun removeMessageHandler(registration: MessageHandlerRegistration) {
check(running)
state.locked { check(handlers.remove(registration as Handler)) }
@ -405,8 +444,8 @@ class InMemoryMessagingNetwork private constructor(
override fun cancelRedelivery(retryId: Long) {}
/** Returns the given (topic & session, data) pair as a newly created message object. */
override fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId, additionalHeaders: Map<String, String>): Message {
return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId)
override fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId, additionalHeaders: Map<String, String>): Message {
return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId.deduplicationId, senderUUID = deduplicationId.senderUUID)
}
/**
@ -470,13 +509,14 @@ class InMemoryMessagingNetwork private constructor(
database.transaction {
for (handler in deliverTo) {
try {
handler.callback(transfer.toReceivedMessage(), handler, DummyDeduplicationHandler())
val receivedMessage = transfer.toReceivedMessage()
state.locked { pendingRedelivery.add(transfer) }
handler.callback(receivedMessage, handler, InMemoryDeduplicationHandler(receivedMessage, transfer))
} catch (e: Exception) {
log.error("Caught exception in handler for $this/${handler.topicSession}", e)
}
}
_receivedMessages.onNext(transfer)
processedMessages += transfer.message.uniqueMessageId
messagesInFlight.countDown()
}
}
@ -493,13 +533,23 @@ class InMemoryMessagingNetwork private constructor(
message.uniqueMessageId,
message.debugTimestamp,
sender.name)
}
private class DummyDeduplicationHandler : DeduplicationHandler {
override fun afterDatabaseTransaction() {
}
override fun insideDatabaseTransaction() {
private inner class InMemoryDeduplicationHandler(override val receivedMessage: ReceivedMessage, val transfer: MessageTransfer) : DeduplicationHandler, ExternalEvent.ExternalMessageEvent {
override val externalCause: ExternalEvent
get() = this
override val deduplicationHandler: DeduplicationHandler
get() = this
override fun afterDatabaseTransaction() {
this@InMemoryMessaging.state.locked { pendingRedelivery.remove(transfer) }
}
override fun insideDatabaseTransaction() {
processedMessages += transfer.message.uniqueMessageId
}
}
fun hasPendingDeliveries(): Boolean = state.locked { pendingRedelivery.isNotEmpty() }
}
}