diff --git a/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogic.kt b/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogic.kt index 56cff779f3..b28c857cde 100644 --- a/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogic.kt +++ b/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogic.kt @@ -69,7 +69,7 @@ abstract class ProtocolLogic { val ours = progressTracker val theirs = subLogic.progressTracker if (ours != null && theirs != null) - ours.childrenFor[ours.currentStep] = theirs + ours.setChildProgressTracker(ours.currentStep, theirs) } /** diff --git a/core/src/main/kotlin/com/r3corda/core/utilities/ProgressTracker.kt b/core/src/main/kotlin/com/r3corda/core/utilities/ProgressTracker.kt index 8dfdca284e..7b4d2a1cb5 100644 --- a/core/src/main/kotlin/com/r3corda/core/utilities/ProgressTracker.kt +++ b/core/src/main/kotlin/com/r3corda/core/utilities/ProgressTracker.kt @@ -49,7 +49,8 @@ class ProgressTracker(vararg steps: Step) { /** The superclass of all step objects. */ open class Step(open val label: String) { - open val changes: Observable = Observable.empty() + open val changes: Observable get() = Observable.empty() + open fun childProgressTracker(): ProgressTracker? = null } /** This class makes it easier to relabel a step on the fly, to provide transient information. */ @@ -77,6 +78,20 @@ class ProgressTracker(vararg steps: Step) { /** The steps in this tracker, same as the steps passed to the constructor but with UNSTARTED and DONE inserted. */ val steps = arrayOf(UNSTARTED, *steps, DONE) + // This field won't be serialized. + private val _changes by TransientProperty { PublishSubject.create() } + + private val childProgressTrackers = HashMap>() + + init { + steps.forEach { + val childTracker = it.childProgressTracker() + if (childTracker != null) { + setChildProgressTracker(it, childTracker) + } + } + } + /** The zero-based index of the current step in the [steps] array (i.e. with UNSTARTED and DONE) */ var stepIndex: Int = 0 private set @@ -98,7 +113,7 @@ class ProgressTracker(vararg steps: Step) { // We are going backwards: unlink and unsubscribe from any child nodes that we're rolling back // through, in preparation for moving through them again. for (i in stepIndex downTo index) { - childrenFor.remove(steps[i]) + removeChildProgressTracker(steps[i]) } } @@ -113,16 +128,28 @@ class ProgressTracker(vararg steps: Step) { /** Returns the current step, descending into children to find the deepest step we are up to. */ val currentStepRecursive: Step - get() = childrenFor[currentStep]?.currentStepRecursive ?: currentStep + get() = getChildProgressTracker(currentStep)?.currentStepRecursive ?: currentStep - /** - * Writable map that lets you insert child [ProgressTracker]s for particular steps. It's OK to edit this even - * after a progress tracker has been started. - */ - val childrenFor: ChildrenProgressTrackers = ChildrenProgressTrackersImpl() + fun getChildProgressTracker(step: Step): ProgressTracker? = childProgressTrackers[step]?.first + + fun setChildProgressTracker(step: ProgressTracker.Step, childProgressTracker: ProgressTracker) { + val subscription = childProgressTracker.changes.subscribe({ _changes.onNext(it) }, { _changes.onError(it) }) + childProgressTrackers[step] = Pair(childProgressTracker, subscription) + childProgressTracker.parent = this + _changes.onNext(Change.Structural(this, step)) + } + + private fun removeChildProgressTracker(step: ProgressTracker.Step) { + childProgressTrackers.remove(step)?.let { + it.first.parent = null + it.second.unsubscribe() + } + _changes.onNext(Change.Structural(this, step)) + } /** The parent of this tracker: set automatically by the parent when a tracker is added as a child */ var parent: ProgressTracker? = null + private set /** Walks up the tree to find the top level tracker. If this is the top level tracker, returns 'this' */ val topLevelTracker: ProgressTracker @@ -138,7 +165,7 @@ class ProgressTracker(vararg steps: Step) { if (step == UNSTARTED) continue if (level > 0 && step == DONE) continue result += Pair(level, step) - childrenFor[step]?.let { result += it._allSteps(level + 1) } + getChildProgressTracker(step)?.let { result += it._allSteps(level + 1) } } return result } @@ -160,45 +187,12 @@ class ProgressTracker(vararg steps: Step) { return currentStep } - // This field won't be serialized. - private val _changes by TransientProperty { PublishSubject.create() } - /** * An observable stream of changes: includes child steps, resets and any changes emitted by individual steps (e.g. * if a step changed its label or rendering). */ val changes: Observable get() = _changes - - // TODO remove this interface and add its three methods directly into ProgressTracker - interface ChildrenProgressTrackers { - operator fun get(step: ProgressTracker.Step): ProgressTracker? - operator fun set(step: ProgressTracker.Step, childProgressTracker: ProgressTracker) - fun remove(step: ProgressTracker.Step) - } - - private inner class ChildrenProgressTrackersImpl : ChildrenProgressTrackers { - - private val map = HashMap>() - - override fun get(step: Step): ProgressTracker? = map[step]?.first - - override fun set(step: Step, childProgressTracker: ProgressTracker) { - val subscription = childProgressTracker.changes.subscribe({ _changes.onNext(it) }, { _changes.onError(it) }) - map[step] = Pair(childProgressTracker, subscription) - childProgressTracker.parent = this@ProgressTracker - _changes.onNext(Change.Structural(this@ProgressTracker, step)) - } - - override fun remove(step: Step) { - map.remove(step)?.let { - it.first.parent = null - it.second.unsubscribe() - } - _changes.onNext(Change.Structural(this@ProgressTracker, step)) - } - } - } diff --git a/core/src/main/kotlin/com/r3corda/protocols/TwoPartyDealProtocol.kt b/core/src/main/kotlin/com/r3corda/protocols/TwoPartyDealProtocol.kt index 20b299602a..f2b1c24b3f 100644 --- a/core/src/main/kotlin/com/r3corda/protocols/TwoPartyDealProtocol.kt +++ b/core/src/main/kotlin/com/r3corda/protocols/TwoPartyDealProtocol.kt @@ -341,7 +341,7 @@ object TwoPartyDealProtocol { override val progressTracker: ProgressTracker = replacementProgressTracker ?: createTracker() fun createTracker(): ProgressTracker = Secondary.tracker().apply { - childrenFor[SIGNING] = ratesFixTracker + setChildProgressTracker(SIGNING, ratesFixTracker) } override fun validateHandshake(handshake: Handshake): Handshake { diff --git a/core/src/test/kotlin/com/r3corda/core/utilities/ProgressTrackerTest.kt b/core/src/test/kotlin/com/r3corda/core/utilities/ProgressTrackerTest.kt index 572a575486..c6e736706e 100644 --- a/core/src/test/kotlin/com/r3corda/core/utilities/ProgressTrackerTest.kt +++ b/core/src/test/kotlin/com/r3corda/core/utilities/ProgressTrackerTest.kt @@ -70,7 +70,7 @@ class ProgressTrackerTest { pt.currentStep = SimpleSteps.ONE assertNextStep(SimpleSteps.ONE) - pt.childrenFor[SimpleSteps.TWO] = pt2 + pt.setChildProgressTracker(SimpleSteps.TWO, pt2) pt.nextStep() assertEquals(SimpleSteps.TWO, (stepNotification.pollFirst() as ProgressTracker.Change.Structural).parent) assertNextStep(SimpleSteps.TWO) @@ -83,7 +83,7 @@ class ProgressTrackerTest { @Test fun `can be rewound`() { val pt2 = ChildSteps.tracker() - pt.childrenFor[SimpleSteps.TWO] = pt2 + pt.setChildProgressTracker(SimpleSteps.TWO, pt2) repeat(4) { pt.nextStep() } pt.currentStep = SimpleSteps.ONE assertEquals(SimpleSteps.TWO, pt.nextStep()) diff --git a/node/src/main/kotlin/com/r3corda/node/utilities/ANSIProgressRenderer.kt b/node/src/main/kotlin/com/r3corda/node/utilities/ANSIProgressRenderer.kt index dcbb7ccd18..e9bcf5bf99 100644 --- a/node/src/main/kotlin/com/r3corda/node/utilities/ANSIProgressRenderer.kt +++ b/node/src/main/kotlin/com/r3corda/node/utilities/ANSIProgressRenderer.kt @@ -146,7 +146,7 @@ object ANSIProgressRenderer { newline() lines++ - val child = childrenFor[step] + val child = getChildProgressTracker(step) if (child != null) lines += child.renderLevel(ansi, indent + 1, allSteps) } diff --git a/src/main/kotlin/com/r3corda/demos/TraderDemo.kt b/src/main/kotlin/com/r3corda/demos/TraderDemo.kt index 63aec87b2f..107ff7ed1e 100644 --- a/src/main/kotlin/com/r3corda/demos/TraderDemo.kt +++ b/src/main/kotlin/com/r3corda/demos/TraderDemo.kt @@ -26,7 +26,6 @@ import com.r3corda.node.services.messaging.ArtemisMessagingService import com.r3corda.node.services.network.NetworkMapService import com.r3corda.node.services.persistence.NodeAttachmentService import com.r3corda.node.services.transactions.SimpleNotaryService -import com.r3corda.node.utilities.ANSIProgressRenderer import com.r3corda.protocols.NotaryProtocol import com.r3corda.protocols.TwoPartyTradeProtocol import com.typesafe.config.ConfigFactory @@ -293,14 +292,14 @@ class TraderDemoProtocolSeller(val myAddress: HostAndPort, object SELF_ISSUING : ProgressTracker.Step("Got session ID back, issuing and timestamping some commercial paper") - object TRADING : ProgressTracker.Step("Starting the trade protocol") + object TRADING : ProgressTracker.Step("Starting the trade protocol") { + override fun childProgressTracker(): ProgressTracker = TwoPartyTradeProtocol.Seller.tracker() + } // We vend a progress tracker that already knows there's going to be a TwoPartyTradingProtocol involved at some // point: by setting up the tracker in advance, the user can see what's coming in more detail, instead of being // surprised when it appears as a new set of tasks below the current one. - fun tracker() = ProgressTracker(ANNOUNCING, SELF_ISSUING, TRADING).apply { - childrenFor[TRADING] = TwoPartyTradeProtocol.Seller.tracker() - } + fun tracker() = ProgressTracker(ANNOUNCING, SELF_ISSUING, TRADING) } @Suspendable @@ -318,7 +317,7 @@ class TraderDemoProtocolSeller(val myAddress: HostAndPort, progressTracker.currentStep = TRADING val seller = TwoPartyTradeProtocol.Seller(otherSide, notary, commercialPaper, 1000.DOLLARS, cpOwnerKey, - sessionID, progressTracker.childrenFor[TRADING]!!) + sessionID, progressTracker.getChildProgressTracker(TRADING)!!) val tradeTX: SignedTransaction = subProtocol(seller) serviceHub.recordTransactions(listOf(tradeTX)) diff --git a/src/main/kotlin/com/r3corda/demos/protocols/AutoOfferProtocol.kt b/src/main/kotlin/com/r3corda/demos/protocols/AutoOfferProtocol.kt index 38cbbc29ae..c90561608c 100644 --- a/src/main/kotlin/com/r3corda/demos/protocols/AutoOfferProtocol.kt +++ b/src/main/kotlin/com/r3corda/demos/protocols/AutoOfferProtocol.kt @@ -30,12 +30,12 @@ object AutoOfferProtocol { object Handler { object RECEIVED : ProgressTracker.Step("Received offer") - object DEALING : ProgressTracker.Step("Starting the deal protocol") - - fun tracker() = ProgressTracker(RECEIVED, DEALING).apply { - childrenFor[DEALING] = TwoPartyDealProtocol.Primary.tracker() + object DEALING : ProgressTracker.Step("Starting the deal protocol") { + override fun childProgressTracker(): ProgressTracker = TwoPartyDealProtocol.Primary.tracker() } + fun tracker() = ProgressTracker(RECEIVED, DEALING) + class Callback(val success: (SignedTransaction) -> Unit) : FutureCallback { override fun onFailure(t: Throwable?) { // TODO handle exceptions @@ -56,7 +56,7 @@ object AutoOfferProtocol { // TODO required as messaging layer does not currently queue messages that arrive before we expect them Thread.sleep(100) val seller = TwoPartyDealProtocol.Instigator(autoOfferMessage.otherSide, node.services.networkMapCache.notaryNodes.first(), - autoOfferMessage.dealBeingOffered, node.services.keyManagementService.freshKey(), autoOfferMessage.otherSessionID, progressTracker.childrenFor[DEALING]!!) + autoOfferMessage.dealBeingOffered, node.services.keyManagementService.freshKey(), autoOfferMessage.otherSessionID, progressTracker.getChildProgressTracker(DEALING)!!) val future = node.smm.add("${TwoPartyDealProtocol.DEAL_TOPIC}.seller", seller) // This is required because we are doing child progress outside of a subprotocol. In future, we should just wrap things like this in a protocol to avoid it Futures.addCallback(future, Callback() { @@ -73,14 +73,14 @@ object AutoOfferProtocol { companion object { object RECEIVED : ProgressTracker.Step("Received API call") object ANNOUNCING : ProgressTracker.Step("Announcing to the peer node") - object DEALING : ProgressTracker.Step("Starting the deal protocol") + object DEALING : ProgressTracker.Step("Starting the deal protocol") { + override fun childProgressTracker(): ProgressTracker = TwoPartyDealProtocol.Secondary.tracker() + } // We vend a progress tracker that already knows there's going to be a TwoPartyTradingProtocol involved at some // point: by setting up the tracker in advance, the user can see what's coming in more detail, instead of being // surprised when it appears as a new set of tasks below the current one. - fun tracker() = ProgressTracker(RECEIVED, ANNOUNCING, DEALING).apply { - childrenFor[DEALING] = TwoPartyDealProtocol.Secondary.tracker() - } + fun tracker() = ProgressTracker(RECEIVED, ANNOUNCING, DEALING) } override val progressTracker = tracker() @@ -103,7 +103,7 @@ object AutoOfferProtocol { progressTracker.currentStep = ANNOUNCING send(TOPIC, otherSide, 0, AutoOfferMessage(serviceHub.networkService.myAddress, ourSessionID, dealToBeOffered)) progressTracker.currentStep = DEALING - val stx = subProtocol(TwoPartyDealProtocol.Acceptor(otherSide, notary.identity, dealToBeOffered, ourSessionID, progressTracker.childrenFor[DEALING]!!)) + val stx = subProtocol(TwoPartyDealProtocol.Acceptor(otherSide, notary.identity, dealToBeOffered, ourSessionID, progressTracker.getChildProgressTracker(DEALING)!!)) return stx } diff --git a/src/main/kotlin/com/r3corda/demos/protocols/UpdateBusinessDayProtocol.kt b/src/main/kotlin/com/r3corda/demos/protocols/UpdateBusinessDayProtocol.kt index 17788c6dc5..e32bc610a3 100644 --- a/src/main/kotlin/com/r3corda/demos/protocols/UpdateBusinessDayProtocol.kt +++ b/src/main/kotlin/com/r3corda/demos/protocols/UpdateBusinessDayProtocol.kt @@ -94,7 +94,7 @@ object UpdateBusinessDayProtocol { @Suspendable private fun nextFixingFloatingLeg(dealStateAndRef: StateAndRef, party: NodeInfo, sessionID: Long): StateAndRef? { - progressTracker.childrenFor[FIXING] = TwoPartyDealProtocol.Primary.tracker() + progressTracker.setChildProgressTracker(FIXING, TwoPartyDealProtocol.Primary.tracker()) progressTracker.currentStep = FIXING val myName = serviceHub.storageService.myLegalIdentity.name @@ -103,17 +103,22 @@ object UpdateBusinessDayProtocol { val keyPair = serviceHub.keyManagementService.toKeyPair(myOldParty.owningKey) val participant = TwoPartyDealProtocol.Floater(party.address, sessionID, serviceHub.networkMapCache.notaryNodes[0], dealStateAndRef, keyPair, - sessionID, progressTracker.childrenFor[FIXING]!!) + sessionID, progressTracker.getChildProgressTracker(FIXING)!!) val result = subProtocol(participant) return result.tx.outRef(0) } @Suspendable private fun nextFixingFixedLeg(dealStateAndRef: StateAndRef, party: NodeInfo, sessionID: Long): StateAndRef? { - progressTracker.childrenFor[FIXING] = TwoPartyDealProtocol.Secondary.tracker() + progressTracker.setChildProgressTracker(FIXING, TwoPartyDealProtocol.Secondary.tracker()) progressTracker.currentStep = FIXING - val participant = TwoPartyDealProtocol.Fixer(party.address, serviceHub.networkMapCache.notaryNodes[0].identity, dealStateAndRef, sessionID, progressTracker.childrenFor[FIXING]!!) + val participant = TwoPartyDealProtocol.Fixer( + party.address, + serviceHub.networkMapCache.notaryNodes[0].identity, + dealStateAndRef, + sessionID, + progressTracker.getChildProgressTracker(FIXING)!!) val result = subProtocol(participant) return result.tx.outRef(0) } @@ -139,11 +144,11 @@ object UpdateBusinessDayProtocol { companion object { object NOTIFYING : ProgressTracker.Step("Notifying peer") - object LOCAL : ProgressTracker.Step("Updating locally") - - fun tracker() = ProgressTracker(NOTIFYING, LOCAL).apply { - childrenFor[LOCAL] = Updater.tracker() + object LOCAL : ProgressTracker.Step("Updating locally") { + override fun childProgressTracker(): ProgressTracker = Updater.tracker() } + + fun tracker() = ProgressTracker(NOTIFYING, LOCAL) } @Suspendable @@ -156,7 +161,7 @@ object UpdateBusinessDayProtocol { } if ((serviceHub.clock as DemoClock).updateDate(message.date)) { progressTracker.currentStep = LOCAL - subProtocol(Updater(message.date, message.sessionID, progressTracker.childrenFor[LOCAL]!!)) + subProtocol(Updater(message.date, message.sessionID, progressTracker.getChildProgressTracker(LOCAL)!!)) } return true }