Mirror of https://github.com/corda/corda.git (synced 2025-04-11 13:21:26 +00:00)

CORDA-1475 CORDA-1465 Allow flows to retry from last checkpoint (#3204)

commit 59fdb3df67
parent 7cbc316b9d

core/src/main/kotlin/net/corda/core/internal
node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence
node/src
    integration-test/kotlin/net/corda/node
    main/kotlin/net/corda/node
        internal
        services
            api
            events
            messaging
            persistence
            statemachine
        utilities
    test/kotlin/net/corda/node
        services
            events
            persistence
            statemachine
            vault
        utilities
testing/node-driver/src/main/kotlin/net/corda/testing/node

@@ -41,4 +41,5 @@ interface FlowStateMachine<FLOWRETURN> {
val resultFuture: CordaFuture<FLOWRETURN>
val context: InvocationContext
val ourIdentity: Party
val ourSenderUUID: String?
}

@@ -5,7 +5,6 @@ import net.corda.core.schemas.MappedSchema
import net.corda.core.utilities.contextLogger
import rx.Observable
import rx.Subscriber
import rx.subjects.PublishSubject
import rx.subjects.UnicastSubject
import java.io.Closeable
import java.sql.Connection
@@ -67,9 +66,7 @@ class CordaPersistence(
}
val entityManagerFactory get() = hibernateConfig.sessionFactoryForRegisteredSchemas

data class Boundary(val txId: UUID)

internal val transactionBoundaries = PublishSubject.create<Boundary>().toSerialized()
data class Boundary(val txId: UUID, val success: Boolean)

init {
// Found a unit test that was forgetting to close the database transactions. When you close() on the top level
@@ -186,15 +183,19 @@ class CordaPersistence(
*
* For examples, see the call hierarchy of this function.
*/
fun <T : Any> rx.Observer<T>.bufferUntilDatabaseCommit(): rx.Observer<T> {
val currentTxId = contextTransaction.id
val databaseTxBoundary: Observable<CordaPersistence.Boundary> = contextDatabase.transactionBoundaries.first { it.txId == currentTxId }
fun <T : Any> rx.Observer<T>.bufferUntilDatabaseCommit(propagateRollbackAsError: Boolean = false): rx.Observer<T> {
val currentTx = contextTransaction
val subject = UnicastSubject.create<T>()
val databaseTxBoundary: Observable<CordaPersistence.Boundary> = currentTx.boundary.filter { it.success }
if (propagateRollbackAsError) {
currentTx.boundary.filter { !it.success }.subscribe { this.onError(DatabaseTransactionRolledBackException(it.txId)) }
}
subject.delaySubscription(databaseTxBoundary).subscribe(this)
databaseTxBoundary.doOnCompleted { subject.onCompleted() }
return subject
}

class DatabaseTransactionRolledBackException(txId: UUID) : Exception("Database transaction $txId was rolled back")

// A subscriber that delegates to multiple others, wrapping a database transaction around the combination.
private class DatabaseTransactionWrappingSubscriber<U>(private val db: CordaPersistence?) : Subscriber<U>() {
// Some unsubscribes happen inside onNext() so need something that supports concurrent modification.

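For context, a minimal usage sketch (illustrative only, not part of the diff) of the reworked bufferUntilDatabaseCommit: emissions made inside a database transaction are held back until that transaction's boundary fires with success = true, and with propagateRollbackAsError set a rollback surfaces as onError instead of the buffered events being silently dropped. The names database and updates below are assumptions.

import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.bufferUntilDatabaseCommit
import rx.subjects.PublishSubject

fun publishAfterCommit(database: CordaPersistence, updates: PublishSubject<String>) {
    database.transaction {
        // Wrap the downstream observer; onNext calls are buffered until this transaction commits.
        val deferred = updates.bufferUntilDatabaseCommit(propagateRollbackAsError = true)
        deferred.onNext("row inserted")   // not visible to subscribers of `updates` yet
        // ... do the actual database work here ...
    } // on commit the buffered event is released; on rollback subscribers see an error instead
}
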
@@ -3,6 +3,7 @@ package net.corda.nodeapi.internal.persistence
import co.paralleluniverse.strands.Strand
import org.hibernate.Session
import org.hibernate.Transaction
import rx.subjects.PublishSubject
import java.sql.Connection
import java.util.*

@@ -35,11 +36,16 @@ class DatabaseTransaction(

val session: Session by sessionDelegate
private lateinit var hibernateTransaction: Transaction

internal val boundary = PublishSubject.create<CordaPersistence.Boundary>()
private var committed = false

fun commit() {
if (sessionDelegate.isInitialized()) {
hibernateTransaction.commit()
}
connection.commit()
committed = true
}

fun rollback() {
@@ -58,7 +64,15 @@ class DatabaseTransaction(
connection.close()
contextTransactionOrNull = outerTransaction
if (outerTransaction == null) {
database.transactionBoundaries.onNext(CordaPersistence.Boundary(id))
boundary.onNext(CordaPersistence.Boundary(id, committed))
}
}

fun onCommit(callback: () -> Unit) {
boundary.filter { it.success }.subscribe { callback() }
}

fun onRollback(callback: () -> Unit) {
boundary.filter { !it.success }.subscribe { callback() }
}
}

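A small usage sketch (illustrative, not from this commit) of the new per-transaction hooks: side effects that must only happen once the surrounding transaction has durably committed can be registered through contextTransaction.onCommit, mirroring how NodeSchedulerService uses it further down in this change. The database value is assumed to be a CordaPersistence instance.

import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.contextTransaction

fun doWorkWithHooks(database: CordaPersistence) {
    database.transaction {
        // Register callbacks on this transaction's boundary before doing the work.
        contextTransaction.onCommit { println("committed: safe to release side effects") }
        contextTransaction.onRollback { println("rolled back: discard in-memory bookkeeping") }
        // ... transactional work ...
    }
}
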
@@ -0,0 +1,158 @@
package net.corda.node.flows

import co.paralleluniverse.fibers.Suspendable
import net.corda.client.rpc.CordaRPCClient
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.messaging.startFlow
import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.node.services.Permissions
import net.corda.testing.core.singleIdentity
import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.driver
import net.corda.testing.driver.internal.RandomFree
import net.corda.testing.node.User
import org.junit.Before
import org.junit.Test
import java.lang.management.ManagementFactory
import java.sql.SQLException
import java.util.*
import kotlin.test.assertEquals
import kotlin.test.assertNotNull


class FlowRetryTest {
@Before
fun resetCounters() {
InitiatorFlow.seen.clear()
InitiatedFlow.seen.clear()
}

@Test
fun `flows continue despite errors`() {
val numSessions = 2
val numIterations = 10
val user = User("mark", "dadada", setOf(Permissions.startFlow<InitiatorFlow>()))
val result: Any? = driver(DriverParameters(isDebug = true, startNodesInProcess = isQuasarAgentSpecified(),
portAllocation = RandomFree)) {

val nodeAHandle = startNode(rpcUsers = listOf(user)).getOrThrow()
val nodeBHandle = startNode(rpcUsers = listOf(user)).getOrThrow()

val result = CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use {
it.proxy.startFlow(::InitiatorFlow, numSessions, numIterations, nodeBHandle.nodeInfo.singleIdentity()).returnValue.getOrThrow()
}
result
}
assertNotNull(result)
assertEquals("$numSessions:$numIterations", result)
}
}

fun isQuasarAgentSpecified(): Boolean {
val jvmArgs = ManagementFactory.getRuntimeMXBean().inputArguments
return jvmArgs.any { it.startsWith("-javaagent:") && it.contains("quasar") }
}

class ExceptionToCauseRetry : SQLException("deadlock")

@StartableByRPC
@InitiatingFlow
class InitiatorFlow(private val sessionsCount: Int, private val iterationsCount: Int, private val other: Party) : FlowLogic<Any>() {
companion object {
object FIRST_STEP : ProgressTracker.Step("Step one")

fun tracker() = ProgressTracker(FIRST_STEP)

val seen = Collections.synchronizedSet(HashSet<Visited>())

fun visit(sessionNum: Int, iterationNum: Int, step: Step) {
val visited = Visited(sessionNum, iterationNum, step)
if (visited !in seen) {
seen += visited
throw ExceptionToCauseRetry()
}
}
}

override val progressTracker = tracker()

@Suspendable
override fun call(): Any {
progressTracker.currentStep = FIRST_STEP
var received: Any? = null
visit(-1, -1, Step.First)
for (sessionNum in 1..sessionsCount) {
visit(sessionNum, -1, Step.BeforeInitiate)
val session = initiateFlow(other)
visit(sessionNum, -1, Step.AfterInitiate)
session.send(SessionInfo(sessionNum, iterationsCount))
visit(sessionNum, -1, Step.AfterInitiateSendReceive)
for (iteration in 1..iterationsCount) {
visit(sessionNum, iteration, Step.BeforeSend)
logger.info("A Sending $sessionNum:$iteration")
session.send("$sessionNum:$iteration")
visit(sessionNum, iteration, Step.AfterSend)
received = session.receive<Any>().unwrap { it }
visit(sessionNum, iteration, Step.AfterReceive)
logger.info("A Got $sessionNum:$iteration")
}
doSleep()
}
return received!!
}

// This non-flow-friendly sleep triggered a bug with session end messages and non-retryable checkpoints.
private fun doSleep() {
Thread.sleep(2000)
}
}

@InitiatedBy(InitiatorFlow::class)
class InitiatedFlow(val session: FlowSession) : FlowLogic<Any>() {
companion object {
object FIRST_STEP : ProgressTracker.Step("Step one")

fun tracker() = ProgressTracker(FIRST_STEP)

val seen = Collections.synchronizedSet(HashSet<Visited>())

fun visit(sessionNum: Int, iterationNum: Int, step: Step) {
val visited = Visited(sessionNum, iterationNum, step)
if (visited !in seen) {
seen += visited
throw ExceptionToCauseRetry()
}
}
}

override val progressTracker = tracker()

@Suspendable
override fun call() {
progressTracker.currentStep = FIRST_STEP
visit(-1, -1, Step.AfterInitiate)
val sessionInfo = session.receive<SessionInfo>().unwrap { it }
visit(sessionInfo.sessionNum, -1, Step.AfterInitiateSendReceive)
for (iteration in 1..sessionInfo.iterationsCount) {
visit(sessionInfo.sessionNum, iteration, Step.BeforeReceive)
val got = session.receive<Any>().unwrap { it }
visit(sessionInfo.sessionNum, iteration, Step.AfterReceive)
logger.info("B Got $got")
logger.info("B Sending $got")
visit(sessionInfo.sessionNum, iteration, Step.BeforeSend)
session.send(got)
visit(sessionInfo.sessionNum, iteration, Step.AfterSend)
}
}
}

@CordaSerializable
data class SessionInfo(val sessionNum: Int, val iterationsCount: Int)

enum class Step { First, BeforeInitiate, AfterInitiate, AfterInitiateSendReceive, BeforeSend, AfterSend, BeforeReceive, AfterReceive }

data class Visited(val sessionNum: Int, val iterationNum: Int, val step: Step)

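The retry-forcing idiom in the test above (throw a transient-looking exception the first time each step is visited, succeed on the retry) can be distilled into a tiny self-contained sketch; the names here are illustrative, not part of the commit.

import java.sql.SQLException
import java.util.Collections

// Fails exactly once per key with a transient-looking error, then succeeds on every retry.
object FailOnce {
    private val seen = Collections.synchronizedSet(HashSet<String>())
    fun visit(key: String) {
        if (seen.add(key)) throw SQLException("deadlock")  // looks transient, so the hospital would retry
    }
}

fun main() {
    try { FailOnce.visit("step-1") } catch (e: SQLException) { println("first attempt failed: ${e.message}") }
    FailOnce.visit("step-1")  // the "retry" passes because the key has already been seen
    println("retry succeeded")
}
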
@@ -25,6 +25,7 @@ import net.corda.core.node.services.vault.QueryCriteria
import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.seconds
import net.corda.testMessage.ScheduledState
import net.corda.testMessage.SpentState
import net.corda.testing.contracts.DummyContract
@@ -96,7 +97,7 @@ class ScheduledFlowIntegrationTests {
val aliceClient = CordaRPCClient(alice.rpcAddress).start(rpcUser.username, rpcUser.password)
val bobClient = CordaRPCClient(bob.rpcAddress).start(rpcUser.username, rpcUser.password)

val scheduledFor = Instant.now().plusSeconds(20)
val scheduledFor = Instant.now().plusSeconds(10)
val initialiseFutures = mutableListOf<CordaFuture<*>>()
for (i in 0 until N) {
initialiseFutures.add(aliceClient.proxy.startFlow(::InsertInitialStateFlow, bob.nodeInfo.legalIdentities.first(), defaultNotaryIdentity, i, scheduledFor).returnValue)
@@ -111,6 +112,9 @@ class ScheduledFlowIntegrationTests {
}
spendAttemptFutures.getOrThrowAll()

// TODO: the queries below are not atomic so we need to allow enough time for the scheduler to finish. Would be better to query scheduler.
Thread.sleep(20.seconds.toMillis())

val aliceStates = aliceClient.proxy.vaultQuery(ScheduledState::class.java).states.filter { it.state.data.processed }
val aliceSpentStates = aliceClient.proxy.vaultQuery(SpentState::class.java).states

@@ -985,8 +985,37 @@ internal fun logVendorString(database: CordaPersistence, log: Logger) {
}

internal class FlowStarterImpl(private val smm: StateMachineManager, private val flowLogicRefFactory: FlowLogicRefFactory) : FlowStarter {
override fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext, deduplicationHandler: DeduplicationHandler?): CordaFuture<FlowStateMachine<T>> {
return smm.startFlow(logic, context, ourIdentity = null, deduplicationHandler = deduplicationHandler)
override fun <T> startFlow(event: ExternalEvent.ExternalStartFlowEvent<T>): CordaFuture<FlowStateMachine<T>> {
smm.deliverExternalEvent(event)
return event.future
}

override fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext): CordaFuture<FlowStateMachine<T>> {
val startFlowEvent = object : ExternalEvent.ExternalStartFlowEvent<T>, DeduplicationHandler {
override fun insideDatabaseTransaction() {}

override fun afterDatabaseTransaction() {}

override val externalCause: ExternalEvent
get() = this
override val deduplicationHandler: DeduplicationHandler
get() = this

override val flowLogic: FlowLogic<T>
get() = logic
override val context: InvocationContext
get() = context

override fun wireUpFuture(flowFuture: CordaFuture<FlowStateMachine<T>>) {
_future.captureLater(flowFuture)
}

private val _future = openFuture<FlowStateMachine<T>>()
override val future: CordaFuture<FlowStateMachine<T>>
get() = _future

}
return startFlow(startFlowEvent)
}

override fun <T> invokeFlowAsync(

@@ -20,6 +20,12 @@ interface CheckpointStorage {
*/
fun removeCheckpoint(id: StateMachineRunId): Boolean

/**
* Load an existing checkpoint from the store.
* @return the checkpoint, still in serialized form, or null if not found.
*/
fun getCheckpoint(id: StateMachineRunId): SerializedBytes<Checkpoint>?

/**
* Stream all checkpoints from the store. If this is backed by a database the stream will be valid until the
* underlying database connection is closed, so any processing should happen before it is closed.

@@ -19,9 +19,9 @@ import net.corda.core.utilities.contextLogger
import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.internal.cordapp.CordappProviderInternal
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.network.NetworkMapUpdater
import net.corda.node.services.statemachine.ExternalEvent
import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.nodeapi.internal.persistence.CordaPersistence

@@ -134,11 +134,17 @@ interface ServiceHubInternal : ServiceHub {
interface FlowStarter {

/**
* Starts an already constructed flow. Note that you must be on the server thread to call this method.
* Starts an already constructed flow. Note that you must be on the server thread to call this method. This method
* just synthesizes an [ExternalEvent.ExternalStartFlowEvent] and calls the method below.
* @param context indicates who started the flow, see: [InvocationContext].
* @param deduplicationHandler allows exactly-once start of the flow, see [DeduplicationHandler]
*/
fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext, deduplicationHandler: DeduplicationHandler? = null): CordaFuture<FlowStateMachine<T>>
fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext): CordaFuture<FlowStateMachine<T>>

/**
* Starts a flow as described by an [ExternalEvent.ExternalStartFlowEvent]. If a transient error
* occurs during invocation, it will re-attempt to start the flow.
*/
fun <T> startFlow(event: ExternalEvent.ExternalStartFlowEvent<T>): CordaFuture<FlowStateMachine<T>>

/**
* Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the flow.

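As a rough usage sketch of the reshaped FlowStarter API described above (illustrative, not code from this commit), a node-internal caller would now start a flow and then wait on the flow's own result future:

import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext
import net.corda.core.flows.FlowLogic
import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.concurrent.flatMap
import net.corda.node.services.api.FlowStarter

// Sketch: `logic` is any already constructed FlowLogic<Unit>; the context comes from the caller.
fun startAndAwait(flowStarter: FlowStarter, logic: FlowLogic<Unit>, context: InvocationContext): CordaFuture<Unit> {
    val machineFuture: CordaFuture<FlowStateMachine<Unit>> = flowStarter.startFlow(logic, context)
    return machineFuture.flatMap { it.resultFuture }   // resolve once the flow itself completes
}
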
@@ -2,6 +2,7 @@ package net.corda.node.services.events

import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext
import net.corda.core.context.InvocationOrigin
import net.corda.core.contracts.SchedulableState
@@ -10,11 +11,9 @@ import net.corda.core.contracts.ScheduledStateRef
import net.corda.core.contracts.StateRef
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowLogicRefFactory
import net.corda.core.internal.ThreadBox
import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.*
import net.corda.core.internal.concurrent.flatMap
import net.corda.core.internal.join
import net.corda.core.internal.until
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.node.ServicesForResolution
import net.corda.core.schemas.PersistentStateRef
import net.corda.core.serialization.SingletonSerializeAsToken
@@ -26,8 +25,10 @@ import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.NodePropertiesStore
import net.corda.node.services.api.SchedulerService
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.statemachine.ExternalEvent
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import net.corda.nodeapi.internal.persistence.contextTransaction
import org.apache.activemq.artemis.utils.ReusableLatch
import org.apache.mina.util.ConcurrentHashSet
import org.slf4j.Logger
@@ -156,29 +157,31 @@ class NodeSchedulerService(private val clock: CordaClock,

override fun scheduleStateActivity(action: ScheduledStateRef) {
log.trace { "Schedule $action" }
if (!schedulerRepo.merge(action)) {
// Only increase the number of unfinished schedules if the state didn't already exist on the queue
unfinishedSchedules.countUp()
}
mutex.locked {
if (action.scheduledAt < nextScheduledAction?.scheduledAt ?: Instant.MAX) {
// We are earliest
rescheduleWakeUp()
} else if (action.ref == nextScheduledAction?.ref && action.scheduledAt != nextScheduledAction?.scheduledAt) {
// We were earliest but might not be any more
rescheduleWakeUp()
// Only increase the number of unfinished schedules if the state didn't already exist on the queue
val countUp = !schedulerRepo.merge(action)
contextTransaction.onCommit {
if (countUp) unfinishedSchedules.countUp()
mutex.locked {
if (action.scheduledAt < nextScheduledAction?.scheduledAt ?: Instant.MAX) {
// We are earliest
rescheduleWakeUp()
} else if (action.ref == nextScheduledAction?.ref && action.scheduledAt != nextScheduledAction?.scheduledAt) {
// We were earliest but might not be any more
rescheduleWakeUp()
}
}
}
}

override fun unscheduleStateActivity(ref: StateRef) {
log.trace { "Unschedule $ref" }
if (startingStateRefs.all { it.ref != ref } && schedulerRepo.delete(ref)) {
unfinishedSchedules.countDown()
}
mutex.locked {
if (nextScheduledAction?.ref == ref) {
rescheduleWakeUp()
val countDown = startingStateRefs.all { it.ref != ref } && schedulerRepo.delete(ref)
contextTransaction.onCommit {
if (countDown) unfinishedSchedules.countDown()
mutex.locked {
if (nextScheduledAction?.ref == ref) {
rescheduleWakeUp()
}
}
}
}
@@ -227,7 +230,12 @@ class NodeSchedulerService(private val clock: CordaClock,
schedulerTimerExecutor.join()
}

private inner class FlowStartDeduplicationHandler(val scheduledState: ScheduledStateRef) : DeduplicationHandler {
private inner class FlowStartDeduplicationHandler(val scheduledState: ScheduledStateRef, override val flowLogic: FlowLogic<Any?>, override val context: InvocationContext) : DeduplicationHandler, ExternalEvent.ExternalStartFlowEvent<Any?> {
override val externalCause: ExternalEvent
get() = this
override val deduplicationHandler: FlowStartDeduplicationHandler
get() = this

override fun insideDatabaseTransaction() {
schedulerRepo.delete(scheduledState.ref)
}
@@ -239,6 +247,18 @@ class NodeSchedulerService(private val clock: CordaClock,
override fun toString(): String {
return "${javaClass.simpleName}($scheduledState)"
}

override fun wireUpFuture(flowFuture: CordaFuture<FlowStateMachine<Any?>>) {
_future.captureLater(flowFuture)
val future = _future.flatMap { it.resultFuture }
future.then {
unfinishedSchedules.countDown()
}
}

private val _future = openFuture<FlowStateMachine<Any?>>()
override val future: CordaFuture<FlowStateMachine<Any?>>
get() = _future
}

private fun onTimeReached(scheduledState: ScheduledStateRef) {
@@ -250,11 +270,8 @@ class NodeSchedulerService(private val clock: CordaClock,
flowName = scheduledFlow.javaClass.name
// TODO refactor the scheduler to store and propagate the original invocation context
val context = InvocationContext.newInstance(InvocationOrigin.Scheduled(scheduledState))
val deduplicationHandler = FlowStartDeduplicationHandler(scheduledState)
val future = flowStarter.startFlow(scheduledFlow, context, deduplicationHandler).flatMap { it.resultFuture }
future.then {
unfinishedSchedules.countDown()
}
val startFlowEvent = FlowStartDeduplicationHandler(scheduledState, scheduledFlow, context)
flowStarter.startFlow(startFlowEvent)
}
}
} catch (e: Exception) {

@@ -10,6 +10,8 @@ import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.serialize
import net.corda.core.utilities.ByteSequence
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ExternalEvent
import net.corda.node.services.statemachine.SenderDeduplicationId
import java.time.Instant
import javax.annotation.concurrent.ThreadSafe

@@ -25,6 +27,12 @@ import javax.annotation.concurrent.ThreadSafe
*/
@ThreadSafe
interface MessagingService {
/**
* A unique identifier for this sender that changes whenever a node restarts. This is used in conjunction with a sequence
* number for message de-duplication at the recipient.
*/
val ourSenderUUID: String

/**
* The provided function will be invoked for each received message whose topic and session matches. The callback
* will run on the main server thread provided when the messaging service is constructed, and a database
@@ -93,11 +101,12 @@ interface MessagingService {
/**
* Returns an initialised [Message] with the current time, etc, already filled in.
*
* @param topicSession identifier for the topic and session the message is sent to.
* @param additionalProperties optional additional message headers.
* @param topic identifier for the topic the message is sent to.
* @param data the payload for the message.
* @param deduplicationId optional message deduplication ID including sender identifier.
* @param additionalHeaders optional additional message headers.
*/
fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId = DeduplicationId.createRandom(newSecureRandom()), additionalHeaders: Map<String, String> = emptyMap()): Message
fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId = SenderDeduplicationId(DeduplicationId.createRandom(newSecureRandom()), ourSenderUUID), additionalHeaders: Map<String, String> = emptyMap()): Message

/** Given information about either a specific node or a service returns its corresponding address */
fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients
@@ -106,7 +115,7 @@ interface MessagingService {
val myAddress: SingleMessageRecipient
}

fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: DeduplicationId = DeduplicationId.createRandom(newSecureRandom()), retryId: Long? = null, additionalHeaders: Map<String, String> = emptyMap()) = send(createMessage(topicSession, payload.serialize().bytes, deduplicationId, additionalHeaders), to, retryId)
fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: SenderDeduplicationId = SenderDeduplicationId(DeduplicationId.createRandom(newSecureRandom()), ourSenderUUID), retryId: Long? = null, additionalHeaders: Map<String, String> = emptyMap()) = send(createMessage(topicSession, payload.serialize().bytes, deduplicationId, additionalHeaders), to, retryId)

interface MessageHandlerRegistration

@@ -152,15 +161,17 @@ object TopicStringValidator {
}

/**
* This handler is used to implement exactly-once delivery of an event on top of a possibly duplicated one. This is done
* This handler is used to implement exactly-once delivery of an external event on top of an at-least-once delivery. This is done
* using two hooks that are called from the event processor, one called from the database transaction committing the
* side-effect caused by the event, and another one called after the transaction has committed successfully.
* side-effect caused by the external event, and another one called after the transaction has committed successfully.
*
* For example for messaging we can use [insideDatabaseTransaction] to store the message's unique ID for later
* deduplication, and [afterDatabaseTransaction] to acknowledge the message and stop retries.
*
* We also use this for exactly-once start of a scheduled flow, [insideDatabaseTransaction] is used to remove the
* to-be-scheduled state of the flow, [afterDatabaseTransaction] is used for cleanup of in-memory bookkeeping.
*
* It holds a reference back to the causing external event.
*/
interface DeduplicationHandler {
/**
@@ -174,6 +185,11 @@ interface DeduplicationHandler {
* cleanup/acknowledgement/stopping of retries.
*/
fun afterDatabaseTransaction()

/**
* The external event for which we are trying to reduce from at-least-once delivery to exactly-once.
*/
val externalCause: ExternalEvent
}

typealias MessageHandler = (ReceivedMessage, MessageHandlerRegistration, DeduplicationHandler) -> Unit

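To illustrate the contract described above (a sketch under assumed names, not code from this commit), a handler for a message-like event would persist the deduplication ID inside the transaction and acknowledge the delivery only afterwards:

import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.statemachine.ExternalEvent

// Sketch only: `persistId` and `acknowledge` stand in for whatever storage/acknowledgement
// mechanism the real event source provides.
class ExampleDeduplicationHandler(
        private val event: ExternalEvent,
        private val persistId: () -> Unit,
        private val acknowledge: () -> Unit
) : DeduplicationHandler {
    override fun insideDatabaseTransaction() = persistId()     // record the ID in the same tx as the side effect
    override fun afterDatabaseTransaction() = acknowledge()    // stop redelivery once durably committed
    override val externalCause: ExternalEvent get() = event    // lets a retried flow replay this event
}
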
@@ -8,7 +8,6 @@ import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import java.io.Serializable
import java.time.Instant
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import javax.persistence.Column
import javax.persistence.Entity
@@ -18,8 +17,6 @@ import javax.persistence.Id
* Encapsulate the de-duplication logic.
*/
class P2PMessageDeduplicator(private val database: CordaPersistence) {
val ourSenderUUID = UUID.randomUUID().toString()

// A temporary in-memory set of deduplication IDs and associated high water mark details.
// When we receive a message we don't persist the ID immediately,
// so we store the ID here in the meantime (until the persisting db tx has committed). This is because Artemis may

@@ -23,6 +23,8 @@ import net.corda.node.internal.artemis.ReactiveArtemisConsumer.Companion.multipl
import net.corda.node.services.api.NetworkMapCacheInternal
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ExternalEvent
import net.corda.node.services.statemachine.SenderDeduplicationId
import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.PersistentMap
import net.corda.nodeapi.ArtemisTcpTransport.Companion.p2pConnectorTcpTransport
@@ -124,7 +126,7 @@ class P2PMessagingClient(val config: NodeConfiguration,
)
}

private class NodeClientMessage(override val topic: String, override val data: ByteSequence, override val uniqueMessageId: DeduplicationId, override val senderUUID: String?, override val additionalHeaders: Map<String, String>) : Message {
class NodeClientMessage(override val topic: String, override val data: ByteSequence, override val uniqueMessageId: DeduplicationId, override val senderUUID: String?, override val additionalHeaders: Map<String, String>) : Message {
override val debugTimestamp: Instant = Instant.now()
override fun toString() = "$topic#${String(data.bytes)}"
}
@@ -158,6 +160,8 @@ class P2PMessagingClient(val config: NodeConfiguration,
data class HandlerRegistration(val topic: String, val callback: Any) : MessageHandlerRegistration

override val myAddress: SingleMessageRecipient = NodeAddress(myIdentity, advertisedAddress)
override val ourSenderUUID = UUID.randomUUID().toString()

private val messageRedeliveryDelaySeconds = config.p2pMessagingRetry.messageRedeliveryDelay.seconds
private val state = ThreadBox(InnerState())
private val knownQueues = Collections.newSetFromMap(ConcurrentHashMap<String, Boolean>())
@@ -227,7 +231,7 @@ class P2PMessagingClient(val config: NodeConfiguration,
producer!!,
versionInfo,
this@P2PMessagingClient,
ourSenderUUID = deduplicator.ourSenderUUID
ourSenderUUID = ourSenderUUID
)

registerBridgeControl(bridgeSession!!, inboxes.toList())
@@ -435,18 +439,23 @@ class P2PMessagingClient(val config: NodeConfiguration,
}
}

inner class MessageDeduplicationHandler(val artemisMessage: ClientMessage, val cordaMessage: ReceivedMessage) : DeduplicationHandler {
private inner class MessageDeduplicationHandler(val artemisMessage: ClientMessage, override val receivedMessage: ReceivedMessage) : DeduplicationHandler, ExternalEvent.ExternalMessageEvent {
override val externalCause: ExternalEvent
get() = this
override val deduplicationHandler: MessageDeduplicationHandler
get() = this

override fun insideDatabaseTransaction() {
deduplicator.persistDeduplicationId(cordaMessage.uniqueMessageId)
deduplicator.persistDeduplicationId(receivedMessage.uniqueMessageId)
}

override fun afterDatabaseTransaction() {
deduplicator.signalMessageProcessFinish(cordaMessage.uniqueMessageId)
deduplicator.signalMessageProcessFinish(receivedMessage.uniqueMessageId)
messagingExecutor!!.acknowledge(artemisMessage)
}

override fun toString(): String {
return "${javaClass.simpleName}(${cordaMessage.uniqueMessageId})"
return "${javaClass.simpleName}(${receivedMessage.uniqueMessageId})"
}
}

@@ -610,8 +619,8 @@ class P2PMessagingClient(val config: NodeConfiguration,
handlers.remove(registration.topic)
}

override fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId, additionalHeaders: Map<String, String>): Message {
return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId, deduplicator.ourSenderUUID, additionalHeaders)
override fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId, additionalHeaders: Map<String, String>): Message {
return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId.deduplicationId, deduplicationId.senderUUID, additionalHeaders)
}

override fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients {

@@ -10,9 +10,9 @@ import net.corda.nodeapi.internal.persistence.currentDBSession
import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.io.Serializable
import java.util.*
import java.util.stream.Stream
import java.io.Serializable
import javax.persistence.Column
import javax.persistence.Entity
import javax.persistence.Id
@@ -53,6 +53,11 @@ class DBCheckpointStorage : CheckpointStorage {
return session.createQuery(delete).executeUpdate() > 0
}

override fun getCheckpoint(id: StateMachineRunId): SerializedBytes<Checkpoint>? {
val bytes = currentDBSession().get(DBCheckpoint::class.java, id.uuid.toString())?.checkpoint ?: return null
return SerializedBytes<Checkpoint>(bytes)
}

override fun getAllCheckpoints(): Stream<Pair<StateMachineRunId, SerializedBytes<Checkpoint>>> {
val session = currentDBSession()
val criteriaQuery = session.criteriaBuilder.createQuery(DBCheckpoint::class.java)

@@ -21,7 +21,6 @@ import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY
import rx.Observable
import rx.subjects.PublishSubject
import java.io.Serializable
import java.util.*
import javax.persistence.*

// cache value type to just store the immutable bits of a signed transaction plus conversion helpers
@@ -71,11 +70,11 @@ class DBTransactionStorage(cacheSizeBytes: Long) : WritableTransactionStorage, S
// to the memory pressure at all here.
private const val transactionSignatureOverheadEstimate = 1024

private fun weighTx(tx: Optional<TxCacheValue>): Int {
if (!tx.isPresent) {
private fun weighTx(tx: AppendOnlyPersistentMapBase.Transactional<TxCacheValue>): Int {
val actTx = tx.valueWithoutIsolation
if (actTx == null) {
return 0
}
val actTx = tx.get()
return actTx.second.sumBy { it.size + transactionSignatureOverheadEstimate } + actTx.first.size
}
}

@@ -24,7 +24,7 @@ sealed class Action {
data class SendInitial(
val party: Party,
val initialise: InitialSessionMessage,
val deduplicationId: DeduplicationId
val deduplicationId: SenderDeduplicationId
) : Action()

/**
@@ -33,7 +33,7 @@ sealed class Action {
data class SendExisting(
val peerParty: Party,
val message: ExistingSessionMessage,
val deduplicationId: DeduplicationId
val deduplicationId: SenderDeduplicationId
) : Action()

/**
@@ -62,7 +62,8 @@ sealed class Action {
*/
data class PropagateErrors(
val errorMessages: List<ErrorSessionMessage>,
val sessions: List<SessionState.Initiated>
val sessions: List<SessionState.Initiated>,
val senderUUID: String?
) : Action()

/**
@@ -129,6 +130,11 @@ sealed class Action {
* Release soft locks associated with given ID (currently the flow ID).
*/
data class ReleaseSoftLocks(val uuid: UUID?) : Action()

/**
* Retry a flow from the last checkpoint, or if there is no checkpoint, restart the flow with the same invocation details.
*/
data class RetryFlowFromSafePoint(val currentState: StateMachineState) : Action()
}

/**

@@ -73,6 +73,7 @@ class ActionExecutorImpl(
is Action.CommitTransaction -> executeCommitTransaction()
is Action.ExecuteAsyncOperation -> executeAsyncOperation(fiber, action)
is Action.ReleaseSoftLocks -> executeReleaseSoftLocks(action)
is Action.RetryFlowFromSafePoint -> executeRetryFlowFromSafePoint(action)
}
}

@@ -125,7 +126,7 @@ class ActionExecutorImpl(
@Suspendable
private fun executePropagateErrors(action: Action.PropagateErrors) {
action.errorMessages.forEach { (exception) ->
log.debug("Propagating error", exception)
log.warn("Propagating error", exception)
}
for (sessionState in action.sessions) {
// We cannot propagate if the session isn't live.
@@ -137,7 +138,7 @@ class ActionExecutorImpl(
val sinkSessionId = sessionState.initiatedState.peerSinkSessionId
val existingMessage = ExistingSessionMessage(sinkSessionId, errorMessage)
val deduplicationId = DeduplicationId.createForError(errorMessage.errorId, sinkSessionId)
flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, deduplicationId)
flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, action.senderUUID))
}
}
}
@@ -226,6 +227,10 @@ class ActionExecutorImpl(
)
}

private fun executeRetryFlowFromSafePoint(action: Action.RetryFlowFromSafePoint) {
stateMachineManager.retryFlowFromSafePoint(action.currentState)
}

private fun serializeCheckpoint(checkpoint: Checkpoint): SerializedBytes<Checkpoint> {
return checkpoint.serialize(context = checkpointSerializationContext)
}

@@ -45,3 +45,9 @@ data class DeduplicationId(val toString: String) {
}
}
}

/**
* Represents the deduplication ID of a flow message, and the sender identifier for the flow doing the sending. The identifier might be
* null if the flow is trying to replay messages and doesn't want an optimisation to ignore the deduplication ID.
*/
data class SenderDeduplicationId(val deduplicationId: DeduplicationId, val senderUUID: String?)

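For illustration (not part of the diff), constructing the new ID type when sending a session message could look like this; a null senderUUID is the signal that the recipient must fall back to full, persisted de-duplication:

import net.corda.core.crypto.newSecureRandom
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.SenderDeduplicationId

// Sketch: tags a fresh random deduplication ID with the sending node's identifier.
fun freshSenderDeduplicationId(ourSenderUUID: String?): SenderDeduplicationId =
        SenderDeduplicationId(DeduplicationId.createRandom(newSecureRandom()), ourSenderUUID)

// Passing ourSenderUUID = null (e.g. for a flow restored from a checkpoint) disables the
// sender-sequence optimisation at the recipient.
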
@@ -31,9 +31,9 @@ sealed class Event {
*/
data class DeliverSessionMessage(
val sessionMessage: ExistingSessionMessage,
val deduplicationHandler: DeduplicationHandler,
override val deduplicationHandler: DeduplicationHandler,
val sender: Party
) : Event()
) : Event(), GeneratedByExternalEvent

/**
* Signal that an error has happened. This may be due to an uncaught exception in the flow or some external error.
@@ -133,4 +133,19 @@ sealed class Event {
* @param returnValue the result of the operation.
*/
data class AsyncOperationCompletion(val returnValue: Any?) : Event()

/**
* Retry a flow from the last checkpoint, or if there is no checkpoint, restart the flow with the same invocation details.
*/
object RetryFlowFromSafePoint : Event() {
override fun toString() = "RetryFlowFromSafePoint"
}

/**
* Indicates that an event was generated by an external event and that external event needs to be replayed if we retry the flow,
* even if it has not yet been processed and placed on the pending de-duplication handlers list.
*/
interface GeneratedByExternalEvent {
val deduplicationHandler: DeduplicationHandler
}
}

@@ -9,10 +9,15 @@ interface FlowHospital {
/**
* The flow running in [flowFiber] has errored.
*/
fun flowErrored(flowFiber: FlowFiber)
fun flowErrored(flowFiber: FlowFiber, currentState: StateMachineState, errors: List<Throwable>)

/**
* The flow running in [flowFiber] has cleaned, possibly as a result of a flow hospital resume.
*/
fun flowCleaned(flowFiber: FlowFiber)

/**
* The flow has been removed from the state machine.
*/
fun flowRemoved(flowFiber: FlowFiber)
}

@@ -24,7 +24,7 @@ interface FlowMessaging {
* listen on the send acknowledgement.
*/
@Suspendable
fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId)
fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: SenderDeduplicationId)

/**
* Start the messaging using the [onMessage] message handler.
@@ -49,7 +49,7 @@ class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging {
}

@Suspendable
override fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId) {
override fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: SenderDeduplicationId) {
log.trace { "Sending message $deduplicationId $message to party $party" }
val networkMessage = serviceHub.networkService.createMessage(sessionTopic, serializeSessionMessage(message).bytes, deduplicationId, message.additionalHeaders(party))
val partyInfo = serviceHub.networkMapCache.getPartyInfo(party) ?: throw IllegalArgumentException("Don't know about $party")

@@ -47,6 +47,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
fun currentStateMachine(): FlowStateMachineImpl<*>? = Strand.currentStrand() as? FlowStateMachineImpl<*>

private val log: Logger = LoggerFactory.getLogger("net.corda.flow")

private val SERIALIZER_BLOCKER = Fiber::class.java.getDeclaredField("SERIALIZER_BLOCKER").apply { isAccessible = true }.get(null)
}

override val serviceHub get() = getTransientField(TransientValues::serviceHub)
@@ -65,6 +67,14 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
internal var transientValues: TransientReference<TransientValues>? = null
internal var transientState: TransientReference<StateMachineState>? = null

/**
* What sender identifier to put on messages sent by this flow. This will either be the identifier for the current
* state machine manager / messaging client, or null to indicate this flow is restored from a checkpoint and
* the de-duplication of messages it sends should not be optimised since this could be unreliable.
*/
override val ourSenderUUID: String?
get() = transientState?.value?.senderUUID

private fun <A> getTransientField(field: KProperty1<TransientValues, A>): A {
val suppliedValues = transientValues ?: throw IllegalStateException("${field.name} wasn't supplied!")
return field.get(suppliedValues.value)
@@ -168,6 +178,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
fun setLoggingContext() {
context.pushToLoggingContext()
MDC.put("flow-id", id.uuid.toString())
MDC.put("fiber-id", this.getId().toString())
}

@Suspendable
@@ -185,7 +196,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
suspend(FlowIORequest.WaitForSessionConfirmations, maySkipCheckpoint = true)
Try.Success(result)
} catch (throwable: Throwable) {
logger.warn("Flow threw exception", throwable)
logger.info("Flow threw exception... sending to flow hospital", throwable)
Try.Failure<R>(throwable)
}
val softLocksId = if (hasSoftLockedStates) logic.runId.uuid else null
@@ -325,7 +336,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
isDbTransactionOpenOnExit = false
)
require(continuation == FlowContinuation.ProcessEvents)
Fiber.unparkDeserialized(this, scheduler)
unpark(SERIALIZER_BLOCKER)
}
setLoggingContext()
return uncheckedCast(processEventsUntilFlowIsResumed(

@@ -9,12 +9,17 @@ import net.corda.core.utilities.loggerFor
object PropagatingFlowHospital : FlowHospital {
private val log = loggerFor<PropagatingFlowHospital>()

override fun flowErrored(flowFiber: FlowFiber) {
log.debug { "Flow ${flowFiber.id} dirtied ${flowFiber.snapshot().checkpoint.errorState}" }
override fun flowErrored(flowFiber: FlowFiber, currentState: StateMachineState, errors: List<Throwable>) {
log.debug { "Flow ${flowFiber.id} in state $currentState encountered error" }
flowFiber.scheduleEvent(Event.StartErrorPropagation)
for ((index, error) in errors.withIndex()) {
log.warn("Flow ${flowFiber.id} is propagating error [$index] ", error)
}
}

override fun flowCleaned(flowFiber: FlowFiber) {
throw IllegalStateException("Flow ${flowFiber.id} cleaned after error propagation triggered")
}

override fun flowRemoved(flowFiber: FlowFiber) {}
}

@@ -46,13 +46,15 @@ import java.security.SecureRandom
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import java.util.concurrent.locks.ReentrantLock
import javax.annotation.concurrent.ThreadSafe
import kotlin.collections.ArrayList
import kotlin.concurrent.withLock
import kotlin.streams.toList

/**
* The StateMachineManagerImpl will always invoke the flow fibers on the given [AffinityExecutor], regardless of which
* thread actually starts them via [startFlow].
* thread actually starts them via [deliverExternalEvent].
*/
@ThreadSafe
class SingleThreadedStateMachineManager(
@@ -90,6 +92,7 @@ class SingleThreadedStateMachineManager(
private val flowMessaging: FlowMessaging = FlowMessagingImpl(serviceHub)
private val fiberDeserializationChecker = if (serviceHub.configuration.shouldCheckCheckpoints()) FiberDeserializationChecker() else null
private val transitionExecutor = makeTransitionExecutor()
private val ourSenderUUID = serviceHub.networkService.ourSenderUUID

private var checkpointSerializationContext: SerializationContext? = null
private var tokenizableServices: List<Any>? = null
@@ -128,7 +131,7 @@ class SingleThreadedStateMachineManager(
resumeRestoredFlows(fibers)
flowMessaging.start { receivedMessage, deduplicationHandler ->
executor.execute {
onSessionMessage(receivedMessage, deduplicationHandler)
deliverExternalEvent(deduplicationHandler.externalCause)
}
}
}
@@ -176,7 +179,7 @@ class SingleThreadedStateMachineManager(
}
}

override fun <A> startFlow(
private fun <A> startFlow(
flowLogic: FlowLogic<A>,
context: InvocationContext,
ourIdentity: Party?,
@@ -310,7 +313,73 @@ class SingleThreadedStateMachineManager(
}
}

private fun onSessionMessage(message: ReceivedMessage, deduplicationHandler: DeduplicationHandler) {
override fun retryFlowFromSafePoint(currentState: StateMachineState) {
// Get set of external events
val flowId = currentState.flowLogic.runId
val oldFlowLeftOver = mutex.locked { flows[flowId] }?.fiber?.transientValues?.value?.eventQueue
if (oldFlowLeftOver == null) {
logger.error("Unable to find flow for flow $flowId. Something is very wrong. The flow will not retry.")
return
}
val flow = if (currentState.isAnyCheckpointPersisted) {
val serializedCheckpoint = checkpointStorage.getCheckpoint(flowId)
if (serializedCheckpoint == null) {
logger.error("Unable to find database checkpoint for flow $flowId. Something is very wrong. The flow will not retry.")
return
}
val checkpoint = deserializeCheckpoint(serializedCheckpoint)
if (checkpoint == null) {
logger.error("Unable to deserialize database checkpoint for flow $flowId. Something is very wrong. The flow will not retry.")
return
}
// Resurrect flow
createFlowFromCheckpoint(
id = flowId,
checkpoint = checkpoint,
initialDeduplicationHandler = null,
isAnyCheckpointPersisted = true,
isStartIdempotent = false,
senderUUID = null
)
} else {
// Just flow initiation message
null
}
externalEventMutex.withLock {
if (flow != null) addAndStartFlow(flowId, flow)
// Deliver all the external events from the old flow instance.
val unprocessedExternalEvents = mutableListOf<ExternalEvent>()
do {
val event = oldFlowLeftOver.tryReceive()
if (event is Event.GeneratedByExternalEvent) {
unprocessedExternalEvents += event.deduplicationHandler.externalCause
}
} while (event != null)
val externalEvents = currentState.pendingDeduplicationHandlers.map { it.externalCause } + unprocessedExternalEvents
for (externalEvent in externalEvents) {
deliverExternalEvent(externalEvent)
}
}
}

private val externalEventMutex = ReentrantLock()
override fun deliverExternalEvent(event: ExternalEvent) {
externalEventMutex.withLock {
when (event) {
is ExternalEvent.ExternalMessageEvent -> onSessionMessage(event)
is ExternalEvent.ExternalStartFlowEvent<*> -> onExternalStartFlow(event)
}
}
}

private fun <T> onExternalStartFlow(event: ExternalEvent.ExternalStartFlowEvent<T>) {
val future = startFlow(event.flowLogic, event.context, ourIdentity = null, deduplicationHandler = event.deduplicationHandler)
event.wireUpFuture(future)
}

private fun onSessionMessage(event: ExternalEvent.ExternalMessageEvent) {
val message: ReceivedMessage = event.receivedMessage
val deduplicationHandler: DeduplicationHandler = event.deduplicationHandler
val peer = message.peer
val sessionMessage = try {
message.data.deserialize<SessionMessage>()
@@ -384,7 +453,7 @@ class SingleThreadedStateMachineManager(
}

if (replyError != null) {
flowMessaging.sendSessionMessage(sender, replyError, DeduplicationId.createRandom(secureRandom))
flowMessaging.sendSessionMessage(sender, replyError, SenderDeduplicationId(DeduplicationId.createRandom(secureRandom), ourSenderUUID))
deduplicationHandler.afterDatabaseTransaction()
}
}
@@ -458,7 +527,8 @@ class SingleThreadedStateMachineManager(
isAnyCheckpointPersisted = false,
isStartIdempotent = isStartIdempotent,
isRemoved = false,
flowLogic = flowLogic
flowLogic = flowLogic,
senderUUID = ourSenderUUID
)
flowStateMachineImpl.transientState = TransientReference(initialState)
mutex.locked {
@@ -493,7 +563,7 @@ class SingleThreadedStateMachineManager(

private fun createTransientValues(id: StateMachineRunId, resultFuture: CordaFuture<Any?>): FlowStateMachineImpl.TransientValues {
return FlowStateMachineImpl.TransientValues(
eventQueue = Channels.newChannel(stateMachineConfiguration.eventQueueSize, Channels.OverflowPolicy.BLOCK),
eventQueue = Channels.newChannel(-1, Channels.OverflowPolicy.BLOCK),
resultFuture = resultFuture,
database = database,
transitionExecutor = transitionExecutor,
@@ -509,7 +579,8 @@ class SingleThreadedStateMachineManager(
checkpoint: Checkpoint,
isAnyCheckpointPersisted: Boolean,
isStartIdempotent: Boolean,
initialDeduplicationHandler: DeduplicationHandler?
initialDeduplicationHandler: DeduplicationHandler?,
senderUUID: String? = ourSenderUUID
): Flow {
val flowState = checkpoint.flowState
val resultFuture = openFuture<Any?>()
@@ -524,7 +595,8 @@ class SingleThreadedStateMachineManager(
isAnyCheckpointPersisted = isAnyCheckpointPersisted,
isStartIdempotent = isStartIdempotent,
isRemoved = false,
flowLogic = logic
flowLogic = logic,
senderUUID = senderUUID
)
val fiber = FlowStateMachineImpl(id, logic, scheduler)
fiber.transientValues = TransientReference(createTransientValues(id, resultFuture))
@@ -542,7 +614,8 @@ class SingleThreadedStateMachineManager(
isAnyCheckpointPersisted = isAnyCheckpointPersisted,
isStartIdempotent = isStartIdempotent,
isRemoved = false,
flowLogic = fiber.logic
flowLogic = fiber.logic,
senderUUID = senderUUID
)
fiber.transientValues = TransientReference(createTransientValues(id, resultFuture))
fiber.transientState = TransientReference(state)
@@ -566,9 +639,13 @@ class SingleThreadedStateMachineManager(
startedFutures[id]?.setException(IllegalStateException("Will not start flow as SMM is stopping"))
logger.trace("Not resuming as SMM is stopping.")
} else {
incrementLiveFibers()
unfinishedFibers.countUp()
flows[id] = flow
val oldFlow = flows.put(id, flow)
if (oldFlow == null) {
incrementLiveFibers()
unfinishedFibers.countUp()
} else {
oldFlow.resultFuture.captureLater(flow.resultFuture)
}
flow.fiber.scheduleEvent(Event.DoRemainingWork)
when (checkpoint.flowState) {
is FlowState.Unstarted -> {
@@ -604,7 +681,7 @@ class SingleThreadedStateMachineManager(

private fun makeTransitionExecutor(): TransitionExecutor {
val interceptors = ArrayList<TransitionInterceptor>()
interceptors.add { HospitalisingInterceptor(PropagatingFlowHospital, it) }
interceptors.add { HospitalisingInterceptor(StaffedFlowHospital, it) }
if (serviceHub.configuration.devMode) {
interceptors.add { DumpHistoryOnErrorInterceptor(it) }
}

@ -0,0 +1,127 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.core.utilities.loggerFor
|
||||
import java.sql.SQLException
|
||||
import java.time.Instant
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
|
||||
/**
|
||||
* This hospital consults "staff" to see if they can automatically diagnose and treat flows.
|
||||
*/
|
||||
object StaffedFlowHospital : FlowHospital {
|
||||
private val log = loggerFor<StaffedFlowHospital>()
|
||||
|
||||
private val staff = listOf(DeadlockNurse, DuplicateInsertSpecialist)
|
||||
|
||||
private val patients = ConcurrentHashMap<StateMachineRunId, MedicalHistory>()
|
||||
|
||||
val numberOfPatients = patients.size
|
||||
|
||||
class MedicalHistory {
|
||||
val records: MutableList<Record> = mutableListOf()
|
||||
|
||||
sealed class Record(val suspendCount: Int) {
|
||||
class Admitted(val at: Instant, suspendCount: Int) : Record(suspendCount) {
|
||||
override fun toString() = "Admitted(at=$at, suspendCount=$suspendCount)"
|
||||
}
|
||||
|
||||
class Discharged(val at: Instant, suspendCount: Int, val by: Staff, val error: Throwable) : Record(suspendCount) {
|
||||
override fun toString() = "Discharged(at=$at, suspendCount=$suspendCount, by=$by)"
|
||||
}
|
||||
}
|
||||
|
||||
fun notDischargedForTheSameThingMoreThan(max: Int, by: Staff): Boolean {
|
||||
val lastAdmittanceSuspendCount = (records.last() as MedicalHistory.Record.Admitted).suspendCount
|
||||
return records.filterIsInstance(MedicalHistory.Record.Discharged::class.java).filter { it.by == by && it.suspendCount == lastAdmittanceSuspendCount }.count() <= max
|
||||
}
|
||||
|
||||
override fun toString(): String = "${this.javaClass.simpleName}(records = $records)"
|
||||
}
|
||||
|
||||
override fun flowErrored(flowFiber: FlowFiber, currentState: StateMachineState, errors: List<Throwable>) {
|
||||
log.info("Flow ${flowFiber.id} admitted to hospital in state $currentState")
|
||||
val medicalHistory = patients.computeIfAbsent(flowFiber.id) { MedicalHistory() }
|
||||
medicalHistory.records += MedicalHistory.Record.Admitted(Instant.now(), currentState.checkpoint.numberOfSuspends)
|
||||
for ((index, error) in errors.withIndex()) {
|
||||
log.info("Flow ${flowFiber.id} has error [$index]", error)
|
||||
if (!errorIsDischarged(flowFiber, currentState, error, medicalHistory)) {
|
||||
// If any error isn't discharged, then we propagate.
|
||||
log.warn("Flow ${flowFiber.id} error was not discharged, propagating.")
|
||||
flowFiber.scheduleEvent(Event.StartErrorPropagation)
|
||||
return
|
||||
}
|
||||
}
|
||||
// If all are discharged, retry.
|
||||
flowFiber.scheduleEvent(Event.RetryFlowFromSafePoint)
|
||||
}
|
||||
|
||||
private fun errorIsDischarged(flowFiber: FlowFiber, currentState: StateMachineState, error: Throwable, medicalHistory: MedicalHistory): Boolean {
|
||||
for (staffMember in staff) {
|
||||
val diagnosis = staffMember.consult(flowFiber, currentState, error, medicalHistory)
|
||||
if (diagnosis == Diagnosis.DISCHARGE) {
|
||||
medicalHistory.records += MedicalHistory.Record.Discharged(Instant.now(), currentState.checkpoint.numberOfSuspends, staffMember, error)
|
||||
log.info("Flow ${flowFiber.id} error discharged from hospital by $staffMember")
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// It's okay for flows to be cleaned... we fix them now!
|
||||
override fun flowCleaned(flowFiber: FlowFiber) {}
|
||||
|
||||
override fun flowRemoved(flowFiber: FlowFiber) {
|
||||
patients.remove(flowFiber.id)
|
||||
}
|
||||
|
||||
enum class Diagnosis {
|
||||
/**
|
||||
* Retry from last safe point.
|
||||
*/
|
||||
DISCHARGE,
|
||||
/**
|
||||
* Please try another member of staff.
|
||||
*/
|
||||
NOT_MY_SPECIALTY
|
||||
}
|
||||
|
||||
interface Staff {
|
||||
fun consult(flowFiber: FlowFiber, currentState: StateMachineState, newError: Throwable, history: MedicalHistory): Diagnosis
|
||||
}
|
||||
|
||||
/**
|
||||
* SQL Deadlock detection.
|
||||
*/
|
||||
object DeadlockNurse : Staff {
|
||||
override fun consult(flowFiber: FlowFiber, currentState: StateMachineState, newError: Throwable, history: MedicalHistory): Diagnosis {
|
||||
return if (mentionsDeadlock(newError)) {
|
||||
Diagnosis.DISCHARGE
|
||||
} else {
|
||||
Diagnosis.NOT_MY_SPECIALTY
|
||||
}
|
||||
}
|
||||
|
||||
private fun mentionsDeadlock(exception: Throwable?): Boolean {
|
||||
return exception != null && (exception is SQLException && ((exception.message?.toLowerCase()?.contains("deadlock")
|
||||
?: false)) || mentionsDeadlock(exception.cause))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Primary key violation detection for duplicate inserts. Will detect other constraint violations too.
|
||||
*/
|
||||
object DuplicateInsertSpecialist : Staff {
|
||||
override fun consult(flowFiber: FlowFiber, currentState: StateMachineState, newError: Throwable, history: MedicalHistory): Diagnosis {
|
||||
return if (mentionsConstraintViolation(newError) && history.notDischargedForTheSameThingMoreThan(3, this)) {
|
||||
Diagnosis.DISCHARGE
|
||||
} else {
|
||||
Diagnosis.NOT_MY_SPECIALTY
|
||||
}
|
||||
}
|
||||
|
||||
private fun mentionsConstraintViolation(exception: Throwable?): Boolean {
|
||||
return exception != null && (exception is org.hibernate.exception.ConstraintViolationException || mentionsConstraintViolation(exception.cause))
|
||||
}
|
||||
}
|
||||
}
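For a rough sketch of how the staff pattern extends, consider a hypothetical extra staff member that discharges errors looking like transient I/O failures. The FlakyNetworkNurse name, the IOException check and the retry cap of 3 are invented for illustration; only the Staff, Diagnosis and MedicalHistory types come from the object above, and a real member would also have to be added to the private staff list.

object FlakyNetworkNurse : StaffedFlowHospital.Staff {
    override fun consult(flowFiber: FlowFiber, currentState: StateMachineState, newError: Throwable, history: StaffedFlowHospital.MedicalHistory): StaffedFlowHospital.Diagnosis {
        // Discharge (i.e. retry from the last safe point) if the error chain mentions an I/O problem,
        // but only if this flow has not already been discharged too often for the same suspend point.
        return if (mentionsIOException(newError) && history.notDischargedForTheSameThingMoreThan(3, this)) {
            StaffedFlowHospital.Diagnosis.DISCHARGE
        } else {
            StaffedFlowHospital.Diagnosis.NOT_MY_SPECIALTY
        }
    }

    // Walk the cause chain looking for any IOException.
    private fun mentionsIOException(exception: Throwable?): Boolean {
        return exception != null && (exception is java.io.IOException || mentionsIOException(exception.cause))
    }
}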
|
@ -4,11 +4,11 @@ import net.corda.core.concurrent.CordaFuture
|
||||
import net.corda.core.context.InvocationContext
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.internal.FlowStateMachine
|
||||
import net.corda.core.messaging.DataFeed
|
||||
import net.corda.core.utilities.Try
|
||||
import net.corda.node.services.messaging.DeduplicationHandler
|
||||
import net.corda.node.services.messaging.ReceivedMessage
|
||||
import rx.Observable
|
||||
|
||||
/**
|
||||
@ -40,21 +40,6 @@ interface StateMachineManager {
|
||||
*/
|
||||
fun stop(allowedUnsuspendedFiberCount: Int)
|
||||
|
||||
/**
|
||||
* Starts a new flow.
|
||||
*
|
||||
* @param flowLogic The flow's code.
|
||||
* @param context The context of the flow.
|
||||
* @param ourIdentity The identity to use for the flow.
|
||||
* @param deduplicationHandler Allows exactly-once start of the flow, see [DeduplicationHandler].
|
||||
*/
|
||||
fun <A> startFlow(
|
||||
flowLogic: FlowLogic<A>,
|
||||
context: InvocationContext,
|
||||
ourIdentity: Party?,
|
||||
deduplicationHandler: DeduplicationHandler?
|
||||
): CordaFuture<FlowStateMachine<A>>
|
||||
|
||||
/**
|
||||
* Represents an addition/removal of a state machine.
|
||||
*/
|
||||
@ -91,6 +76,12 @@ interface StateMachineManager {
|
||||
* @return whether the flow existed and was killed.
|
||||
*/
|
||||
fun killFlow(id: StateMachineRunId): Boolean
|
||||
|
||||
/**
|
||||
* Deliver an external event to the state machine. Such an event might be a new P2P message, or a request to start a flow.
|
||||
* The event may be replayed if a flow fails and attempts to retry.
|
||||
*/
|
||||
fun deliverExternalEvent(event: ExternalEvent)
|
||||
}
|
||||
|
||||
// These must be idempotent! A later failure in the state transition may error the flow state, and a replay may call
|
||||
@ -100,4 +91,38 @@ interface StateMachineManagerInternal {
|
||||
fun addSessionBinding(flowId: StateMachineRunId, sessionId: SessionId)
|
||||
fun removeSessionBindings(sessionIds: Set<SessionId>)
|
||||
fun removeFlow(flowId: StateMachineRunId, removalReason: FlowRemovalReason, lastState: StateMachineState)
|
||||
fun retryFlowFromSafePoint(currentState: StateMachineState)
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents an external event that can be injected into the state machine and that might need to be replayed if
|
||||
* a flow retries. They always have de-duplication handlers to assist with the at-most-once logic where required.
|
||||
*/
|
||||
interface ExternalEvent {
|
||||
val deduplicationHandler: DeduplicationHandler
|
||||
|
||||
/**
|
||||
* An external P2P message event.
|
||||
*/
|
||||
interface ExternalMessageEvent : ExternalEvent {
|
||||
val receivedMessage: ReceivedMessage
|
||||
}
|
||||
|
||||
/**
|
||||
* An external request to start a flow, from the scheduler for example.
|
||||
*/
|
||||
interface ExternalStartFlowEvent<T> : ExternalEvent {
|
||||
val flowLogic: FlowLogic<T>
|
||||
val context: InvocationContext
|
||||
|
||||
/**
|
||||
* A callback for the state machine to pass back the [Future] associated with the flow start to the submitter.
|
||||
*/
|
||||
fun wireUpFuture(flowFuture: CordaFuture<FlowStateMachine<T>>)
|
||||
|
||||
/**
|
||||
* The future representing the flow start, passed back from the state machine to the submitter of this event.
|
||||
*/
|
||||
val future: CordaFuture<FlowStateMachine<T>>
|
||||
}
|
||||
}
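As a minimal sketch of how a submitter could satisfy this contract (not the node's actual implementation; the class name is invented), an ExternalStartFlowEvent can simply hold an open future and capture the one the state machine wires up:

// Assumed imports: net.corda.core.concurrent.CordaFuture, net.corda.core.internal.FlowStateMachine,
// net.corda.core.internal.concurrent.openFuture, plus FlowLogic, InvocationContext and DeduplicationHandler as above.
class SketchStartFlowEvent<T>(
        override val flowLogic: FlowLogic<T>,
        override val context: InvocationContext,
        override val deduplicationHandler: DeduplicationHandler
) : ExternalEvent.ExternalStartFlowEvent<T> {
    // The state machine hands back the start future via wireUpFuture(); captureLater() forwards it
    // to the future the submitter is already holding.
    private val _future = openFuture<FlowStateMachine<T>>()

    override fun wireUpFuture(flowFuture: CordaFuture<FlowStateMachine<T>>) {
        _future.captureLater(flowFuture)
    }

    override val future: CordaFuture<FlowStateMachine<T>>
        get() = _future
}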
|
||||
|
@ -26,6 +26,7 @@ import net.corda.node.services.messaging.DeduplicationHandler
|
||||
* possible.
|
||||
* @param isRemoved true if the flow has been removed from the state machine manager. This is used to avoid any further
|
||||
* work.
|
||||
* @param senderUUID the identifier of the sending state machine or null if this flow is resumed from a checkpoint so that it does not participate in de-duplication high-water-marking.
|
||||
*/
|
||||
// TODO perhaps add a read-only environment to the state machine for things that don't change over time?
|
||||
// TODO evaluate persistent datastructure libraries to replace the inefficient copying we currently do.
|
||||
@ -37,7 +38,8 @@ data class StateMachineState(
|
||||
val isTransactionTracked: Boolean,
|
||||
val isAnyCheckpointPersisted: Boolean,
|
||||
val isStartIdempotent: Boolean,
|
||||
val isRemoved: Boolean
|
||||
val isRemoved: Boolean,
|
||||
val senderUUID: String?
|
||||
)
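Illustrative only (the helper below is not part of the commit): because StateMachineState is a data class, a retry can rebuild the state with the sender UUID cleared, which is what keeps replayed sends out of the de-duplication high-water-marking described above.

fun exampleStateForRetry(previous: StateMachineState): StateMachineState =
        previous.copy(senderUUID = null)  // resumed from a checkpoint, so opt out of high-water-marking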
|
||||
|
||||
/**
|
||||
|
@ -34,7 +34,8 @@ class DumpHistoryOnErrorInterceptor(val delegate: TransitionExecutor) : Transiti
|
||||
(record ?: ArrayList()).apply { add(transitionRecord) }
|
||||
}
|
||||
|
||||
if (nextState.checkpoint.errorState is ErrorState.Errored) {
|
||||
// Just if we decide to propagate, and not if just on the way to the hospital.
|
||||
if (nextState.checkpoint.errorState is ErrorState.Errored && nextState.checkpoint.errorState.propagating) {
|
||||
log.warn("Flow ${fiber.id} errored, dumping all transitions:\n${record!!.joinToString("\n")}")
|
||||
for (error in nextState.checkpoint.errorState.errors) {
|
||||
log.warn("Flow ${fiber.id} error", error.exception)
|
||||
|
node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/HospitalisingInterceptor.kt
@ -26,20 +26,23 @@ class HospitalisingInterceptor(
|
||||
actionExecutor: ActionExecutor
|
||||
): Pair<FlowContinuation, StateMachineState> {
|
||||
val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor)
|
||||
when (nextState.checkpoint.errorState) {
|
||||
ErrorState.Clean -> {
|
||||
if (hospitalisedFlows.remove(fiber.id) != null) {
|
||||
flowHospital.flowCleaned(fiber)
|
||||
|
||||
when (nextState.checkpoint.errorState) {
|
||||
is ErrorState.Clean -> {
|
||||
if (hospitalisedFlows.remove(fiber.id) != null) {
|
||||
flowHospital.flowCleaned(fiber)
|
||||
}
|
||||
}
|
||||
is ErrorState.Errored -> {
|
||||
val exceptionsToHandle = nextState.checkpoint.errorState.errors.map { it.exception }
|
||||
if (hospitalisedFlows.putIfAbsent(fiber.id, fiber) == null) {
|
||||
flowHospital.flowErrored(fiber, previousState, exceptionsToHandle)
|
||||
}
|
||||
}
|
||||
}
|
||||
is ErrorState.Errored -> {
|
||||
if (hospitalisedFlows.putIfAbsent(fiber.id, fiber) == null) {
|
||||
flowHospital.flowErrored(fiber)
|
||||
}
|
||||
}
|
||||
}
|
||||
if (nextState.isRemoved) {
|
||||
hospitalisedFlows.remove(fiber.id)
|
||||
flowHospital.flowRemoved(fiber)
|
||||
}
|
||||
return Pair(continuation, nextState)
|
||||
}
|
||||
|
@ -46,9 +46,6 @@ class DeliverSessionMessageTransition(
|
||||
is EndSessionMessage -> endMessageTransition()
|
||||
}
|
||||
}
|
||||
if (!isErrored()) {
|
||||
persistCheckpoint()
|
||||
}
|
||||
// Schedule a DoRemainingWork to check whether the flow needs to be woken up.
|
||||
actions.add(Action.ScheduleEvent(Event.DoRemainingWork))
|
||||
FlowContinuation.ProcessEvents
|
||||
@ -73,7 +70,7 @@ class DeliverSessionMessageTransition(
|
||||
// Send messages that were buffered pending confirmation of session.
|
||||
val sendActions = sessionState.bufferedMessages.map { (deduplicationId, bufferedMessage) ->
|
||||
val existingMessage = ExistingSessionMessage(message.initiatedSessionId, bufferedMessage)
|
||||
Action.SendExisting(initiatedSession.peerParty, existingMessage, deduplicationId)
|
||||
Action.SendExisting(initiatedSession.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID))
|
||||
}
|
||||
actions.addAll(sendActions)
|
||||
currentState = currentState.copy(checkpoint = newCheckpoint)
|
||||
@ -146,21 +143,6 @@ class DeliverSessionMessageTransition(
|
||||
}
|
||||
}
|
||||
|
||||
private fun TransitionBuilder.persistCheckpoint() {
|
||||
// We persist the message as soon as it arrives.
|
||||
actions.addAll(arrayOf(
|
||||
Action.CreateTransaction,
|
||||
Action.PersistCheckpoint(context.id, currentState.checkpoint),
|
||||
Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers),
|
||||
Action.CommitTransaction,
|
||||
Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers)
|
||||
))
|
||||
currentState = currentState.copy(
|
||||
pendingDeduplicationHandlers = emptyList(),
|
||||
isAnyCheckpointPersisted = true
|
||||
)
|
||||
}
|
||||
|
||||
private fun TransitionBuilder.endMessageTransition() {
|
||||
val sessionId = event.sessionMessage.recipientSessionId
|
||||
val sessions = currentState.checkpoint.sessions
|
||||
|
@ -46,7 +46,7 @@ class ErrorFlowTransition(
|
||||
sessions = newSessions
|
||||
)
|
||||
currentState = currentState.copy(checkpoint = newCheckpoint)
|
||||
actions.add(Action.PropagateErrors(errorMessages, initiatedSessions))
|
||||
actions.add(Action.PropagateErrors(errorMessages, initiatedSessions, startingState.senderUUID))
|
||||
}
|
||||
|
||||
// If we're errored but not propagating keep processing events.
|
||||
|
@ -216,7 +216,7 @@ class StartedFlowTransition(
|
||||
}
|
||||
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++)
|
||||
val initialMessage = createInitialSessionMessage(sessionState.initiatingSubFlow, sourceSessionId, null)
|
||||
actions.add(Action.SendInitial(sessionState.party, initialMessage, deduplicationId))
|
||||
actions.add(Action.SendInitial(sessionState.party, initialMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID)))
|
||||
newSessions[sourceSessionId] = SessionState.Initiating(
|
||||
bufferedMessages = emptyList(),
|
||||
rejectionError = null
|
||||
@ -253,7 +253,7 @@ class StartedFlowTransition(
|
||||
when (existingSessionState) {
|
||||
is SessionState.Uninitiated -> {
|
||||
val initialMessage = createInitialSessionMessage(existingSessionState.initiatingSubFlow, sourceSessionId, message)
|
||||
actions.add(Action.SendInitial(existingSessionState.party, initialMessage, deduplicationId))
|
||||
actions.add(Action.SendInitial(existingSessionState.party, initialMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID)))
|
||||
newSessions[sourceSessionId] = SessionState.Initiating(
|
||||
bufferedMessages = emptyList(),
|
||||
rejectionError = null
|
||||
@ -270,7 +270,7 @@ class StartedFlowTransition(
|
||||
is InitiatedSessionState.Live -> {
|
||||
val sinkSessionId = existingSessionState.initiatedState.peerSinkSessionId
|
||||
val existingMessage = ExistingSessionMessage(sinkSessionId, sessionMessage)
|
||||
actions.add(Action.SendExisting(existingSessionState.peerParty, existingMessage, deduplicationId))
|
||||
actions.add(Action.SendExisting(existingSessionState.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID)))
|
||||
Unit
|
||||
}
|
||||
InitiatedSessionState.Ended -> {
|
||||
|
@ -18,18 +18,19 @@ class TopLevelTransition(
|
||||
) : Transition {
|
||||
override fun transition(): TransitionResult {
|
||||
return when (event) {
|
||||
is Event.DoRemainingWork -> DoRemainingWorkTransition(context, startingState).transition()
|
||||
is Event.DeliverSessionMessage -> DeliverSessionMessageTransition(context, startingState, event).transition()
|
||||
is Event.Error -> errorTransition(event)
|
||||
is Event.TransactionCommitted -> transactionCommittedTransition(event)
|
||||
is Event.SoftShutdown -> softShutdownTransition()
|
||||
is Event.StartErrorPropagation -> startErrorPropagationTransition()
|
||||
is Event.EnterSubFlow -> enterSubFlowTransition(event)
|
||||
is Event.LeaveSubFlow -> leaveSubFlowTransition()
|
||||
is Event.Suspend -> suspendTransition(event)
|
||||
is Event.FlowFinish -> flowFinishTransition(event)
|
||||
is Event.InitiateFlow -> initiateFlowTransition(event)
|
||||
is Event.AsyncOperationCompletion -> asyncOperationCompletionTransition(event)
|
||||
is Event.DoRemainingWork -> DoRemainingWorkTransition(context, startingState).transition()
|
||||
is Event.DeliverSessionMessage -> DeliverSessionMessageTransition(context, startingState, event).transition()
|
||||
is Event.Error -> errorTransition(event)
|
||||
is Event.TransactionCommitted -> transactionCommittedTransition(event)
|
||||
is Event.SoftShutdown -> softShutdownTransition()
|
||||
is Event.StartErrorPropagation -> startErrorPropagationTransition()
|
||||
is Event.EnterSubFlow -> enterSubFlowTransition(event)
|
||||
is Event.LeaveSubFlow -> leaveSubFlowTransition()
|
||||
is Event.Suspend -> suspendTransition(event)
|
||||
is Event.FlowFinish -> flowFinishTransition(event)
|
||||
is Event.InitiateFlow -> initiateFlowTransition(event)
|
||||
is Event.AsyncOperationCompletion -> asyncOperationCompletionTransition(event)
|
||||
is Event.RetryFlowFromSafePoint -> retryFlowFromSafePointTransition(startingState)
|
||||
}
|
||||
}
|
||||
|
||||
@ -191,7 +192,7 @@ class TopLevelTransition(
|
||||
if (state is SessionState.Initiated && state.initiatedState is InitiatedSessionState.Live) {
|
||||
val message = ExistingSessionMessage(state.initiatedState.peerSinkSessionId, EndSessionMessage)
|
||||
val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index)
|
||||
Action.SendExisting(state.peerParty, message, deduplicationId)
|
||||
Action.SendExisting(state.peerParty, message, SenderDeduplicationId(deduplicationId, currentState.senderUUID))
|
||||
} else {
|
||||
null
|
||||
}
|
||||
@ -230,4 +231,14 @@ class TopLevelTransition(
|
||||
resumeFlowLogic(event.returnValue)
|
||||
}
|
||||
}
|
||||
|
||||
private fun retryFlowFromSafePointTransition(startingState: StateMachineState): TransitionResult {
|
||||
return builder {
|
||||
// Need to create a flow from the prior checkpoint or flow initiation.
|
||||
actions.add(Action.CreateTransaction)
|
||||
actions.add(Action.RetryFlowFromSafePoint(startingState))
|
||||
actions.add(Action.CommitTransaction)
|
||||
FlowContinuation.Abort
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -58,7 +58,7 @@ class UnstartedFlowTransition(
|
||||
Action.SendExisting(
|
||||
flowStart.peerSession.counterparty,
|
||||
sessionMessage,
|
||||
DeduplicationId.createForNormal(currentState.checkpoint, 0)
|
||||
SenderDeduplicationId(DeduplicationId.createForNormal(currentState.checkpoint, 0), currentState.senderUUID)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
@ -3,14 +3,21 @@ package net.corda.node.utilities
|
||||
import com.github.benmanes.caffeine.cache.LoadingCache
|
||||
import com.github.benmanes.caffeine.cache.Weigher
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.nodeapi.internal.persistence.DatabaseTransaction
|
||||
import net.corda.nodeapi.internal.persistence.contextTransaction
|
||||
import net.corda.nodeapi.internal.persistence.currentDBSession
|
||||
import java.lang.ref.WeakReference
|
||||
import java.util.*
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
import java.util.concurrent.atomic.AtomicReference
|
||||
|
||||
|
||||
/**
|
||||
* Implements a caching layer on top of an *append-only* table accessed via Hibernate mapping. Note that if the same key is [set] twice the
|
||||
* behaviour is unpredictable! There is a best-effort check for double inserts, but this should *not* be relied on, so
|
||||
* ONLY USE THIS IF YOUR TABLE IS APPEND-ONLY
|
||||
* Implements a caching layer on top of an *append-only* table accessed via Hibernate mapping. Note that if the same key is [set] twice,
|
||||
* typically this will result in a duplicate insert if this is racing with another transaction. The flow framework will then retry.
|
||||
*
|
||||
* This class relies heavily on the fact that compute operations in the cache are atomic for a particular key.
|
||||
*/
|
||||
abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
|
||||
val toPersistentEntityKey: (K) -> EK,
|
||||
@ -23,7 +30,8 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
|
||||
private val log = contextLogger()
|
||||
}
|
||||
|
||||
protected abstract val cache: LoadingCache<K, Optional<V>>
|
||||
protected abstract val cache: LoadingCache<K, Transactional<V>>
|
||||
protected val pendingKeys = ConcurrentHashMap<K, MutableSet<DatabaseTransaction>>()
|
||||
|
||||
/**
|
||||
* Returns the value associated with the key, first loading that value from the storage if necessary.
|
||||
@ -47,32 +55,31 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
|
||||
return result.map { x -> fromPersistentEntity(x) }.asSequence()
|
||||
}
|
||||
|
||||
private tailrec fun set(key: K, value: V, logWarning: Boolean, store: (K, V) -> V?): Boolean {
|
||||
var insertionAttempt = false
|
||||
var isUnique = true
|
||||
val existingInCache = cache.get(key) {
|
||||
// Thread safe, if multiple threads may wait until the first one has loaded.
|
||||
insertionAttempt = true
|
||||
// Key wasn't in the cache and might be in the underlying storage.
|
||||
// Depending on 'store' method, this may insert without checking key duplication or it may avoid inserting a duplicated key.
|
||||
val existingInDb = store(key, value)
|
||||
if (existingInDb != null) { // Always reuse an existing value from the storage of a duplicated key.
|
||||
isUnique = false
|
||||
Optional.of(existingInDb)
|
||||
} else {
|
||||
Optional.of(value)
|
||||
}
|
||||
}!!
|
||||
if (!insertionAttempt) {
|
||||
if (existingInCache.isPresent) {
|
||||
// Key already exists in cache, do nothing.
|
||||
isUnique = false
|
||||
} else {
|
||||
// This happens when the key was queried before with no value associated. We invalidate the cached null
|
||||
// value and recursively call set again. This is to avoid race conditions where another thread queries after
|
||||
// the invalidate but before the set.
|
||||
cache.invalidate(key!!)
|
||||
return set(key, value, logWarning, store)
|
||||
private fun set(key: K, value: V, logWarning: Boolean, store: (K, V) -> V?): Boolean {
|
||||
// Will be set to true if store says it isn't in the database.
|
||||
var isUnique = false
|
||||
cache.asMap().compute(key) { _, oldValue ->
|
||||
// Always write to the database, unless we can see it's already committed.
|
||||
when (oldValue) {
|
||||
is Transactional.InFlight<*, V> -> {
|
||||
// Someone else is writing, so store away!
|
||||
// TODO: we can do collision detection here and prevent it happening in the database. But we also have to do deadlock detection, so a bit of work.
|
||||
isUnique = (store(key, value) == null)
|
||||
oldValue.apply { alsoWrite(value) }
|
||||
}
|
||||
is Transactional.Committed<V> -> oldValue // The value is already globally visible and cached. So do nothing since the values are always the same.
|
||||
else -> {
|
||||
// Null or Missing. Store away!
|
||||
isUnique = (store(key, value) == null)
|
||||
if (!isUnique && !weAreWriting(key)) {
|
||||
// If we found a value already in the database, and we were not already writing, then it's already committed but got evicted.
|
||||
Transactional.Committed(value)
|
||||
} else {
|
||||
// Some database transactions, including ours, are writing: readers see whatever is in the database and writers see the (in-memory) value.
|
||||
Transactional.InFlight(this, key, { loadValue(key) }).apply { alsoWrite(value) }
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
if (logWarning && !isUnique) {
|
||||
@ -93,7 +100,8 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
|
||||
|
||||
/**
|
||||
* Associates the specified value with the specified key in this map and persists it.
|
||||
* If the map previously contained a mapping for the key, the old value is not replaced.
|
||||
* If the map previously contained a committed mapping for the key, the old value is not replaced. It may throw an error from the
|
||||
* underlying storage if this races with another database transaction to store a value for the same key.
|
||||
* @return true if added key was unique, otherwise false
|
||||
*/
|
||||
fun addWithDuplicatesAllowed(key: K, value: V, logWarning: Boolean = true): Boolean =
|
||||
@ -116,7 +124,7 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
|
||||
|
||||
protected fun loadValue(key: K): V? {
|
||||
val result = currentDBSession().find(persistentEntityClass, toPersistentEntityKey(key))
|
||||
return result?.let(fromPersistentEntity)?.second
|
||||
return result?.apply { currentDBSession().detach(result) }?.let(fromPersistentEntity)?.second
|
||||
}
|
||||
|
||||
operator fun contains(key: K) = get(key) != null
|
||||
@ -132,9 +140,161 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
|
||||
session.createQuery(deleteQuery).executeUpdate()
|
||||
cache.invalidateAll()
|
||||
}
|
||||
|
||||
// Helpers to know if transaction(s) are currently writing the given key.
|
||||
protected fun weAreWriting(key: K): Boolean = pendingKeys.get(key)?.contains(contextTransaction) ?: false
|
||||
protected fun anyoneWriting(key: K): Boolean = pendingKeys.get(key)?.isNotEmpty() ?: false
|
||||
|
||||
// Indicate this database transaction is a writer of this key.
|
||||
private fun addPendingKey(key: K, databaseTransaction: DatabaseTransaction): Boolean {
|
||||
var added = true
|
||||
pendingKeys.compute(key) { k, oldSet ->
|
||||
if (oldSet == null) {
|
||||
val newSet = HashSet<DatabaseTransaction>(0)
|
||||
newSet += databaseTransaction
|
||||
newSet
|
||||
} else {
|
||||
added = oldSet.add(databaseTransaction)
|
||||
oldSet
|
||||
}
|
||||
}
|
||||
return added
|
||||
}
|
||||
|
||||
// Remove this database transaction as a writer of this key, because the transaction committed or rolled back.
|
||||
private fun removePendingKey(key: K, databaseTransaction: DatabaseTransaction) {
|
||||
pendingKeys.compute(key) { k, oldSet ->
|
||||
if (oldSet == null) {
|
||||
oldSet
|
||||
} else {
|
||||
oldSet -= databaseTransaction
|
||||
if (oldSet.size == 0) null else oldSet
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a value in the cache, with transaction isolation semantics.
|
||||
*
|
||||
* There are 3 states. Globally missing, globally visible, and being written in a transaction somewhere now or in
|
||||
* the past (and it rolled back).
|
||||
*/
|
||||
sealed class Transactional<T> {
|
||||
abstract val value: T
|
||||
abstract val isPresent: Boolean
|
||||
abstract val valueWithoutIsolation: T?
|
||||
|
||||
fun orElse(alt: T?) = if (isPresent) value else alt
|
||||
|
||||
// Everyone can see it, and database transaction committed.
|
||||
class Committed<T>(override val value: T) : Transactional<T>() {
|
||||
override val isPresent: Boolean
|
||||
get() = true
|
||||
override val valueWithoutIsolation: T?
|
||||
get() = value
|
||||
}
|
||||
|
||||
// No one can see it.
|
||||
class Missing<T>() : Transactional<T>() {
|
||||
override val value: T
|
||||
get() = throw NoSuchElementException("Not present")
|
||||
override val isPresent: Boolean
|
||||
get() = false
|
||||
override val valueWithoutIsolation: T?
|
||||
get() = null
|
||||
}
|
||||
|
||||
// Written in a transaction (uncommitted) somewhere, but there's a small window when this might be seen after commit,
|
||||
// hence the committed flag.
|
||||
class InFlight<K, T>(private val map: AppendOnlyPersistentMapBase<K, T, *, *>,
|
||||
private val key: K,
|
||||
private val _readerValueLoader: () -> T?,
|
||||
private val _writerValueLoader: () -> T = { throw IllegalAccessException("No value loader provided") }) : Transactional<T>() {
|
||||
|
||||
// A flag to indicate this has now been committed, but hasn't yet been replaced with Committed. This also
|
||||
// de-duplicates writes of the Committed value to the cache.
|
||||
private val committed = AtomicBoolean(false)
|
||||
|
||||
// What to do if a non-writer needs to see the value and it hasn't yet been committed to the database.
|
||||
// Can be updated into a no-op once evaluated.
|
||||
private val readerValueLoader = AtomicReference<() -> T?>(_readerValueLoader)
|
||||
// What to do if a writer needs to see the value and it hasn't yet been committed to the database.
|
||||
// Can be updated into a no-op once evaluated.
|
||||
private val writerValueLoader = AtomicReference<() -> T>(_writerValueLoader)
|
||||
|
||||
fun alsoWrite(_value: T) {
|
||||
// Make the lazy loader the writers see actually just return the value that has been set.
|
||||
writerValueLoader.set({ _value })
|
||||
// We make all these vals so that the lambdas do not need a reference to this, and so the onCommit only has a weak ref to the value.
|
||||
// We want this so that the cache could evict the value (due to memory constraints etc) without the onCommit callback
|
||||
// retaining what could be a large memory footprint object.
|
||||
val tx = contextTransaction
|
||||
val strongKey = key
|
||||
val weakValue = WeakReference<T>(_value)
|
||||
val strongCommitted = committed
|
||||
val strongMap = map
|
||||
if (map.addPendingKey(key, tx)) {
|
||||
// If the transaction commits, update cache to make globally visible if we're first for this key,
|
||||
// and then stop saying the transaction is writing the key.
|
||||
tx.onCommit {
|
||||
if (strongCommitted.compareAndSet(false, true)) {
|
||||
val dereferencedKey = strongKey
|
||||
val dereferencedValue = weakValue.get()
|
||||
if (dereferencedValue != null) {
|
||||
strongMap.cache.put(dereferencedKey, Committed(dereferencedValue))
|
||||
}
|
||||
}
|
||||
strongMap.removePendingKey(strongKey, tx)
|
||||
}
|
||||
// If the transaction rolls back, stop saying this transaction is writing the key.
|
||||
tx.onRollback {
|
||||
strongMap.removePendingKey(strongKey, tx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Lazy load the value a "writer" would see. If the original loader hasn't been replaced, replace it
|
||||
// with one that just returns the value once evaluated.
|
||||
private fun loadAsWriter(): T {
|
||||
val _value = writerValueLoader.get()()
|
||||
if (writerValueLoader.get() == _writerValueLoader) {
|
||||
writerValueLoader.set({ _value })
|
||||
}
|
||||
return _value
|
||||
}
|
||||
|
||||
// Lazy load the value a "reader" would see. If the original loader hasn't been replaced, replace it
|
||||
// with one that just returns the value once evaluated.
|
||||
private fun loadAsReader(): T? {
|
||||
val _value = readerValueLoader.get()()
|
||||
if (readerValueLoader.get() == _readerValueLoader) {
|
||||
readerValueLoader.set({ _value })
|
||||
}
|
||||
return _value
|
||||
}
|
||||
|
||||
// Whether someone reading (only) can see the entry.
|
||||
private val isPresentAsReader: Boolean get() = (loadAsReader() != null)
|
||||
// Whether the entry is already written and committed, or we are writing it (and thus it can be seen).
|
||||
private val isPresentAsWriter: Boolean get() = committed.get() || map.weAreWriting(key)
|
||||
|
||||
override val isPresent: Boolean
|
||||
get() = isPresentAsWriter || isPresentAsReader
|
||||
|
||||
// If it is committed or we are writing, reveal the value, potentially lazy loading from the database.
|
||||
// If none of the above, see what was already in the database, potentially lazily.
|
||||
override val value: T
|
||||
get() = if (isPresentAsWriter) loadAsWriter() else if (isPresentAsReader) loadAsReader()!! else throw NoSuchElementException("Not present")
|
||||
|
||||
// The value from the perspective of the eviction algorithm of the cache. i.e. we want to reveal memory footprint to it etc.
|
||||
override val valueWithoutIsolation: T?
|
||||
get() = if (writerValueLoader.get() != _writerValueLoader) writerValueLoader.get()() else if (readerValueLoader.get() != _readerValueLoader) readerValueLoader.get()() else null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class AppendOnlyPersistentMap<K, V, E, out EK>(
|
||||
// Open for tests to override
|
||||
open class AppendOnlyPersistentMap<K, V, E, out EK>(
|
||||
toPersistentEntityKey: (K) -> EK,
|
||||
fromPersistentEntity: (E) -> Pair<K, V>,
|
||||
toPersistentEntity: (key: K, value: V) -> E,
|
||||
@ -146,26 +306,71 @@ class AppendOnlyPersistentMap<K, V, E, out EK>(
|
||||
toPersistentEntity,
|
||||
persistentEntityClass) {
|
||||
//TODO determine cacheBound based on entity class later or with node config allowing tuning, or using some heuristic based on heap size
|
||||
override val cache = NonInvalidatingCache<K, Optional<V>>(
|
||||
override val cache = NonInvalidatingCache<K, Transactional<V>>(
|
||||
bound = cacheBound,
|
||||
loadFunction = { key -> Optional.ofNullable(loadValue(key)) })
|
||||
loadFunction = { key: K ->
|
||||
// This gets called if a value is read and the cache has no Transactional for this key yet.
|
||||
val value: V? = loadValue(key)
|
||||
if (value == null) {
|
||||
// No visible value
|
||||
if (anyoneWriting(key)) {
|
||||
// If someone is writing (but not us)
|
||||
// For those not writing, the value cannot be seen.
|
||||
// For those writing, they need to re-load the value from the database (which their database transaction CAN see).
|
||||
Transactional.InFlight<K, V>(this, key, { null }, { loadValue(key)!! })
|
||||
} else {
|
||||
// If no one is writing, then the value does not exist.
|
||||
Transactional.Missing<V>()
|
||||
}
|
||||
} else {
|
||||
// A value was found
|
||||
if (weAreWriting(key)) {
|
||||
// If we are writing, it might not be globally visible, and was evicted from the cache.
|
||||
// For those not writing, they need to check the database again.
|
||||
// For those writing, they can see the value found.
|
||||
Transactional.InFlight<K, V>(this, key, { loadValue(key) }, { value })
|
||||
} else {
|
||||
// If no one is writing, then make it globally visible.
|
||||
Transactional.Committed<V>(value)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Same as above, but with weighted values (e.g. memory footprint sensitive).
|
||||
class WeightBasedAppendOnlyPersistentMap<K, V, E, out EK>(
|
||||
toPersistentEntityKey: (K) -> EK,
|
||||
fromPersistentEntity: (E) -> Pair<K, V>,
|
||||
toPersistentEntity: (key: K, value: V) -> E,
|
||||
persistentEntityClass: Class<E>,
|
||||
maxWeight: Long,
|
||||
weighingFunc: (K, Optional<V>) -> Int
|
||||
weighingFunc: (K, Transactional<V>) -> Int
|
||||
) : AppendOnlyPersistentMapBase<K, V, E, EK>(
|
||||
toPersistentEntityKey,
|
||||
fromPersistentEntity,
|
||||
toPersistentEntity,
|
||||
persistentEntityClass) {
|
||||
override val cache = NonInvalidatingWeightBasedCache(
|
||||
override val cache = NonInvalidatingWeightBasedCache<K, Transactional<V>>(
|
||||
maxWeight = maxWeight,
|
||||
weigher = Weigher<K, Optional<V>> { key, value -> weighingFunc(key, value) },
|
||||
loadFunction = { key -> Optional.ofNullable(loadValue(key)) }
|
||||
)
|
||||
}
|
||||
weigher = object : Weigher<K, Transactional<V>> {
|
||||
override fun weigh(key: K, value: Transactional<V>): Int {
|
||||
return weighingFunc(key, value)
|
||||
}
|
||||
},
|
||||
loadFunction = { key: K ->
|
||||
val value: V? = loadValue(key)
|
||||
if (value == null) {
|
||||
if (anyoneWriting(key)) {
|
||||
Transactional.InFlight<K, V>(this, key, { null }, { loadValue(key)!! })
|
||||
} else {
|
||||
Transactional.Missing<V>()
|
||||
}
|
||||
} else {
|
||||
if (weAreWriting(key)) {
|
||||
Transactional.InFlight<K, V>(this, key, { loadValue(key) }, { value })
|
||||
} else {
|
||||
Transactional.Committed<V>(value)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -14,14 +14,15 @@ import net.corda.node.internal.configureDatabase
|
||||
import net.corda.node.services.api.FlowStarter
|
||||
import net.corda.node.services.api.NodePropertiesStore
|
||||
import net.corda.node.services.messaging.DeduplicationHandler
|
||||
import net.corda.node.services.statemachine.ExternalEvent
|
||||
import net.corda.nodeapi.internal.persistence.CordaPersistence
|
||||
import net.corda.nodeapi.internal.persistence.DatabaseConfig
|
||||
import net.corda.nodeapi.internal.persistence.DatabaseTransaction
|
||||
import net.corda.testing.internal.doLookup
|
||||
import net.corda.testing.internal.rigorousMock
|
||||
import net.corda.testing.internal.spectator
|
||||
import net.corda.testing.node.MockServices
|
||||
import net.corda.testing.node.TestClock
|
||||
import org.junit.After
|
||||
import org.junit.Ignore
|
||||
import org.junit.Rule
|
||||
import org.junit.Test
|
||||
@ -50,7 +51,7 @@ open class NodeSchedulerServiceTestBase {
|
||||
dedupe.insideDatabaseTransaction()
|
||||
dedupe.afterDatabaseTransaction()
|
||||
openFuture<FlowStateMachine<*>>()
|
||||
}.whenever(it).startFlow(any<FlowLogic<*>>(), any(), any())
|
||||
}.whenever(it).startFlow(any<ExternalEvent.ExternalStartFlowEvent<*>>())
|
||||
}
|
||||
private val flowsDraingMode = rigorousMock<NodePropertiesStore.FlowsDrainingModeOperations>().also {
|
||||
doReturn(false).whenever(it).isEnabled()
|
||||
@ -80,7 +81,7 @@ open class NodeSchedulerServiceTestBase {
|
||||
|
||||
protected fun assertStarted(flowLogic: FlowLogic<*>) {
|
||||
// Like in assertWaitingFor, use timeout to make verify wait as we often race the call to startFlow:
|
||||
verify(flowStarter, timeout(5000)).startFlow(same(flowLogic), any(), any())
|
||||
verify(flowStarter, timeout(5000)).startFlow(argForWhich<ExternalEvent.ExternalStartFlowEvent<*>> { this.flowLogic == flowLogic })
|
||||
}
|
||||
|
||||
protected fun assertStarted(event: Event) = assertStarted(event.flowLogic)
|
||||
@ -112,11 +113,11 @@ class MockScheduledFlowRepository : ScheduledFlowRepository {
|
||||
}
|
||||
|
||||
class NodeSchedulerServiceTest : NodeSchedulerServiceTestBase() {
|
||||
private val database = rigorousMock<CordaPersistence>().also {
|
||||
doAnswer {
|
||||
val block: DatabaseTransaction.() -> Any? = it.getArgument(0)
|
||||
rigorousMock<DatabaseTransaction>().block()
|
||||
}.whenever(it).transaction(any())
|
||||
private val database = configureDatabase(MockServices.makeTestDataSourceProperties(), DatabaseConfig(), rigorousMock())
|
||||
|
||||
@After
|
||||
fun closeDatabase() {
|
||||
database.close()
|
||||
}
|
||||
|
||||
private val scheduler = NodeSchedulerService(
|
||||
@ -148,7 +149,9 @@ class NodeSchedulerServiceTest : NodeSchedulerServiceTestBase() {
|
||||
}).whenever(it).data
|
||||
}
|
||||
flows[logicRef] = flowLogic
|
||||
scheduler.scheduleStateActivity(ssr)
|
||||
database.transaction {
|
||||
scheduler.scheduleStateActivity(ssr)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -207,7 +210,9 @@ class NodeSchedulerServiceTest : NodeSchedulerServiceTestBase() {
|
||||
fun `test activity due in the future and schedule another for same time then unschedule second`() {
|
||||
val eventA = schedule(mark + 1.days)
|
||||
val eventB = schedule(mark + 1.days)
|
||||
scheduler.unscheduleStateActivity(eventB.stateRef)
|
||||
database.transaction {
|
||||
scheduler.unscheduleStateActivity(eventB.stateRef)
|
||||
}
|
||||
assertWaitingFor(eventA)
|
||||
testClock.advanceBy(1.days)
|
||||
assertStarted(eventA)
|
||||
@ -217,7 +222,9 @@ class NodeSchedulerServiceTest : NodeSchedulerServiceTestBase() {
|
||||
fun `test activity due in the future and schedule another for same time then unschedule original`() {
|
||||
val eventA = schedule(mark + 1.days)
|
||||
val eventB = schedule(mark + 1.days)
|
||||
scheduler.unscheduleStateActivity(eventA.stateRef)
|
||||
database.transaction {
|
||||
scheduler.unscheduleStateActivity(eventA.stateRef)
|
||||
}
|
||||
assertWaitingFor(eventB)
|
||||
testClock.advanceBy(1.days)
|
||||
assertStarted(eventB)
|
||||
@ -225,7 +232,9 @@ class NodeSchedulerServiceTest : NodeSchedulerServiceTestBase() {
|
||||
|
||||
@Test
|
||||
fun `test activity due in the future then unschedule`() {
|
||||
scheduler.unscheduleStateActivity(schedule(mark + 1.days).stateRef)
|
||||
database.transaction {
|
||||
scheduler.unscheduleStateActivity(schedule(mark + 1.days).stateRef)
|
||||
}
|
||||
testClock.advanceBy(1.days)
|
||||
}
|
||||
}
|
||||
|
node/src/test/kotlin/net/corda/node/services/persistence/AppendOnlyPersistentMapTest.kt
@ -0,0 +1,290 @@
|
||||
package net.corda.node.services.persistence
|
||||
|
||||
import net.corda.core.schemas.MappedSchema
|
||||
import net.corda.core.utilities.loggerFor
|
||||
import net.corda.node.internal.configureDatabase
|
||||
import net.corda.node.services.schema.NodeSchemaService
|
||||
import net.corda.node.utilities.AppendOnlyPersistentMap
|
||||
import net.corda.nodeapi.internal.persistence.DatabaseConfig
|
||||
import net.corda.testing.internal.rigorousMock
|
||||
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
|
||||
import org.junit.After
|
||||
import org.junit.Assert.*
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
import org.junit.runners.Parameterized
|
||||
import java.io.Serializable
|
||||
import java.util.concurrent.CountDownLatch
|
||||
import javax.persistence.Column
|
||||
import javax.persistence.Entity
|
||||
import javax.persistence.Id
|
||||
import javax.persistence.PersistenceException
|
||||
|
||||
@RunWith(Parameterized::class)
|
||||
class AppendOnlyPersistentMapTest(var scenario: Scenario) {
|
||||
companion object {
|
||||
|
||||
private val scenarios = arrayOf<Scenario>(
|
||||
Scenario(false, ReadOrWrite.Read, ReadOrWrite.Read, Outcome.Fail, Outcome.Fail),
|
||||
Scenario(false, ReadOrWrite.Write, ReadOrWrite.Read, Outcome.Success, Outcome.Fail, Outcome.Success),
|
||||
Scenario(false, ReadOrWrite.Read, ReadOrWrite.Write, Outcome.Fail, Outcome.Success),
|
||||
Scenario(false, ReadOrWrite.Write, ReadOrWrite.Write, Outcome.Success, Outcome.SuccessButErrorOnCommit),
|
||||
Scenario(false, ReadOrWrite.WriteDuplicateAllowed, ReadOrWrite.Read, Outcome.Success, Outcome.Fail, Outcome.Success),
|
||||
Scenario(false, ReadOrWrite.Read, ReadOrWrite.WriteDuplicateAllowed, Outcome.Fail, Outcome.Success),
|
||||
Scenario(false, ReadOrWrite.WriteDuplicateAllowed, ReadOrWrite.WriteDuplicateAllowed, Outcome.Success, Outcome.SuccessButErrorOnCommit, Outcome.Fail),
|
||||
Scenario(true, ReadOrWrite.Read, ReadOrWrite.Read, Outcome.Success, Outcome.Success),
|
||||
Scenario(true, ReadOrWrite.Write, ReadOrWrite.Read, Outcome.SuccessButErrorOnCommit, Outcome.Success),
|
||||
Scenario(true, ReadOrWrite.Read, ReadOrWrite.Write, Outcome.Success, Outcome.Fail),
|
||||
Scenario(true, ReadOrWrite.Write, ReadOrWrite.Write, Outcome.SuccessButErrorOnCommit, Outcome.SuccessButErrorOnCommit),
|
||||
Scenario(true, ReadOrWrite.WriteDuplicateAllowed, ReadOrWrite.Read, Outcome.Fail, Outcome.Success),
|
||||
Scenario(true, ReadOrWrite.Read, ReadOrWrite.WriteDuplicateAllowed, Outcome.Success, Outcome.Fail),
|
||||
Scenario(true, ReadOrWrite.WriteDuplicateAllowed, ReadOrWrite.WriteDuplicateAllowed, Outcome.Fail, Outcome.Fail)
|
||||
)
|
||||
|
||||
@Parameterized.Parameters(name = "{0}")
|
||||
@JvmStatic
|
||||
fun data(): Array<Array<Scenario>> = scenarios.map { arrayOf(it) }.toTypedArray()
|
||||
}
|
||||
|
||||
enum class ReadOrWrite { Read, Write, WriteDuplicateAllowed }
|
||||
enum class Outcome { Success, Fail, SuccessButErrorOnCommit }
|
||||
|
||||
data class Scenario(val prePopulated: Boolean,
|
||||
val a: ReadOrWrite,
|
||||
val b: ReadOrWrite,
|
||||
val aExpected: Outcome,
|
||||
val bExpected: Outcome,
|
||||
val bExpectedIfSingleThreaded: Outcome = bExpected)
|
||||
|
||||
private val database = configureDatabase(makeTestDataSourceProperties(),
|
||||
DatabaseConfig(),
|
||||
rigorousMock(),
|
||||
NodeSchemaService(setOf(MappedSchema(AppendOnlyPersistentMapTest::class.java, 1, listOf(PersistentMapEntry::class.java)))))
|
||||
|
||||
@After
|
||||
fun closeDatabase() {
|
||||
database.close()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `concurrent test no purge between A and B`() {
|
||||
prepopulateIfRequired()
|
||||
val map = createMap()
|
||||
val a = TestThread("A", map).apply { start() }
|
||||
val b = TestThread("B", map).apply { start() }
|
||||
|
||||
// Begin A
|
||||
a.phase1.countDown()
|
||||
a.await(a::phase2)
|
||||
|
||||
// Begin B
|
||||
b.phase1.countDown()
|
||||
b.await(b::phase2)
|
||||
|
||||
// Commit A
|
||||
a.phase3.countDown()
|
||||
a.await(a::phase4)
|
||||
|
||||
// Commit B
|
||||
b.phase3.countDown()
|
||||
b.await(b::phase4)
|
||||
|
||||
// End
|
||||
a.join()
|
||||
b.join()
|
||||
assertTrue(map.pendingKeysIsEmpty())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `test no purge with only a single transaction`() {
|
||||
prepopulateIfRequired()
|
||||
val map = createMap()
|
||||
val a = TestThread("A", map, true).apply {
|
||||
phase1.countDown()
|
||||
phase3.countDown()
|
||||
}
|
||||
val b = TestThread("B", map, true).apply {
|
||||
phase1.countDown()
|
||||
phase3.countDown()
|
||||
}
|
||||
try {
|
||||
database.transaction {
|
||||
a.run()
|
||||
b.run()
|
||||
}
|
||||
} catch (t: PersistenceException) {
|
||||
// This only helps if the exception is thrown on commit; otherwise the other latches are not counted down.
|
||||
assertEquals(t.message, Outcome.SuccessButErrorOnCommit, a.outcome)
|
||||
}
|
||||
a.await(a::phase4)
|
||||
b.await(b::phase4)
|
||||
assertTrue(map.pendingKeysIsEmpty())
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
fun `concurrent test purge between A and B`() {
|
||||
// Writes intentionally do not check the database first, so purging between read and write changes behaviour
|
||||
val remapped = mapOf(Scenario(true, ReadOrWrite.Read, ReadOrWrite.Write, Outcome.Success, Outcome.Fail) to Scenario(true, ReadOrWrite.Read, ReadOrWrite.Write, Outcome.Success, Outcome.SuccessButErrorOnCommit))
|
||||
scenario = remapped[scenario] ?: scenario
|
||||
prepopulateIfRequired()
|
||||
val map = createMap()
|
||||
val a = TestThread("A", map).apply { start() }
|
||||
val b = TestThread("B", map).apply { start() }
|
||||
|
||||
// Begin A
|
||||
a.phase1.countDown()
|
||||
a.await(a::phase2)
|
||||
|
||||
map.invalidate()
|
||||
|
||||
// Begin B
|
||||
b.phase1.countDown()
|
||||
b.await(b::phase2)
|
||||
|
||||
// Commit A
|
||||
a.phase3.countDown()
|
||||
a.await(a::phase4)
|
||||
|
||||
// Commit B
|
||||
b.phase3.countDown()
|
||||
b.await(b::phase4)
|
||||
|
||||
// End
|
||||
a.join()
|
||||
b.join()
|
||||
assertTrue(map.pendingKeysIsEmpty())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `test purge mid-way in a single transaction`() {
|
||||
// Writes intentionally do not check the database first, so purging between read and write changes behaviour
|
||||
val remapped = mapOf(Scenario(true, ReadOrWrite.Read, ReadOrWrite.Write, Outcome.Success, Outcome.Fail) to Scenario(true, ReadOrWrite.Read, ReadOrWrite.Write, Outcome.SuccessButErrorOnCommit, Outcome.SuccessButErrorOnCommit))
|
||||
scenario = remapped[scenario] ?: scenario
|
||||
prepopulateIfRequired()
|
||||
val map = createMap()
|
||||
val a = TestThread("A", map, true).apply {
|
||||
phase1.countDown()
|
||||
phase3.countDown()
|
||||
}
|
||||
val b = TestThread("B", map, true).apply {
|
||||
phase1.countDown()
|
||||
phase3.countDown()
|
||||
}
|
||||
try {
|
||||
database.transaction {
|
||||
a.run()
|
||||
map.invalidate()
|
||||
b.run()
|
||||
}
|
||||
} catch (t: PersistenceException) {
|
||||
// This only helps if the exception is thrown on commit; otherwise the other latches are not counted down.
|
||||
assertEquals(t.message, Outcome.SuccessButErrorOnCommit, a.outcome)
|
||||
}
|
||||
a.await(a::phase4)
|
||||
b.await(b::phase4)
|
||||
assertTrue(map.pendingKeysIsEmpty())
|
||||
}
|
||||
|
||||
inner class TestThread(name: String, val map: AppendOnlyPersistentMap<Long, String, PersistentMapEntry, Long>, singleThreaded: Boolean = false) : Thread(name) {
|
||||
private val log = loggerFor<TestThread>()
|
||||
|
||||
val readOrWrite = if (name == "A") scenario.a else scenario.b
|
||||
val outcome = if (name == "A") scenario.aExpected else if (singleThreaded) scenario.bExpectedIfSingleThreaded else scenario.bExpected
|
||||
|
||||
val phase1 = latch()
|
||||
val phase2 = latch()
|
||||
val phase3 = latch()
|
||||
val phase4 = latch()
|
||||
|
||||
override fun run() {
|
||||
try {
|
||||
database.transaction {
|
||||
await(::phase1)
|
||||
doActivity()
|
||||
phase2.countDown()
|
||||
await(::phase3)
|
||||
}
|
||||
} catch (t: PersistenceException) {
|
||||
// This only helps if the exception is thrown on commit; otherwise the other latches are not counted down.
|
||||
assertEquals(t.message, Outcome.SuccessButErrorOnCommit, outcome)
|
||||
}
|
||||
phase4.countDown()
|
||||
}
|
||||
|
||||
private fun doActivity() {
|
||||
if (readOrWrite == ReadOrWrite.Read) {
|
||||
log.info("Reading")
|
||||
val value = map.get(1)
|
||||
log.info("Read $value")
|
||||
if (outcome == Outcome.Success || outcome == Outcome.SuccessButErrorOnCommit) {
|
||||
assertEquals("X", value)
|
||||
} else {
|
||||
assertNull(value)
|
||||
}
|
||||
} else if (readOrWrite == ReadOrWrite.Write) {
|
||||
log.info("Writing")
|
||||
val wasSet = map.set(1, "X")
|
||||
log.info("Write $wasSet")
|
||||
if (outcome == Outcome.Success || outcome == Outcome.SuccessButErrorOnCommit) {
|
||||
assertEquals(true, wasSet)
|
||||
} else {
|
||||
assertEquals(false, wasSet)
|
||||
}
|
||||
} else if (readOrWrite == ReadOrWrite.WriteDuplicateAllowed) {
|
||||
log.info("Writing with duplicates allowed")
|
||||
val wasSet = map.addWithDuplicatesAllowed(1, "X")
|
||||
log.info("Write with duplicates allowed $wasSet")
|
||||
if (outcome == Outcome.Success || outcome == Outcome.SuccessButErrorOnCommit) {
|
||||
assertEquals(true, wasSet)
|
||||
} else {
|
||||
assertEquals(false, wasSet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun latch() = CountDownLatch(1)
|
||||
fun await(latch: () -> CountDownLatch) {
|
||||
log.info("Awaiting $latch")
|
||||
latch().await()
|
||||
}
|
||||
}
|
||||
|
||||
private fun prepopulateIfRequired() {
|
||||
if (scenario.prePopulated) {
|
||||
database.transaction {
|
||||
val map = createMap()
|
||||
map.set(1, "X")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Entity
|
||||
@javax.persistence.Table(name = "persist_map_test")
|
||||
class PersistentMapEntry(
|
||||
@Id
|
||||
@Column(name = "key")
|
||||
var key: Long = -1,
|
||||
|
||||
@Column(name = "value", length = 16)
|
||||
var value: String = ""
|
||||
) : Serializable
|
||||
|
||||
class TestMap : AppendOnlyPersistentMap<Long, String, PersistentMapEntry, Long>(
|
||||
toPersistentEntityKey = { it },
|
||||
fromPersistentEntity = { Pair(it.key, it.value) },
|
||||
toPersistentEntity = { key: Long, value: String ->
|
||||
PersistentMapEntry().apply {
|
||||
this.key = key
|
||||
this.value = value
|
||||
}
|
||||
},
|
||||
persistentEntityClass = PersistentMapEntry::class.java
|
||||
) {
|
||||
fun pendingKeysIsEmpty() = pendingKeys.isEmpty()
|
||||
|
||||
fun invalidate() = cache.invalidateAll()
|
||||
}
|
||||
|
||||
fun createMap() = TestMap()
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
package net.corda.node.services.persistence
|
||||
|
||||
import net.corda.node.internal.configureDatabase
|
||||
import net.corda.nodeapi.internal.persistence.DatabaseConfig
|
||||
import net.corda.testing.internal.rigorousMock
|
||||
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
|
||||
import org.junit.After
|
||||
import org.junit.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
|
||||
class TransactionCallbackTest {
|
||||
private val database = configureDatabase(makeTestDataSourceProperties(), DatabaseConfig(), rigorousMock())
|
||||
|
||||
@After
|
||||
fun closeDatabase() {
|
||||
database.close()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `onCommit called and onRollback not called on commit`() {
|
||||
var onCommitCount = 0
|
||||
var onRollbackCount = 0
|
||||
database.transaction {
|
||||
onCommit { onCommitCount++ }
|
||||
onRollback { onRollbackCount++ }
|
||||
}
|
||||
assertEquals(1, onCommitCount)
|
||||
assertEquals(0, onRollbackCount)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `onCommit not called and onRollback called on rollback`() {
|
||||
class TestException : Exception()
|
||||
|
||||
var onCommitCount = 0
|
||||
var onRollbackCount = 0
|
||||
try {
|
||||
database.transaction {
|
||||
onCommit { onCommitCount++ }
|
||||
onRollback { onRollbackCount++ }
|
||||
throw TestException()
|
||||
}
|
||||
} catch (e: TestException) {
|
||||
}
|
||||
assertEquals(0, onCommitCount)
|
||||
assertEquals(1, onRollbackCount)
|
||||
}
|
||||
}
|
@ -0,0 +1,166 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.core.concurrent.CordaFuture
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.flows.FlowSession
|
||||
import net.corda.core.flows.InitiatedBy
|
||||
import net.corda.core.flows.InitiatingFlow
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.messaging.MessageRecipients
|
||||
import net.corda.core.utilities.getOrThrow
|
||||
import net.corda.core.utilities.unwrap
|
||||
import net.corda.node.internal.StartedNode
|
||||
import net.corda.node.services.messaging.Message
|
||||
import net.corda.node.services.persistence.DBTransactionStorage
|
||||
import net.corda.nodeapi.internal.persistence.contextTransaction
|
||||
import net.corda.testing.node.internal.InternalMockNetwork
|
||||
import net.corda.testing.node.internal.MessagingServiceSpy
|
||||
import net.corda.testing.node.internal.newContext
|
||||
import net.corda.testing.node.internal.setMessagingServiceSpy
|
||||
import org.assertj.core.api.Assertions
|
||||
import org.junit.After
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
import java.sql.SQLException
|
||||
import java.time.Duration
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertNotNull
|
||||
import kotlin.test.assertNull
|
||||
|
||||
class RetryFlowMockTest {
|
||||
private lateinit var mockNet: InternalMockNetwork
|
||||
private lateinit var internalNodeA: StartedNode<InternalMockNetwork.MockNode>
|
||||
private lateinit var internalNodeB: StartedNode<InternalMockNetwork.MockNode>
|
||||
|
||||
@Before
|
||||
fun start() {
|
||||
mockNet = InternalMockNetwork(threadPerNode = true, cordappPackages = listOf(this.javaClass.`package`.name))
|
||||
internalNodeA = mockNet.createNode()
|
||||
internalNodeB = mockNet.createNode()
|
||||
mockNet.startNodes()
|
||||
RetryFlow.count = 0
|
||||
SendAndRetryFlow.count = 0
|
||||
RetryInsertFlow.count = 0
|
||||
}
|
||||
|
||||
private fun <T> StartedNode<InternalMockNetwork.MockNode>.startFlow(logic: FlowLogic<T>): CordaFuture<T> = this.services.startFlow(logic, this.services.newContext()).getOrThrow().resultFuture
|
||||
|
||||
@After
|
||||
fun cleanUp() {
|
||||
mockNet.stopNodes()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `Single retry`() {
|
||||
assertEquals(Unit, internalNodeA.startFlow(RetryFlow(1)).get())
|
||||
assertEquals(2, RetryFlow.count)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `Retry forever`() {
|
||||
Assertions.assertThatThrownBy {
|
||||
internalNodeA.startFlow(RetryFlow(Int.MAX_VALUE)).getOrThrow()
|
||||
}.isInstanceOf(LimitedRetryCausingError::class.java)
|
||||
assertEquals(5, RetryFlow.count)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `Retry does not set senderUUID`() {
|
||||
val messagesSent = mutableListOf<Message>()
|
||||
val partyB = internalNodeB.info.legalIdentities.first()
|
||||
internalNodeA.setMessagingServiceSpy(object : MessagingServiceSpy(internalNodeA.network) {
|
||||
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) {
|
||||
messagesSent.add(message)
|
||||
messagingService.send(message, target, retryId)
|
||||
}
|
||||
})
|
||||
internalNodeA.startFlow(SendAndRetryFlow(1, partyB)).get()
|
||||
assertNotNull(messagesSent.first().senderUUID)
|
||||
assertNull(messagesSent.last().senderUUID)
|
||||
assertEquals(2, SendAndRetryFlow.count)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `Retry duplicate insert`() {
|
||||
assertEquals(Unit, internalNodeA.startFlow(RetryInsertFlow(1)).get())
|
||||
assertEquals(2, RetryInsertFlow.count)
|
||||
}
|
||||
|
||||
@Test
fun `Patient records do not leak in hospital`() {
assertEquals(Unit, internalNodeA.startFlow(RetryFlow(1)).get())
assertEquals(0, StaffedFlowHospital.numberOfPatients)
assertEquals(2, RetryFlow.count)
}
}

class LimitedRetryCausingError : org.hibernate.exception.ConstraintViolationException("Test message", SQLException(), "Test constraint")

class RetryCausingError : SQLException("deadlock")

class RetryFlow(val i: Int) : FlowLogic<Unit>() {
companion object {
var count = 0
}

@Suspendable
override fun call() {
logger.info("Hello $count")
if (count++ < i) {
if (i == Int.MAX_VALUE) {
throw LimitedRetryCausingError()
} else {
throw RetryCausingError()
}
}
}
}

@InitiatingFlow
class SendAndRetryFlow(val i: Int, val other: Party) : FlowLogic<Unit>() {
companion object {
var count = 0
}

@Suspendable
override fun call() {
logger.info("Sending...")
val session = initiateFlow(other)
session.send("Boo")
if (count++ < i) {
throw RetryCausingError()
}
}
}

@InitiatedBy(SendAndRetryFlow::class)
class ReceiveFlow2(val other: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
val received = other.receive<String>().unwrap { it }
logger.info("Received... $received")
}
}

class RetryInsertFlow(val i: Int) : FlowLogic<Unit>() {
companion object {
var count = 0
}

@Suspendable
override fun call() {
logger.info("Hello")
doInsert()
// Checkpoint so we roll back to here
FlowLogic.sleep(Duration.ofSeconds(0))
if (count++ < i) {
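// The second insert reuses the same transaction id ("Foo"), so it is expected to fail with a
// constraint violation and retry the flow from the checkpoint above.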
doInsert()
}
}

private fun doInsert() {
val tx = DBTransactionStorage.DBTransaction("Foo")
contextTransaction.session.save(tx)
}
}
@ -28,7 +28,10 @@ import net.corda.nodeapi.internal.persistence.DatabaseTransaction
import net.corda.testing.core.*
import net.corda.testing.internal.TEST_TX_TIME
import net.corda.testing.internal.rigorousMock
import net.corda.testing.internal.vault.*
import net.corda.testing.internal.vault.DUMMY_LINEAR_CONTRACT_PROGRAM_ID
import net.corda.testing.internal.vault.DummyLinearContract
import net.corda.testing.internal.vault.DummyLinearStateSchemaV1
import net.corda.testing.internal.vault.VaultFiller
import net.corda.testing.node.MockServices
import net.corda.testing.node.MockServices.Companion.makeTestDatabaseAndMockServices
import net.corda.testing.node.makeTestIdentityService
@ -171,17 +174,11 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
@JvmField
val expectedEx: ExpectedException = ExpectedException.none()

@Suppress("LeakingThis")
@Rule
@JvmField
val transactionRule = VaultQueryRollbackRule(this)

companion object {
@ClassRule @JvmField
val testSerialization = SerializationEnvironmentRule()
}


/**
* Helper method for generating a Persistent H2 test database
*/
@ -194,7 +191,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
database.close()
}

private fun consumeCash(amount: Amount<Currency>) = vaultFiller.consumeCash(amount, CHARLIE)
protected fun consumeCash(amount: Amount<Currency>) = vaultFiller.consumeCash(amount, CHARLIE)
private fun setUpDb(_database: CordaPersistence, delay: Long = 0) {
_database.transaction {
// create new states
@ -1988,239 +1985,6 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
}
}

/**
* Dynamic trackBy() tests
*/

@Test
fun trackCashStates_unconsumed() {
val updates = database.transaction {
val updates =
// DOCSTART VaultQueryExample15
vaultService.trackBy<Cash.State>().updates // UNCONSUMED default
// DOCEND VaultQueryExample15

vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 5, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()

// consume stuff
consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())

close() // transaction needs to be closed to trigger events
updates
}

updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 5) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 1) {}
}
)
}
}

@Test
fun trackCashStates_consumed() {

val updates = database.transaction {
val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED)
val updates = vaultService.trackBy<Cash.State>(criteria).updates

vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 5, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()

consumeCash(100.POUNDS)

// consume more stuff
consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())

close() // transaction needs to be closed to trigger events
updates
}

updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.size == 1) {}
require(produced.isEmpty()) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.size == 5) {}
require(produced.isEmpty()) {}
}
)
}
}

@Test
fun trackCashStates_all() {
val updates = database.transaction {
val updates =
database.transaction {
val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL)
vaultService.trackBy<Cash.State>(criteria).updates
}
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 5, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()

// consume stuff
consumeCash(99.POUNDS)

consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())

close() // transaction needs to be closed to trigger events
updates
}

updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 5) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 1) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.size == 1) {}
require(produced.size == 1) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.size == 5) {}
require(produced.isEmpty()) {}
}
)
}
}

@Test
fun trackLinearStates() {

val updates = database.transaction {
// DOCSTART VaultQueryExample16
val (snapshot, updates) = vaultService.trackBy<LinearState>()
// DOCEND VaultQueryExample16
assertThat(snapshot.states).hasSize(0)

vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 3, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()

// consume stuff
consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())

close() // transaction needs to be closed to trigger events
updates
}

updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 10) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 3) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 1) {}
}
)
}
}

@Test
fun trackDealStates() {
val updates = database.transaction {
// DOCSTART VaultQueryExample17
val (snapshot, updates) = vaultService.trackBy<DealState>()
// DOCEND VaultQueryExample17
assertThat(snapshot.states).hasSize(0)

vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 3, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()

// consume stuff
consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())

close()
updates
}

updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 3) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 1) {}
}
)
}
}

@Test
fun unconsumedCashStatesForSpending_single_issuer_reference() {
database.transaction {
@ -2281,10 +2045,241 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
*/
}

class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by vaultQueryTestRule {
class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate {

companion object {
@ClassRule @JvmField
val vaultQueryTestRule = VaultQueryTestRule()
val delegate = VaultQueryTestRule()
}

@Rule
@JvmField
val vaultQueryTestRule = delegate

/**
* Dynamic trackBy() tests are H2 only, since rollback stops events being emitted.
*/

@Test
fun trackCashStates_unconsumed() {
val updates = database.transaction {
val updates =
// DOCSTART VaultQueryExample15
vaultService.trackBy<Cash.State>().updates // UNCONSUMED default
// DOCEND VaultQueryExample15

vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 5, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()

// consume stuff
consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())

updates
}

updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 5) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 1) {}
}
)
}
}

@Test
fun trackCashStates_consumed() {

val updates = database.transaction {
val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED)
val updates = vaultService.trackBy<Cash.State>(criteria).updates

vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 5, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()

consumeCash(100.POUNDS)

// consume more stuff
consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())

updates
}

updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.size == 1) {}
require(produced.isEmpty()) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.size == 5) {}
require(produced.isEmpty()) {}
}
)
}
}

@Test
fun trackCashStates_all() {
val updates = database.transaction {
val updates =
database.transaction {
val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL)
vaultService.trackBy<Cash.State>(criteria).updates
}
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 5, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()

// consume stuff
consumeCash(99.POUNDS)

consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())

updates
}

updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 5) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 1) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.size == 1) {}
require(produced.size == 1) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.size == 5) {}
require(produced.isEmpty()) {}
}
)
}
}

@Test
fun trackLinearStates() {

val updates = database.transaction {
// DOCSTART VaultQueryExample16
val (snapshot, updates) = vaultService.trackBy<LinearState>()
// DOCEND VaultQueryExample16
assertThat(snapshot.states).hasSize(0)

vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 3, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()

// consume stuff
consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())

updates
}

updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 10) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 3) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 1) {}
}
)
}
}

@Test
fun trackDealStates() {
val updates = database.transaction {
// DOCSTART VaultQueryExample17
val (snapshot, updates) = vaultService.trackBy<DealState>()
// DOCEND VaultQueryExample17
assertThat(snapshot.states).hasSize(0)

vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 3, DUMMY_CASH_ISSUER)
val linearStates = vaultFiller.fillWithSomeTestLinearStates(10).states
val dealStates = vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")).states
// add more cash
vaultFiller.fillWithSomeTestCash(100.POUNDS, notaryServices, 1, DUMMY_CASH_ISSUER)
// add another deal
vaultFiller.fillWithSomeTestDeals(listOf("SAMPLE DEAL"))
this.session.flush()

// consume stuff
consumeCash(100.DOLLARS)
vaultFiller.consumeDeals(dealStates.toList())
vaultFiller.consumeLinearStates(linearStates.toList())

updates
}

updates.expectEvents {
sequence(
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 3) {}
},
expect { (consumed, produced, flowId) ->
require(flowId == null) {}
require(consumed.isEmpty()) {}
require(produced.size == 1) {}
}
)
}
}
}
@ -5,8 +5,8 @@ import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.internal.tee
import net.corda.node.internal.configureDatabase
import net.corda.nodeapi.internal.persistence.*
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import net.corda.testing.internal.rigorousMock
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Test
@ -14,6 +14,7 @@ import rx.Observable
import rx.subjects.PublishSubject
import java.io.Closeable
import java.util.*
import kotlin.test.fail

class ObservablesTests {
private fun isInDatabaseTransaction() = contextTransactionOrNull != null
@ -58,6 +59,72 @@ class ObservablesTests {
assertThat(secondEvent.get()).isEqualTo(0 to false)
}

class TestException : Exception("Synthetic exception for tests") {}

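// With the default propagateRollbackAsError = false, events buffered inside a transaction that rolls
// back are expected to be dropped silently: only the unbuffered event (1) reaches the subscriber.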
@Test
fun `bufferUntilDatabaseCommit swallows if transaction rolled back`() {
val database = createDatabase()

val source = PublishSubject.create<Int>()
val observable: Observable<Int> = source

val firstEvent = SettableFuture.create<Pair<Int, Boolean>>()
val secondEvent = SettableFuture.create<Pair<Int, Boolean>>()

observable.first().subscribe { firstEvent.set(it to isInDatabaseTransaction()) }
observable.skip(1).first().subscribe { secondEvent.set(it to isInDatabaseTransaction()) }

try {
database.transaction {
val delayedSubject = source.bufferUntilDatabaseCommit()
assertThat(source).isNotEqualTo(delayedSubject)
delayedSubject.onNext(0)
source.onNext(1)
assertThat(firstEvent.isDone).isTrue()
assertThat(secondEvent.isDone).isFalse()
throw TestException()
}
fail("Should not have successfully completed transaction")
} catch (e: TestException) {
}
assertThat(secondEvent.isDone).isFalse()

assertThat(firstEvent.get()).isEqualTo(1 to true)
}

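// With propagateRollbackAsError = true, the rollback is expected to surface to subscribers as an
// onError (DatabaseTransactionRolledBackException) rather than the buffered event being silently lost.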
@Test
fun `bufferUntilDatabaseCommit propagates error if transaction rolled back`() {
val database = createDatabase()

val source = PublishSubject.create<Int>()
val observable: Observable<Int> = source

val firstEvent = SettableFuture.create<Pair<Int, Boolean>>()
val secondEvent = SettableFuture.create<Pair<Int, Boolean>>()

observable.first().subscribe({ firstEvent.set(it to isInDatabaseTransaction()) }, {})
observable.skip(1).subscribe({ secondEvent.set(it to isInDatabaseTransaction()) }, {})
observable.skip(1).subscribe({}, { secondEvent.set(2 to isInDatabaseTransaction()) })

try {
database.transaction {
val delayedSubject = source.bufferUntilDatabaseCommit(propagateRollbackAsError = true)
assertThat(source).isNotEqualTo(delayedSubject)
delayedSubject.onNext(0)
source.onNext(1)
assertThat(firstEvent.isDone).isTrue()
assertThat(secondEvent.isDone).isFalse()
throw TestException()
}
fail("Should not have successfully completed transaction")
} catch (e: TestException) {
}
assertThat(secondEvent.isDone).isTrue()

assertThat(firstEvent.get()).isEqualTo(1 to true)
assertThat(secondEvent.get()).isEqualTo(2 to false)
}

@Test
fun `bufferUntilDatabaseCommit delays until transaction closed repeatable`() {
val database = createDatabase()
@ -20,6 +20,8 @@ import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.trace
import net.corda.node.services.messaging.*
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ExternalEvent
import net.corda.node.services.statemachine.SenderDeduplicationId
import net.corda.node.utilities.AffinityExecutor
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.testing.node.internal.InMemoryMessage
@ -108,7 +110,7 @@ class InMemoryMessagingNetwork private constructor(
get() = _receivedMessages
internal val endpoints: List<InternalMockMessagingService> @Synchronized get() = handleEndpointMap.values.toList()
/** Get a [List] of all the [MockMessagingService] endpoints **/
val endpointsExternal: List<MockMessagingService> @Synchronized get() = handleEndpointMap.values.map{ MockMessagingService.createMockMessagingService(it) }.toList()
val endpointsExternal: List<MockMessagingService> @Synchronized get() = handleEndpointMap.values.map { MockMessagingService.createMockMessagingService(it) }.toList()

/**
* Creates a node at the given address: useful if you want to recreate a node to simulate a restart.
@ -135,7 +137,10 @@ class InMemoryMessagingNetwork private constructor(
?: emptyList() //TODO only notary can be distributed?
synchronized(this) {
val node = InMemoryMessaging(manuallyPumped, peerHandle, executor, database)
handleEndpointMap[peerHandle] = node
val oldNode = handleEndpointMap.put(peerHandle, node)
if (oldNode != null) {
node.inheritPendingRedelivery(oldNode)
}
serviceHandles.forEach {
serviceToPeersMapping.getOrPut(it) { LinkedHashSet() }.add(peerHandle)
}
@ -161,7 +166,10 @@ class InMemoryMessagingNetwork private constructor(

@Synchronized
private fun netNodeHasShutdown(peerHandle: PeerHandle) {
handleEndpointMap.remove(peerHandle)
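// Keep the endpoint registered while it still has undelivered messages, so that a node recreated at
// the same address can inherit them (see inheritPendingRedelivery) instead of losing them on restart.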
val endpoint = handleEndpointMap[peerHandle]
if (!(endpoint?.hasPendingDeliveries() ?: false)) {
handleEndpointMap.remove(peerHandle)
}
}

@Synchronized
@ -266,6 +274,30 @@ class InMemoryMessagingNetwork private constructor(
return transfer
}

/**
* When a new message handler is added, this implies we have started a new node. The add handler logic uses this to
* push back any un-acknowledged messages for this peer onto the head of the queue (rather than the tail) to maintain message
* delivery order. We push them back because their consumption was not complete and a restarted node would
* see them re-delivered if this was Artemis.
*/
@Synchronized
private fun unPopMessages(transfers: Collection<MessageTransfer>, us: PeerHandle) {
messageReceiveQueues.compute(us) { _, existing ->
if (existing == null) {
LinkedBlockingQueue<MessageTransfer>().apply {
addAll(transfers)
}
} else {
existing.apply {
val drained = mutableListOf<MessageTransfer>()
existing.drainTo(drained)
existing.addAll(transfers)
existing.addAll(drained)
}
}
}
}

private fun pumpSendInternal(transfer: MessageTransfer) {
when (transfer.recipients) {
is PeerHandle -> getQueueForPeerHandle(transfer.recipients).add(transfer)
@ -338,6 +370,7 @@ class InMemoryMessagingNetwork private constructor(
private val processedMessages: MutableSet<DeduplicationId> = Collections.synchronizedSet(HashSet<DeduplicationId>())

override val myAddress: PeerHandle get() = peerHandle
override val ourSenderUUID: String = UUID.randomUUID().toString()

private val backgroundThread = if (manuallyPumped) null else
thread(isDaemon = true, name = "In-memory message dispatcher") {
@ -370,10 +403,16 @@ class InMemoryMessagingNetwork private constructor(
Pair(handler, pending)
}

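// Instead of re-sending pending transfers through pumpSendInternal, push them back onto the head of
// this peer's receive queue (see unPopMessages above) so delivery order is preserved across a restart.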
transfers.forEach { pumpSendInternal(it) }
unPopMessages(transfers, peerHandle)
return handler
}

fun inheritPendingRedelivery(other: InMemoryMessaging) {
state.locked {
pendingRedelivery.addAll(other.state.locked { pendingRedelivery })
}
}

override fun removeMessageHandler(registration: MessageHandlerRegistration) {
check(running)
state.locked { check(handlers.remove(registration as Handler)) }
@ -405,8 +444,8 @@ class InMemoryMessagingNetwork private constructor(
override fun cancelRedelivery(retryId: Long) {}

/** Returns the given (topic & session, data) pair as a newly created message object. */
override fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId, additionalHeaders: Map<String, String>): Message {
return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId)
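// SenderDeduplicationId bundles the DeduplicationId with the sender UUID (if any), and the in-memory
// message now carries that sender identity alongside the deduplication ID.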
override fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId, additionalHeaders: Map<String, String>): Message {
return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId.deduplicationId, senderUUID = deduplicationId.senderUUID)
}

/**
@ -470,13 +509,14 @@ class InMemoryMessagingNetwork private constructor(
database.transaction {
for (handler in deliverTo) {
try {
handler.callback(transfer.toReceivedMessage(), handler, DummyDeduplicationHandler())
val receivedMessage = transfer.toReceivedMessage()
state.locked { pendingRedelivery.add(transfer) }
handler.callback(receivedMessage, handler, InMemoryDeduplicationHandler(receivedMessage, transfer))
} catch (e: Exception) {
log.error("Caught exception in handler for $this/${handler.topicSession}", e)
}
}
_receivedMessages.onNext(transfer)
processedMessages += transfer.message.uniqueMessageId
messagesInFlight.countDown()
}
}
@ -493,13 +533,23 @@ class InMemoryMessagingNetwork private constructor(
message.uniqueMessageId,
message.debugTimestamp,
sender.name)
}

private class DummyDeduplicationHandler : DeduplicationHandler {
override fun afterDatabaseTransaction() {
}
override fun insideDatabaseTransaction() {
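// Replaces DummyDeduplicationHandler: the message is marked processed inside the database transaction
// and only leaves pendingRedelivery once the transaction commits, roughly mirroring a broker acknowledgement.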
private inner class InMemoryDeduplicationHandler(override val receivedMessage: ReceivedMessage, val transfer: MessageTransfer) : DeduplicationHandler, ExternalEvent.ExternalMessageEvent {
override val externalCause: ExternalEvent
get() = this
override val deduplicationHandler: DeduplicationHandler
get() = this

override fun afterDatabaseTransaction() {
this@InMemoryMessaging.state.locked { pendingRedelivery.remove(transfer) }
}

override fun insideDatabaseTransaction() {
processedMessages += transfer.message.uniqueMessageId
}
}

fun hasPendingDeliveries(): Boolean = state.locked { pendingRedelivery.isNotEmpty() }
}
}