Add information on who started flow on a node. (#549)

* Add information on who started flow on a node with name where possible.
Add sealed class holding information on different ways of starting a flow: RPC, peer, shell, scheduled.

* Remove invokeFlowAsync from ServiceHub, move it to ServiceHubInternal.
We shouldn't be able to start new state machines from inside flows.
This commit is contained in:
Katarzyna Streich 2017-04-24 17:05:51 +01:00 committed by GitHub
parent af7f5ef0d7
commit c1b7b1cb75
17 changed files with 170 additions and 57 deletions

View File

@ -8,6 +8,7 @@ import net.corda.core.contracts.DOLLARS
import net.corda.core.contracts.USD import net.corda.core.contracts.USD
import net.corda.core.crypto.isFulfilledBy import net.corda.core.crypto.isFulfilledBy
import net.corda.core.crypto.keys import net.corda.core.crypto.keys
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
import net.corda.core.getOrThrow import net.corda.core.getOrThrow
import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.CordaRPCOps
@ -41,11 +42,14 @@ import rx.Observable
class NodeMonitorModelTest : DriverBasedTest() { class NodeMonitorModelTest : DriverBasedTest() {
lateinit var aliceNode: NodeInfo lateinit var aliceNode: NodeInfo
lateinit var bobNode: NodeInfo
lateinit var notaryNode: NodeInfo lateinit var notaryNode: NodeInfo
lateinit var rpc: CordaRPCOps lateinit var rpc: CordaRPCOps
lateinit var rpcBob: CordaRPCOps
lateinit var stateMachineTransactionMapping: Observable<StateMachineTransactionMapping> lateinit var stateMachineTransactionMapping: Observable<StateMachineTransactionMapping>
lateinit var stateMachineUpdates: Observable<StateMachineUpdate> lateinit var stateMachineUpdates: Observable<StateMachineUpdate>
lateinit var stateMachineUpdatesBob: Observable<StateMachineUpdate>
lateinit var progressTracking: Observable<ProgressTrackingEvent> lateinit var progressTracking: Observable<ProgressTrackingEvent>
lateinit var transactions: Observable<SignedTransaction> lateinit var transactions: Observable<SignedTransaction>
lateinit var vaultUpdates: Observable<Vault.Update> lateinit var vaultUpdates: Observable<Vault.Update>
@ -66,7 +70,6 @@ class NodeMonitorModelTest : DriverBasedTest() {
notaryNode = notaryNodeHandle.nodeInfo notaryNode = notaryNodeHandle.nodeInfo
newNode = { nodeName -> startNode(nodeName).getOrThrow().nodeInfo } newNode = { nodeName -> startNode(nodeName).getOrThrow().nodeInfo }
val monitor = NodeMonitorModel() val monitor = NodeMonitorModel()
stateMachineTransactionMapping = monitor.stateMachineTransactionMapping.bufferUntilSubscribed() stateMachineTransactionMapping = monitor.stateMachineTransactionMapping.bufferUntilSubscribed()
stateMachineUpdates = monitor.stateMachineUpdates.bufferUntilSubscribed() stateMachineUpdates = monitor.stateMachineUpdates.bufferUntilSubscribed()
progressTracking = monitor.progressTracking.bufferUntilSubscribed() progressTracking = monitor.progressTracking.bufferUntilSubscribed()
@ -76,12 +79,18 @@ class NodeMonitorModelTest : DriverBasedTest() {
monitor.register(aliceNodeHandle.configuration.rpcAddress!!, cashUser.username, cashUser.password) monitor.register(aliceNodeHandle.configuration.rpcAddress!!, cashUser.username, cashUser.password)
rpc = monitor.proxyObservable.value!! rpc = monitor.proxyObservable.value!!
val bobNodeHandle = startNode(BOB.name, rpcUsers = listOf(cashUser)).getOrThrow()
bobNode = bobNodeHandle.nodeInfo
val monitorBob = NodeMonitorModel()
stateMachineUpdatesBob = monitorBob.stateMachineUpdates.bufferUntilSubscribed()
monitorBob.register(bobNodeHandle.configuration.rpcAddress!!, cashUser.username, cashUser.password)
rpcBob = monitorBob.proxyObservable.value!!
runTest() runTest()
} }
@Test @Test
fun `network map update`() { fun `network map update`() {
newNode(BOB.name)
newNode(CHARLIE.name) newNode(CHARLIE.name)
networkMapUpdates.filter { !it.node.advertisedServices.any { it.info.type.isNotary() } } networkMapUpdates.filter { !it.node.advertisedServices.any { it.info.type.isNotary() } }
.filter { !it.node.advertisedServices.any { it.info.type == NetworkMapService.type } } .filter { !it.node.advertisedServices.any { it.info.type == NetworkMapService.type } }
@ -114,12 +123,12 @@ class NodeMonitorModelTest : DriverBasedTest() {
sequence( sequence(
// SNAPSHOT // SNAPSHOT
expect { output: Vault.Update -> expect { output: Vault.Update ->
require(output.consumed.size == 0) { output.consumed.size } require(output.consumed.isEmpty()) { output.consumed.size }
require(output.produced.size == 0) { output.produced.size } require(output.produced.isEmpty()) { output.produced.size }
}, },
// ISSUE // ISSUE
expect { output: Vault.Update -> expect { output: Vault.Update ->
require(output.consumed.size == 0) { output.consumed.size } require(output.consumed.isEmpty()) { output.consumed.size }
require(output.produced.size == 1) { output.produced.size } require(output.produced.size == 1) { output.produced.size }
} }
) )
@ -129,7 +138,7 @@ class NodeMonitorModelTest : DriverBasedTest() {
@Test @Test
fun `cash issue and move`() { fun `cash issue and move`() {
rpc.startFlow(::CashIssueFlow, 100.DOLLARS, OpaqueBytes.of(1), aliceNode.legalIdentity, notaryNode.notaryIdentity).returnValue.getOrThrow() rpc.startFlow(::CashIssueFlow, 100.DOLLARS, OpaqueBytes.of(1), aliceNode.legalIdentity, notaryNode.notaryIdentity).returnValue.getOrThrow()
rpc.startFlow(::CashPaymentFlow, 100.DOLLARS, aliceNode.legalIdentity).returnValue.getOrThrow() rpc.startFlow(::CashPaymentFlow, 100.DOLLARS, bobNode.legalIdentity).returnValue.getOrThrow()
var issueSmId: StateMachineRunId? = null var issueSmId: StateMachineRunId? = null
var moveSmId: StateMachineRunId? = null var moveSmId: StateMachineRunId? = null
@ -140,6 +149,8 @@ class NodeMonitorModelTest : DriverBasedTest() {
// ISSUE // ISSUE
expect { add: StateMachineUpdate.Added -> expect { add: StateMachineUpdate.Added ->
issueSmId = add.id issueSmId = add.id
val initiator = add.stateMachineInfo.initiator
require(initiator is FlowInitiator.RPC && initiator.username == "user1")
}, },
expect { remove: StateMachineUpdate.Removed -> expect { remove: StateMachineUpdate.Removed ->
require(remove.id == issueSmId) require(remove.id == issueSmId)
@ -147,6 +158,8 @@ class NodeMonitorModelTest : DriverBasedTest() {
// MOVE // MOVE
expect { add: StateMachineUpdate.Added -> expect { add: StateMachineUpdate.Added ->
moveSmId = add.id moveSmId = add.id
val initiator = add.stateMachineInfo.initiator
require(initiator is FlowInitiator.RPC && initiator.username == "user1")
}, },
expect { remove: StateMachineUpdate.Removed -> expect { remove: StateMachineUpdate.Removed ->
require(remove.id == moveSmId) require(remove.id == moveSmId)
@ -154,6 +167,16 @@ class NodeMonitorModelTest : DriverBasedTest() {
) )
} }
stateMachineUpdatesBob.expectEvents {
sequence(
// MOVE
expect { add: StateMachineUpdate.Added ->
val initiator = add.stateMachineInfo.initiator
require(initiator is FlowInitiator.Peer && initiator.party.name == aliceNode.legalIdentity.name)
}
)
}
transactions.expectEvents { transactions.expectEvents {
sequence( sequence(
// ISSUE // ISSUE
@ -184,18 +207,18 @@ class NodeMonitorModelTest : DriverBasedTest() {
sequence( sequence(
// SNAPSHOT // SNAPSHOT
expect { output: Vault.Update -> expect { output: Vault.Update ->
require(output.consumed.size == 0) { output.consumed.size } require(output.consumed.isEmpty()) { output.consumed.size }
require(output.produced.size == 0) { output.produced.size } require(output.produced.isEmpty()) { output.produced.size }
}, },
// ISSUE // ISSUE
expect { update -> expect { update ->
require(update.consumed.size == 0) { update.consumed.size } require(update.consumed.isEmpty()) { update.consumed.size }
require(update.produced.size == 1) { update.produced.size } require(update.produced.size == 1) { update.produced.size }
}, },
// MOVE // MOVE
expect { update -> expect { update ->
require(update.consumed.size == 1) { update.consumed.size } require(update.consumed.size == 1) { update.consumed.size }
require(update.produced.size == 1) { update.produced.size } require(update.produced.isEmpty()) { update.produced.size }
} }
) )
} }

View File

@ -1,10 +1,13 @@
package net.corda.client.rpc package net.corda.client.rpc
import net.corda.core.contracts.DOLLARS import net.corda.core.contracts.DOLLARS
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowException import net.corda.core.flows.FlowException
import net.corda.core.getOrThrow import net.corda.core.getOrThrow
import net.corda.core.messaging.FlowHandle import net.corda.core.messaging.FlowHandle
import net.corda.core.messaging.FlowProgressHandle import net.corda.core.messaging.FlowProgressHandle
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.StateMachineUpdate
import net.corda.core.messaging.startFlow import net.corda.core.messaging.startFlow
import net.corda.core.messaging.startTrackedFlow import net.corda.core.messaging.startTrackedFlow
import net.corda.core.node.services.ServiceInfo import net.corda.core.node.services.ServiceInfo
@ -66,10 +69,7 @@ class CordaRPCClientTest : NodeBasedTest() {
@Test @Test
fun `close-send deadlock and premature shutdown on empty observable`() { fun `close-send deadlock and premature shutdown on empty observable`() {
println("Starting client") val proxy = createRpcProxy(rpcUser.username, rpcUser.password)
client.start(rpcUser.username, rpcUser.password)
println("Creating proxy")
val proxy = client.proxy()
println("Starting flow") println("Starting flow")
val flowHandle = proxy.startTrackedFlow( val flowHandle = proxy.startTrackedFlow(
::CashIssueFlow, ::CashIssueFlow,
@ -104,11 +104,7 @@ class CordaRPCClientTest : NodeBasedTest() {
@Test @Test
fun `get cash balances`() { fun `get cash balances`() {
println("Starting client") val proxy = createRpcProxy(rpcUser.username, rpcUser.password)
client.start(rpcUser.username, rpcUser.password)
println("Creating proxy")
val proxy = client.proxy()
val startCash = proxy.getCashBalances() val startCash = proxy.getCashBalances()
assertTrue(startCash.isEmpty(), "Should not start with any cash") assertTrue(startCash.isEmpty(), "Should not start with any cash")
@ -125,4 +121,38 @@ class CordaRPCClientTest : NodeBasedTest() {
assertEquals(123.DOLLARS, finishCash.get(Currency.getInstance("USD"))) assertEquals(123.DOLLARS, finishCash.get(Currency.getInstance("USD")))
} }
@Test
fun `flow initiator via RPC`() {
val proxy = createRpcProxy(rpcUser.username, rpcUser.password)
val smUpdates = proxy.stateMachinesAndUpdates()
var countRpcFlows = 0
var countShellFlows = 0
smUpdates.second.subscribe {
if (it is StateMachineUpdate.Added) {
val initiator = it.stateMachineInfo.initiator
if (initiator is FlowInitiator.RPC)
countRpcFlows++
if (initiator is FlowInitiator.Shell)
countShellFlows++
}
}
val nodeIdentity = node.info.legalIdentity
node.services.startFlow(CashIssueFlow(2000.DOLLARS, OpaqueBytes.of(0), nodeIdentity, nodeIdentity), FlowInitiator.Shell).resultFuture.getOrThrow()
proxy.startFlow(::CashIssueFlow,
123.DOLLARS, OpaqueBytes.of(0),
nodeIdentity, nodeIdentity
).returnValue.getOrThrow()
proxy.startFlowDynamic(CashIssueFlow::class.java,
1000.DOLLARS, OpaqueBytes.of(0),
nodeIdentity, nodeIdentity).returnValue.getOrThrow()
assertEquals(2, countRpcFlows)
assertEquals(1, countShellFlows)
}
private fun createRpcProxy(username: String, password: String): CordaRPCOps {
println("Starting client")
client.start(username, password)
println("Creating proxy")
return client.proxy()
}
} }

View File

@ -186,7 +186,7 @@ abstract class FlowLogic<out T> {
return stateMachine.waitForLedgerCommit(hash, this) return stateMachine.waitForLedgerCommit(hash, this)
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////
private var _stateMachine: FlowStateMachine<*>? = null private var _stateMachine: FlowStateMachine<*>? = null
/** /**

View File

@ -2,6 +2,7 @@ package net.corda.core.flows
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.contracts.ScheduledStateRef
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.messaging.FlowHandle import net.corda.core.messaging.FlowHandle
@ -12,6 +13,23 @@ import net.corda.core.utilities.UntrustworthyData
import org.slf4j.Logger import org.slf4j.Logger
import java.util.* import java.util.*
/**
* FlowInitiator holds information on who started the flow. We have different ways of doing that: via RPC [FlowInitiator.RPC],
* communication started by peer node [FlowInitiator.Peer], scheduled flows [FlowInitiator.Scheduled]
* or manual [FlowInitiator.Manual]. The last case is for all flows started in tests, shell etc. It was added
* because we can start flow directly using [StateMachineManager.add] or [ServiceHubInternal.startFlow].
*/
@CordaSerializable
sealed class FlowInitiator {
/** Started using [CordaRPCOps.startFlowDynamic]. */
data class RPC(val username: String) : FlowInitiator()
/** Started when we get new session initiation request. */
data class Peer(val party: Party) : FlowInitiator()
/** Started as scheduled activity. */
class Scheduled(val scheduledState: ScheduledStateRef) : FlowInitiator()
object Shell : FlowInitiator() // TODO When proper ssh access enabled, add username/use RPC?
}
/** /**
* A unique identifier for a single state machine run, valid across node restarts. Note that a single run always * A unique identifier for a single state machine run, valid across node restarts. Note that a single run always
* has at least one flow, but that flow may also invoke sub-flows: they all share the same run id. * has at least one flow, but that flow may also invoke sub-flows: they all share the same run id.
@ -48,4 +66,5 @@ interface FlowStateMachine<R> {
val logger: Logger val logger: Logger
val id: StateMachineRunId val id: StateMachineRunId
val resultFuture: ListenableFuture<R> val resultFuture: ListenableFuture<R>
val flowInitiator: FlowInitiator
} }

View File

@ -7,6 +7,7 @@ import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.UpgradedContract import net.corda.core.contracts.UpgradedContract
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
@ -26,6 +27,7 @@ import java.util.*
data class StateMachineInfo( data class StateMachineInfo(
val id: StateMachineRunId, val id: StateMachineRunId,
val flowLogicClassName: String, val flowLogicClassName: String,
val initiator: FlowInitiator,
val progressTrackerStepAndUpdates: Pair<String, Observable<String>>? val progressTrackerStepAndUpdates: Pair<String, Observable<String>>?
) { ) {
override fun toString(): String = "${javaClass.simpleName}($id, $flowLogicClassName)" override fun toString(): String = "${javaClass.simpleName}($id, $flowLogicClassName)"

View File

@ -2,6 +2,7 @@ package net.corda.core.node
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.crypto.keys import net.corda.core.crypto.keys
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine import net.corda.core.flows.FlowStateMachine
import net.corda.core.messaging.MessagingService import net.corda.core.messaging.MessagingService
@ -86,14 +87,6 @@ interface ServiceHub : ServicesForResolution {
return definingTx.tx.outRef<T>(ref.index) return definingTx.tx.outRef<T>(ref.index)
} }
/**
* Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the flow.
* Note that you must be on the server thread to call this method.
*
* @throws IllegalFlowLogicException or IllegalArgumentException if there are problems with the [logicType] or [args].
*/
fun <T : Any> invokeFlowAsync(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowStateMachine<T>
/** /**
* Helper property to shorten code for fetching the Node's KeyPair associated with the * Helper property to shorten code for fetching the Node's KeyPair associated with the
* public legalIdentity Party from the key management service. * public legalIdentity Party from the key management service.

View File

@ -10,6 +10,7 @@ import net.corda.core.contracts.Amount
import net.corda.core.contracts.PartyAndReference import net.corda.core.contracts.PartyAndReference
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.crypto.X509Utilities import net.corda.core.crypto.X509Utilities
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowLogicRefFactory import net.corda.core.flows.FlowLogicRefFactory
import net.corda.core.flows.FlowStateMachine import net.corda.core.flows.FlowStateMachine
@ -127,8 +128,8 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
override val monitoringService: MonitoringService = MonitoringService(MetricRegistry()) override val monitoringService: MonitoringService = MonitoringService(MetricRegistry())
override val flowLogicRefFactory: FlowLogicRefFactory get() = flowLogicFactory override val flowLogicRefFactory: FlowLogicRefFactory get() = flowLogicFactory
override fun <T> startFlow(logic: FlowLogic<T>): FlowStateMachine<T> { override fun <T> startFlow(logic: FlowLogic<T>, flowInitiator: FlowInitiator): FlowStateMachine<T> {
return serverThread.fetchFrom { smm.add(logic) } return serverThread.fetchFrom { smm.add(logic, flowInitiator) }
} }
override fun registerFlowInitiator(markerClass: Class<*>, flowFactory: (Party) -> FlowLogic<*>) { override fun registerFlowInitiator(markerClass: Class<*>, flowFactory: (Party) -> FlowLogic<*>) {

View File

@ -5,6 +5,7 @@ import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateAndRef import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.UpgradedContract import net.corda.core.contracts.UpgradedContract
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
import net.corda.core.messaging.* import net.corda.core.messaging.*
@ -20,6 +21,7 @@ import net.corda.node.services.statemachine.StateMachineManager
import net.corda.node.utilities.AddOrRemove import net.corda.node.utilities.AddOrRemove
import net.corda.node.utilities.transaction import net.corda.node.utilities.transaction
import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x500.X500Name
import net.corda.nodeapi.CURRENT_RPC_USER
import org.jetbrains.exposed.sql.Database import org.jetbrains.exposed.sql.Database
import rx.Observable import rx.Observable
import java.io.InputStream import java.io.InputStream
@ -61,7 +63,7 @@ class CordaRPCOpsImpl(
return database.transaction { return database.transaction {
val (allStateMachines, changes) = smm.track() val (allStateMachines, changes) = smm.track()
Pair( Pair(
allStateMachines.map { stateMachineInfoFromFlowLogic(it.id, it.logic) }, allStateMachines.map { stateMachineInfoFromFlowLogic(it.id, it.logic, it.flowInitiator) },
changes.map { stateMachineUpdateFromStateMachineChange(it) } changes.map { stateMachineUpdateFromStateMachineChange(it) }
) )
} }
@ -98,13 +100,15 @@ class CordaRPCOpsImpl(
// TODO: Check that this flow is annotated as being intended for RPC invocation // TODO: Check that this flow is annotated as being intended for RPC invocation
override fun <T : Any> startTrackedFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowProgressHandle<T> { override fun <T : Any> startTrackedFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowProgressHandle<T> {
requirePermission(startFlowPermission(logicType)) requirePermission(startFlowPermission(logicType))
return services.invokeFlowAsync(logicType, *args).createHandle(hasProgress = true) as FlowProgressHandle<T> val currentUser = FlowInitiator.RPC(CURRENT_RPC_USER.get().username)
return services.invokeFlowAsync(logicType, currentUser, *args).createHandle(hasProgress = true) as FlowProgressHandle<T>
} }
// TODO: Check that this flow is annotated as being intended for RPC invocation // TODO: Check that this flow is annotated as being intended for RPC invocation
override fun <T : Any> startFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowHandle<T> { override fun <T : Any> startFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowHandle<T> {
requirePermission(startFlowPermission(logicType)) requirePermission(startFlowPermission(logicType))
return services.invokeFlowAsync(logicType, *args).createHandle(hasProgress = false) val currentUser = FlowInitiator.RPC(CURRENT_RPC_USER.get().username)
return services.invokeFlowAsync(logicType, currentUser, *args).createHandle(hasProgress = false)
} }
override fun attachmentExists(id: SecureHash): Boolean { override fun attachmentExists(id: SecureHash): Boolean {
@ -147,13 +151,13 @@ class CordaRPCOpsImpl(
override fun registeredFlows(): List<String> = services.flowLogicRefFactory.flowWhitelist.keys.sorted() override fun registeredFlows(): List<String> = services.flowLogicRefFactory.flowWhitelist.keys.sorted()
companion object { companion object {
private fun stateMachineInfoFromFlowLogic(id: StateMachineRunId, flowLogic: FlowLogic<*>): StateMachineInfo { private fun stateMachineInfoFromFlowLogic(id: StateMachineRunId, flowLogic: FlowLogic<*>, flowInitiator: FlowInitiator): StateMachineInfo {
return StateMachineInfo(id, flowLogic.javaClass.name, flowLogic.track()) return StateMachineInfo(id, flowLogic.javaClass.name, flowInitiator, flowLogic.track())
} }
private fun stateMachineUpdateFromStateMachineChange(change: StateMachineManager.Change): StateMachineUpdate { private fun stateMachineUpdateFromStateMachineChange(change: StateMachineManager.Change): StateMachineUpdate {
return when (change.addOrRemove) { return when (change.addOrRemove) {
AddOrRemove.ADD -> StateMachineUpdate.Added(stateMachineInfoFromFlowLogic(change.id, change.logic)) AddOrRemove.ADD -> StateMachineUpdate.Added(stateMachineInfoFromFlowLogic(change.id, change.logic, change.flowInitiator))
AddOrRemove.REMOVE -> StateMachineUpdate.Removed(change.id) AddOrRemove.REMOVE -> StateMachineUpdate.Removed(change.id)
} }
} }

View File

@ -1,6 +1,8 @@
package net.corda.node.services.api package net.corda.node.services.api
import com.google.common.annotations.VisibleForTesting
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowLogicRefFactory import net.corda.core.flows.FlowLogicRefFactory
import net.corda.core.flows.FlowStateMachine import net.corda.core.flows.FlowStateMachine
@ -64,14 +66,34 @@ abstract class ServiceHubInternal : PluginServiceHub {
} }
/** /**
* Starts an already constructed flow. Note that you must be on the server thread to call this method. * Starts an already constructed flow. Note that you must be on the server thread to call this method. [FlowInitiator]
* defaults to [FlowInitiator.RPC] with username "Only For Testing".
*/ */
abstract fun <T> startFlow(logic: FlowLogic<T>): FlowStateMachine<T> // TODO Move it to test utils.
@VisibleForTesting
fun <T> startFlow(logic: FlowLogic<T>): FlowStateMachine<T> = startFlow(logic, FlowInitiator.RPC("Only For Testing"))
override fun <T : Any> invokeFlowAsync(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowStateMachine<T> { /**
* Starts an already constructed flow. Note that you must be on the server thread to call this method.
* @param flowInitiator indicates who started the flow, see: [FlowInitiator].
*/
abstract fun <T> startFlow(logic: FlowLogic<T>, flowInitiator: FlowInitiator): FlowStateMachine<T>
/**
* Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the flow.
* Note that you must be on the server thread to call this method. [flowInitiator] points how flow was started,
* See: [FlowInitiator].
*
* @throws IllegalFlowLogicException or IllegalArgumentException if there are problems with the [logicType] or [args].
*/
fun <T : Any> invokeFlowAsync(
logicType: Class<out FlowLogic<T>>,
flowInitiator: FlowInitiator,
vararg args: Any?): FlowStateMachine<T> {
val logicRef = flowLogicRefFactory.create(logicType, *args) val logicRef = flowLogicRefFactory.create(logicType, *args)
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
val logic = flowLogicRefFactory.toFlowLogic(logicRef) as FlowLogic<T> val logic = flowLogicRefFactory.toFlowLogic(logicRef) as FlowLogic<T>
return startFlow(logic) return startFlow(logic, flowInitiator)
} }
} }

View File

@ -7,6 +7,7 @@ import net.corda.core.contracts.SchedulableState
import net.corda.core.contracts.ScheduledActivity import net.corda.core.contracts.ScheduledActivity
import net.corda.core.contracts.ScheduledStateRef import net.corda.core.contracts.ScheduledStateRef
import net.corda.core.contracts.StateRef import net.corda.core.contracts.StateRef
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowLogicRefFactory import net.corda.core.flows.FlowLogicRefFactory
import net.corda.core.node.services.SchedulerService import net.corda.core.node.services.SchedulerService
@ -158,7 +159,7 @@ class NodeSchedulerService(private val services: ServiceHubInternal,
} }
private fun onTimeReached(scheduledState: ScheduledStateRef) { private fun onTimeReached(scheduledState: ScheduledStateRef) {
services.startFlow(RunScheduled(scheduledState, this@NodeSchedulerService)) services.startFlow(RunScheduled(scheduledState, this@NodeSchedulerService), FlowInitiator.Scheduled(scheduledState))
} }
class RunScheduled(val scheduledState: ScheduledStateRef, val scheduler: NodeSchedulerService) : FlowLogic<Unit>() { class RunScheduled(val scheduledState: ScheduledStateRef, val scheduler: NodeSchedulerService) : FlowLogic<Unit>() {
@ -167,7 +168,6 @@ class NodeSchedulerService(private val services: ServiceHubInternal,
fun tracker() = ProgressTracker(RUNNING) fun tracker() = ProgressTracker(RUNNING)
} }
override val progressTracker = tracker() override val progressTracker = tracker()
@Suspendable @Suspendable

View File

@ -11,6 +11,7 @@ import net.corda.core.abbreviate
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowException import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine import net.corda.core.flows.FlowStateMachine
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
@ -39,7 +40,8 @@ import java.util.concurrent.TimeUnit
class FlowStateMachineImpl<R>(override val id: StateMachineRunId, class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val logic: FlowLogic<R>, val logic: FlowLogic<R>,
scheduler: FiberScheduler) : Fiber<Unit>(id.toString(), scheduler), FlowStateMachine<R> { scheduler: FiberScheduler,
override val flowInitiator: FlowInitiator) : Fiber<Unit>(id.toString(), scheduler), FlowStateMachine<R> {
companion object { companion object {
// Used to work around a small limitation in Quasar. // Used to work around a small limitation in Quasar.
private val QUASAR_UNBLOCKER = run { private val QUASAR_UNBLOCKER = run {

View File

@ -18,6 +18,7 @@ import net.corda.core.bufferUntilSubscribed
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.commonName import net.corda.core.crypto.commonName
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowException import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine import net.corda.core.flows.FlowStateMachine
@ -113,7 +114,8 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
data class Change( data class Change(
val logic: FlowLogic<*>, val logic: FlowLogic<*>,
val addOrRemove: AddOrRemove, val addOrRemove: AddOrRemove,
val id: StateMachineRunId val id: StateMachineRunId,
val flowInitiator: FlowInitiator
) )
// A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines // A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines
@ -125,7 +127,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
val fibersWaitingForLedgerCommit = HashMultimap.create<SecureHash, FlowStateMachineImpl<*>>()!! val fibersWaitingForLedgerCommit = HashMultimap.create<SecureHash, FlowStateMachineImpl<*>>()!!
fun notifyChangeObservers(fiber: FlowStateMachineImpl<*>, addOrRemove: AddOrRemove) { fun notifyChangeObservers(fiber: FlowStateMachineImpl<*>, addOrRemove: AddOrRemove) {
changesPublisher.bufferUntilDatabaseCommit().onNext(Change(fiber.logic, addOrRemove, fiber.id)) changesPublisher.bufferUntilDatabaseCommit().onNext(Change(fiber.logic, addOrRemove, fiber.id, fiber.flowInitiator))
} }
} }
@ -359,7 +361,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
val session = try { val session = try {
val flow = flowFactory(sender) val flow = flowFactory(sender)
val fiber = createFiber(flow) val fiber = createFiber(flow, FlowInitiator.Peer(sender))
val session = FlowSession(flow, random63BitValue(), sender, FlowSessionState.Initiated(sender, otherPartySessionId)) val session = FlowSession(flow, random63BitValue(), sender, FlowSessionState.Initiated(sender, otherPartySessionId))
if (sessionInit.firstPayload != null) { if (sessionInit.firstPayload != null) {
session.receivedMessages += ReceivedSessionMessage(sender, SessionData(session.ourSessionId, sessionInit.firstPayload)) session.receivedMessages += ReceivedSessionMessage(sender, SessionData(session.ourSessionId, sessionInit.firstPayload))
@ -398,9 +400,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
private fun quasarKryo(): KryoPool = quasarKryoPool private fun quasarKryo(): KryoPool = quasarKryoPool
private fun <T> createFiber(logic: FlowLogic<T>): FlowStateMachineImpl<T> { private fun <T> createFiber(logic: FlowLogic<T>, flowInitiator: FlowInitiator): FlowStateMachineImpl<T> {
val id = StateMachineRunId.createRandom() val id = StateMachineRunId.createRandom()
return FlowStateMachineImpl(id, logic, scheduler).apply { initFiber(this) } return FlowStateMachineImpl(id, logic, scheduler, flowInitiator).apply { initFiber(this) }
} }
private fun initFiber(fiber: FlowStateMachineImpl<*>) { private fun initFiber(fiber: FlowStateMachineImpl<*>) {
@ -471,7 +473,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
* *
* Note that you must be on the [executor] thread. * Note that you must be on the [executor] thread.
*/ */
fun <T> add(logic: FlowLogic<T>): FlowStateMachine<T> { fun <T> add(logic: FlowLogic<T>, flowInitiator: FlowInitiator): FlowStateMachine<T> {
// TODO: Check that logic has @Suspendable on its call method. // TODO: Check that logic has @Suspendable on its call method.
executor.checkOnThread() executor.checkOnThread()
// We swap out the parent transaction context as using this frequently leads to a deadlock as we wait // We swap out the parent transaction context as using this frequently leads to a deadlock as we wait
@ -479,7 +481,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
// unable to acquire the table lock and move forward till the calling transaction finishes. // unable to acquire the table lock and move forward till the calling transaction finishes.
// Committing in line here on a fresh context ensure we can progress. // Committing in line here on a fresh context ensure we can progress.
val fiber = database.isolatedTransaction { val fiber = database.isolatedTransaction {
val fiber = createFiber(logic) val fiber = createFiber(logic, flowInitiator)
updateCheckpoint(fiber) updateCheckpoint(fiber)
fiber fiber
} }

View File

@ -10,6 +10,7 @@ import com.google.common.io.Closeables
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture import com.google.common.util.concurrent.SettableFuture
import net.corda.core.* import net.corda.core.*
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine import net.corda.core.flows.FlowStateMachine
import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.CordaRPCOps
@ -221,8 +222,9 @@ object InteractiveShell {
if (!FlowLogic::class.java.isAssignableFrom(clazz)) if (!FlowLogic::class.java.isAssignableFrom(clazz))
throw IllegalStateException("Found a non-FlowLogic class in the whitelist? $clazz") throw IllegalStateException("Found a non-FlowLogic class in the whitelist? $clazz")
try { try {
// TODO Flow invocation should use startFlowDynamic.
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
val fsm = runFlowFromString({ node.services.startFlow(it) }, inputData, clazz as Class<FlowLogic<*>>) val fsm = runFlowFromString({ node.services.startFlow(it, FlowInitiator.Shell) }, inputData, clazz as Class<FlowLogic<*>>)
// Show the progress tracker on the console until the flow completes or is interrupted with a // Show the progress tracker on the console until the flow completes or is interrupted with a
// Ctrl-C keypress. // Ctrl-C keypress.
val latch = CountDownLatch(1) val latch = CountDownLatch(1)

View File

@ -5,6 +5,7 @@ import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.contracts.Amount import net.corda.core.contracts.Amount
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowInitiator
import net.corda.core.crypto.X509Utilities import net.corda.core.crypto.X509Utilities
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine import net.corda.core.flows.FlowStateMachine
@ -97,5 +98,7 @@ class InteractiveShellTest {
get() = throw UnsupportedOperationException() get() = throw UnsupportedOperationException()
override val resultFuture: ListenableFuture<Any?> override val resultFuture: ListenableFuture<Any?>
get() = throw UnsupportedOperationException() get() = throw UnsupportedOperationException()
override val flowInitiator: FlowInitiator
get() = throw UnsupportedOperationException()
} }
} }

View File

@ -2,6 +2,7 @@ package net.corda.node.services
import com.codahale.metrics.MetricRegistry import com.codahale.metrics.MetricRegistry
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowLogicRefFactory import net.corda.core.flows.FlowLogicRefFactory
import net.corda.core.flows.FlowStateMachine import net.corda.core.flows.FlowStateMachine
@ -81,8 +82,8 @@ open class MockServiceHubInternal(
override fun recordTransactions(txs: Iterable<SignedTransaction>) = recordTransactionsInternal(txStorageService, txs) override fun recordTransactions(txs: Iterable<SignedTransaction>) = recordTransactionsInternal(txStorageService, txs)
override fun <T> startFlow(logic: FlowLogic<T>): FlowStateMachine<T> { override fun <T> startFlow(logic: FlowLogic<T>, flowInitiator: FlowInitiator): FlowStateMachine<T> {
return smm.executor.fetchFrom { smm.add(logic) } return smm.executor.fetchFrom { smm.add(logic, flowInitiator) }
} }
override fun registerFlowInitiator(markerClass: Class<*>, flowFactory: (Party) -> FlowLogic<*>) { override fun registerFlowInitiator(markerClass: Class<*>, flowFactory: (Party) -> FlowLogic<*>) {

View File

@ -4,6 +4,7 @@ import co.paralleluniverse.fibers.Suspendable
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.crypto.containsAny import net.corda.core.crypto.containsAny
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowLogicRefFactory import net.corda.core.flows.FlowLogicRefFactory
import net.corda.core.node.CordaPluginRegistry import net.corda.core.node.CordaPluginRegistry
@ -13,6 +14,7 @@ import net.corda.core.utilities.DUMMY_NOTARY
import net.corda.flows.FinalityFlow import net.corda.flows.FinalityFlow
import net.corda.node.services.network.NetworkMapService import net.corda.node.services.network.NetworkMapService
import net.corda.node.services.transactions.ValidatingNotaryService import net.corda.node.services.transactions.ValidatingNotaryService
import net.corda.node.utilities.AddOrRemove
import net.corda.node.utilities.transaction import net.corda.node.utilities.transaction
import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork
import org.junit.After import org.junit.After
@ -112,6 +114,15 @@ class ScheduledFlowTests {
@Test @Test
fun `create and run scheduled flow then wait for result`() { fun `create and run scheduled flow then wait for result`() {
val stateMachines = nodeA.smm.track()
var countScheduledFlows = 0
stateMachines.second.subscribe {
if (it.addOrRemove == AddOrRemove.ADD) {
val initiator = it.flowInitiator
if (initiator is FlowInitiator.Scheduled)
countScheduledFlows++
}
}
nodeA.services.startFlow(InsertInitialStateFlow(nodeB.info.legalIdentity)) nodeA.services.startFlow(InsertInitialStateFlow(nodeB.info.legalIdentity))
net.waitQuiescent() net.waitQuiescent()
val stateFromA = nodeA.database.transaction { val stateFromA = nodeA.database.transaction {
@ -120,6 +131,7 @@ class ScheduledFlowTests {
val stateFromB = nodeB.database.transaction { val stateFromB = nodeB.database.transaction {
nodeB.services.vaultService.linearHeadsOfType<ScheduledState>().values.first() nodeB.services.vaultService.linearHeadsOfType<ScheduledState>().values.first()
} }
assertEquals(1, countScheduledFlows)
assertEquals(stateFromA, stateFromB, "Must be same copy on both nodes") assertEquals(stateFromA, stateFromB, "Must be same copy on both nodes")
assertTrue("Must be processed", stateFromB.state.data.processed) assertTrue("Must be processed", stateFromB.state.data.processed)
} }

View File

@ -3,6 +3,7 @@ package net.corda.testing.node
import net.corda.core.contracts.Attachment import net.corda.core.contracts.Attachment
import net.corda.core.contracts.PartyAndReference import net.corda.core.contracts.PartyAndReference
import net.corda.core.crypto.* import net.corda.core.crypto.*
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine import net.corda.core.flows.FlowStateMachine
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
@ -46,10 +47,6 @@ import javax.annotation.concurrent.ThreadSafe
* building chains of transactions and verifying them. It isn't sufficient for testing flows however. * building chains of transactions and verifying them. It isn't sufficient for testing flows however.
*/ */
open class MockServices(val key: KeyPair = generateKeyPair()) : ServiceHub { open class MockServices(val key: KeyPair = generateKeyPair()) : ServiceHub {
override fun <T : Any> invokeFlowAsync(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowStateMachine<T> {
throw UnsupportedOperationException("not implemented")
}
override fun recordTransactions(txs: Iterable<SignedTransaction>) { override fun recordTransactions(txs: Iterable<SignedTransaction>) {
txs.forEach { txs.forEach {
storageService.stateMachineRecordedTransactionMapping.addMapping(StateMachineRunId.createRandom(), it.id) storageService.stateMachineRecordedTransactionMapping.addMapping(StateMachineRunId.createRandom(), it.id)