Add information on who started flow on a node. (#549)

* Add information on who started flow on a node with name where possible.
Add sealed class holding information on different ways of starting a flow: RPC, peer, shell, scheduled.

* Remove invokeFlowAsync from ServiceHub, move it to ServiceHubInternal.
We shouldn't be able to start new state machines from inside flows.
This commit is contained in:
Katarzyna Streich 2017-04-24 17:05:51 +01:00 committed by GitHub
parent af7f5ef0d7
commit c1b7b1cb75
17 changed files with 170 additions and 57 deletions

View File

@ -8,6 +8,7 @@ import net.corda.core.contracts.DOLLARS
import net.corda.core.contracts.USD
import net.corda.core.crypto.isFulfilledBy
import net.corda.core.crypto.keys
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.StateMachineRunId
import net.corda.core.getOrThrow
import net.corda.core.messaging.CordaRPCOps
@ -41,11 +42,14 @@ import rx.Observable
class NodeMonitorModelTest : DriverBasedTest() {
lateinit var aliceNode: NodeInfo
lateinit var bobNode: NodeInfo
lateinit var notaryNode: NodeInfo
lateinit var rpc: CordaRPCOps
lateinit var rpcBob: CordaRPCOps
lateinit var stateMachineTransactionMapping: Observable<StateMachineTransactionMapping>
lateinit var stateMachineUpdates: Observable<StateMachineUpdate>
lateinit var stateMachineUpdatesBob: Observable<StateMachineUpdate>
lateinit var progressTracking: Observable<ProgressTrackingEvent>
lateinit var transactions: Observable<SignedTransaction>
lateinit var vaultUpdates: Observable<Vault.Update>
@ -66,7 +70,6 @@ class NodeMonitorModelTest : DriverBasedTest() {
notaryNode = notaryNodeHandle.nodeInfo
newNode = { nodeName -> startNode(nodeName).getOrThrow().nodeInfo }
val monitor = NodeMonitorModel()
stateMachineTransactionMapping = monitor.stateMachineTransactionMapping.bufferUntilSubscribed()
stateMachineUpdates = monitor.stateMachineUpdates.bufferUntilSubscribed()
progressTracking = monitor.progressTracking.bufferUntilSubscribed()
@ -76,12 +79,18 @@ class NodeMonitorModelTest : DriverBasedTest() {
monitor.register(aliceNodeHandle.configuration.rpcAddress!!, cashUser.username, cashUser.password)
rpc = monitor.proxyObservable.value!!
val bobNodeHandle = startNode(BOB.name, rpcUsers = listOf(cashUser)).getOrThrow()
bobNode = bobNodeHandle.nodeInfo
val monitorBob = NodeMonitorModel()
stateMachineUpdatesBob = monitorBob.stateMachineUpdates.bufferUntilSubscribed()
monitorBob.register(bobNodeHandle.configuration.rpcAddress!!, cashUser.username, cashUser.password)
rpcBob = monitorBob.proxyObservable.value!!
runTest()
}
@Test
fun `network map update`() {
newNode(BOB.name)
newNode(CHARLIE.name)
networkMapUpdates.filter { !it.node.advertisedServices.any { it.info.type.isNotary() } }
.filter { !it.node.advertisedServices.any { it.info.type == NetworkMapService.type } }
@ -114,12 +123,12 @@ class NodeMonitorModelTest : DriverBasedTest() {
sequence(
// SNAPSHOT
expect { output: Vault.Update ->
require(output.consumed.size == 0) { output.consumed.size }
require(output.produced.size == 0) { output.produced.size }
require(output.consumed.isEmpty()) { output.consumed.size }
require(output.produced.isEmpty()) { output.produced.size }
},
// ISSUE
expect { output: Vault.Update ->
require(output.consumed.size == 0) { output.consumed.size }
require(output.consumed.isEmpty()) { output.consumed.size }
require(output.produced.size == 1) { output.produced.size }
}
)
@ -129,7 +138,7 @@ class NodeMonitorModelTest : DriverBasedTest() {
@Test
fun `cash issue and move`() {
rpc.startFlow(::CashIssueFlow, 100.DOLLARS, OpaqueBytes.of(1), aliceNode.legalIdentity, notaryNode.notaryIdentity).returnValue.getOrThrow()
rpc.startFlow(::CashPaymentFlow, 100.DOLLARS, aliceNode.legalIdentity).returnValue.getOrThrow()
rpc.startFlow(::CashPaymentFlow, 100.DOLLARS, bobNode.legalIdentity).returnValue.getOrThrow()
var issueSmId: StateMachineRunId? = null
var moveSmId: StateMachineRunId? = null
@ -140,6 +149,8 @@ class NodeMonitorModelTest : DriverBasedTest() {
// ISSUE
expect { add: StateMachineUpdate.Added ->
issueSmId = add.id
val initiator = add.stateMachineInfo.initiator
require(initiator is FlowInitiator.RPC && initiator.username == "user1")
},
expect { remove: StateMachineUpdate.Removed ->
require(remove.id == issueSmId)
@ -147,6 +158,8 @@ class NodeMonitorModelTest : DriverBasedTest() {
// MOVE
expect { add: StateMachineUpdate.Added ->
moveSmId = add.id
val initiator = add.stateMachineInfo.initiator
require(initiator is FlowInitiator.RPC && initiator.username == "user1")
},
expect { remove: StateMachineUpdate.Removed ->
require(remove.id == moveSmId)
@ -154,6 +167,16 @@ class NodeMonitorModelTest : DriverBasedTest() {
)
}
stateMachineUpdatesBob.expectEvents {
sequence(
// MOVE
expect { add: StateMachineUpdate.Added ->
val initiator = add.stateMachineInfo.initiator
require(initiator is FlowInitiator.Peer && initiator.party.name == aliceNode.legalIdentity.name)
}
)
}
transactions.expectEvents {
sequence(
// ISSUE
@ -184,18 +207,18 @@ class NodeMonitorModelTest : DriverBasedTest() {
sequence(
// SNAPSHOT
expect { output: Vault.Update ->
require(output.consumed.size == 0) { output.consumed.size }
require(output.produced.size == 0) { output.produced.size }
require(output.consumed.isEmpty()) { output.consumed.size }
require(output.produced.isEmpty()) { output.produced.size }
},
// ISSUE
expect { update ->
require(update.consumed.size == 0) { update.consumed.size }
require(update.consumed.isEmpty()) { update.consumed.size }
require(update.produced.size == 1) { update.produced.size }
},
// MOVE
expect { update ->
require(update.consumed.size == 1) { update.consumed.size }
require(update.produced.size == 1) { update.produced.size }
require(update.produced.isEmpty()) { update.produced.size }
}
)
}

View File

@ -1,10 +1,13 @@
package net.corda.client.rpc
import net.corda.core.contracts.DOLLARS
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowException
import net.corda.core.getOrThrow
import net.corda.core.messaging.FlowHandle
import net.corda.core.messaging.FlowProgressHandle
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.StateMachineUpdate
import net.corda.core.messaging.startFlow
import net.corda.core.messaging.startTrackedFlow
import net.corda.core.node.services.ServiceInfo
@ -66,10 +69,7 @@ class CordaRPCClientTest : NodeBasedTest() {
@Test
fun `close-send deadlock and premature shutdown on empty observable`() {
println("Starting client")
client.start(rpcUser.username, rpcUser.password)
println("Creating proxy")
val proxy = client.proxy()
val proxy = createRpcProxy(rpcUser.username, rpcUser.password)
println("Starting flow")
val flowHandle = proxy.startTrackedFlow(
::CashIssueFlow,
@ -104,11 +104,7 @@ class CordaRPCClientTest : NodeBasedTest() {
@Test
fun `get cash balances`() {
println("Starting client")
client.start(rpcUser.username, rpcUser.password)
println("Creating proxy")
val proxy = client.proxy()
val proxy = createRpcProxy(rpcUser.username, rpcUser.password)
val startCash = proxy.getCashBalances()
assertTrue(startCash.isEmpty(), "Should not start with any cash")
@ -125,4 +121,38 @@ class CordaRPCClientTest : NodeBasedTest() {
assertEquals(123.DOLLARS, finishCash.get(Currency.getInstance("USD")))
}
@Test
fun `flow initiator via RPC`() {
val proxy = createRpcProxy(rpcUser.username, rpcUser.password)
val smUpdates = proxy.stateMachinesAndUpdates()
var countRpcFlows = 0
var countShellFlows = 0
smUpdates.second.subscribe {
if (it is StateMachineUpdate.Added) {
val initiator = it.stateMachineInfo.initiator
if (initiator is FlowInitiator.RPC)
countRpcFlows++
if (initiator is FlowInitiator.Shell)
countShellFlows++
}
}
val nodeIdentity = node.info.legalIdentity
node.services.startFlow(CashIssueFlow(2000.DOLLARS, OpaqueBytes.of(0), nodeIdentity, nodeIdentity), FlowInitiator.Shell).resultFuture.getOrThrow()
proxy.startFlow(::CashIssueFlow,
123.DOLLARS, OpaqueBytes.of(0),
nodeIdentity, nodeIdentity
).returnValue.getOrThrow()
proxy.startFlowDynamic(CashIssueFlow::class.java,
1000.DOLLARS, OpaqueBytes.of(0),
nodeIdentity, nodeIdentity).returnValue.getOrThrow()
assertEquals(2, countRpcFlows)
assertEquals(1, countShellFlows)
}
private fun createRpcProxy(username: String, password: String): CordaRPCOps {
println("Starting client")
client.start(username, password)
println("Creating proxy")
return client.proxy()
}
}

View File

@ -186,7 +186,7 @@ abstract class FlowLogic<out T> {
return stateMachine.waitForLedgerCommit(hash, this)
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////////////
private var _stateMachine: FlowStateMachine<*>? = null
/**

View File

@ -2,6 +2,7 @@ package net.corda.core.flows
import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.contracts.ScheduledStateRef
import net.corda.core.crypto.Party
import net.corda.core.crypto.SecureHash
import net.corda.core.messaging.FlowHandle
@ -12,6 +13,23 @@ import net.corda.core.utilities.UntrustworthyData
import org.slf4j.Logger
import java.util.*
/**
* FlowInitiator holds information on who started the flow. We have different ways of doing that: via RPC [FlowInitiator.RPC],
* communication started by peer node [FlowInitiator.Peer], scheduled flows [FlowInitiator.Scheduled]
* or manual [FlowInitiator.Manual]. The last case is for all flows started in tests, shell etc. It was added
* because we can start flow directly using [StateMachineManager.add] or [ServiceHubInternal.startFlow].
*/
@CordaSerializable
sealed class FlowInitiator {
/** Started using [CordaRPCOps.startFlowDynamic]. */
data class RPC(val username: String) : FlowInitiator()
/** Started when we get new session initiation request. */
data class Peer(val party: Party) : FlowInitiator()
/** Started as scheduled activity. */
class Scheduled(val scheduledState: ScheduledStateRef) : FlowInitiator()
object Shell : FlowInitiator() // TODO When proper ssh access enabled, add username/use RPC?
}
/**
* A unique identifier for a single state machine run, valid across node restarts. Note that a single run always
* has at least one flow, but that flow may also invoke sub-flows: they all share the same run id.
@ -48,4 +66,5 @@ interface FlowStateMachine<R> {
val logger: Logger
val id: StateMachineRunId
val resultFuture: ListenableFuture<R>
val flowInitiator: FlowInitiator
}

View File

@ -7,6 +7,7 @@ import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.UpgradedContract
import net.corda.core.crypto.Party
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.node.NodeInfo
@ -26,6 +27,7 @@ import java.util.*
data class StateMachineInfo(
val id: StateMachineRunId,
val flowLogicClassName: String,
val initiator: FlowInitiator,
val progressTrackerStepAndUpdates: Pair<String, Observable<String>>?
) {
override fun toString(): String = "${javaClass.simpleName}($id, $flowLogicClassName)"

View File

@ -2,6 +2,7 @@ package net.corda.core.node
import net.corda.core.contracts.*
import net.corda.core.crypto.keys
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine
import net.corda.core.messaging.MessagingService
@ -86,14 +87,6 @@ interface ServiceHub : ServicesForResolution {
return definingTx.tx.outRef<T>(ref.index)
}
/**
* Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the flow.
* Note that you must be on the server thread to call this method.
*
* @throws IllegalFlowLogicException or IllegalArgumentException if there are problems with the [logicType] or [args].
*/
fun <T : Any> invokeFlowAsync(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowStateMachine<T>
/**
* Helper property to shorten code for fetching the Node's KeyPair associated with the
* public legalIdentity Party from the key management service.

View File

@ -10,6 +10,7 @@ import net.corda.core.contracts.Amount
import net.corda.core.contracts.PartyAndReference
import net.corda.core.crypto.Party
import net.corda.core.crypto.X509Utilities
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowLogicRefFactory
import net.corda.core.flows.FlowStateMachine
@ -127,8 +128,8 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
override val monitoringService: MonitoringService = MonitoringService(MetricRegistry())
override val flowLogicRefFactory: FlowLogicRefFactory get() = flowLogicFactory
override fun <T> startFlow(logic: FlowLogic<T>): FlowStateMachine<T> {
return serverThread.fetchFrom { smm.add(logic) }
override fun <T> startFlow(logic: FlowLogic<T>, flowInitiator: FlowInitiator): FlowStateMachine<T> {
return serverThread.fetchFrom { smm.add(logic, flowInitiator) }
}
override fun registerFlowInitiator(markerClass: Class<*>, flowFactory: (Party) -> FlowLogic<*>) {

View File

@ -5,6 +5,7 @@ import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.UpgradedContract
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.messaging.*
@ -20,6 +21,7 @@ import net.corda.node.services.statemachine.StateMachineManager
import net.corda.node.utilities.AddOrRemove
import net.corda.node.utilities.transaction
import org.bouncycastle.asn1.x500.X500Name
import net.corda.nodeapi.CURRENT_RPC_USER
import org.jetbrains.exposed.sql.Database
import rx.Observable
import java.io.InputStream
@ -61,7 +63,7 @@ class CordaRPCOpsImpl(
return database.transaction {
val (allStateMachines, changes) = smm.track()
Pair(
allStateMachines.map { stateMachineInfoFromFlowLogic(it.id, it.logic) },
allStateMachines.map { stateMachineInfoFromFlowLogic(it.id, it.logic, it.flowInitiator) },
changes.map { stateMachineUpdateFromStateMachineChange(it) }
)
}
@ -98,13 +100,15 @@ class CordaRPCOpsImpl(
// TODO: Check that this flow is annotated as being intended for RPC invocation
override fun <T : Any> startTrackedFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowProgressHandle<T> {
requirePermission(startFlowPermission(logicType))
return services.invokeFlowAsync(logicType, *args).createHandle(hasProgress = true) as FlowProgressHandle<T>
val currentUser = FlowInitiator.RPC(CURRENT_RPC_USER.get().username)
return services.invokeFlowAsync(logicType, currentUser, *args).createHandle(hasProgress = true) as FlowProgressHandle<T>
}
// TODO: Check that this flow is annotated as being intended for RPC invocation
override fun <T : Any> startFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowHandle<T> {
requirePermission(startFlowPermission(logicType))
return services.invokeFlowAsync(logicType, *args).createHandle(hasProgress = false)
val currentUser = FlowInitiator.RPC(CURRENT_RPC_USER.get().username)
return services.invokeFlowAsync(logicType, currentUser, *args).createHandle(hasProgress = false)
}
override fun attachmentExists(id: SecureHash): Boolean {
@ -147,13 +151,13 @@ class CordaRPCOpsImpl(
override fun registeredFlows(): List<String> = services.flowLogicRefFactory.flowWhitelist.keys.sorted()
companion object {
private fun stateMachineInfoFromFlowLogic(id: StateMachineRunId, flowLogic: FlowLogic<*>): StateMachineInfo {
return StateMachineInfo(id, flowLogic.javaClass.name, flowLogic.track())
private fun stateMachineInfoFromFlowLogic(id: StateMachineRunId, flowLogic: FlowLogic<*>, flowInitiator: FlowInitiator): StateMachineInfo {
return StateMachineInfo(id, flowLogic.javaClass.name, flowInitiator, flowLogic.track())
}
private fun stateMachineUpdateFromStateMachineChange(change: StateMachineManager.Change): StateMachineUpdate {
return when (change.addOrRemove) {
AddOrRemove.ADD -> StateMachineUpdate.Added(stateMachineInfoFromFlowLogic(change.id, change.logic))
AddOrRemove.ADD -> StateMachineUpdate.Added(stateMachineInfoFromFlowLogic(change.id, change.logic, change.flowInitiator))
AddOrRemove.REMOVE -> StateMachineUpdate.Removed(change.id)
}
}

View File

@ -1,6 +1,8 @@
package net.corda.node.services.api
import com.google.common.annotations.VisibleForTesting
import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowLogicRefFactory
import net.corda.core.flows.FlowStateMachine
@ -64,14 +66,34 @@ abstract class ServiceHubInternal : PluginServiceHub {
}
/**
* Starts an already constructed flow. Note that you must be on the server thread to call this method.
* Starts an already constructed flow. Note that you must be on the server thread to call this method. [FlowInitiator]
* defaults to [FlowInitiator.RPC] with username "Only For Testing".
*/
abstract fun <T> startFlow(logic: FlowLogic<T>): FlowStateMachine<T>
// TODO Move it to test utils.
@VisibleForTesting
fun <T> startFlow(logic: FlowLogic<T>): FlowStateMachine<T> = startFlow(logic, FlowInitiator.RPC("Only For Testing"))
override fun <T : Any> invokeFlowAsync(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowStateMachine<T> {
/**
* Starts an already constructed flow. Note that you must be on the server thread to call this method.
* @param flowInitiator indicates who started the flow, see: [FlowInitiator].
*/
abstract fun <T> startFlow(logic: FlowLogic<T>, flowInitiator: FlowInitiator): FlowStateMachine<T>
/**
* Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the flow.
* Note that you must be on the server thread to call this method. [flowInitiator] points how flow was started,
* See: [FlowInitiator].
*
* @throws IllegalFlowLogicException or IllegalArgumentException if there are problems with the [logicType] or [args].
*/
fun <T : Any> invokeFlowAsync(
logicType: Class<out FlowLogic<T>>,
flowInitiator: FlowInitiator,
vararg args: Any?): FlowStateMachine<T> {
val logicRef = flowLogicRefFactory.create(logicType, *args)
@Suppress("UNCHECKED_CAST")
val logic = flowLogicRefFactory.toFlowLogic(logicRef) as FlowLogic<T>
return startFlow(logic)
return startFlow(logic, flowInitiator)
}
}

View File

@ -7,6 +7,7 @@ import net.corda.core.contracts.SchedulableState
import net.corda.core.contracts.ScheduledActivity
import net.corda.core.contracts.ScheduledStateRef
import net.corda.core.contracts.StateRef
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowLogicRefFactory
import net.corda.core.node.services.SchedulerService
@ -158,7 +159,7 @@ class NodeSchedulerService(private val services: ServiceHubInternal,
}
private fun onTimeReached(scheduledState: ScheduledStateRef) {
services.startFlow(RunScheduled(scheduledState, this@NodeSchedulerService))
services.startFlow(RunScheduled(scheduledState, this@NodeSchedulerService), FlowInitiator.Scheduled(scheduledState))
}
class RunScheduled(val scheduledState: ScheduledStateRef, val scheduler: NodeSchedulerService) : FlowLogic<Unit>() {
@ -167,7 +168,6 @@ class NodeSchedulerService(private val services: ServiceHubInternal,
fun tracker() = ProgressTracker(RUNNING)
}
override val progressTracker = tracker()
@Suspendable

View File

@ -11,6 +11,7 @@ import net.corda.core.abbreviate
import net.corda.core.crypto.Party
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine
import net.corda.core.flows.StateMachineRunId
@ -39,7 +40,8 @@ import java.util.concurrent.TimeUnit
class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val logic: FlowLogic<R>,
scheduler: FiberScheduler) : Fiber<Unit>(id.toString(), scheduler), FlowStateMachine<R> {
scheduler: FiberScheduler,
override val flowInitiator: FlowInitiator) : Fiber<Unit>(id.toString(), scheduler), FlowStateMachine<R> {
companion object {
// Used to work around a small limitation in Quasar.
private val QUASAR_UNBLOCKER = run {

View File

@ -18,6 +18,7 @@ import net.corda.core.bufferUntilSubscribed
import net.corda.core.crypto.Party
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.commonName
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine
@ -113,7 +114,8 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
data class Change(
val logic: FlowLogic<*>,
val addOrRemove: AddOrRemove,
val id: StateMachineRunId
val id: StateMachineRunId,
val flowInitiator: FlowInitiator
)
// A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines
@ -125,7 +127,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
val fibersWaitingForLedgerCommit = HashMultimap.create<SecureHash, FlowStateMachineImpl<*>>()!!
fun notifyChangeObservers(fiber: FlowStateMachineImpl<*>, addOrRemove: AddOrRemove) {
changesPublisher.bufferUntilDatabaseCommit().onNext(Change(fiber.logic, addOrRemove, fiber.id))
changesPublisher.bufferUntilDatabaseCommit().onNext(Change(fiber.logic, addOrRemove, fiber.id, fiber.flowInitiator))
}
}
@ -359,7 +361,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
val session = try {
val flow = flowFactory(sender)
val fiber = createFiber(flow)
val fiber = createFiber(flow, FlowInitiator.Peer(sender))
val session = FlowSession(flow, random63BitValue(), sender, FlowSessionState.Initiated(sender, otherPartySessionId))
if (sessionInit.firstPayload != null) {
session.receivedMessages += ReceivedSessionMessage(sender, SessionData(session.ourSessionId, sessionInit.firstPayload))
@ -398,9 +400,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
private fun quasarKryo(): KryoPool = quasarKryoPool
private fun <T> createFiber(logic: FlowLogic<T>): FlowStateMachineImpl<T> {
private fun <T> createFiber(logic: FlowLogic<T>, flowInitiator: FlowInitiator): FlowStateMachineImpl<T> {
val id = StateMachineRunId.createRandom()
return FlowStateMachineImpl(id, logic, scheduler).apply { initFiber(this) }
return FlowStateMachineImpl(id, logic, scheduler, flowInitiator).apply { initFiber(this) }
}
private fun initFiber(fiber: FlowStateMachineImpl<*>) {
@ -471,7 +473,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
*
* Note that you must be on the [executor] thread.
*/
fun <T> add(logic: FlowLogic<T>): FlowStateMachine<T> {
fun <T> add(logic: FlowLogic<T>, flowInitiator: FlowInitiator): FlowStateMachine<T> {
// TODO: Check that logic has @Suspendable on its call method.
executor.checkOnThread()
// We swap out the parent transaction context as using this frequently leads to a deadlock as we wait
@ -479,7 +481,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
// unable to acquire the table lock and move forward till the calling transaction finishes.
// Committing in line here on a fresh context ensure we can progress.
val fiber = database.isolatedTransaction {
val fiber = createFiber(logic)
val fiber = createFiber(logic, flowInitiator)
updateCheckpoint(fiber)
fiber
}

View File

@ -10,6 +10,7 @@ import com.google.common.io.Closeables
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import net.corda.core.*
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine
import net.corda.core.messaging.CordaRPCOps
@ -221,8 +222,9 @@ object InteractiveShell {
if (!FlowLogic::class.java.isAssignableFrom(clazz))
throw IllegalStateException("Found a non-FlowLogic class in the whitelist? $clazz")
try {
// TODO Flow invocation should use startFlowDynamic.
@Suppress("UNCHECKED_CAST")
val fsm = runFlowFromString({ node.services.startFlow(it) }, inputData, clazz as Class<FlowLogic<*>>)
val fsm = runFlowFromString({ node.services.startFlow(it, FlowInitiator.Shell) }, inputData, clazz as Class<FlowLogic<*>>)
// Show the progress tracker on the console until the flow completes or is interrupted with a
// Ctrl-C keypress.
val latch = CountDownLatch(1)

View File

@ -5,6 +5,7 @@ import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.contracts.Amount
import net.corda.core.crypto.Party
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowInitiator
import net.corda.core.crypto.X509Utilities
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine
@ -97,5 +98,7 @@ class InteractiveShellTest {
get() = throw UnsupportedOperationException()
override val resultFuture: ListenableFuture<Any?>
get() = throw UnsupportedOperationException()
override val flowInitiator: FlowInitiator
get() = throw UnsupportedOperationException()
}
}

View File

@ -2,6 +2,7 @@ package net.corda.node.services
import com.codahale.metrics.MetricRegistry
import net.corda.core.crypto.Party
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowLogicRefFactory
import net.corda.core.flows.FlowStateMachine
@ -81,8 +82,8 @@ open class MockServiceHubInternal(
override fun recordTransactions(txs: Iterable<SignedTransaction>) = recordTransactionsInternal(txStorageService, txs)
override fun <T> startFlow(logic: FlowLogic<T>): FlowStateMachine<T> {
return smm.executor.fetchFrom { smm.add(logic) }
override fun <T> startFlow(logic: FlowLogic<T>, flowInitiator: FlowInitiator): FlowStateMachine<T> {
return smm.executor.fetchFrom { smm.add(logic, flowInitiator) }
}
override fun registerFlowInitiator(markerClass: Class<*>, flowFactory: (Party) -> FlowLogic<*>) {

View File

@ -4,6 +4,7 @@ import co.paralleluniverse.fibers.Suspendable
import net.corda.core.contracts.*
import net.corda.core.crypto.Party
import net.corda.core.crypto.containsAny
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowLogicRefFactory
import net.corda.core.node.CordaPluginRegistry
@ -13,6 +14,7 @@ import net.corda.core.utilities.DUMMY_NOTARY
import net.corda.flows.FinalityFlow
import net.corda.node.services.network.NetworkMapService
import net.corda.node.services.transactions.ValidatingNotaryService
import net.corda.node.utilities.AddOrRemove
import net.corda.node.utilities.transaction
import net.corda.testing.node.MockNetwork
import org.junit.After
@ -112,6 +114,15 @@ class ScheduledFlowTests {
@Test
fun `create and run scheduled flow then wait for result`() {
val stateMachines = nodeA.smm.track()
var countScheduledFlows = 0
stateMachines.second.subscribe {
if (it.addOrRemove == AddOrRemove.ADD) {
val initiator = it.flowInitiator
if (initiator is FlowInitiator.Scheduled)
countScheduledFlows++
}
}
nodeA.services.startFlow(InsertInitialStateFlow(nodeB.info.legalIdentity))
net.waitQuiescent()
val stateFromA = nodeA.database.transaction {
@ -120,6 +131,7 @@ class ScheduledFlowTests {
val stateFromB = nodeB.database.transaction {
nodeB.services.vaultService.linearHeadsOfType<ScheduledState>().values.first()
}
assertEquals(1, countScheduledFlows)
assertEquals(stateFromA, stateFromB, "Must be same copy on both nodes")
assertTrue("Must be processed", stateFromB.state.data.processed)
}

View File

@ -3,6 +3,7 @@ package net.corda.testing.node
import net.corda.core.contracts.Attachment
import net.corda.core.contracts.PartyAndReference
import net.corda.core.crypto.*
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine
import net.corda.core.flows.StateMachineRunId
@ -46,10 +47,6 @@ import javax.annotation.concurrent.ThreadSafe
* building chains of transactions and verifying them. It isn't sufficient for testing flows however.
*/
open class MockServices(val key: KeyPair = generateKeyPair()) : ServiceHub {
override fun <T : Any> invokeFlowAsync(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowStateMachine<T> {
throw UnsupportedOperationException("not implemented")
}
override fun recordTransactions(txs: Iterable<SignedTransaction>) {
txs.forEach {
storageService.stateMachineRecordedTransactionMapping.addMapping(StateMachineRunId.createRandom(), it.id)