From 456c9a85e1d13aaebd6fb7fa9cbdee9494a6204f Mon Sep 17 00:00:00 2001 From: Stefano Franz Date: Wed, 17 Oct 2018 11:27:14 +0100 Subject: [PATCH] =?UTF-8?q?remove=20requirement=20to=20override=20default?= =?UTF-8?q?=20progress=20tracker=20for=20interacti=E2=80=A6=20(#3985)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * remove requirement to override default progress tracker for interactive shell - this is no longer needed * fix failing tests --- .../kotlin/net/corda/core/flows/FlowLogic.kt | 21 +++-- .../corda/core/utilities/ProgressTracker.kt | 84 +++++++++++-------- .../core/utilities/ProgressTrackerTest.kt | 42 +++++----- .../statemachine/FlowStateMachineImpl.kt | 6 +- .../statemachine/FlowFrameworkTests.kt | 6 +- .../corda/attachmentdemo/AttachmentDemo.kt | 13 ++- .../net/corda/tools/shell/InteractiveShell.kt | 10 +-- 7 files changed, 110 insertions(+), 72 deletions(-) diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt index e85439a601..10f101bdc7 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt @@ -64,8 +64,10 @@ abstract class FlowLogic { /** * Return the outermost [FlowLogic] instance, or null if not in a flow. */ - @Suppress("unused") @JvmStatic - val currentTopLevel: FlowLogic<*>? get() = (Strand.currentStrand() as? FlowStateMachine<*>)?.logic + @Suppress("unused") + @JvmStatic + val currentTopLevel: FlowLogic<*>? + get() = (Strand.currentStrand() as? FlowStateMachine<*>)?.logic /** * If on a flow, suspends the flow and only wakes it up after at least [duration] time has passed. Otherwise, @@ -123,10 +125,11 @@ abstract class FlowLogic { * Note: The current implementation returns the single identity of the node. This will change once multiple identities * is implemented. */ - val ourIdentityAndCert: PartyAndCertificate get() { - return serviceHub.myInfo.legalIdentitiesAndCerts.find { it.party == stateMachine.ourIdentity } - ?: throw IllegalStateException("Identity specified by ${stateMachine.id} (${stateMachine.ourIdentity}) is not one of ours!") - } + val ourIdentityAndCert: PartyAndCertificate + get() { + return serviceHub.myInfo.legalIdentitiesAndCerts.find { it.party == stateMachine.ourIdentity } + ?: throw IllegalStateException("Identity specified by ${stateMachine.id} (${stateMachine.ourIdentity}) is not one of ours!") + } /** * Specifies the identity to use for this flow. This will be one of the multiple identities that belong to this node. @@ -141,9 +144,11 @@ abstract class FlowLogic { // Used to implement the deprecated send/receive functions using Party. When such a deprecated function is used we // create a fresh session for the Party, put it here and use it in subsequent deprecated calls. private val deprecatedPartySessionMap = HashMap() + private fun getDeprecatedSessionForParty(party: Party): FlowSession { return deprecatedPartySessionMap.getOrPut(party) { initiateFlow(party) } } + /** * Returns a [FlowInfo] object describing the flow [otherParty] is using. With [FlowInfo.flowVersion] it * provides the necessary information needed for the evolution of flows and enabling backwards compatibility. @@ -342,7 +347,7 @@ abstract class FlowLogic { * Note that this has to return a tracker before the flow is invoked. You can't change your mind half way * through. */ - open val progressTracker: ProgressTracker? = null + open val progressTracker: ProgressTracker? = ProgressTracker.DEFAULT_TRACKER() /** * This is where you fill out your business logic. @@ -383,7 +388,7 @@ abstract class FlowLogic { * * @return Returns null if this flow has no progress tracker. */ - fun trackStepsTree(): DataFeed>, List>>? { + fun trackStepsTree(): DataFeed>, List>>? { // TODO this is not threadsafe, needs an atomic get-step-and-subscribe return progressTracker?.let { DataFeed(it.allStepsLabels, it.stepsTreeChanges) diff --git a/core/src/main/kotlin/net/corda/core/utilities/ProgressTracker.kt b/core/src/main/kotlin/net/corda/core/utilities/ProgressTracker.kt index f646f0569a..7741500c31 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/ProgressTracker.kt +++ b/core/src/main/kotlin/net/corda/core/utilities/ProgressTracker.kt @@ -30,10 +30,11 @@ import java.util.* * using the [Observable] subscribeOn call. */ @CordaSerializable -class ProgressTracker(vararg steps: Step) { +class ProgressTracker(vararg inputSteps: Step) { + @CordaSerializable sealed class Change(val progressTracker: ProgressTracker) { - data class Position(val tracker: ProgressTracker, val newStep: Step) : Change(tracker) { + data class Position(val tracker: ProgressTracker, val newStep: Step) : Change(tracker) { override fun toString() = newStep.label } @@ -64,6 +65,10 @@ class ProgressTracker(vararg steps: Step) { override fun equals(other: Any?) = other === UNSTARTED } + object STARTING : Step("Starting") { + override fun equals(other: Any?) = other === STARTING + } + object DONE : Step("Done") { override fun equals(other: Any?) = other === DONE } @@ -74,7 +79,7 @@ class ProgressTracker(vararg steps: Step) { private val childProgressTrackers = mutableMapOf() /** The steps in this tracker, same as the steps passed to the constructor but with UNSTARTED and DONE inserted. */ - val steps = arrayOf(UNSTARTED, *steps, DONE) + val steps = arrayOf(UNSTARTED, STARTING, *inputSteps, DONE) private var _allStepsCache: List> = _allSteps() @@ -83,42 +88,16 @@ class ProgressTracker(vararg steps: Step) { private val _stepsTreeChanges by transient { PublishSubject.create>>() } private val _stepsTreeIndexChanges by transient { PublishSubject.create() } - - - init { - steps.forEach { - val childTracker = it.childProgressTracker() - if (childTracker != null) { - setChildProgressTracker(it, childTracker) - } - } - } - - /** The zero-based index of the current step in the [steps] array (i.e. with UNSTARTED and DONE) */ - var stepIndex: Int = 0 - private set(value) { - field = value - } - - /** The zero-bases index of the current step in a [allStepsLabels] list */ - var stepsTreeIndex: Int = -1 - private set(value) { - field = value - _stepsTreeIndexChanges.onNext(value) - } - - /** - * Reading returns the value of steps[stepIndex], writing moves the position of the current tracker. Once moved to - * the [DONE] state, this tracker is finished and the current step cannot be moved again. - */ var currentStep: Step get() = steps[stepIndex] set(value) { - check(!hasEnded) { "Cannot rewind a progress tracker once it has ended" } + check((value === DONE && hasEnded) || !hasEnded) { + "Cannot rewind a progress tracker once it has ended" + } if (currentStep == value) return val index = steps.indexOf(value) - require(index != -1, { "Step ${value.label} not found in progress tracker." }) + require(index != -1) { "Step ${value.label} not found in progress tracker." } if (index < stepIndex) { // We are going backwards: unlink and unsubscribe from any child nodes that we're rolling back @@ -144,6 +123,39 @@ class ProgressTracker(vararg steps: Step) { } } + + init { + steps.forEach { + configureChildTrackerForStep(it) + } + this.currentStep = UNSTARTED + } + + private fun configureChildTrackerForStep(it: Step) { + val childTracker = it.childProgressTracker() + if (childTracker != null) { + setChildProgressTracker(it, childTracker) + } + } + + /** The zero-based index of the current step in the [steps] array (i.e. with UNSTARTED and DONE) */ + var stepIndex: Int = 0 + private set(value) { + field = value + } + + /** The zero-bases index of the current step in a [allStepsLabels] list */ + var stepsTreeIndex: Int = -1 + private set(value) { + field = value + _stepsTreeIndexChanges.onNext(value) + } + + /** + * Reading returns the value of steps[stepIndex], writing moves the position of the current tracker. Once moved to + * the [DONE] state, this tracker is finished and the current step cannot be moved again. + */ + /** Returns the current step, descending into children to find the deepest step we are up to. */ val currentStepRecursive: Step get() = getChildProgressTracker(currentStep)?.currentStepRecursive ?: currentStep @@ -263,7 +275,7 @@ class ProgressTracker(vararg steps: Step) { /** * An observable stream of changes to the [allStepsLabels] */ - val stepsTreeChanges: Observable>> get() = _stepsTreeChanges + val stepsTreeChanges: Observable>> get() = _stepsTreeChanges /** * An observable stream of changes to the [stepsTreeIndex] @@ -272,6 +284,10 @@ class ProgressTracker(vararg steps: Step) { /** Returns true if the progress tracker has ended, either by reaching the [DONE] step or prematurely with an error */ val hasEnded: Boolean get() = _changes.hasCompleted() || _changes.hasThrowable() + + companion object { + val DEFAULT_TRACKER = { ProgressTracker() } + } } // TODO: Expose the concept of errors. // TODO: It'd be helpful if this class was at least partly thread safe. diff --git a/core/src/test/kotlin/net/corda/core/utilities/ProgressTrackerTest.kt b/core/src/test/kotlin/net/corda/core/utilities/ProgressTrackerTest.kt index 6a682faff6..e476df1a8b 100644 --- a/core/src/test/kotlin/net/corda/core/utilities/ProgressTrackerTest.kt +++ b/core/src/test/kotlin/net/corda/core/utilities/ProgressTrackerTest.kt @@ -50,11 +50,11 @@ class ProgressTrackerTest { assertEquals(0, pt.stepIndex) var stepNotification: ProgressTracker.Step? = null pt.changes.subscribe { stepNotification = (it as? ProgressTracker.Change.Position)?.newStep } - + assertEquals(ProgressTracker.UNSTARTED, pt.currentStep) + assertEquals(ProgressTracker.STARTING, pt.nextStep()) assertEquals(SimpleSteps.ONE, pt.nextStep()) - assertEquals(1, pt.stepIndex) + assertEquals(2, pt.stepIndex) assertEquals(SimpleSteps.ONE, stepNotification) - assertEquals(SimpleSteps.TWO, pt.nextStep()) assertEquals(SimpleSteps.THREE, pt.nextStep()) assertEquals(SimpleSteps.FOUR, pt.nextStep()) @@ -87,8 +87,10 @@ class ProgressTrackerTest { assertEquals(SimpleSteps.TWO, (stepNotification.pollFirst() as ProgressTracker.Change.Structural).parent) assertNextStep(SimpleSteps.TWO) + assertEquals(pt2.currentStep, ProgressTracker.UNSTARTED) + assertEquals(ProgressTracker.STARTING, pt2.nextStep()) assertEquals(ChildSteps.AYY, pt2.nextStep()) - assertNextStep(ChildSteps.AYY) + assertEquals((stepNotification.last as ProgressTracker.Change.Position).newStep, ChildSteps.AYY) assertEquals(ChildSteps.BEE, pt2.nextStep()) } @@ -115,19 +117,19 @@ class ProgressTrackerTest { // Travel tree. pt.currentStep = SimpleSteps.ONE - assertCurrentStepsTree(0, SimpleSteps.ONE) + assertCurrentStepsTree(1, SimpleSteps.ONE) pt.currentStep = SimpleSteps.TWO - assertCurrentStepsTree(1, SimpleSteps.TWO) + assertCurrentStepsTree(2, SimpleSteps.TWO) pt2.currentStep = ChildSteps.BEE - assertCurrentStepsTree(3, ChildSteps.BEE) + assertCurrentStepsTree(5, ChildSteps.BEE) pt.currentStep = SimpleSteps.THREE - assertCurrentStepsTree(5, SimpleSteps.THREE) + assertCurrentStepsTree(7, SimpleSteps.THREE) // Assert no structure changes and proper steps propagation. - assertThat(stepsIndexNotifications).containsExactlyElementsOf(listOf(0, 1, 3, 5)) + assertThat(stepsIndexNotifications).containsExactlyElementsOf(listOf(1, 2, 5, 7)) assertThat(stepsTreeNotification).isEmpty() } @@ -153,16 +155,16 @@ class ProgressTrackerTest { } pt.currentStep = SimpleSteps.ONE - assertCurrentStepsTree(0, SimpleSteps.ONE) + assertCurrentStepsTree(1, SimpleSteps.ONE) pt.currentStep = SimpleSteps.FOUR - assertCurrentStepsTree(3, SimpleSteps.FOUR) + assertCurrentStepsTree(4, SimpleSteps.FOUR) pt2.currentStep = ChildSteps.SEA - assertCurrentStepsTree(6, ChildSteps.SEA) + assertCurrentStepsTree(8, ChildSteps.SEA) // Assert no structure changes and proper steps propagation. - assertThat(stepsIndexNotifications).containsExactlyElementsOf(listOf(0, 3, 6)) + assertThat(stepsIndexNotifications).containsExactlyElementsOf(listOf(1, 4, 8)) assertThat(stepsTreeNotification).isEmpty() } @@ -189,18 +191,18 @@ class ProgressTrackerTest { } pt.currentStep = SimpleSteps.TWO - assertCurrentStepsTree(1, SimpleSteps.TWO) + assertCurrentStepsTree(2, SimpleSteps.TWO) pt.currentStep = SimpleSteps.FOUR - assertCurrentStepsTree(6, SimpleSteps.FOUR) + assertCurrentStepsTree(8, SimpleSteps.FOUR) pt.setChildProgressTracker(SimpleSteps.THREE, pt3) - assertCurrentStepsTree(9, SimpleSteps.FOUR) + assertCurrentStepsTree(12, SimpleSteps.FOUR) // Assert no structure changes and proper steps propagation. - assertThat(stepsIndexNotifications).containsExactlyElementsOf(listOf(1, 6, 9)) + assertThat(stepsIndexNotifications).containsExactlyElementsOf(listOf(2, 8, 12)) assertThat(stepsTreeNotification).hasSize(2) // 1 change + 1 our initial state } @@ -228,14 +230,14 @@ class ProgressTrackerTest { pt.currentStep = SimpleSteps.TWO pt2.currentStep = ChildSteps.SEA pt3.currentStep = BabySteps.UNOS - assertCurrentStepsTree(4, ChildSteps.SEA) + assertCurrentStepsTree(6, ChildSteps.SEA) pt.setChildProgressTracker(SimpleSteps.TWO, pt3) - assertCurrentStepsTree(2, BabySteps.UNOS) + assertCurrentStepsTree(4, BabySteps.UNOS) // Assert no structure changes and proper steps propagation. - assertThat(stepsIndexNotifications).containsExactlyElementsOf(listOf(1, 4, 2)) + assertThat(stepsIndexNotifications).containsExactlyElementsOf(listOf(2, 6, 4)) assertThat(stepsTreeNotification).hasSize(2) // 1 change + 1 our initial state. } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index 4fcd24e7d6..34b81ff9df 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -14,6 +14,7 @@ import net.corda.core.identity.Party import net.corda.core.internal.* import net.corda.core.serialization.internal.CheckpointSerializationContext import net.corda.core.serialization.internal.checkpointSerialize +import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.Try import net.corda.core.utilities.debug import net.corda.core.utilities.trace @@ -205,6 +206,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, @Suspendable override fun run() { + logic.progressTracker?.currentStep = ProgressTracker.STARTING logic.stateMachine = this setLoggingContext() @@ -263,7 +265,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, processEventImmediately( Event.EnterSubFlow(subFlow.javaClass, createSubFlowVersion( - serviceHub.cordappProvider.getCordappForFlow(subFlow), serviceHub.myInfo.platformVersion + serviceHub.cordappProvider.getCordappForFlow(subFlow), serviceHub.myInfo.platformVersion ) ), isDbTransactionOpenOnEntry = true, @@ -435,7 +437,7 @@ val Class>.flowVersionAndInitiatingClass: Pair() { + @Suspendable + override fun call(): String { + return "You Called me!" + } +} + + @Suppress("DEPRECATION") // DOCSTART 1 fun recipient(rpc: CordaRPCOps, webPort: Int) { diff --git a/tools/shell/src/main/kotlin/net/corda/tools/shell/InteractiveShell.kt b/tools/shell/src/main/kotlin/net/corda/tools/shell/InteractiveShell.kt index 6729ab0ea2..55109ff233 100644 --- a/tools/shell/src/main/kotlin/net/corda/tools/shell/InteractiveShell.kt +++ b/tools/shell/src/main/kotlin/net/corda/tools/shell/InteractiveShell.kt @@ -17,7 +17,10 @@ import net.corda.core.flows.FlowLogic import net.corda.core.internal.* import net.corda.core.internal.concurrent.doneFuture import net.corda.core.internal.concurrent.openFuture -import net.corda.core.messaging.* +import net.corda.core.messaging.CordaRPCOps +import net.corda.core.messaging.DataFeed +import net.corda.core.messaging.FlowProgressHandle +import net.corda.core.messaging.StateMachineUpdate import net.corda.nodeapi.internal.pendingFlowsCount import net.corda.tools.shell.utlities.ANSIProgressRenderer import net.corda.tools.shell.utlities.StdoutANSIProgressRenderer @@ -359,11 +362,6 @@ object InteractiveShell { errors.add("${getPrototype()}: Wrong number of arguments (${args.size} provided, ${ctor.genericParameterTypes.size} needed)") continue } - val flow = ctor.newInstance(*args) as FlowLogic<*> - if (flow.progressTracker == null) { - errors.add("A flow must override the progress tracker in order to be run from the shell") - continue - } return invoke(clazz, args) } catch (e: StringToMethodCallParser.UnparseableCallException.MissingParameter) { errors.add("${getPrototype()}: missing parameter ${e.paramName}")