diff --git a/client/jfx/src/integration-test/kotlin/net/corda/client/jfx/NodeMonitorModelTest.kt b/client/jfx/src/integration-test/kotlin/net/corda/client/jfx/NodeMonitorModelTest.kt index 656ed3139f..480f51852f 100644 --- a/client/jfx/src/integration-test/kotlin/net/corda/client/jfx/NodeMonitorModelTest.kt +++ b/client/jfx/src/integration-test/kotlin/net/corda/client/jfx/NodeMonitorModelTest.kt @@ -8,6 +8,7 @@ import net.corda.core.contracts.DOLLARS import net.corda.core.contracts.USD import net.corda.core.crypto.isFulfilledBy import net.corda.core.crypto.keys +import net.corda.core.flows.FlowInitiator import net.corda.core.flows.StateMachineRunId import net.corda.core.getOrThrow import net.corda.core.messaging.CordaRPCOps @@ -41,11 +42,14 @@ import rx.Observable class NodeMonitorModelTest : DriverBasedTest() { lateinit var aliceNode: NodeInfo + lateinit var bobNode: NodeInfo lateinit var notaryNode: NodeInfo lateinit var rpc: CordaRPCOps + lateinit var rpcBob: CordaRPCOps lateinit var stateMachineTransactionMapping: Observable lateinit var stateMachineUpdates: Observable + lateinit var stateMachineUpdatesBob: Observable lateinit var progressTracking: Observable lateinit var transactions: Observable lateinit var vaultUpdates: Observable @@ -66,7 +70,6 @@ class NodeMonitorModelTest : DriverBasedTest() { notaryNode = notaryNodeHandle.nodeInfo newNode = { nodeName -> startNode(nodeName).getOrThrow().nodeInfo } val monitor = NodeMonitorModel() - stateMachineTransactionMapping = monitor.stateMachineTransactionMapping.bufferUntilSubscribed() stateMachineUpdates = monitor.stateMachineUpdates.bufferUntilSubscribed() progressTracking = monitor.progressTracking.bufferUntilSubscribed() @@ -76,12 +79,18 @@ class NodeMonitorModelTest : DriverBasedTest() { monitor.register(aliceNodeHandle.configuration.rpcAddress!!, cashUser.username, cashUser.password) rpc = monitor.proxyObservable.value!! + + val bobNodeHandle = startNode(BOB.name, rpcUsers = listOf(cashUser)).getOrThrow() + bobNode = bobNodeHandle.nodeInfo + val monitorBob = NodeMonitorModel() + stateMachineUpdatesBob = monitorBob.stateMachineUpdates.bufferUntilSubscribed() + monitorBob.register(bobNodeHandle.configuration.rpcAddress!!, cashUser.username, cashUser.password) + rpcBob = monitorBob.proxyObservable.value!! runTest() } @Test fun `network map update`() { - newNode(BOB.name) newNode(CHARLIE.name) networkMapUpdates.filter { !it.node.advertisedServices.any { it.info.type.isNotary() } } .filter { !it.node.advertisedServices.any { it.info.type == NetworkMapService.type } } @@ -114,12 +123,12 @@ class NodeMonitorModelTest : DriverBasedTest() { sequence( // SNAPSHOT expect { output: Vault.Update -> - require(output.consumed.size == 0) { output.consumed.size } - require(output.produced.size == 0) { output.produced.size } + require(output.consumed.isEmpty()) { output.consumed.size } + require(output.produced.isEmpty()) { output.produced.size } }, // ISSUE expect { output: Vault.Update -> - require(output.consumed.size == 0) { output.consumed.size } + require(output.consumed.isEmpty()) { output.consumed.size } require(output.produced.size == 1) { output.produced.size } } ) @@ -129,7 +138,7 @@ class NodeMonitorModelTest : DriverBasedTest() { @Test fun `cash issue and move`() { rpc.startFlow(::CashIssueFlow, 100.DOLLARS, OpaqueBytes.of(1), aliceNode.legalIdentity, notaryNode.notaryIdentity).returnValue.getOrThrow() - rpc.startFlow(::CashPaymentFlow, 100.DOLLARS, aliceNode.legalIdentity).returnValue.getOrThrow() + rpc.startFlow(::CashPaymentFlow, 100.DOLLARS, bobNode.legalIdentity).returnValue.getOrThrow() var issueSmId: StateMachineRunId? = null var moveSmId: StateMachineRunId? = null @@ -140,6 +149,8 @@ class NodeMonitorModelTest : DriverBasedTest() { // ISSUE expect { add: StateMachineUpdate.Added -> issueSmId = add.id + val initiator = add.stateMachineInfo.initiator + require(initiator is FlowInitiator.RPC && initiator.username == "user1") }, expect { remove: StateMachineUpdate.Removed -> require(remove.id == issueSmId) @@ -147,6 +158,8 @@ class NodeMonitorModelTest : DriverBasedTest() { // MOVE expect { add: StateMachineUpdate.Added -> moveSmId = add.id + val initiator = add.stateMachineInfo.initiator + require(initiator is FlowInitiator.RPC && initiator.username == "user1") }, expect { remove: StateMachineUpdate.Removed -> require(remove.id == moveSmId) @@ -154,6 +167,16 @@ class NodeMonitorModelTest : DriverBasedTest() { ) } + stateMachineUpdatesBob.expectEvents { + sequence( + // MOVE + expect { add: StateMachineUpdate.Added -> + val initiator = add.stateMachineInfo.initiator + require(initiator is FlowInitiator.Peer && initiator.party.name == aliceNode.legalIdentity.name) + } + ) + } + transactions.expectEvents { sequence( // ISSUE @@ -184,18 +207,18 @@ class NodeMonitorModelTest : DriverBasedTest() { sequence( // SNAPSHOT expect { output: Vault.Update -> - require(output.consumed.size == 0) { output.consumed.size } - require(output.produced.size == 0) { output.produced.size } + require(output.consumed.isEmpty()) { output.consumed.size } + require(output.produced.isEmpty()) { output.produced.size } }, // ISSUE expect { update -> - require(update.consumed.size == 0) { update.consumed.size } + require(update.consumed.isEmpty()) { update.consumed.size } require(update.produced.size == 1) { update.produced.size } }, // MOVE expect { update -> require(update.consumed.size == 1) { update.consumed.size } - require(update.produced.size == 1) { update.produced.size } + require(update.produced.isEmpty()) { update.produced.size } } ) } diff --git a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt index 8ea64e56ab..b7f51b1b55 100644 --- a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt +++ b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt @@ -1,10 +1,13 @@ package net.corda.client.rpc import net.corda.core.contracts.DOLLARS +import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowException import net.corda.core.getOrThrow import net.corda.core.messaging.FlowHandle import net.corda.core.messaging.FlowProgressHandle +import net.corda.core.messaging.CordaRPCOps +import net.corda.core.messaging.StateMachineUpdate import net.corda.core.messaging.startFlow import net.corda.core.messaging.startTrackedFlow import net.corda.core.node.services.ServiceInfo @@ -66,10 +69,7 @@ class CordaRPCClientTest : NodeBasedTest() { @Test fun `close-send deadlock and premature shutdown on empty observable`() { - println("Starting client") - client.start(rpcUser.username, rpcUser.password) - println("Creating proxy") - val proxy = client.proxy() + val proxy = createRpcProxy(rpcUser.username, rpcUser.password) println("Starting flow") val flowHandle = proxy.startTrackedFlow( ::CashIssueFlow, @@ -104,11 +104,7 @@ class CordaRPCClientTest : NodeBasedTest() { @Test fun `get cash balances`() { - println("Starting client") - client.start(rpcUser.username, rpcUser.password) - println("Creating proxy") - val proxy = client.proxy() - + val proxy = createRpcProxy(rpcUser.username, rpcUser.password) val startCash = proxy.getCashBalances() assertTrue(startCash.isEmpty(), "Should not start with any cash") @@ -125,4 +121,38 @@ class CordaRPCClientTest : NodeBasedTest() { assertEquals(123.DOLLARS, finishCash.get(Currency.getInstance("USD"))) } + @Test + fun `flow initiator via RPC`() { + val proxy = createRpcProxy(rpcUser.username, rpcUser.password) + val smUpdates = proxy.stateMachinesAndUpdates() + var countRpcFlows = 0 + var countShellFlows = 0 + smUpdates.second.subscribe { + if (it is StateMachineUpdate.Added) { + val initiator = it.stateMachineInfo.initiator + if (initiator is FlowInitiator.RPC) + countRpcFlows++ + if (initiator is FlowInitiator.Shell) + countShellFlows++ + } + } + val nodeIdentity = node.info.legalIdentity + node.services.startFlow(CashIssueFlow(2000.DOLLARS, OpaqueBytes.of(0), nodeIdentity, nodeIdentity), FlowInitiator.Shell).resultFuture.getOrThrow() + proxy.startFlow(::CashIssueFlow, + 123.DOLLARS, OpaqueBytes.of(0), + nodeIdentity, nodeIdentity + ).returnValue.getOrThrow() + proxy.startFlowDynamic(CashIssueFlow::class.java, + 1000.DOLLARS, OpaqueBytes.of(0), + nodeIdentity, nodeIdentity).returnValue.getOrThrow() + assertEquals(2, countRpcFlows) + assertEquals(1, countShellFlows) + } + + private fun createRpcProxy(username: String, password: String): CordaRPCOps { + println("Starting client") + client.start(username, password) + println("Creating proxy") + return client.proxy() + } } diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt index d0f083a16d..55c8bac354 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt @@ -186,7 +186,7 @@ abstract class FlowLogic { return stateMachine.waitForLedgerCommit(hash, this) } - //////////////////////////////////////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////////////////////////////////////// private var _stateMachine: FlowStateMachine<*>? = null /** diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowStateMachine.kt b/core/src/main/kotlin/net/corda/core/flows/FlowStateMachine.kt index 2ba7c2d021..b23a79c125 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowStateMachine.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowStateMachine.kt @@ -2,6 +2,7 @@ package net.corda.core.flows import co.paralleluniverse.fibers.Suspendable import com.google.common.util.concurrent.ListenableFuture +import net.corda.core.contracts.ScheduledStateRef import net.corda.core.crypto.Party import net.corda.core.crypto.SecureHash import net.corda.core.messaging.FlowHandle @@ -12,6 +13,23 @@ import net.corda.core.utilities.UntrustworthyData import org.slf4j.Logger import java.util.* +/** + * FlowInitiator holds information on who started the flow. We have different ways of doing that: via RPC [FlowInitiator.RPC], + * communication started by peer node [FlowInitiator.Peer], scheduled flows [FlowInitiator.Scheduled] + * or manual [FlowInitiator.Manual]. The last case is for all flows started in tests, shell etc. It was added + * because we can start flow directly using [StateMachineManager.add] or [ServiceHubInternal.startFlow]. + */ +@CordaSerializable +sealed class FlowInitiator { + /** Started using [CordaRPCOps.startFlowDynamic]. */ + data class RPC(val username: String) : FlowInitiator() + /** Started when we get new session initiation request. */ + data class Peer(val party: Party) : FlowInitiator() + /** Started as scheduled activity. */ + class Scheduled(val scheduledState: ScheduledStateRef) : FlowInitiator() + object Shell : FlowInitiator() // TODO When proper ssh access enabled, add username/use RPC? +} + /** * A unique identifier for a single state machine run, valid across node restarts. Note that a single run always * has at least one flow, but that flow may also invoke sub-flows: they all share the same run id. @@ -48,4 +66,5 @@ interface FlowStateMachine { val logger: Logger val id: StateMachineRunId val resultFuture: ListenableFuture + val flowInitiator: FlowInitiator } diff --git a/core/src/main/kotlin/net/corda/core/messaging/CordaRPCOps.kt b/core/src/main/kotlin/net/corda/core/messaging/CordaRPCOps.kt index 325c781c4c..d75cc26e20 100644 --- a/core/src/main/kotlin/net/corda/core/messaging/CordaRPCOps.kt +++ b/core/src/main/kotlin/net/corda/core/messaging/CordaRPCOps.kt @@ -7,6 +7,7 @@ import net.corda.core.contracts.StateAndRef import net.corda.core.contracts.UpgradedContract import net.corda.core.crypto.Party import net.corda.core.crypto.SecureHash +import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowLogic import net.corda.core.flows.StateMachineRunId import net.corda.core.node.NodeInfo @@ -26,6 +27,7 @@ import java.util.* data class StateMachineInfo( val id: StateMachineRunId, val flowLogicClassName: String, + val initiator: FlowInitiator, val progressTrackerStepAndUpdates: Pair>? ) { override fun toString(): String = "${javaClass.simpleName}($id, $flowLogicClassName)" diff --git a/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt b/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt index 208df63775..1064d537ed 100644 --- a/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt +++ b/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt @@ -2,6 +2,7 @@ package net.corda.core.node import net.corda.core.contracts.* import net.corda.core.crypto.keys +import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowStateMachine import net.corda.core.messaging.MessagingService @@ -86,14 +87,6 @@ interface ServiceHub : ServicesForResolution { return definingTx.tx.outRef(ref.index) } - /** - * Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the flow. - * Note that you must be on the server thread to call this method. - * - * @throws IllegalFlowLogicException or IllegalArgumentException if there are problems with the [logicType] or [args]. - */ - fun invokeFlowAsync(logicType: Class>, vararg args: Any?): FlowStateMachine - /** * Helper property to shorten code for fetching the Node's KeyPair associated with the * public legalIdentity Party from the key management service. diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index 21d4ff0982..b9550ad036 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -10,6 +10,7 @@ import net.corda.core.contracts.Amount import net.corda.core.contracts.PartyAndReference import net.corda.core.crypto.Party import net.corda.core.crypto.X509Utilities +import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogicRefFactory import net.corda.core.flows.FlowStateMachine @@ -127,8 +128,8 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, override val monitoringService: MonitoringService = MonitoringService(MetricRegistry()) override val flowLogicRefFactory: FlowLogicRefFactory get() = flowLogicFactory - override fun startFlow(logic: FlowLogic): FlowStateMachine { - return serverThread.fetchFrom { smm.add(logic) } + override fun startFlow(logic: FlowLogic, flowInitiator: FlowInitiator): FlowStateMachine { + return serverThread.fetchFrom { smm.add(logic, flowInitiator) } } override fun registerFlowInitiator(markerClass: Class<*>, flowFactory: (Party) -> FlowLogic<*>) { diff --git a/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt b/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt index b999e0af7b..c0e567c6f8 100644 --- a/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt +++ b/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt @@ -5,6 +5,7 @@ import net.corda.core.contracts.ContractState import net.corda.core.contracts.StateAndRef import net.corda.core.contracts.UpgradedContract import net.corda.core.crypto.SecureHash +import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowLogic import net.corda.core.flows.StateMachineRunId import net.corda.core.messaging.* @@ -20,6 +21,7 @@ import net.corda.node.services.statemachine.StateMachineManager import net.corda.node.utilities.AddOrRemove import net.corda.node.utilities.transaction import org.bouncycastle.asn1.x500.X500Name +import net.corda.nodeapi.CURRENT_RPC_USER import org.jetbrains.exposed.sql.Database import rx.Observable import java.io.InputStream @@ -61,7 +63,7 @@ class CordaRPCOpsImpl( return database.transaction { val (allStateMachines, changes) = smm.track() Pair( - allStateMachines.map { stateMachineInfoFromFlowLogic(it.id, it.logic) }, + allStateMachines.map { stateMachineInfoFromFlowLogic(it.id, it.logic, it.flowInitiator) }, changes.map { stateMachineUpdateFromStateMachineChange(it) } ) } @@ -98,13 +100,15 @@ class CordaRPCOpsImpl( // TODO: Check that this flow is annotated as being intended for RPC invocation override fun startTrackedFlowDynamic(logicType: Class>, vararg args: Any?): FlowProgressHandle { requirePermission(startFlowPermission(logicType)) - return services.invokeFlowAsync(logicType, *args).createHandle(hasProgress = true) as FlowProgressHandle + val currentUser = FlowInitiator.RPC(CURRENT_RPC_USER.get().username) + return services.invokeFlowAsync(logicType, currentUser, *args).createHandle(hasProgress = true) as FlowProgressHandle } // TODO: Check that this flow is annotated as being intended for RPC invocation override fun startFlowDynamic(logicType: Class>, vararg args: Any?): FlowHandle { requirePermission(startFlowPermission(logicType)) - return services.invokeFlowAsync(logicType, *args).createHandle(hasProgress = false) + val currentUser = FlowInitiator.RPC(CURRENT_RPC_USER.get().username) + return services.invokeFlowAsync(logicType, currentUser, *args).createHandle(hasProgress = false) } override fun attachmentExists(id: SecureHash): Boolean { @@ -147,13 +151,13 @@ class CordaRPCOpsImpl( override fun registeredFlows(): List = services.flowLogicRefFactory.flowWhitelist.keys.sorted() companion object { - private fun stateMachineInfoFromFlowLogic(id: StateMachineRunId, flowLogic: FlowLogic<*>): StateMachineInfo { - return StateMachineInfo(id, flowLogic.javaClass.name, flowLogic.track()) + private fun stateMachineInfoFromFlowLogic(id: StateMachineRunId, flowLogic: FlowLogic<*>, flowInitiator: FlowInitiator): StateMachineInfo { + return StateMachineInfo(id, flowLogic.javaClass.name, flowInitiator, flowLogic.track()) } private fun stateMachineUpdateFromStateMachineChange(change: StateMachineManager.Change): StateMachineUpdate { return when (change.addOrRemove) { - AddOrRemove.ADD -> StateMachineUpdate.Added(stateMachineInfoFromFlowLogic(change.id, change.logic)) + AddOrRemove.ADD -> StateMachineUpdate.Added(stateMachineInfoFromFlowLogic(change.id, change.logic, change.flowInitiator)) AddOrRemove.REMOVE -> StateMachineUpdate.Removed(change.id) } } diff --git a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt index 7f58fb3c87..6be557f461 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt @@ -1,6 +1,8 @@ package net.corda.node.services.api +import com.google.common.annotations.VisibleForTesting import com.google.common.util.concurrent.ListenableFuture +import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogicRefFactory import net.corda.core.flows.FlowStateMachine @@ -64,14 +66,34 @@ abstract class ServiceHubInternal : PluginServiceHub { } /** - * Starts an already constructed flow. Note that you must be on the server thread to call this method. + * Starts an already constructed flow. Note that you must be on the server thread to call this method. [FlowInitiator] + * defaults to [FlowInitiator.RPC] with username "Only For Testing". */ - abstract fun startFlow(logic: FlowLogic): FlowStateMachine + // TODO Move it to test utils. + @VisibleForTesting + fun startFlow(logic: FlowLogic): FlowStateMachine = startFlow(logic, FlowInitiator.RPC("Only For Testing")) - override fun invokeFlowAsync(logicType: Class>, vararg args: Any?): FlowStateMachine { + /** + * Starts an already constructed flow. Note that you must be on the server thread to call this method. + * @param flowInitiator indicates who started the flow, see: [FlowInitiator]. + */ + abstract fun startFlow(logic: FlowLogic, flowInitiator: FlowInitiator): FlowStateMachine + + + /** + * Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the flow. + * Note that you must be on the server thread to call this method. [flowInitiator] points how flow was started, + * See: [FlowInitiator]. + * + * @throws IllegalFlowLogicException or IllegalArgumentException if there are problems with the [logicType] or [args]. + */ + fun invokeFlowAsync( + logicType: Class>, + flowInitiator: FlowInitiator, + vararg args: Any?): FlowStateMachine { val logicRef = flowLogicRefFactory.create(logicType, *args) @Suppress("UNCHECKED_CAST") val logic = flowLogicRefFactory.toFlowLogic(logicRef) as FlowLogic - return startFlow(logic) + return startFlow(logic, flowInitiator) } } diff --git a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt index 8e5ee70ad1..881e89ba0f 100644 --- a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt +++ b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt @@ -7,6 +7,7 @@ import net.corda.core.contracts.SchedulableState import net.corda.core.contracts.ScheduledActivity import net.corda.core.contracts.ScheduledStateRef import net.corda.core.contracts.StateRef +import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogicRefFactory import net.corda.core.node.services.SchedulerService @@ -158,7 +159,7 @@ class NodeSchedulerService(private val services: ServiceHubInternal, } private fun onTimeReached(scheduledState: ScheduledStateRef) { - services.startFlow(RunScheduled(scheduledState, this@NodeSchedulerService)) + services.startFlow(RunScheduled(scheduledState, this@NodeSchedulerService), FlowInitiator.Scheduled(scheduledState)) } class RunScheduled(val scheduledState: ScheduledStateRef, val scheduler: NodeSchedulerService) : FlowLogic() { @@ -167,7 +168,6 @@ class NodeSchedulerService(private val services: ServiceHubInternal, fun tracker() = ProgressTracker(RUNNING) } - override val progressTracker = tracker() @Suspendable diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index ef38d4116b..00a89df830 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -11,6 +11,7 @@ import net.corda.core.abbreviate import net.corda.core.crypto.Party import net.corda.core.crypto.SecureHash import net.corda.core.flows.FlowException +import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowStateMachine import net.corda.core.flows.StateMachineRunId @@ -39,7 +40,8 @@ import java.util.concurrent.TimeUnit class FlowStateMachineImpl(override val id: StateMachineRunId, val logic: FlowLogic, - scheduler: FiberScheduler) : Fiber(id.toString(), scheduler), FlowStateMachine { + scheduler: FiberScheduler, + override val flowInitiator: FlowInitiator) : Fiber(id.toString(), scheduler), FlowStateMachine { companion object { // Used to work around a small limitation in Quasar. private val QUASAR_UNBLOCKER = run { diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index 37dec2e474..d3303a7e65 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -18,6 +18,7 @@ import net.corda.core.bufferUntilSubscribed import net.corda.core.crypto.Party import net.corda.core.crypto.SecureHash import net.corda.core.crypto.commonName +import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowException import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowStateMachine @@ -113,7 +114,8 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, data class Change( val logic: FlowLogic<*>, val addOrRemove: AddOrRemove, - val id: StateMachineRunId + val id: StateMachineRunId, + val flowInitiator: FlowInitiator ) // A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines @@ -125,7 +127,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, val fibersWaitingForLedgerCommit = HashMultimap.create>()!! fun notifyChangeObservers(fiber: FlowStateMachineImpl<*>, addOrRemove: AddOrRemove) { - changesPublisher.bufferUntilDatabaseCommit().onNext(Change(fiber.logic, addOrRemove, fiber.id)) + changesPublisher.bufferUntilDatabaseCommit().onNext(Change(fiber.logic, addOrRemove, fiber.id, fiber.flowInitiator)) } } @@ -359,7 +361,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, val session = try { val flow = flowFactory(sender) - val fiber = createFiber(flow) + val fiber = createFiber(flow, FlowInitiator.Peer(sender)) val session = FlowSession(flow, random63BitValue(), sender, FlowSessionState.Initiated(sender, otherPartySessionId)) if (sessionInit.firstPayload != null) { session.receivedMessages += ReceivedSessionMessage(sender, SessionData(session.ourSessionId, sessionInit.firstPayload)) @@ -398,9 +400,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, private fun quasarKryo(): KryoPool = quasarKryoPool - private fun createFiber(logic: FlowLogic): FlowStateMachineImpl { + private fun createFiber(logic: FlowLogic, flowInitiator: FlowInitiator): FlowStateMachineImpl { val id = StateMachineRunId.createRandom() - return FlowStateMachineImpl(id, logic, scheduler).apply { initFiber(this) } + return FlowStateMachineImpl(id, logic, scheduler, flowInitiator).apply { initFiber(this) } } private fun initFiber(fiber: FlowStateMachineImpl<*>) { @@ -471,7 +473,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, * * Note that you must be on the [executor] thread. */ - fun add(logic: FlowLogic): FlowStateMachine { + fun add(logic: FlowLogic, flowInitiator: FlowInitiator): FlowStateMachine { // TODO: Check that logic has @Suspendable on its call method. executor.checkOnThread() // We swap out the parent transaction context as using this frequently leads to a deadlock as we wait @@ -479,7 +481,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, // unable to acquire the table lock and move forward till the calling transaction finishes. // Committing in line here on a fresh context ensure we can progress. val fiber = database.isolatedTransaction { - val fiber = createFiber(logic) + val fiber = createFiber(logic, flowInitiator) updateCheckpoint(fiber) fiber } diff --git a/node/src/main/kotlin/net/corda/node/shell/InteractiveShell.kt b/node/src/main/kotlin/net/corda/node/shell/InteractiveShell.kt index daf7d2c4b9..cc7d4b4f31 100644 --- a/node/src/main/kotlin/net/corda/node/shell/InteractiveShell.kt +++ b/node/src/main/kotlin/net/corda/node/shell/InteractiveShell.kt @@ -10,6 +10,7 @@ import com.google.common.io.Closeables import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.SettableFuture import net.corda.core.* +import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowStateMachine import net.corda.core.messaging.CordaRPCOps @@ -221,8 +222,9 @@ object InteractiveShell { if (!FlowLogic::class.java.isAssignableFrom(clazz)) throw IllegalStateException("Found a non-FlowLogic class in the whitelist? $clazz") try { + // TODO Flow invocation should use startFlowDynamic. @Suppress("UNCHECKED_CAST") - val fsm = runFlowFromString({ node.services.startFlow(it) }, inputData, clazz as Class>) + val fsm = runFlowFromString({ node.services.startFlow(it, FlowInitiator.Shell) }, inputData, clazz as Class>) // Show the progress tracker on the console until the flow completes or is interrupted with a // Ctrl-C keypress. val latch = CountDownLatch(1) diff --git a/node/src/test/kotlin/net/corda/node/InteractiveShellTest.kt b/node/src/test/kotlin/net/corda/node/InteractiveShellTest.kt index 411559c2b7..970992b71b 100644 --- a/node/src/test/kotlin/net/corda/node/InteractiveShellTest.kt +++ b/node/src/test/kotlin/net/corda/node/InteractiveShellTest.kt @@ -5,6 +5,7 @@ import com.google.common.util.concurrent.ListenableFuture import net.corda.core.contracts.Amount import net.corda.core.crypto.Party import net.corda.core.crypto.SecureHash +import net.corda.core.flows.FlowInitiator import net.corda.core.crypto.X509Utilities import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowStateMachine @@ -97,5 +98,7 @@ class InteractiveShellTest { get() = throw UnsupportedOperationException() override val resultFuture: ListenableFuture get() = throw UnsupportedOperationException() + override val flowInitiator: FlowInitiator + get() = throw UnsupportedOperationException() } } \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/services/MockServiceHubInternal.kt b/node/src/test/kotlin/net/corda/node/services/MockServiceHubInternal.kt index 20e293ecdb..dededeada0 100644 --- a/node/src/test/kotlin/net/corda/node/services/MockServiceHubInternal.kt +++ b/node/src/test/kotlin/net/corda/node/services/MockServiceHubInternal.kt @@ -2,6 +2,7 @@ package net.corda.node.services import com.codahale.metrics.MetricRegistry import net.corda.core.crypto.Party +import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogicRefFactory import net.corda.core.flows.FlowStateMachine @@ -81,8 +82,8 @@ open class MockServiceHubInternal( override fun recordTransactions(txs: Iterable) = recordTransactionsInternal(txStorageService, txs) - override fun startFlow(logic: FlowLogic): FlowStateMachine { - return smm.executor.fetchFrom { smm.add(logic) } + override fun startFlow(logic: FlowLogic, flowInitiator: FlowInitiator): FlowStateMachine { + return smm.executor.fetchFrom { smm.add(logic, flowInitiator) } } override fun registerFlowInitiator(markerClass: Class<*>, flowFactory: (Party) -> FlowLogic<*>) { diff --git a/node/src/test/kotlin/net/corda/node/services/events/ScheduledFlowTests.kt b/node/src/test/kotlin/net/corda/node/services/events/ScheduledFlowTests.kt index 03f2267da6..1a40b04666 100644 --- a/node/src/test/kotlin/net/corda/node/services/events/ScheduledFlowTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/events/ScheduledFlowTests.kt @@ -4,6 +4,7 @@ import co.paralleluniverse.fibers.Suspendable import net.corda.core.contracts.* import net.corda.core.crypto.Party import net.corda.core.crypto.containsAny +import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogicRefFactory import net.corda.core.node.CordaPluginRegistry @@ -13,6 +14,7 @@ import net.corda.core.utilities.DUMMY_NOTARY import net.corda.flows.FinalityFlow import net.corda.node.services.network.NetworkMapService import net.corda.node.services.transactions.ValidatingNotaryService +import net.corda.node.utilities.AddOrRemove import net.corda.node.utilities.transaction import net.corda.testing.node.MockNetwork import org.junit.After @@ -112,6 +114,15 @@ class ScheduledFlowTests { @Test fun `create and run scheduled flow then wait for result`() { + val stateMachines = nodeA.smm.track() + var countScheduledFlows = 0 + stateMachines.second.subscribe { + if (it.addOrRemove == AddOrRemove.ADD) { + val initiator = it.flowInitiator + if (initiator is FlowInitiator.Scheduled) + countScheduledFlows++ + } + } nodeA.services.startFlow(InsertInitialStateFlow(nodeB.info.legalIdentity)) net.waitQuiescent() val stateFromA = nodeA.database.transaction { @@ -120,6 +131,7 @@ class ScheduledFlowTests { val stateFromB = nodeB.database.transaction { nodeB.services.vaultService.linearHeadsOfType().values.first() } + assertEquals(1, countScheduledFlows) assertEquals(stateFromA, stateFromB, "Must be same copy on both nodes") assertTrue("Must be processed", stateFromB.state.data.processed) } diff --git a/test-utils/src/main/kotlin/net/corda/testing/node/MockServices.kt b/test-utils/src/main/kotlin/net/corda/testing/node/MockServices.kt index e7f84d35d0..7e0bd79cef 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/node/MockServices.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/node/MockServices.kt @@ -3,6 +3,7 @@ package net.corda.testing.node import net.corda.core.contracts.Attachment import net.corda.core.contracts.PartyAndReference import net.corda.core.crypto.* +import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowStateMachine import net.corda.core.flows.StateMachineRunId @@ -46,10 +47,6 @@ import javax.annotation.concurrent.ThreadSafe * building chains of transactions and verifying them. It isn't sufficient for testing flows however. */ open class MockServices(val key: KeyPair = generateKeyPair()) : ServiceHub { - override fun invokeFlowAsync(logicType: Class>, vararg args: Any?): FlowStateMachine { - throw UnsupportedOperationException("not implemented") - } - override fun recordTransactions(txs: Iterable) { txs.forEach { storageService.stateMachineRecordedTransactionMapping.addMapping(StateMachineRunId.createRandom(), it.id)