remove requirement to override default progress tracker for interacti… (#3985)

* remove requirement to override default progress tracker for interactive shell - this is no longer needed

* fix failing tests
This commit is contained in:
Stefano Franz 2018-10-17 11:27:14 +01:00 committed by GitHub
parent 715c38766d
commit 456c9a85e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 110 additions and 72 deletions

View File

@ -64,8 +64,10 @@ abstract class FlowLogic<out T> {
/**
* Return the outermost [FlowLogic] instance, or null if not in a flow.
*/
@Suppress("unused") @JvmStatic
val currentTopLevel: FlowLogic<*>? get() = (Strand.currentStrand() as? FlowStateMachine<*>)?.logic
@Suppress("unused")
@JvmStatic
val currentTopLevel: FlowLogic<*>?
get() = (Strand.currentStrand() as? FlowStateMachine<*>)?.logic
/**
* If on a flow, suspends the flow and only wakes it up after at least [duration] time has passed. Otherwise,
@ -123,10 +125,11 @@ abstract class FlowLogic<out T> {
* Note: The current implementation returns the single identity of the node. This will change once multiple identities
* is implemented.
*/
val ourIdentityAndCert: PartyAndCertificate get() {
return serviceHub.myInfo.legalIdentitiesAndCerts.find { it.party == stateMachine.ourIdentity }
?: throw IllegalStateException("Identity specified by ${stateMachine.id} (${stateMachine.ourIdentity}) is not one of ours!")
}
val ourIdentityAndCert: PartyAndCertificate
get() {
return serviceHub.myInfo.legalIdentitiesAndCerts.find { it.party == stateMachine.ourIdentity }
?: throw IllegalStateException("Identity specified by ${stateMachine.id} (${stateMachine.ourIdentity}) is not one of ours!")
}
/**
* Specifies the identity to use for this flow. This will be one of the multiple identities that belong to this node.
@ -141,9 +144,11 @@ abstract class FlowLogic<out T> {
// Used to implement the deprecated send/receive functions using Party. When such a deprecated function is used we
// create a fresh session for the Party, put it here and use it in subsequent deprecated calls.
private val deprecatedPartySessionMap = HashMap<Party, FlowSession>()
private fun getDeprecatedSessionForParty(party: Party): FlowSession {
return deprecatedPartySessionMap.getOrPut(party) { initiateFlow(party) }
}
/**
* Returns a [FlowInfo] object describing the flow [otherParty] is using. With [FlowInfo.flowVersion] it
* provides the necessary information needed for the evolution of flows and enabling backwards compatibility.
@ -342,7 +347,7 @@ abstract class FlowLogic<out T> {
* Note that this has to return a tracker before the flow is invoked. You can't change your mind half way
* through.
*/
open val progressTracker: ProgressTracker? = null
open val progressTracker: ProgressTracker? = ProgressTracker.DEFAULT_TRACKER()
/**
* This is where you fill out your business logic.
@ -383,7 +388,7 @@ abstract class FlowLogic<out T> {
*
* @return Returns null if this flow has no progress tracker.
*/
fun trackStepsTree(): DataFeed<List<Pair<Int,String>>, List<Pair<Int,String>>>? {
fun trackStepsTree(): DataFeed<List<Pair<Int, String>>, List<Pair<Int, String>>>? {
// TODO this is not threadsafe, needs an atomic get-step-and-subscribe
return progressTracker?.let {
DataFeed(it.allStepsLabels, it.stepsTreeChanges)

View File

@ -30,10 +30,11 @@ import java.util.*
* using the [Observable] subscribeOn call.
*/
@CordaSerializable
class ProgressTracker(vararg steps: Step) {
class ProgressTracker(vararg inputSteps: Step) {
@CordaSerializable
sealed class Change(val progressTracker: ProgressTracker) {
data class Position(val tracker: ProgressTracker, val newStep: Step) : Change(tracker) {
data class Position(val tracker: ProgressTracker, val newStep: Step) : Change(tracker) {
override fun toString() = newStep.label
}
@ -64,6 +65,10 @@ class ProgressTracker(vararg steps: Step) {
override fun equals(other: Any?) = other === UNSTARTED
}
object STARTING : Step("Starting") {
override fun equals(other: Any?) = other === STARTING
}
object DONE : Step("Done") {
override fun equals(other: Any?) = other === DONE
}
@ -74,7 +79,7 @@ class ProgressTracker(vararg steps: Step) {
private val childProgressTrackers = mutableMapOf<Step, Child>()
/** The steps in this tracker, same as the steps passed to the constructor but with UNSTARTED and DONE inserted. */
val steps = arrayOf(UNSTARTED, *steps, DONE)
val steps = arrayOf(UNSTARTED, STARTING, *inputSteps, DONE)
private var _allStepsCache: List<Pair<Int, Step>> = _allSteps()
@ -83,42 +88,16 @@ class ProgressTracker(vararg steps: Step) {
private val _stepsTreeChanges by transient { PublishSubject.create<List<Pair<Int, String>>>() }
private val _stepsTreeIndexChanges by transient { PublishSubject.create<Int>() }
init {
steps.forEach {
val childTracker = it.childProgressTracker()
if (childTracker != null) {
setChildProgressTracker(it, childTracker)
}
}
}
/** The zero-based index of the current step in the [steps] array (i.e. with UNSTARTED and DONE) */
var stepIndex: Int = 0
private set(value) {
field = value
}
/** The zero-bases index of the current step in a [allStepsLabels] list */
var stepsTreeIndex: Int = -1
private set(value) {
field = value
_stepsTreeIndexChanges.onNext(value)
}
/**
* Reading returns the value of steps[stepIndex], writing moves the position of the current tracker. Once moved to
* the [DONE] state, this tracker is finished and the current step cannot be moved again.
*/
var currentStep: Step
get() = steps[stepIndex]
set(value) {
check(!hasEnded) { "Cannot rewind a progress tracker once it has ended" }
check((value === DONE && hasEnded) || !hasEnded) {
"Cannot rewind a progress tracker once it has ended"
}
if (currentStep == value) return
val index = steps.indexOf(value)
require(index != -1, { "Step ${value.label} not found in progress tracker." })
require(index != -1) { "Step ${value.label} not found in progress tracker." }
if (index < stepIndex) {
// We are going backwards: unlink and unsubscribe from any child nodes that we're rolling back
@ -144,6 +123,39 @@ class ProgressTracker(vararg steps: Step) {
}
}
init {
steps.forEach {
configureChildTrackerForStep(it)
}
this.currentStep = UNSTARTED
}
private fun configureChildTrackerForStep(it: Step) {
val childTracker = it.childProgressTracker()
if (childTracker != null) {
setChildProgressTracker(it, childTracker)
}
}
/** The zero-based index of the current step in the [steps] array (i.e. with UNSTARTED and DONE) */
var stepIndex: Int = 0
private set(value) {
field = value
}
/** The zero-bases index of the current step in a [allStepsLabels] list */
var stepsTreeIndex: Int = -1
private set(value) {
field = value
_stepsTreeIndexChanges.onNext(value)
}
/**
* Reading returns the value of steps[stepIndex], writing moves the position of the current tracker. Once moved to
* the [DONE] state, this tracker is finished and the current step cannot be moved again.
*/
/** Returns the current step, descending into children to find the deepest step we are up to. */
val currentStepRecursive: Step
get() = getChildProgressTracker(currentStep)?.currentStepRecursive ?: currentStep
@ -263,7 +275,7 @@ class ProgressTracker(vararg steps: Step) {
/**
* An observable stream of changes to the [allStepsLabels]
*/
val stepsTreeChanges: Observable<List<Pair<Int,String>>> get() = _stepsTreeChanges
val stepsTreeChanges: Observable<List<Pair<Int, String>>> get() = _stepsTreeChanges
/**
* An observable stream of changes to the [stepsTreeIndex]
@ -272,6 +284,10 @@ class ProgressTracker(vararg steps: Step) {
/** Returns true if the progress tracker has ended, either by reaching the [DONE] step or prematurely with an error */
val hasEnded: Boolean get() = _changes.hasCompleted() || _changes.hasThrowable()
companion object {
val DEFAULT_TRACKER = { ProgressTracker() }
}
}
// TODO: Expose the concept of errors.
// TODO: It'd be helpful if this class was at least partly thread safe.

View File

@ -50,11 +50,11 @@ class ProgressTrackerTest {
assertEquals(0, pt.stepIndex)
var stepNotification: ProgressTracker.Step? = null
pt.changes.subscribe { stepNotification = (it as? ProgressTracker.Change.Position)?.newStep }
assertEquals(ProgressTracker.UNSTARTED, pt.currentStep)
assertEquals(ProgressTracker.STARTING, pt.nextStep())
assertEquals(SimpleSteps.ONE, pt.nextStep())
assertEquals(1, pt.stepIndex)
assertEquals(2, pt.stepIndex)
assertEquals(SimpleSteps.ONE, stepNotification)
assertEquals(SimpleSteps.TWO, pt.nextStep())
assertEquals(SimpleSteps.THREE, pt.nextStep())
assertEquals(SimpleSteps.FOUR, pt.nextStep())
@ -87,8 +87,10 @@ class ProgressTrackerTest {
assertEquals(SimpleSteps.TWO, (stepNotification.pollFirst() as ProgressTracker.Change.Structural).parent)
assertNextStep(SimpleSteps.TWO)
assertEquals(pt2.currentStep, ProgressTracker.UNSTARTED)
assertEquals(ProgressTracker.STARTING, pt2.nextStep())
assertEquals(ChildSteps.AYY, pt2.nextStep())
assertNextStep(ChildSteps.AYY)
assertEquals((stepNotification.last as ProgressTracker.Change.Position).newStep, ChildSteps.AYY)
assertEquals(ChildSteps.BEE, pt2.nextStep())
}
@ -115,19 +117,19 @@ class ProgressTrackerTest {
// Travel tree.
pt.currentStep = SimpleSteps.ONE
assertCurrentStepsTree(0, SimpleSteps.ONE)
assertCurrentStepsTree(1, SimpleSteps.ONE)
pt.currentStep = SimpleSteps.TWO
assertCurrentStepsTree(1, SimpleSteps.TWO)
assertCurrentStepsTree(2, SimpleSteps.TWO)
pt2.currentStep = ChildSteps.BEE
assertCurrentStepsTree(3, ChildSteps.BEE)
assertCurrentStepsTree(5, ChildSteps.BEE)
pt.currentStep = SimpleSteps.THREE
assertCurrentStepsTree(5, SimpleSteps.THREE)
assertCurrentStepsTree(7, SimpleSteps.THREE)
// Assert no structure changes and proper steps propagation.
assertThat(stepsIndexNotifications).containsExactlyElementsOf(listOf(0, 1, 3, 5))
assertThat(stepsIndexNotifications).containsExactlyElementsOf(listOf(1, 2, 5, 7))
assertThat(stepsTreeNotification).isEmpty()
}
@ -153,16 +155,16 @@ class ProgressTrackerTest {
}
pt.currentStep = SimpleSteps.ONE
assertCurrentStepsTree(0, SimpleSteps.ONE)
assertCurrentStepsTree(1, SimpleSteps.ONE)
pt.currentStep = SimpleSteps.FOUR
assertCurrentStepsTree(3, SimpleSteps.FOUR)
assertCurrentStepsTree(4, SimpleSteps.FOUR)
pt2.currentStep = ChildSteps.SEA
assertCurrentStepsTree(6, ChildSteps.SEA)
assertCurrentStepsTree(8, ChildSteps.SEA)
// Assert no structure changes and proper steps propagation.
assertThat(stepsIndexNotifications).containsExactlyElementsOf(listOf(0, 3, 6))
assertThat(stepsIndexNotifications).containsExactlyElementsOf(listOf(1, 4, 8))
assertThat(stepsTreeNotification).isEmpty()
}
@ -189,18 +191,18 @@ class ProgressTrackerTest {
}
pt.currentStep = SimpleSteps.TWO
assertCurrentStepsTree(1, SimpleSteps.TWO)
assertCurrentStepsTree(2, SimpleSteps.TWO)
pt.currentStep = SimpleSteps.FOUR
assertCurrentStepsTree(6, SimpleSteps.FOUR)
assertCurrentStepsTree(8, SimpleSteps.FOUR)
pt.setChildProgressTracker(SimpleSteps.THREE, pt3)
assertCurrentStepsTree(9, SimpleSteps.FOUR)
assertCurrentStepsTree(12, SimpleSteps.FOUR)
// Assert no structure changes and proper steps propagation.
assertThat(stepsIndexNotifications).containsExactlyElementsOf(listOf(1, 6, 9))
assertThat(stepsIndexNotifications).containsExactlyElementsOf(listOf(2, 8, 12))
assertThat(stepsTreeNotification).hasSize(2) // 1 change + 1 our initial state
}
@ -228,14 +230,14 @@ class ProgressTrackerTest {
pt.currentStep = SimpleSteps.TWO
pt2.currentStep = ChildSteps.SEA
pt3.currentStep = BabySteps.UNOS
assertCurrentStepsTree(4, ChildSteps.SEA)
assertCurrentStepsTree(6, ChildSteps.SEA)
pt.setChildProgressTracker(SimpleSteps.TWO, pt3)
assertCurrentStepsTree(2, BabySteps.UNOS)
assertCurrentStepsTree(4, BabySteps.UNOS)
// Assert no structure changes and proper steps propagation.
assertThat(stepsIndexNotifications).containsExactlyElementsOf(listOf(1, 4, 2))
assertThat(stepsIndexNotifications).containsExactlyElementsOf(listOf(2, 6, 4))
assertThat(stepsTreeNotification).hasSize(2) // 1 change + 1 our initial state.
}

View File

@ -14,6 +14,7 @@ import net.corda.core.identity.Party
import net.corda.core.internal.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.Try
import net.corda.core.utilities.debug
import net.corda.core.utilities.trace
@ -205,6 +206,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
@Suspendable
override fun run() {
logic.progressTracker?.currentStep = ProgressTracker.STARTING
logic.stateMachine = this
setLoggingContext()
@ -263,7 +265,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
processEventImmediately(
Event.EnterSubFlow(subFlow.javaClass,
createSubFlowVersion(
serviceHub.cordappProvider.getCordappForFlow(subFlow), serviceHub.myInfo.platformVersion
serviceHub.cordappProvider.getCordappForFlow(subFlow), serviceHub.myInfo.platformVersion
)
),
isDbTransactionOpenOnEntry = true,
@ -435,7 +437,7 @@ val Class<out FlowLogic<*>>.flowVersionAndInitiatingClass: Pair<Int, Class<out F
current = current.superclass
?: return found
?: throw IllegalArgumentException("$name, as a flow that initiates other flows, must be annotated with " +
"${InitiatingFlow::class.java.name}. See https://docs.corda.net/api-flows.html#flowlogic-annotations.")
"${InitiatingFlow::class.java.name}. See https://docs.corda.net/api-flows.html#flowlogic-annotations.")
}
}

View File

@ -38,7 +38,10 @@ import net.corda.testing.node.internal.*
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType
import org.junit.*
import org.junit.After
import org.junit.Before
import org.junit.Ignore
import org.junit.Test
import rx.Notification
import rx.Observable
import java.time.Instant
@ -423,6 +426,7 @@ class FlowFrameworkTests {
}
assertThat(receiveFlowException.message).doesNotContain("evil bug!")
assertThat(receiveFlowSteps.get()).containsExactly(
Notification.createOnNext(ProgressTracker.STARTING),
Notification.createOnNext(ReceiveFlow.START_STEP),
Notification.createOnError(receiveFlowException)
)

View File

@ -11,6 +11,7 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FinalityFlow
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StartableByRPC
import net.corda.core.flows.StartableByService
import net.corda.core.identity.AbstractParty
import net.corda.core.identity.Party
import net.corda.core.internal.Emoji
@ -23,9 +24,9 @@ import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.getOrThrow
import net.corda.testing.node.internal.poll
import net.corda.testing.core.DUMMY_BANK_B_NAME
import net.corda.testing.core.DUMMY_NOTARY_NAME
import net.corda.testing.node.internal.poll
import java.io.InputStream
import java.net.HttpURLConnection
import java.net.URL
@ -131,6 +132,16 @@ class AttachmentDemoFlow(private val otherSide: Party,
}
}
@StartableByRPC
@StartableByService
class NoProgressTrackerShellDemo : FlowLogic<String>() {
@Suspendable
override fun call(): String {
return "You Called me!"
}
}
@Suppress("DEPRECATION")
// DOCSTART 1
fun recipient(rpc: CordaRPCOps, webPort: Int) {

View File

@ -17,7 +17,10 @@ import net.corda.core.flows.FlowLogic
import net.corda.core.internal.*
import net.corda.core.internal.concurrent.doneFuture
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.messaging.*
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.DataFeed
import net.corda.core.messaging.FlowProgressHandle
import net.corda.core.messaging.StateMachineUpdate
import net.corda.nodeapi.internal.pendingFlowsCount
import net.corda.tools.shell.utlities.ANSIProgressRenderer
import net.corda.tools.shell.utlities.StdoutANSIProgressRenderer
@ -359,11 +362,6 @@ object InteractiveShell {
errors.add("${getPrototype()}: Wrong number of arguments (${args.size} provided, ${ctor.genericParameterTypes.size} needed)")
continue
}
val flow = ctor.newInstance(*args) as FlowLogic<*>
if (flow.progressTracker == null) {
errors.add("A flow must override the progress tracker in order to be run from the shell")
continue
}
return invoke(clazz, args)
} catch (e: StringToMethodCallParser.UnparseableCallException.MissingParameter) {
errors.add("${getPrototype()}: missing parameter ${e.paramName}")