From b73a4980626daee540c3151e68314985f0ca277d Mon Sep 17 00:00:00 2001 From: Dimos Raptis Date: Wed, 25 Mar 2020 09:02:14 +0000 Subject: [PATCH] [ENT-4754] - Move subflow preparation logic in FlowStateMachine --- .../kotlin/net/corda/core/flows/FlowLogic.kt | 16 +--------------- .../net/corda/core/internal/FlowStateMachine.kt | 2 +- .../statemachine/FlowStateMachineImpl.kt | 17 ++++++++++++++++- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt index 808c68fbe7..79c7e1c36d 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt @@ -380,10 +380,8 @@ abstract class FlowLogic { @Suspendable @Throws(FlowException::class) open fun subFlow(subLogic: FlowLogic): R { - subLogic.stateMachine = stateMachine - maybeWireUpProgressTracking(subLogic) logger.debug { "Calling subflow: $subLogic" } - val result = stateMachine.subFlow(subLogic) + val result = stateMachine.subFlow(this, subLogic) logger.debug { "Subflow finished with result ${result.toString().abbreviate(300)}" } return result } @@ -540,18 +538,6 @@ abstract class FlowLogic { _stateMachine = value } - private fun maybeWireUpProgressTracking(subLogic: FlowLogic<*>) { - val ours = progressTracker - val theirs = subLogic.progressTracker - if (ours != null && theirs != null && ours != theirs) { - if (ours.currentStep == ProgressTracker.UNSTARTED) { - logger.debug { "Initializing the progress tracker for flow: ${this::class.java.name}." } - ours.nextStep() - } - ours.setChildProgressTracker(ours.currentStep, theirs) - } - } - private fun enforceNoDuplicates(sessions: List) { require(sessions.size == sessions.toSet().size) { "A flow session can only appear once as argument." } } diff --git a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt index d32062eac2..c057efa31e 100644 --- a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt +++ b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt @@ -25,7 +25,7 @@ interface FlowStateMachine { fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map) @Suspendable - fun subFlow(subFlow: FlowLogic): SUBFLOWRETURN + fun subFlow(currentFlow: FlowLogic<*>, subFlow: FlowLogic): SUBFLOWRETURN @Suspendable fun flowStackSnapshot(flowClass: Class>): FlowStackSnapshot? diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index 4a9a407473..1791afc7e5 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -315,7 +315,10 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } @Suspendable - override fun subFlow(subFlow: FlowLogic): R { + override fun subFlow(currentFlow: FlowLogic<*>, subFlow: FlowLogic): R { + subFlow.stateMachine = this + maybeWireUpProgressTracking(currentFlow, subFlow) + checkpointIfSubflowIdempotent(subFlow.javaClass) processEventImmediately( Event.EnterSubFlow(subFlow.javaClass, @@ -338,6 +341,18 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } } + private fun maybeWireUpProgressTracking(currentFlow: FlowLogic<*>, subFlow: FlowLogic<*>) { + val currentFlowProgressTracker = currentFlow.progressTracker + val subflowProgressTracker = subFlow.progressTracker + if (currentFlowProgressTracker != null && subflowProgressTracker != null && currentFlowProgressTracker != subflowProgressTracker) { + if (currentFlowProgressTracker.currentStep == ProgressTracker.UNSTARTED) { + logger.debug { "Initializing the progress tracker for flow: ${this::class.java.name}." } + currentFlowProgressTracker.nextStep() + } + currentFlowProgressTracker.setChildProgressTracker(currentFlowProgressTracker.currentStep, subflowProgressTracker) + } + } + private fun Throwable.isUnrecoverable(): Boolean = this is VirtualMachineError && this !is StackOverflowError /**