diff --git a/.ci/api-current.txt b/.ci/api-current.txt index 7c2e887573..14f2aaf9dd 100644 --- a/.ci/api-current.txt +++ b/.ci/api-current.txt @@ -1074,9 +1074,13 @@ public abstract class net.corda.core.flows.FlowLogic extends java.lang.Object public () @org.jetbrains.annotations.NotNull public abstract net.corda.core.identity.Party getCounterparty() @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public abstract net.corda.core.flows.FlowInfo getCounterpartyFlowInfo() + @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public abstract net.corda.core.flows.FlowInfo getCounterpartyFlowInfo(boolean) @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public abstract net.corda.core.utilities.UntrustworthyData receive(Class) + @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public abstract net.corda.core.utilities.UntrustworthyData receive(Class, boolean) @co.paralleluniverse.fibers.Suspendable public abstract void send(Object) + @co.paralleluniverse.fibers.Suspendable public abstract void send(Object, boolean) @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public abstract net.corda.core.utilities.UntrustworthyData sendAndReceive(Class, Object) + @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public abstract net.corda.core.utilities.UntrustworthyData sendAndReceive(Class, Object, boolean) ## public final class net.corda.core.flows.FlowStackSnapshot extends java.lang.Object public (java.time.Instant, String, List) diff --git a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt index 277769354c..9479ffe36a 100644 --- a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt +++ b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt @@ -2,6 +2,7 @@ package net.corda.client.rpc import net.corda.core.crypto.random63BitValue import net.corda.core.flows.FlowInitiator +import net.corda.core.internal.concurrent.flatMap import net.corda.core.internal.packageName import net.corda.core.messaging.FlowProgressHandle import net.corda.core.messaging.StateMachineUpdate @@ -143,7 +144,7 @@ class CordaRPCClientTest : NodeBasedTest(listOf("net.corda.finance.contracts", C } } val nodeIdentity = node.info.chooseIdentity() - node.services.startFlow(CashIssueFlow(2000.DOLLARS, OpaqueBytes.of(0), nodeIdentity), FlowInitiator.Shell).resultFuture.getOrThrow() + node.services.startFlow(CashIssueFlow(2000.DOLLARS, OpaqueBytes.of(0), nodeIdentity), FlowInitiator.Shell).flatMap { it.resultFuture }.getOrThrow() proxy.startFlow(::CashIssueFlow, 123.DOLLARS, OpaqueBytes.of(0), diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt index 0239c81f20..322a9f9925 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt @@ -41,6 +41,16 @@ import java.time.Instant * and request they start their counterpart flow, then make sure it's annotated with [InitiatingFlow]. This annotation * also has a version property to allow you to version your flow and enables a node to restrict support for the flow to * that particular version. + * + * Functions that suspend the flow (including all functions on [FlowSession]) accept a [maySkipCheckpoint] parameter + * defaulting to false, false meaning a checkpoint should always be created on suspend. This parameter may be set to + * true which allows the implementation to potentially optimise away the checkpoint, saving a roundtrip to the database. + * + * This option however comes with a big warning sign: Setting the parameter to true requires the flow's code to be + * replayable from the previous checkpoint (or start of flow) up until the next checkpoint (or end of flow) in order to + * prepare for hard failures. As suspending functions always commit the flow's database transaction regardless of this + * parameter the flow must be prepared for scenarios where a previous running of the flow *already committed its + * relevant database transactions*. Only set this option to true if you know what you're doing. */ abstract class FlowLogic { /** This is where you should log things to. */ @@ -123,7 +133,7 @@ abstract class FlowLogic { */ @Deprecated("Use FlowSession.getFlowInfo()", level = DeprecationLevel.WARNING) @Suspendable - fun getFlowInfo(otherParty: Party): FlowInfo = stateMachine.getFlowInfo(otherParty, flowUsedForSessions) + fun getFlowInfo(otherParty: Party): FlowInfo = stateMachine.getFlowInfo(otherParty, flowUsedForSessions, maySkipCheckpoint = false) /** * Serializes and queues the given [payload] object for sending to the [otherParty]. Suspends until a response @@ -157,7 +167,7 @@ abstract class FlowLogic { @Deprecated("Use FlowSession.sendAndReceive()", level = DeprecationLevel.WARNING) @Suspendable open fun sendAndReceive(receiveType: Class, otherParty: Party, payload: Any): UntrustworthyData { - return stateMachine.sendAndReceive(receiveType, otherParty, payload, flowUsedForSessions) + return stateMachine.sendAndReceive(receiveType, otherParty, payload, flowUsedForSessions, retrySend = false, maySkipCheckpoint = false) } /** @@ -171,17 +181,17 @@ abstract class FlowLogic { */ @Deprecated("Use FlowSession.sendAndReceiveWithRetry()", level = DeprecationLevel.WARNING) internal inline fun sendAndReceiveWithRetry(otherParty: Party, payload: Any): UntrustworthyData { - return stateMachine.sendAndReceive(R::class.java, otherParty, payload, flowUsedForSessions, retrySend = true) + return stateMachine.sendAndReceive(R::class.java, otherParty, payload, flowUsedForSessions, retrySend = true, maySkipCheckpoint = false) } @Suspendable internal fun FlowSession.sendAndReceiveWithRetry(receiveType: Class, payload: Any): UntrustworthyData { - return stateMachine.sendAndReceive(receiveType, counterparty, payload, flowUsedForSessions, retrySend = true) + return stateMachine.sendAndReceive(receiveType, counterparty, payload, flowUsedForSessions, retrySend = true, maySkipCheckpoint = false) } @Suspendable internal inline fun FlowSession.sendAndReceiveWithRetry(payload: Any): UntrustworthyData { - return stateMachine.sendAndReceive(R::class.java, counterparty, payload, flowUsedForSessions, retrySend = true) + return stateMachine.sendAndReceive(R::class.java, counterparty, payload, flowUsedForSessions, retrySend = true, maySkipCheckpoint = false) } /** @@ -206,7 +216,7 @@ abstract class FlowLogic { @Deprecated("Use FlowSession.receive()", level = DeprecationLevel.WARNING) @Suspendable open fun receive(receiveType: Class, otherParty: Party): UntrustworthyData { - return stateMachine.receive(receiveType, otherParty, flowUsedForSessions) + return stateMachine.receive(receiveType, otherParty, flowUsedForSessions, maySkipCheckpoint = false) } /** Suspends until a message has been received for each session in the specified [sessions]. @@ -250,7 +260,9 @@ abstract class FlowLogic { */ @Deprecated("Use FlowSession.send()", level = DeprecationLevel.WARNING) @Suspendable - open fun send(otherParty: Party, payload: Any) = stateMachine.send(otherParty, payload, flowUsedForSessions) + open fun send(otherParty: Party, payload: Any) { + stateMachine.send(otherParty, payload, flowUsedForSessions, maySkipCheckpoint = false) + } /** * Invokes the given subflow. This function returns once the subflow completes successfully with the result @@ -342,7 +354,10 @@ abstract class FlowLogic { * valid by the local node, but that doesn't imply the vault will consider it relevant. */ @Suspendable - fun waitForLedgerCommit(hash: SecureHash): SignedTransaction = stateMachine.waitForLedgerCommit(hash, this) + @JvmOverloads + fun waitForLedgerCommit(hash: SecureHash, maySkipCheckpoint: Boolean = false): SignedTransaction { + return stateMachine.waitForLedgerCommit(hash, this, maySkipCheckpoint = maySkipCheckpoint) + } /** * Returns a shallow copy of the Quasar stack frames at the time of call to [flowStackSnapshot]. Use this to inspect diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowSession.kt b/core/src/main/kotlin/net/corda/core/flows/FlowSession.kt index ae49ea3ebb..b1782f5424 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowSession.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowSession.kt @@ -54,8 +54,20 @@ abstract class FlowSession { * Returns a [FlowInfo] object describing the flow [counterparty] is using. With [FlowInfo.flowVersion] it * provides the necessary information needed for the evolution of flows and enabling backwards compatibility. * - * This method can be called before any send or receive has been done with [counterparty]. In such a case this will force - * them to start their flow. + * This method can be called before any send or receive has been done with [counterparty]. In such a case this will + * force them to start their flow. + * + * @param maySkipCheckpoint setting it to true indicates to the platform that it may optimise away the checkpoint. + */ + @Suspendable + abstract fun getCounterpartyFlowInfo(maySkipCheckpoint: Boolean): FlowInfo + + /** + * Returns a [FlowInfo] object describing the flow [counterparty] is using. With [FlowInfo.flowVersion] it + * provides the necessary information needed for the evolution of flows and enabling backwards compatibility. + * + * This method can be called before any send or receive has been done with [counterparty]. In such a case this will + * force them to start their flow. */ @Suspendable abstract fun getCounterpartyFlowInfo(): FlowInfo @@ -80,8 +92,26 @@ abstract class FlowSession { /** * Serializes and queues the given [payload] object for sending to the [counterparty]. Suspends until a response - * is received, which must be of the given [receiveType]. Remember that when receiving data from other parties the data - * should not be trusted until it's been thoroughly verified for consistency and that all expectations are + * is received, which must be of the given [receiveType]. Remember that when receiving data from other parties the + * data should not be trusted until it's been thoroughly verified for consistency and that all expectations are + * satisfied, as a malicious peer may send you subtly corrupted data in order to exploit your code. + * + * Note that this function is not just a simple send+receive pair: it is more efficient and more correct to + * use this when you expect to do a message swap than do use [send] and then [receive] in turn. + * + * @param maySkipCheckpoint setting it to true indicates to the platform that it may optimise away the checkpoint. + * @return an [UntrustworthyData] wrapper around the received object. + */ + @Suspendable + abstract fun sendAndReceive( + receiveType: Class, + payload: Any, maySkipCheckpoint: Boolean + ): UntrustworthyData + + /** + * Serializes and queues the given [payload] object for sending to the [counterparty]. Suspends until a response + * is received, which must be of the given [receiveType]. Remember that when receiving data from other parties the + * data should not be trusted until it's been thoroughly verified for consistency and that all expectations are * satisfied, as a malicious peer may send you subtly corrupted data in order to exploit your code. * * Note that this function is not just a simple send+receive pair: it is more efficient and more correct to @@ -104,6 +134,19 @@ abstract class FlowSession { return receive(R::class.java) } + /** + * Suspends until [counterparty] sends us a message of type [receiveType]. + * + * Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly + * verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly + * corrupted data in order to exploit your code. + * + * @param maySkipCheckpoint setting it to true indicates to the platform that it may optimise away the checkpoint. + * @return an [UntrustworthyData] wrapper around the received object. + */ + @Suspendable + abstract fun receive(receiveType: Class, maySkipCheckpoint: Boolean): UntrustworthyData + /** * Suspends until [counterparty] sends us a message of type [receiveType]. * @@ -116,6 +159,18 @@ abstract class FlowSession { @Suspendable abstract fun receive(receiveType: Class): UntrustworthyData + /** + * Queues the given [payload] for sending to the [counterparty] and continues without suspending. + * + * Note that the other party may receive the message at some arbitrary later point or not at all: if [counterparty] + * is offline then message delivery will be retried until it comes back or until the message is older than the + * network's event horizon time. + * + * @param maySkipCheckpoint setting it to true indicates to the platform that it may optimise away the checkpoint. + */ + @Suspendable + abstract fun send(payload: Any, maySkipCheckpoint: Boolean) + /** * Queues the given [payload] for sending to the [counterparty] and continues without suspending. * diff --git a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt index a2b0e2fd15..5e4f3e490a 100644 --- a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt +++ b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt @@ -15,7 +15,7 @@ import java.time.Instant /** This is an internal interface that is implemented by code in the node module. You should look at [FlowLogic]. */ interface FlowStateMachine { @Suspendable - fun getFlowInfo(otherParty: Party, sessionFlow: FlowLogic<*>): FlowInfo + fun getFlowInfo(otherParty: Party, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): FlowInfo @Suspendable fun initiateFlow(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSession @@ -25,16 +25,17 @@ interface FlowStateMachine { otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>, - retrySend: Boolean = false): UntrustworthyData + retrySend: Boolean, + maySkipCheckpoint: Boolean): UntrustworthyData @Suspendable - fun receive(receiveType: Class, otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData + fun receive(receiveType: Class, otherParty: Party, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): UntrustworthyData @Suspendable - fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>) + fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean) @Suspendable - fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>): SignedTransaction + fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): SignedTransaction @Suspendable fun sleepUntil(until: Instant) diff --git a/docs/source/example-code/src/test/kotlin/net/corda/docs/CustomVaultQueryTest.kt b/docs/source/example-code/src/test/kotlin/net/corda/docs/CustomVaultQueryTest.kt index 509cb65312..9690a5ab41 100644 --- a/docs/source/example-code/src/test/kotlin/net/corda/docs/CustomVaultQueryTest.kt +++ b/docs/source/example-code/src/test/kotlin/net/corda/docs/CustomVaultQueryTest.kt @@ -27,12 +27,18 @@ class CustomVaultQueryTest { @Before fun setup() { - mockNet = MockNetwork(threadPerNode = true, cordappPackages = listOf("net.corda.finance.contracts.asset", CashSchemaV1::class.packageName)) + mockNet = MockNetwork( + threadPerNode = true, + cordappPackages = listOf( + "net.corda.finance.contracts.asset", + CashSchemaV1::class.packageName, + "net.corda.docs" + ) + ) mockNet.createNotaryNode(legalName = DUMMY_NOTARY.name) nodeA = mockNet.createPartyNode() nodeB = mockNet.createPartyNode() nodeA.internals.registerInitiatedFlow(TopupIssuerFlow.TopupIssuer::class.java) - nodeA.installCordaService(CustomVaultQuery.Service::class.java) notary = nodeA.services.getDefaultNotary() } diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index dee4123fd0..d53535380d 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -11,13 +11,10 @@ import net.corda.core.flows.* import net.corda.core.identity.CordaX500Name import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate -import net.corda.core.internal.VisibleForTesting -import net.corda.core.internal.cert +import net.corda.core.internal.* import net.corda.core.internal.concurrent.doneFuture import net.corda.core.internal.concurrent.flatMap import net.corda.core.internal.concurrent.openFuture -import net.corda.core.internal.toX509CertHolder -import net.corda.core.internal.uncheckedCast import net.corda.core.messaging.* import net.corda.core.node.AppServiceHub import net.corda.core.node.NodeInfo @@ -30,6 +27,7 @@ import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.debug +import net.corda.core.utilities.getOrThrow import net.corda.node.VersionInfo import net.corda.node.internal.classloading.requireAnnotation import net.corda.node.internal.cordapp.CordappLoader @@ -146,6 +144,7 @@ abstract class AbstractNode(config: NodeConfiguration, protected var myNotaryIdentity: PartyAndCertificate? = null protected lateinit var checkpointStorage: CheckpointStorage protected lateinit var smm: StateMachineManager + private lateinit var tokenizableServices: List protected lateinit var attachments: NodeAttachmentService protected lateinit var inNodeNetworkMapService: NetworkMapService protected lateinit var network: MessagingService @@ -209,18 +208,11 @@ abstract class AbstractNode(config: NodeConfiguration, val startedImpl = initialiseDatabasePersistence(schemaService) { val transactionStorage = makeTransactionStorage() val stateLoader = StateLoaderImpl(transactionStorage) - val tokenizableServices = makeServices(schemaService, transactionStorage, stateLoader) + val services = makeServices(schemaService, transactionStorage, stateLoader) saveOwnNodeInfo() - smm = StateMachineManager(services, - checkpointStorage, - serverThread, - database, - busyNodeLatch, - cordappLoader.appClassLoader) + smm = makeStateMachineManager() val flowStarter = FlowStarterImpl(serverThread, smm) val schedulerService = NodeSchedulerService(platformClock, this@AbstractNode.database, flowStarter, stateLoader, unfinishedSchedules = busyNodeLatch, serverThread = serverThread) - smm.tokenizableServices.addAll(tokenizableServices) - smm.tokenizableServices.add(schedulerService) if (serverThread is ExecutorService) { runOnStop += { // We wait here, even though any in-flight messages should have been drained away because the @@ -233,7 +225,8 @@ abstract class AbstractNode(config: NodeConfiguration, val rpcOps = makeRPCOps(flowStarter) startMessagingService(rpcOps) installCoreFlows() - installCordaServices(flowStarter) + val cordaServices = installCordaServices(flowStarter) + tokenizableServices = services + cordaServices + schedulerService registerCordappFlows() _services.rpcFlows += cordappLoader.cordapps.flatMap { it.rpcFlows } FlowLogicRefFactoryImpl.classloader = cordappLoader.appClassLoader @@ -245,7 +238,7 @@ abstract class AbstractNode(config: NodeConfiguration, _nodeReadyFuture.captureLater(registerWithNetworkMapIfConfigured()) return startedImpl.apply { database.transaction { - smm.start() + smm.start(tokenizableServices) // Shut down the SMM so no Fibers are scheduled. runOnStop += { smm.stop(acceptableLiveFiberCountOnStop()) } schedulerService.start() @@ -254,20 +247,34 @@ abstract class AbstractNode(config: NodeConfiguration, } } + protected open fun makeStateMachineManager(): StateMachineManager { + return StateMachineManagerImpl( + services, + checkpointStorage, + serverThread, + database, + busyNodeLatch, + cordappLoader.appClassLoader + ) + } + private class ServiceInstantiationException(cause: Throwable?) : CordaException("Service Instantiation Error", cause) - private fun installCordaServices(flowStarter: FlowStarter) { + private fun installCordaServices(flowStarter: FlowStarter): List { val loadedServices = cordappLoader.cordapps.flatMap { it.services } - filterServicesToInstall(loadedServices).forEach { + return filterServicesToInstall(loadedServices).mapNotNull { try { installCordaService(flowStarter, it) } catch (e: NoSuchMethodException) { log.error("${it.name}, as a Corda service, must have a constructor with a single parameter of type " + ServiceHub::class.java.name) + null } catch (e: ServiceInstantiationException) { log.error("Corda service ${it.name} failed to instantiate", e.cause) + null } catch (e: Exception) { log.error("Unable to install Corda service ${it.name}", e) + null } } } @@ -309,11 +316,11 @@ abstract class AbstractNode(config: NodeConfiguration, return FlowHandleImpl(id = stateMachine.id, returnValue = stateMachine.resultFuture) } - private fun startFlowChecked(flow: FlowLogic): FlowStateMachineImpl { + private fun startFlowChecked(flow: FlowLogic): FlowStateMachine { val logicType = flow.javaClass require(logicType.isAnnotationPresent(StartableByService::class.java)) { "${logicType.name} was not designed for starting by a CordaService" } val currentUser = FlowInitiator.Service(serviceInstance.javaClass.name) - return flowStarter.startFlow(flow, currentUser) + return flowStarter.startFlow(flow, currentUser).getOrThrow() } override fun equals(other: Any?): Boolean { @@ -327,7 +334,7 @@ abstract class AbstractNode(config: NodeConfiguration, override fun hashCode() = Objects.hash(serviceHub, flowStarter, serviceInstance) } - internal fun installCordaService(flowStarter: FlowStarter, serviceClass: Class): T { + private fun installCordaService(flowStarter: FlowStarter, serviceClass: Class): T { serviceClass.requireAnnotation() val service = try { val serviceContext = AppServiceHubImpl(services, flowStarter) @@ -351,7 +358,6 @@ abstract class AbstractNode(config: NodeConfiguration, throw ServiceInstantiationException(e.cause) } cordappServices.putInstance(serviceClass, service) - smm.tokenizableServices += service if (service is NotaryService) handleCustomNotaryService(service) @@ -359,6 +365,12 @@ abstract class AbstractNode(config: NodeConfiguration, return service } + fun findTokenizableService(clazz: Class): T? { + return tokenizableServices.firstOrNull { clazz.isAssignableFrom(it.javaClass) }?.let { uncheckedCast(it) } + } + + inline fun findTokenizableService() = findTokenizableService(T::class.java) + private fun handleCustomNotaryService(service: NotaryService) { runOnStop += service::stop service.start() @@ -801,7 +813,7 @@ abstract class AbstractNode(config: NodeConfiguration, } internal class FlowStarterImpl(private val serverThread: AffinityExecutor, private val smm: StateMachineManager) : FlowStarter { - override fun startFlow(logic: FlowLogic, flowInitiator: FlowInitiator, ourIdentity: Party?): FlowStateMachineImpl { - return serverThread.fetchFrom { smm.add(logic, flowInitiator, ourIdentity) } + override fun startFlow(logic: FlowLogic, flowInitiator: FlowInitiator, ourIdentity: Party?): CordaFuture> { + return serverThread.fetchFrom { smm.startFlow(logic, flowInitiator, ourIdentity) } } } diff --git a/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt b/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt index af827b9e39..863835c25a 100644 --- a/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt +++ b/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt @@ -10,6 +10,7 @@ import net.corda.core.flows.StartableByRPC import net.corda.core.identity.AbstractParty import net.corda.core.identity.CordaX500Name import net.corda.core.identity.Party +import net.corda.core.internal.FlowStateMachine import net.corda.core.messaging.* import net.corda.core.node.NodeInfo import net.corda.core.node.services.NetworkMapCache @@ -18,12 +19,12 @@ import net.corda.core.node.services.vault.PageSpecification import net.corda.core.node.services.vault.QueryCriteria import net.corda.core.node.services.vault.Sort import net.corda.core.transactions.SignedTransaction +import net.corda.core.utilities.getOrThrow import net.corda.node.services.FlowPermissions.Companion.startFlowPermission import net.corda.node.services.api.FlowStarter import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.messaging.getRpcContext import net.corda.node.services.messaging.requirePermission -import net.corda.node.services.statemachine.FlowStateMachineImpl import net.corda.node.services.statemachine.StateMachineManager import net.corda.node.utilities.CordaPersistence import rx.Observable @@ -94,7 +95,7 @@ class CordaRPCOpsImpl( return database.transaction { val (allStateMachines, changes) = smm.track() DataFeed( - allStateMachines.map { stateMachineInfoFromFlowLogic(it.logic) }, + allStateMachines.map { stateMachineInfoFromFlowLogic(it) }, changes.map { stateMachineUpdateFromStateMachineChange(it) } ) } @@ -146,13 +147,13 @@ class CordaRPCOpsImpl( return FlowHandleImpl(id = stateMachine.id, returnValue = stateMachine.resultFuture) } - private fun startFlow(logicType: Class>, args: Array): FlowStateMachineImpl { + private fun startFlow(logicType: Class>, args: Array): FlowStateMachine { require(logicType.isAnnotationPresent(StartableByRPC::class.java)) { "${logicType.name} was not designed for RPC" } val rpcContext = getRpcContext() rpcContext.requirePermission(startFlowPermission(logicType)) val currentUser = FlowInitiator.RPC(rpcContext.currentUser.username) // TODO RPC flows should have mapping user -> identity that should be resolved automatically on starting flow. - return flowStarter.invokeFlowAsync(logicType, currentUser, *args) + return flowStarter.invokeFlowAsync(logicType, currentUser, *args).getOrThrow() } override fun attachmentExists(id: SecureHash): Boolean { diff --git a/node/src/main/kotlin/net/corda/node/internal/StartedNode.kt b/node/src/main/kotlin/net/corda/node/internal/StartedNode.kt index 2d0645cef4..71b80d5aea 100644 --- a/node/src/main/kotlin/net/corda/node/internal/StartedNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/StartedNode.kt @@ -31,11 +31,6 @@ interface StartedNode { val rpcOps: CordaRPCOps fun dispose() = internals.stop() fun > registerInitiatedFlow(initiatedFlowClass: Class) = internals.registerInitiatedFlow(initiatedFlowClass) - /** - * Use this method to install your Corda services in your tests. This is automatically done by the node when it - * starts up for all classes it finds which are annotated with [CordaService]. - */ - fun installCordaService(serviceClass: Class) = internals.installCordaService(services, serviceClass) } class StateLoaderImpl(private val validatedTransactions: TransactionStorage) : StateLoader { diff --git a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt index de7b8eb2eb..de12960304 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt @@ -20,6 +20,7 @@ import net.corda.core.node.services.NetworkMapCacheBase import net.corda.core.node.services.TransactionStorage import net.corda.core.serialization.CordaSerializable import net.corda.core.transactions.SignedTransaction +import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.loggerFor import net.corda.node.internal.InitiatedFlowFactory import net.corda.node.internal.cordapp.CordappProviderInternal @@ -119,13 +120,13 @@ interface FlowStarter { * defaults to [FlowInitiator.RPC] with username "Only For Testing". */ @VisibleForTesting - fun startFlow(logic: FlowLogic): FlowStateMachine = startFlow(logic, FlowInitiator.RPC("Only For Testing")) + fun startFlow(logic: FlowLogic): FlowStateMachine = startFlow(logic, FlowInitiator.RPC("Only For Testing")).getOrThrow() /** * Starts an already constructed flow. Note that you must be on the server thread to call this method. * @param flowInitiator indicates who started the flow, see: [FlowInitiator]. */ - fun startFlow(logic: FlowLogic, flowInitiator: FlowInitiator, ourIdentity: Party? = null): FlowStateMachineImpl + fun startFlow(logic: FlowLogic, flowInitiator: FlowInitiator, ourIdentity: Party? = null): CordaFuture> /** * Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the flow. @@ -138,7 +139,7 @@ interface FlowStarter { fun invokeFlowAsync( logicType: Class>, flowInitiator: FlowInitiator, - vararg args: Any?): FlowStateMachineImpl { + vararg args: Any?): CordaFuture> { val logicRef = FlowLogicRefFactoryImpl.createForRPC(logicType, *args) val logic: FlowLogic = uncheckedCast(FlowLogicRefFactoryImpl.toFlowLogic(logicRef)) return startFlow(logic, flowInitiator, ourIdentity = null) diff --git a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt index a0c208a53e..85aa36a0cb 100644 --- a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt +++ b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt @@ -1,9 +1,8 @@ package net.corda.node.services.events import co.paralleluniverse.fibers.Suspendable -import co.paralleluniverse.strands.SettableFuture as QuasarSettableFuture import com.google.common.util.concurrent.ListenableFuture -import com.google.common.util.concurrent.SettableFuture as GuavaSettableFuture +import com.google.common.util.concurrent.SettableFuture import net.corda.core.contracts.SchedulableState import net.corda.core.contracts.ScheduledActivity import net.corda.core.contracts.ScheduledStateRef @@ -13,6 +12,7 @@ import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowLogic import net.corda.core.internal.ThreadBox import net.corda.core.internal.VisibleForTesting +import net.corda.core.internal.concurrent.flatMap import net.corda.core.internal.until import net.corda.core.node.StateLoader import net.corda.core.schemas.PersistentStateRef @@ -36,6 +36,8 @@ import javax.annotation.concurrent.ThreadSafe import javax.persistence.Column import javax.persistence.EmbeddedId import javax.persistence.Entity +import co.paralleluniverse.strands.SettableFuture as QuasarSettableFuture +import com.google.common.util.concurrent.SettableFuture as GuavaSettableFuture /** * A first pass of a simple [SchedulerService] that works with [MutableClock]s for testing, demonstrations and simulations @@ -215,7 +217,7 @@ class NodeSchedulerService(private val clock: Clock, * cancelled then we run the scheduled action. Finally we remove that action from the scheduled actions and * recompute the next scheduled action. */ - internal fun rescheduleWakeUp() { + private fun rescheduleWakeUp() { // Note, we already have the mutex but we need the scope again here val (scheduledState, ourRescheduledFuture) = mutex.alreadyLocked { rescheduled?.cancel(false) @@ -245,7 +247,7 @@ class NodeSchedulerService(private val clock: Clock, val scheduledFlow = getScheduledFlow(scheduledState) if (scheduledFlow != null) { flowName = scheduledFlow.javaClass.name - val future = flowStarter.startFlow(scheduledFlow, FlowInitiator.Scheduled(scheduledState)).resultFuture + val future = flowStarter.startFlow(scheduledFlow, FlowInitiator.Scheduled(scheduledState)).flatMap { it.resultFuture } future.then { unfinishedSchedules.countDown() } diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt index a708166239..1697b49047 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt @@ -83,9 +83,40 @@ interface MessagingService { * to send an ACK message back. * * @param retryId if provided the message will be scheduled for redelivery until [cancelRedelivery] is called for this id. - * Note that this feature should only be used when the target is an idempotent distributed service, e.g. a notary. + * Note that this feature should only be used when the target is an idempotent distributed service, e.g. a notary. + * @param sequenceKey an object that may be used to enable a parallel [MessagingService] implementation. Two + * subsequent send()s with the same [sequenceKey] (up to equality) are guaranteed to be delivered in the same + * sequence the send()s were called. By default this is chosen conservatively to be [target]. + * @param acknowledgementHandler if non-null this handler will be called once the sent message has been committed by + * the broker. Note that if specified [send] itself may return earlier than the commit. */ - fun send(message: Message, target: MessageRecipients, retryId: Long? = null) + fun send( + message: Message, + target: MessageRecipients, + retryId: Long? = null, + sequenceKey: Any = target, + acknowledgementHandler: (() -> Unit)? = null + ) + + /** A message with a target and sequenceKey specified. */ + data class AddressedMessage( + val message: Message, + val target: MessageRecipients, + val retryId: Long? = null, + val sequenceKey: Any = target + ) + + /** + * Sends a list of messages to the specified recipients. This function allows for an efficient batching + * implementation. + * + * @param addressedMessages The list of messages together with the recipients, retry ids and sequence keys. + * @param retryId if provided the message will be scheduled for redelivery until [cancelRedelivery] is called for this id. + * Note that this feature should only be used when the target is an idempotent distributed service, e.g. a notary. + * @param acknowledgementHandler if non-null this handler will be called once all sent messages have been committed + * by the broker. Note that if specified [send] itself may return earlier than the commit. + */ + fun send(addressedMessages: List, acknowledgementHandler: (() -> Unit)? = null) /** Cancels the scheduled message redelivery for the specified [retryId] */ fun cancelRedelivery(retryId: Long) diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt b/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt index 5a9cd04730..401709f811 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt @@ -22,7 +22,7 @@ import net.corda.node.services.RPCUserService import net.corda.node.services.api.MonitoringService import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.VerifierType -import net.corda.node.services.statemachine.StateMachineManager +import net.corda.node.services.statemachine.StateMachineManagerImpl import net.corda.node.services.transactions.InMemoryTransactionVerifierService import net.corda.node.services.transactions.OutOfProcessTransactionVerifierService import net.corda.node.utilities.* @@ -485,7 +485,7 @@ class NodeMessagingClient(override val config: NodeConfiguration, } } - override fun send(message: Message, target: MessageRecipients, retryId: Long?) { + override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, acknowledgementHandler: (() -> Unit)?) { // We have to perform sending on a different thread pool, since using the same pool for messaging and // fibers leads to Netty buffer memory leaks, caused by both Netty and Quasar fiddling with thread-locals. messagingExecutor.fetchFrom { @@ -502,7 +502,7 @@ class NodeMessagingClient(override val config: NodeConfiguration, putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString())) // For demo purposes - if set then add a delay to messages in order to demonstrate that the flows are doing as intended - if (amqDelayMillis > 0 && message.topicSession.topic == StateMachineManager.sessionTopic.topic) { + if (amqDelayMillis > 0 && message.topicSession.topic == StateMachineManagerImpl.sessionTopic.topic) { putLongProperty(HDR_SCHEDULED_DELIVERY_TIME, System.currentTimeMillis() + amqDelayMillis) } } @@ -523,6 +523,14 @@ class NodeMessagingClient(override val config: NodeConfiguration, } } } + acknowledgementHandler?.invoke() + } + + override fun send(addressedMessages: List, acknowledgementHandler: (() -> Unit)?) { + for ((message, target, retryId, sequenceKey) in addressedMessages) { + send(message, target, retryId, sequenceKey, null) + } + acknowledgementHandler?.invoke() } private fun sendWithRetry(retryCount: Int, address: String, message: ClientMessage, retryId: Long) { diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionImpl.kt index 044fdc0dbe..054d7c5d01 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionImpl.kt @@ -16,28 +16,46 @@ class FlowSessionImpl( internal lateinit var sessionFlow: FlowLogic<*> @Suspendable - override fun getCounterpartyFlowInfo(): FlowInfo { - return stateMachine.getFlowInfo(counterparty, sessionFlow) + override fun getCounterpartyFlowInfo(maySkipCheckpoint: Boolean): FlowInfo { + return stateMachine.getFlowInfo(counterparty, sessionFlow, maySkipCheckpoint) } @Suspendable - override fun sendAndReceive(receiveType: Class, payload: Any): UntrustworthyData { - return stateMachine.sendAndReceive(receiveType, counterparty, payload, sessionFlow) + override fun getCounterpartyFlowInfo() = getCounterpartyFlowInfo(maySkipCheckpoint = false) + + @Suspendable + override fun sendAndReceive( + receiveType: Class, + payload: Any, + maySkipCheckpoint: Boolean + ): UntrustworthyData { + return stateMachine.sendAndReceive( + receiveType, + counterparty, + payload, + sessionFlow, + retrySend = false, + maySkipCheckpoint = maySkipCheckpoint + ) } @Suspendable - internal fun sendAndReceiveWithRetry(receiveType: Class, payload: Any): UntrustworthyData { - return stateMachine.sendAndReceive(receiveType, counterparty, payload, sessionFlow, retrySend = true) + override fun sendAndReceive(receiveType: Class, payload: Any) = sendAndReceive(receiveType, payload, maySkipCheckpoint = false) + + @Suspendable + override fun receive(receiveType: Class, maySkipCheckpoint: Boolean): UntrustworthyData { + return stateMachine.receive(receiveType, counterparty, sessionFlow, maySkipCheckpoint) } @Suspendable - override fun receive(receiveType: Class): UntrustworthyData { - return stateMachine.receive(receiveType, counterparty, sessionFlow) + override fun receive(receiveType: Class) = receive(receiveType, maySkipCheckpoint = false) + + @Suspendable + override fun send(payload: Any, maySkipCheckpoint: Boolean) { + return stateMachine.send(counterparty, payload, sessionFlow, maySkipCheckpoint) } @Suspendable - override fun send(payload: Any) { - return stateMachine.send(counterparty, payload, sessionFlow) - } + override fun send(payload: Any) = send(payload, maySkipCheckpoint = false) } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index a3df2461c8..57cc99378a 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -163,7 +163,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } @Suspendable - override fun getFlowInfo(otherParty: Party, sessionFlow: FlowLogic<*>): FlowInfo { + override fun getFlowInfo(otherParty: Party, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): FlowInfo { val state = getConfirmedSession(otherParty, sessionFlow).state as FlowSessionState.Initiated return state.context } @@ -173,7 +173,8 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>, - retrySend: Boolean): UntrustworthyData { + retrySend: Boolean, + maySkipCheckpoint: Boolean): UntrustworthyData { requireNonPrimitive(receiveType) logger.debug { "sendAndReceive(${receiveType.name}, $otherParty, ${payload.toString().abbreviate(300)}) ..." } val session = getConfirmedSessionIfPresent(otherParty, sessionFlow) @@ -192,7 +193,8 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, @Suspendable override fun receive(receiveType: Class, otherParty: Party, - sessionFlow: FlowLogic<*>): UntrustworthyData { + sessionFlow: FlowLogic<*>, + maySkipCheckpoint: Boolean): UntrustworthyData { requireNonPrimitive(receiveType) logger.debug { "receive(${receiveType.name}, $otherParty) ..." } val session = getConfirmedSession(otherParty, sessionFlow) @@ -208,7 +210,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } @Suspendable - override fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>) { + override fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean) { logger.debug { "send($otherParty, ${payload.toString().abbreviate(300)})" } val session = getConfirmedSessionIfPresent(otherParty, sessionFlow) if (session == null) { @@ -220,7 +222,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } @Suspendable - override fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>): SignedTransaction { + override fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): SignedTransaction { logger.debug { "waitForLedgerCommit($hash) ..." } suspend(WaitForLedgerCommit(hash, sessionFlow.stateMachine as FlowStateMachineImpl<*>)) val stx = serviceHub.validatedTransactions.getTransaction(hash) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index 74697821e9..e790d28b2c 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -1,68 +1,24 @@ package net.corda.node.services.statemachine -import co.paralleluniverse.fibers.Fiber -import co.paralleluniverse.fibers.FiberExecutorScheduler -import co.paralleluniverse.fibers.Suspendable -import co.paralleluniverse.fibers.instrument.SuspendableHelper -import co.paralleluniverse.strands.Strand -import com.codahale.metrics.Gauge -import com.esotericsoftware.kryo.KryoException -import com.google.common.collect.HashMultimap -import com.google.common.util.concurrent.MoreExecutors -import net.corda.core.CordaException import net.corda.core.concurrent.CordaFuture -import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.random63BitValue -import net.corda.core.flows.* +import net.corda.core.flows.FlowInitiator +import net.corda.core.flows.FlowLogic import net.corda.core.identity.Party -import net.corda.core.internal.* +import net.corda.core.internal.FlowStateMachine import net.corda.core.messaging.DataFeed -import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT -import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY -import net.corda.core.serialization.SerializedBytes -import net.corda.core.serialization.deserialize -import net.corda.core.serialization.serialize import net.corda.core.utilities.Try -import net.corda.core.utilities.debug -import net.corda.core.utilities.loggerFor -import net.corda.core.utilities.trace -import net.corda.node.internal.InitiatedFlowFactory -import net.corda.node.services.api.Checkpoint -import net.corda.node.services.api.CheckpointStorage -import net.corda.node.services.api.ServiceHubInternal -import net.corda.node.services.messaging.ReceivedMessage -import net.corda.node.services.messaging.TopicSession -import net.corda.node.utilities.AffinityExecutor -import net.corda.node.utilities.CordaPersistence -import net.corda.node.utilities.bufferUntilDatabaseCommit -import net.corda.node.utilities.wrapWithDatabaseTransaction -import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl -import net.corda.nodeapi.internal.serialization.withTokenContext -import org.apache.activemq.artemis.utils.ReusableLatch -import org.slf4j.Logger import rx.Observable -import rx.subjects.PublishSubject -import java.io.NotSerializableException -import java.util.* -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.Executors -import java.util.concurrent.TimeUnit.SECONDS -import javax.annotation.concurrent.ThreadSafe -import kotlin.collections.ArrayList /** - * A StateMachineManager is responsible for coordination and persistence of multiple [FlowStateMachineImpl] objects. + * A StateMachineManager is responsible for coordination and persistence of multiple [FlowStateMachine] objects. * Each such object represents an instantiation of a (two-party) flow that has reached a particular point. * - * An implementation of this class will persist state machines to long term storage so they can survive process restarts - * and, if run with a single-threaded executor, will ensure no two state machines run concurrently with each other - * (bad for performance, good for programmer mental health!). + * An implementation of this interface will persist state machines to long term storage so they can survive process + * restarts and, if run with a single-threaded executor, will ensure no two state machines run concurrently with each + * other (bad for performance, good for programmer mental health!). * - * A "state machine" is a class with a single call method. The call method and any others it invokes are rewritten by - * a bytecode rewriting engine called Quasar, to ensure the code can be suspended and resumed at any point. - * - * The SMM will always invoke the flow fibers on the given [AffinityExecutor], regardless of which thread actually - * starts them via [add]. + * A flow is a class with a single call method. The call method and any others it invokes are rewritten by a bytecode + * rewriting engine called Quasar, to ensure the code can be suspended and resumed at any point. * * TODO: Consider the issue of continuation identity more deeply: is it a safe assumption that a serialised * continuation is always unique? @@ -72,588 +28,51 @@ import kotlin.collections.ArrayList * TODO: Ability to control checkpointing explicitly, for cases where you know replaying a message can't hurt * TODO: Don't store all active flows in memory, load from the database on demand. */ -@ThreadSafe -class StateMachineManager(val serviceHub: ServiceHubInternal, - val checkpointStorage: CheckpointStorage, - val executor: AffinityExecutor, - val database: CordaPersistence, - private val unfinishedFibers: ReusableLatch = ReusableLatch(), - private val classloader: ClassLoader = javaClass.classLoader) { +interface StateMachineManager { + /** + * Starts the state machine manager, loading and starting the state machines in storage. + */ + fun start(tokenizableServices: List) + /** + * Stops the state machine manager gracefully, waiting until all but [allowedUnsuspendedFiberCount] flows reach the + * next checkpoint. + */ + fun stop(allowedUnsuspendedFiberCount: Int) - inner class FiberScheduler : FiberExecutorScheduler("Same thread scheduler", executor) - - companion object { - private val logger = loggerFor() - internal val sessionTopic = TopicSession("platform.session") - - init { - Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable -> - (fiber as FlowStateMachineImpl<*>).logger.warn("Caught exception from flow", throwable) - } - } - } + /** + * Starts a new flow. + * + * @param flowLogic The flow's code. + * @param flowInitiator The initiator of the flow. + */ + fun startFlow(flowLogic: FlowLogic, flowInitiator: FlowInitiator, ourIdentity: Party? = null): CordaFuture> + /** + * Represents an addition/removal of a state machine. + */ sealed class Change { abstract val logic: FlowLogic<*> - data class Add(override val logic: FlowLogic<*>) : Change() data class Removed(override val logic: FlowLogic<*>, val result: Try<*>) : Change() } - // A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines - // property. - private class InnerState { - var started = false - val stateMachines = LinkedHashMap, Checkpoint>() - val changesPublisher = PublishSubject.create()!! - val fibersWaitingForLedgerCommit = HashMultimap.create>()!! + /** + * Returns the list of live state machines and a stream of subsequent additions/removals of them. + */ + fun track(): DataFeed>, Change> - fun notifyChangeObservers(change: Change) { - changesPublisher.bufferUntilDatabaseCommit().onNext(change) - } - } + /** + * The stream of additions/removals of flows. + */ + val changes: Observable - private val scheduler = FiberScheduler() - private val mutex = ThreadBox(InnerState()) - // This thread (only enabled in dev mode) deserialises checkpoints in the background to shake out bugs in checkpoint restore. - private val checkpointCheckerThread = if (serviceHub.configuration.devMode) Executors.newSingleThreadExecutor() else null - - @Volatile private var unrestorableCheckpoints = false - - // True if we're shutting down, so don't resume anything. - @Volatile private var stopping = false - // How many Fibers are running and not suspended. If zero and stopping is true, then we are halted. - private val liveFibers = ReusableLatch() - - // Monitoring support. - private val metrics = serviceHub.monitoringService.metrics - - init { - metrics.register("Flows.InFlight", Gauge { mutex.content.stateMachines.size }) - } - - private val checkpointingMeter = metrics.meter("Flows.Checkpointing Rate") - private val totalStartedFlows = metrics.counter("Flows.Started") - private val totalFinishedFlows = metrics.counter("Flows.Finished") - - private val openSessions = ConcurrentHashMap() - private val recentlyClosedSessions = ConcurrentHashMap() - - internal val tokenizableServices = ArrayList() - // Context for tokenized services in checkpoints - private val serializationContext by lazy { - SerializeAsTokenContextImpl(tokenizableServices, SERIALIZATION_FACTORY, CHECKPOINT_CONTEXT, serviceHub) - } - - fun findServices(predicate: (Any) -> Boolean) = tokenizableServices.filter(predicate) - - /** Returns a list of all state machines executing the given flow logic at the top level (subflows do not count) */ - fun

, T> findStateMachines(flowClass: Class

): List>> { - return mutex.locked { - stateMachines.keys.mapNotNull { - flowClass.castIfPossible(it.logic)?.let { it to uncheckedCast, FlowStateMachineImpl>(it.stateMachine).resultFuture } - } - } - } + /** + * Returns the currently live flows of type [flowClass], and their corresponding result future. + */ + fun > findStateMachines(flowClass: Class): List>> + /** + * Returns all currently live flows. + */ val allStateMachines: List> - get() = mutex.locked { stateMachines.keys.map { it.logic } } - - /** - * An observable that emits triples of the changing flow, the type of change, and a process-specific ID number - * which may change across restarts. - * - * We use assignment here so that multiple subscribers share the same wrapped Observable. - */ - val changes: Observable = mutex.content.changesPublisher.wrapWithDatabaseTransaction() - - fun start() { - checkQuasarJavaAgentPresence() - restoreFibersFromCheckpoints() - listenToLedgerTransactions() - serviceHub.networkMapCache.nodeReady.then { executor.execute(this::resumeRestoredFibers) } - } - - private fun checkQuasarJavaAgentPresence() { - check(SuspendableHelper.isJavaAgentActive(), { - """Missing the '-javaagent' JVM argument. Make sure you run the tests with the Quasar java agent attached to your JVM. - #See https://docs.corda.net/troubleshooting.html - 'Fiber classes not instrumented' for more details.""".trimMargin("#") - }) - } - - private fun listenToLedgerTransactions() { - // Observe the stream of committed, validated transactions and resume fibers that are waiting for them. - serviceHub.validatedTransactions.updates.subscribe { stx -> - val hash = stx.id - val fibers: Set> = mutex.locked { fibersWaitingForLedgerCommit.removeAll(hash) } - if (fibers.isNotEmpty()) { - executor.executeASAP { - for (fiber in fibers) { - fiber.logger.trace { "Transaction $hash has committed to the ledger, resuming" } - fiber.waitingForResponse = null - resumeFiber(fiber) - } - } - } - } - } - - private fun decrementLiveFibers() { - liveFibers.countDown() - } - - private fun incrementLiveFibers() { - liveFibers.countUp() - } - - /** - * Start the shutdown process, bringing the [StateMachineManager] to a controlled stop. When this method returns, - * all Fibers have been suspended and checkpointed, or have completed. - * - * @param allowedUnsuspendedFiberCount Optional parameter is used in some tests. - */ - fun stop(allowedUnsuspendedFiberCount: Int = 0) { - require(allowedUnsuspendedFiberCount >= 0) - mutex.locked { - if (stopping) throw IllegalStateException("Already stopping!") - stopping = true - } - // Account for any expected Fibers in a test scenario. - liveFibers.countDown(allowedUnsuspendedFiberCount) - liveFibers.await() - checkpointCheckerThread?.let { MoreExecutors.shutdownAndAwaitTermination(it, 5, SECONDS) } - check(!unrestorableCheckpoints) { "Unrestorable checkpoints where created, please check the logs for details." } - } - - /** - * Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and - * calls to [allStateMachines] - */ - fun track(): DataFeed>, Change> { - return mutex.locked { - DataFeed(stateMachines.keys.toList(), changesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction()) - } - } - - private fun restoreFibersFromCheckpoints() { - mutex.locked { - checkpointStorage.forEach { checkpoint -> - // If a flow is added before start() then don't attempt to restore it - if (!stateMachines.containsValue(checkpoint)) { - deserializeFiber(checkpoint, logger)?.let { - initFiber(it) - stateMachines[it] = checkpoint - } - } - true - } - } - } - - private fun resumeRestoredFibers() { - mutex.locked { - started = true - stateMachines.keys.forEach { resumeRestoredFiber(it) } - } - serviceHub.networkService.addMessageHandler(sessionTopic) { message, _ -> - executor.checkOnThread() - onSessionMessage(message) - } - } - - private fun resumeRestoredFiber(fiber: FlowStateMachineImpl<*>) { - fiber.openSessions.values.forEach { openSessions[it.ourSessionId] = it } - val waitingForResponse = fiber.waitingForResponse - if (waitingForResponse != null) { - if (waitingForResponse is WaitForLedgerCommit) { - val stx = database.transaction { - serviceHub.validatedTransactions.getTransaction(waitingForResponse.hash) - } - if (stx != null) { - fiber.logger.info("Resuming fiber as tx ${waitingForResponse.hash} has committed") - fiber.waitingForResponse = null - resumeFiber(fiber) - } else { - fiber.logger.info("Restored, pending on ledger commit of ${waitingForResponse.hash}") - mutex.locked { fibersWaitingForLedgerCommit.put(waitingForResponse.hash, fiber) } - } - } else { - fiber.logger.info("Restored, pending on receive") - } - } else { - resumeFiber(fiber) - } - } - - private fun onSessionMessage(message: ReceivedMessage) { - val sessionMessage = try { - message.data.deserialize() - } catch (ex: Exception) { - logger.error("Received corrupt SessionMessage data from ${message.peer}") - return - } - val sender = serviceHub.networkMapCache.getPeerByLegalName(message.peer) - if (sender != null) { - when (sessionMessage) { - is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender) - is SessionInit -> onSessionInit(sessionMessage, message, sender) - } - } else { - logger.error("Unknown peer ${message.peer} in $sessionMessage") - } - } - - private fun onExistingSessionMessage(message: ExistingSessionMessage, sender: Party) { - val session = openSessions[message.recipientSessionId] - if (session != null) { - session.fiber.logger.trace { "Received $message on $session from $sender" } - if (session.retryable) { - if (message is SessionConfirm && session.state is FlowSessionState.Initiated) { - session.fiber.logger.trace { "Ignoring duplicate confirmation for session ${session.ourSessionId} – session is idempotent" } - return - } - if (message !is SessionConfirm) { - serviceHub.networkService.cancelRedelivery(session.ourSessionId) - } - } - if (message is SessionEnd) { - openSessions.remove(message.recipientSessionId) - } - session.receivedMessages += ReceivedSessionMessage(sender, message) - if (resumeOnMessage(message, session)) { - // It's important that we reset here and not after the fiber's resumed, in case we receive another message - // before then. - session.fiber.waitingForResponse = null - updateCheckpoint(session.fiber) - session.fiber.logger.trace { "Resuming due to $message" } - resumeFiber(session.fiber) - } - } else { - val peerParty = recentlyClosedSessions.remove(message.recipientSessionId) - if (peerParty != null) { - if (message is SessionConfirm) { - logger.trace { "Received session confirmation but associated fiber has already terminated, so sending session end" } - sendSessionMessage(peerParty, NormalSessionEnd(message.initiatedSessionId)) - } else { - logger.trace { "Ignoring session end message for already closed session: $message" } - } - } else { - logger.warn("Received a session message for unknown session: $message, from $sender") - } - } - } - - // We resume the fiber if it's received a response for which it was waiting for or it's waiting for a ledger - // commit but a counterparty flow has ended with an error (in which case our flow also has to end) - private fun resumeOnMessage(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean { - val waitingForResponse = session.fiber.waitingForResponse - return waitingForResponse?.shouldResume(message, session) ?: false - } - - private fun onSessionInit(sessionInit: SessionInit, receivedMessage: ReceivedMessage, sender: Party) { - logger.trace { "Received $sessionInit from $sender" } - val senderSessionId = sessionInit.initiatorSessionId - - fun sendSessionReject(message: String) = sendSessionMessage(sender, SessionReject(senderSessionId, message)) - - val (session, initiatedFlowFactory) = try { - val initiatedFlowFactory = getInitiatedFlowFactory(sessionInit) - val flowSession = FlowSessionImpl(sender) - val flow = initiatedFlowFactory.createFlow(flowSession) - val senderFlowVersion = when (initiatedFlowFactory) { - is InitiatedFlowFactory.Core -> receivedMessage.platformVersion // The flow version for the core flows is the platform version - is InitiatedFlowFactory.CorDapp -> sessionInit.flowVersion - } - val session = FlowSessionInternal( - flow, - flowSession, - random63BitValue(), - sender, - FlowSessionState.Initiated(sender, senderSessionId, FlowInfo(senderFlowVersion, sessionInit.appName))) - if (sessionInit.firstPayload != null) { - session.receivedMessages += ReceivedSessionMessage(sender, SessionData(session.ourSessionId, sessionInit.firstPayload)) - } - openSessions[session.ourSessionId] = session - // TODO Perhaps the session-init will specificy which of our multiple identies to use, which we would have to - // double-check is actually ours. However, what if we want to control how our identities gets used? - val fiber = createFiber(flow, FlowInitiator.Peer(sender)) - flowSession.sessionFlow = flow - flowSession.stateMachine = fiber - fiber.openSessions[Pair(flow, sender)] = session - updateCheckpoint(fiber) - session to initiatedFlowFactory - } catch (e: SessionRejectException) { - logger.warn("${e.logMessage}: $sessionInit") - sendSessionReject(e.rejectMessage) - return - } catch (e: Exception) { - logger.warn("Couldn't start flow session from $sessionInit", e) - sendSessionReject("Unable to establish session") - return - } - - val (ourFlowVersion, appName) = when (initiatedFlowFactory) { - // The flow version for the core flows is the platform version - is InitiatedFlowFactory.Core -> serviceHub.myInfo.platformVersion to "corda" - is InitiatedFlowFactory.CorDapp -> initiatedFlowFactory.flowVersion to initiatedFlowFactory.appName - } - - sendSessionMessage(sender, SessionConfirm(senderSessionId, session.ourSessionId, ourFlowVersion, appName), session.fiber) - session.fiber.logger.debug { "Initiated by $sender using ${sessionInit.initiatingFlowClass}" } - session.fiber.logger.trace { "Initiated from $sessionInit on $session" } - resumeFiber(session.fiber) - } - - private fun getInitiatedFlowFactory(sessionInit: SessionInit): InitiatedFlowFactory<*> { - val initiatingFlowClass = try { - Class.forName(sessionInit.initiatingFlowClass, true, classloader).asSubclass(FlowLogic::class.java) - } catch (e: ClassNotFoundException) { - throw SessionRejectException("Don't know ${sessionInit.initiatingFlowClass}") - } catch (e: ClassCastException) { - throw SessionRejectException("${sessionInit.initiatingFlowClass} is not a flow") - } - return serviceHub.getFlowFactory(initiatingFlowClass) ?: - throw SessionRejectException("$initiatingFlowClass is not registered") - } - - private fun serializeFiber(fiber: FlowStateMachineImpl<*>): SerializedBytes> { - return fiber.serialize(context = CHECKPOINT_CONTEXT.withTokenContext(serializationContext)) - } - - private fun deserializeFiber(checkpoint: Checkpoint, logger: Logger): FlowStateMachineImpl<*>? { - return try { - checkpoint.serializedFiber.deserialize(context = CHECKPOINT_CONTEXT.withTokenContext(serializationContext)).apply { - fromCheckpoint = true - } - } catch (t: Throwable) { - logger.error("Encountered unrestorable checkpoint!", t) - null - } - } - - private fun createFiber(logic: FlowLogic, flowInitiator: FlowInitiator, ourIdentity: Party? = null): FlowStateMachineImpl { - val fsm = FlowStateMachineImpl( - StateMachineRunId.createRandom(), - logic, - scheduler, - flowInitiator, - ourIdentity ?: serviceHub.myInfo.legalIdentities[0]) - initFiber(fsm) - return fsm - } - - private fun initFiber(fiber: FlowStateMachineImpl<*>) { - verifyFlowLogicIsSuspendable(fiber.logic) - fiber.database = database - fiber.serviceHub = serviceHub - fiber.ourIdentityAndCert = serviceHub.myInfo.legalIdentitiesAndCerts.find { it.party == fiber.ourIdentity } - ?: throw IllegalStateException("Identity specified by ${fiber.id} (${fiber.ourIdentity}) is not one of ours!") - fiber.actionOnSuspend = { ioRequest -> - updateCheckpoint(fiber) - // We commit on the fibers transaction that was copied across ThreadLocals during suspend - // This will free up the ThreadLocal so on return the caller can carry on with other transactions - fiber.commitTransaction() - processIORequest(ioRequest) - decrementLiveFibers() - } - fiber.actionOnEnd = { result, propagated -> - try { - mutex.locked { - stateMachines.remove(fiber)?.let { checkpointStorage.removeCheckpoint(it) } - notifyChangeObservers(Change.Removed(fiber.logic, result)) - } - endAllFiberSessions(fiber, result, propagated) - } finally { - fiber.commitTransaction() - decrementLiveFibers() - totalFinishedFlows.inc() - unfinishedFibers.countDown() - } - } - mutex.locked { - totalStartedFlows.inc() - unfinishedFibers.countUp() - notifyChangeObservers(Change.Add(fiber.logic)) - } - } - - private fun verifyFlowLogicIsSuspendable(logic: FlowLogic) { - // Quasar requires (in Java 8) that at least the call method be annotated suspendable. Unfortunately, it's - // easy to forget to add this when creating a new flow, so we check here to give the user a better error. - // - // The Kotlin compiler can sometimes generate a synthetic bridge method from a single call declaration, which - // forwards to the void method and then returns Unit. However annotations do not get copied across to this - // bridge, so we have to do a more complex scan here. - val call = logic.javaClass.methods.first { !it.isSynthetic && it.name == "call" && it.parameterCount == 0 } - if (call.getAnnotation(Suspendable::class.java) == null) { - throw FlowException("${logic.javaClass.name}.call() is not annotated as @Suspendable. Please fix this.") - } - } - - private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>, result: Try<*>, propagated: Boolean) { - openSessions.values.removeIf { session -> - if (session.fiber == fiber) { - session.endSession((result as? Try.Failure)?.exception, propagated) - true - } else { - false - } - } - } - - private fun FlowSessionInternal.endSession(exception: Throwable?, propagated: Boolean) { - val initiatedState = state as? FlowSessionState.Initiated ?: return - val sessionEnd = if (exception == null) { - NormalSessionEnd(initiatedState.peerSessionId) - } else { - val errorResponse = if (exception is FlowException && (!propagated || initiatingParty != null)) { - // Only propagate this FlowException if our local flow threw it or it was propagated to us and we only - // pass it down invocation chain to the flow that initiated us, not to flows we've started sessions with. - exception - } else { - null - } - ErrorSessionEnd(initiatedState.peerSessionId, errorResponse) - } - sendSessionMessage(initiatedState.peerParty, sessionEnd, fiber) - recentlyClosedSessions[ourSessionId] = initiatedState.peerParty - } - - /** - * Kicks off a brand new state machine of the given class. - * The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is - * restarted with checkpointed state machines in the storage service. - * - * Note that you must be on the [executor] thread. - */ - fun add(logic: FlowLogic, flowInitiator: FlowInitiator, ourIdentity: Party? = null): FlowStateMachineImpl { - // TODO: Check that logic has @Suspendable on its call method. - executor.checkOnThread() - val fiber = database.transaction { - val fiber = createFiber(logic, flowInitiator, ourIdentity) - updateCheckpoint(fiber) - fiber - } - // If we are not started then our checkpoint will be picked up during start - mutex.locked { - if (started) { - resumeFiber(fiber) - } - } - return fiber - } - - private fun updateCheckpoint(fiber: FlowStateMachineImpl<*>) { - check(fiber.state != Strand.State.RUNNING) { "Fiber cannot be running when checkpointing" } - val newCheckpoint = Checkpoint(serializeFiber(fiber)) - val previousCheckpoint = mutex.locked { stateMachines.put(fiber, newCheckpoint) } - if (previousCheckpoint != null) { - checkpointStorage.removeCheckpoint(previousCheckpoint) - } - checkpointStorage.addCheckpoint(newCheckpoint) - checkpointingMeter.mark() - - checkpointCheckerThread?.execute { - // Immediately check that the checkpoint is valid by deserialising it. The idea is to plug any holes we have - // in our testing by failing any test where unrestorable checkpoints are created. - if (deserializeFiber(newCheckpoint, fiber.logger) == null) { - unrestorableCheckpoints = true - } - } - } - - private fun resumeFiber(fiber: FlowStateMachineImpl<*>) { - // Avoid race condition when setting stopping to true and then checking liveFibers - incrementLiveFibers() - if (!stopping) { - executor.executeASAP { - fiber.resume(scheduler) - } - } else { - fiber.logger.trace("Not resuming as SMM is stopping.") - decrementLiveFibers() - } - } - - private fun processIORequest(ioRequest: FlowIORequest) { - executor.checkOnThread() - when (ioRequest) { - is SendRequest -> processSendRequest(ioRequest) - is WaitForLedgerCommit -> processWaitForCommitRequest(ioRequest) - is Sleep -> processSleepRequest(ioRequest) - } - } - - private fun processSendRequest(ioRequest: SendRequest) { - val retryId = if (ioRequest.message is SessionInit) { - with(ioRequest.session) { - openSessions[ourSessionId] = this - if (retryable) ourSessionId else null - } - } else null - sendSessionMessage(ioRequest.session.state.sendToParty, ioRequest.message, ioRequest.session.fiber, retryId) - if (ioRequest !is ReceiveRequest<*>) { - // We sent a message, but don't expect a response, so re-enter the continuation to let it keep going. - resumeFiber(ioRequest.session.fiber) - } - } - - private fun processWaitForCommitRequest(ioRequest: WaitForLedgerCommit) { - // Is it already committed? - val stx = database.transaction { - serviceHub.validatedTransactions.getTransaction(ioRequest.hash) - } - if (stx != null) { - resumeFiber(ioRequest.fiber) - } else { - // No, then register to wait. - // - // We assume this code runs on the server thread, which is the only place transactions are committed - // currently. When we liberalise our threading somewhat, handing of wait requests will need to be - // reworked to make the wait atomic in another way. Otherwise there is a race between checking the - // database and updating the waiting list. - mutex.locked { - fibersWaitingForLedgerCommit[ioRequest.hash] += ioRequest.fiber - } - } - } - - private fun processSleepRequest(ioRequest: Sleep) { - // Resume the fiber now we have checkpointed, so we can sleep on the Fiber. - resumeFiber(ioRequest.fiber) - } - - private fun sendSessionMessage(party: Party, message: SessionMessage, fiber: FlowStateMachineImpl<*>? = null, retryId: Long? = null) { - val partyInfo = serviceHub.networkMapCache.getPartyInfo(party) - ?: throw IllegalArgumentException("Don't know about party $party") - val address = serviceHub.networkService.getAddressOfParty(partyInfo) - val logger = fiber?.logger ?: logger - logger.trace { "Sending $message to party $party @ $address" + if (retryId != null) " with retry $retryId" else "" } - - val serialized = try { - message.serialize() - } catch (e: Exception) { - when (e) { - // Handling Kryo and AMQP serialization problems. Unfortunately the two exception types do not share much of a common exception interface. - is KryoException, - is NotSerializableException -> { - if (message !is ErrorSessionEnd || message.errorResponse == null) throw e - logger.warn("Something in ${message.errorResponse.javaClass.name} is not serialisable. " + - "Instead sending back an exception which is serialisable to ensure session end occurs properly.", e) - // The subclass may have overridden toString so we use that - val exMessage = message.errorResponse.let { if (it.javaClass != FlowException::class.java) it.toString() else it.message } - message.copy(errorResponse = FlowException(exMessage)).serialize() - } - else -> throw e - } - } - - serviceHub.networkService.apply { - send(createMessage(sessionTopic, serialized.bytes), address, retryId = retryId) - } - } -} - -class SessionRejectException(val rejectMessage: String, val logMessage: String) : CordaException(rejectMessage) { - constructor(message: String) : this(message, message) -} +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt new file mode 100644 index 0000000000..7fdd7f920c --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt @@ -0,0 +1,634 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Fiber +import co.paralleluniverse.fibers.FiberExecutorScheduler +import co.paralleluniverse.fibers.Suspendable +import co.paralleluniverse.fibers.instrument.SuspendableHelper +import co.paralleluniverse.strands.Strand +import com.codahale.metrics.Gauge +import com.esotericsoftware.kryo.KryoException +import com.google.common.collect.HashMultimap +import com.google.common.util.concurrent.MoreExecutors +import net.corda.core.CordaException +import net.corda.core.concurrent.CordaFuture +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.random63BitValue +import net.corda.core.flows.* +import net.corda.core.identity.Party +import net.corda.core.internal.* +import net.corda.core.internal.concurrent.doneFuture +import net.corda.core.messaging.DataFeed +import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT +import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY +import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.deserialize +import net.corda.core.serialization.serialize +import net.corda.core.utilities.Try +import net.corda.core.utilities.debug +import net.corda.core.utilities.loggerFor +import net.corda.core.utilities.trace +import net.corda.node.internal.InitiatedFlowFactory +import net.corda.node.services.api.Checkpoint +import net.corda.node.services.api.CheckpointStorage +import net.corda.node.services.api.ServiceHubInternal +import net.corda.node.services.messaging.ReceivedMessage +import net.corda.node.services.messaging.TopicSession +import net.corda.node.utilities.AffinityExecutor +import net.corda.node.utilities.CordaPersistence +import net.corda.node.utilities.bufferUntilDatabaseCommit +import net.corda.node.utilities.wrapWithDatabaseTransaction +import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl +import net.corda.nodeapi.internal.serialization.withTokenContext +import org.apache.activemq.artemis.utils.ReusableLatch +import org.slf4j.Logger +import rx.Observable +import rx.subjects.PublishSubject +import java.io.NotSerializableException +import java.util.* +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit.SECONDS +import javax.annotation.concurrent.ThreadSafe + +/** + * The StateMachineManagerImpl will always invoke the flow fibers on the given [AffinityExecutor], regardless of which + * thread actually starts them via [startFlow]. + */ +@ThreadSafe +class StateMachineManagerImpl( + val serviceHub: ServiceHubInternal, + val checkpointStorage: CheckpointStorage, + val executor: AffinityExecutor, + val database: CordaPersistence, + private val unfinishedFibers: ReusableLatch = ReusableLatch(), + private val classloader: ClassLoader = StateMachineManagerImpl::class.java.classLoader +) : StateMachineManager { + inner class FiberScheduler : FiberExecutorScheduler("Same thread scheduler", executor) + + companion object { + private val logger = loggerFor() + internal val sessionTopic = TopicSession("platform.session") + + init { + Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable -> + (fiber as FlowStateMachineImpl<*>).logger.warn("Caught exception from flow", throwable) + } + } + } + + // A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines + // property. + private class InnerState { + var started = false + val stateMachines = LinkedHashMap, Checkpoint>() + val changesPublisher = PublishSubject.create()!! + val fibersWaitingForLedgerCommit = HashMultimap.create>()!! + + fun notifyChangeObservers(change: StateMachineManager.Change) { + changesPublisher.bufferUntilDatabaseCommit().onNext(change) + } + } + + private val scheduler = FiberScheduler() + private val mutex = ThreadBox(InnerState()) + // This thread (only enabled in dev mode) deserialises checkpoints in the background to shake out bugs in checkpoint restore. + private val checkpointCheckerThread = if (serviceHub.configuration.devMode) Executors.newSingleThreadExecutor() else null + + @Volatile private var unrestorableCheckpoints = false + + // True if we're shutting down, so don't resume anything. + @Volatile private var stopping = false + // How many Fibers are running and not suspended. If zero and stopping is true, then we are halted. + private val liveFibers = ReusableLatch() + + // Monitoring support. + private val metrics = serviceHub.monitoringService.metrics + + init { + metrics.register("Flows.InFlight", Gauge { mutex.content.stateMachines.size }) + } + + private val checkpointingMeter = metrics.meter("Flows.Checkpointing Rate") + private val totalStartedFlows = metrics.counter("Flows.Started") + private val totalFinishedFlows = metrics.counter("Flows.Finished") + + private val openSessions = ConcurrentHashMap() + private val recentlyClosedSessions = ConcurrentHashMap() + + // Context for tokenized services in checkpoints + private lateinit var tokenizableServices: List + private val serializationContext by lazy { + SerializeAsTokenContextImpl(tokenizableServices, SERIALIZATION_FACTORY, CHECKPOINT_CONTEXT, serviceHub) + } + + /** Returns a list of all state machines executing the given flow logic at the top level (subflows do not count) */ + override fun > findStateMachines(flowClass: Class): List>> { + return mutex.locked { + stateMachines.keys.mapNotNull { + flowClass.castIfPossible(it.logic)?.let { it to uncheckedCast, FlowStateMachineImpl<*>>(it.stateMachine).resultFuture } + } + } + } + + override val allStateMachines: List> + get() = mutex.locked { stateMachines.keys.map { it.logic } } + + /** + * An observable that emits triples of the changing flow, the type of change, and a process-specific ID number + * which may change across restarts. + * + * We use assignment here so that multiple subscribers share the same wrapped Observable. + */ + override val changes: Observable = mutex.content.changesPublisher.wrapWithDatabaseTransaction() + + override fun start(tokenizableServices: List) { + this.tokenizableServices = tokenizableServices + checkQuasarJavaAgentPresence() + restoreFibersFromCheckpoints() + listenToLedgerTransactions() + serviceHub.networkMapCache.nodeReady.then { executor.execute(this::resumeRestoredFibers) } + } + + private fun checkQuasarJavaAgentPresence() { + check(SuspendableHelper.isJavaAgentActive(), { + """Missing the '-javaagent' JVM argument. Make sure you run the tests with the Quasar java agent attached to your JVM. + #See https://docs.corda.net/troubleshooting.html - 'Fiber classes not instrumented' for more details.""".trimMargin("#") + }) + } + + private fun listenToLedgerTransactions() { + // Observe the stream of committed, validated transactions and resume fibers that are waiting for them. + serviceHub.validatedTransactions.updates.subscribe { stx -> + val hash = stx.id + val fibers: Set> = mutex.locked { fibersWaitingForLedgerCommit.removeAll(hash) } + if (fibers.isNotEmpty()) { + executor.executeASAP { + for (fiber in fibers) { + fiber.logger.trace { "Transaction $hash has committed to the ledger, resuming" } + fiber.waitingForResponse = null + resumeFiber(fiber) + } + } + } + } + } + + private fun decrementLiveFibers() { + liveFibers.countDown() + } + + private fun incrementLiveFibers() { + liveFibers.countUp() + } + + /** + * Start the shutdown process, bringing the [StateMachineManagerImpl] to a controlled stop. When this method returns, + * all Fibers have been suspended and checkpointed, or have completed. + * + * @param allowedUnsuspendedFiberCount Optional parameter is used in some tests. + */ + override fun stop(allowedUnsuspendedFiberCount: Int) { + require(allowedUnsuspendedFiberCount >= 0) + mutex.locked { + if (stopping) throw IllegalStateException("Already stopping!") + stopping = true + } + // Account for any expected Fibers in a test scenario. + liveFibers.countDown(allowedUnsuspendedFiberCount) + liveFibers.await() + checkpointCheckerThread?.let { MoreExecutors.shutdownAndAwaitTermination(it, 5, SECONDS) } + check(!unrestorableCheckpoints) { "Unrestorable checkpoints where created, please check the logs for details." } + } + + /** + * Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and + * calls to [allStateMachines] + */ + override fun track(): DataFeed>, StateMachineManager.Change> { + return mutex.locked { + DataFeed(stateMachines.keys.map { it.logic }, changesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction()) + } + } + + private fun restoreFibersFromCheckpoints() { + mutex.locked { + checkpointStorage.forEach { checkpoint -> + // If a flow is added before start() then don't attempt to restore it + if (!stateMachines.containsValue(checkpoint)) { + deserializeFiber(checkpoint, logger)?.let { + initFiber(it) + stateMachines[it] = checkpoint + } + } + true + } + } + } + + private fun resumeRestoredFibers() { + mutex.locked { + started = true + stateMachines.keys.forEach { resumeRestoredFiber(it) } + } + serviceHub.networkService.addMessageHandler(sessionTopic) { message, _ -> + executor.checkOnThread() + onSessionMessage(message) + } + } + + private fun resumeRestoredFiber(fiber: FlowStateMachineImpl<*>) { + fiber.openSessions.values.forEach { openSessions[it.ourSessionId] = it } + val waitingForResponse = fiber.waitingForResponse + if (waitingForResponse != null) { + if (waitingForResponse is WaitForLedgerCommit) { + val stx = database.transaction { + serviceHub.validatedTransactions.getTransaction(waitingForResponse.hash) + } + if (stx != null) { + fiber.logger.info("Resuming fiber as tx ${waitingForResponse.hash} has committed") + fiber.waitingForResponse = null + resumeFiber(fiber) + } else { + fiber.logger.info("Restored, pending on ledger commit of ${waitingForResponse.hash}") + mutex.locked { fibersWaitingForLedgerCommit.put(waitingForResponse.hash, fiber) } + } + } else { + fiber.logger.info("Restored, pending on receive") + } + } else { + resumeFiber(fiber) + } + } + + private fun onSessionMessage(message: ReceivedMessage) { + val sessionMessage = try { + message.data.deserialize() + } catch (ex: Exception) { + logger.error("Received corrupt SessionMessage data from ${message.peer}") + return + } + val sender = serviceHub.networkMapCache.getPeerByLegalName(message.peer) + if (sender != null) { + when (sessionMessage) { + is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender) + is SessionInit -> onSessionInit(sessionMessage, message, sender) + } + } else { + logger.error("Unknown peer ${message.peer} in $sessionMessage") + } + } + + private fun onExistingSessionMessage(message: ExistingSessionMessage, sender: Party) { + val session = openSessions[message.recipientSessionId] + if (session != null) { + session.fiber.logger.trace { "Received $message on $session from $sender" } + if (session.retryable) { + if (message is SessionConfirm && session.state is FlowSessionState.Initiated) { + session.fiber.logger.trace { "Ignoring duplicate confirmation for session ${session.ourSessionId} – session is idempotent" } + return + } + if (message !is SessionConfirm) { + serviceHub.networkService.cancelRedelivery(session.ourSessionId) + } + } + if (message is SessionEnd) { + openSessions.remove(message.recipientSessionId) + } + session.receivedMessages += ReceivedSessionMessage(sender, message) + if (resumeOnMessage(message, session)) { + // It's important that we reset here and not after the fiber's resumed, in case we receive another message + // before then. + session.fiber.waitingForResponse = null + updateCheckpoint(session.fiber) + session.fiber.logger.trace { "Resuming due to $message" } + resumeFiber(session.fiber) + } + } else { + val peerParty = recentlyClosedSessions.remove(message.recipientSessionId) + if (peerParty != null) { + if (message is SessionConfirm) { + logger.trace { "Received session confirmation but associated fiber has already terminated, so sending session end" } + sendSessionMessage(peerParty, NormalSessionEnd(message.initiatedSessionId)) + } else { + logger.trace { "Ignoring session end message for already closed session: $message" } + } + } else { + logger.warn("Received a session message for unknown session: $message, from $sender") + } + } + } + + // We resume the fiber if it's received a response for which it was waiting for or it's waiting for a ledger + // commit but a counterparty flow has ended with an error (in which case our flow also has to end) + private fun resumeOnMessage(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean { + val waitingForResponse = session.fiber.waitingForResponse + return waitingForResponse?.shouldResume(message, session) ?: false + } + + private fun onSessionInit(sessionInit: SessionInit, receivedMessage: ReceivedMessage, sender: Party) { + logger.trace { "Received $sessionInit from $sender" } + val senderSessionId = sessionInit.initiatorSessionId + + fun sendSessionReject(message: String) = sendSessionMessage(sender, SessionReject(senderSessionId, message)) + + val (session, initiatedFlowFactory) = try { + val initiatedFlowFactory = getInitiatedFlowFactory(sessionInit) + val flowSession = FlowSessionImpl(sender) + val flow = initiatedFlowFactory.createFlow(flowSession) + val senderFlowVersion = when (initiatedFlowFactory) { + is InitiatedFlowFactory.Core -> receivedMessage.platformVersion // The flow version for the core flows is the platform version + is InitiatedFlowFactory.CorDapp -> sessionInit.flowVersion + } + val session = FlowSessionInternal( + flow, + flowSession, + random63BitValue(), + sender, + FlowSessionState.Initiated(sender, senderSessionId, FlowInfo(senderFlowVersion, sessionInit.appName))) + if (sessionInit.firstPayload != null) { + session.receivedMessages += ReceivedSessionMessage(sender, SessionData(session.ourSessionId, sessionInit.firstPayload)) + } + openSessions[session.ourSessionId] = session + // TODO Perhaps the session-init will specificy which of our multiple identies to use, which we would have to + // double-check is actually ours. However, what if we want to control how our identities gets used? + val fiber = createFiber(flow, FlowInitiator.Peer(sender)) + flowSession.sessionFlow = flow + flowSession.stateMachine = fiber + fiber.openSessions[Pair(flow, sender)] = session + updateCheckpoint(fiber) + session to initiatedFlowFactory + } catch (e: SessionRejectException) { + logger.warn("${e.logMessage}: $sessionInit") + sendSessionReject(e.rejectMessage) + return + } catch (e: Exception) { + logger.warn("Couldn't start flow session from $sessionInit", e) + sendSessionReject("Unable to establish session") + return + } + + val (ourFlowVersion, appName) = when (initiatedFlowFactory) { + // The flow version for the core flows is the platform version + is InitiatedFlowFactory.Core -> serviceHub.myInfo.platformVersion to "corda" + is InitiatedFlowFactory.CorDapp -> initiatedFlowFactory.flowVersion to initiatedFlowFactory.appName + } + + sendSessionMessage(sender, SessionConfirm(senderSessionId, session.ourSessionId, ourFlowVersion, appName), session.fiber) + session.fiber.logger.debug { "Initiated by $sender using ${sessionInit.initiatingFlowClass}" } + session.fiber.logger.trace { "Initiated from $sessionInit on $session" } + resumeFiber(session.fiber) + } + + private fun getInitiatedFlowFactory(sessionInit: SessionInit): InitiatedFlowFactory<*> { + val initiatingFlowClass = try { + Class.forName(sessionInit.initiatingFlowClass, true, classloader).asSubclass(FlowLogic::class.java) + } catch (e: ClassNotFoundException) { + throw SessionRejectException("Don't know ${sessionInit.initiatingFlowClass}") + } catch (e: ClassCastException) { + throw SessionRejectException("${sessionInit.initiatingFlowClass} is not a flow") + } + return serviceHub.getFlowFactory(initiatingFlowClass) ?: + throw SessionRejectException("$initiatingFlowClass is not registered") + } + + private fun serializeFiber(fiber: FlowStateMachineImpl<*>): SerializedBytes> { + return fiber.serialize(context = CHECKPOINT_CONTEXT.withTokenContext(serializationContext)) + } + + private fun deserializeFiber(checkpoint: Checkpoint, logger: Logger): FlowStateMachineImpl<*>? { + return try { + checkpoint.serializedFiber.deserialize(context = CHECKPOINT_CONTEXT.withTokenContext(serializationContext)).apply { + fromCheckpoint = true + } + } catch (t: Throwable) { + logger.error("Encountered unrestorable checkpoint!", t) + null + } + } + + private fun createFiber(logic: FlowLogic, flowInitiator: FlowInitiator, ourIdentity: Party? = null): FlowStateMachineImpl { + val fsm = FlowStateMachineImpl( + StateMachineRunId.createRandom(), + logic, + scheduler, + flowInitiator, + ourIdentity ?: serviceHub.myInfo.legalIdentities[0]) + initFiber(fsm) + return fsm + } + + private fun initFiber(fiber: FlowStateMachineImpl<*>) { + verifyFlowLogicIsSuspendable(fiber.logic) + fiber.database = database + fiber.serviceHub = serviceHub + fiber.ourIdentityAndCert = serviceHub.myInfo.legalIdentitiesAndCerts.find { it.party == fiber.ourIdentity } + ?: throw IllegalStateException("Identity specified by ${fiber.id} (${fiber.ourIdentity}) is not one of ours!") + fiber.actionOnSuspend = { ioRequest -> + updateCheckpoint(fiber) + // We commit on the fibers transaction that was copied across ThreadLocals during suspend + // This will free up the ThreadLocal so on return the caller can carry on with other transactions + fiber.commitTransaction() + processIORequest(ioRequest) + decrementLiveFibers() + } + fiber.actionOnEnd = { result, propagated -> + try { + mutex.locked { + stateMachines.remove(fiber)?.let { checkpointStorage.removeCheckpoint(it) } + notifyChangeObservers(StateMachineManager.Change.Removed(fiber.logic, result)) + } + endAllFiberSessions(fiber, result, propagated) + } finally { + fiber.commitTransaction() + decrementLiveFibers() + totalFinishedFlows.inc() + unfinishedFibers.countDown() + } + } + mutex.locked { + totalStartedFlows.inc() + unfinishedFibers.countUp() + notifyChangeObservers(StateMachineManager.Change.Add(fiber.logic)) + } + } + + private fun verifyFlowLogicIsSuspendable(logic: FlowLogic) { + // Quasar requires (in Java 8) that at least the call method be annotated suspendable. Unfortunately, it's + // easy to forget to add this when creating a new flow, so we check here to give the user a better error. + // + // The Kotlin compiler can sometimes generate a synthetic bridge method from a single call declaration, which + // forwards to the void method and then returns Unit. However annotations do not get copied across to this + // bridge, so we have to do a more complex scan here. + val call = logic.javaClass.methods.first { !it.isSynthetic && it.name == "call" && it.parameterCount == 0 } + if (call.getAnnotation(Suspendable::class.java) == null) { + throw FlowException("${logic.javaClass.name}.call() is not annotated as @Suspendable. Please fix this.") + } + } + + private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>, result: Try<*>, propagated: Boolean) { + openSessions.values.removeIf { session -> + if (session.fiber == fiber) { + session.endSession((result as? Try.Failure)?.exception, propagated) + true + } else { + false + } + } + } + + private fun FlowSessionInternal.endSession(exception: Throwable?, propagated: Boolean) { + val initiatedState = state as? FlowSessionState.Initiated ?: return + val sessionEnd = if (exception == null) { + NormalSessionEnd(initiatedState.peerSessionId) + } else { + val errorResponse = if (exception is FlowException && (!propagated || initiatingParty != null)) { + // Only propagate this FlowException if our local flow threw it or it was propagated to us and we only + // pass it down invocation chain to the flow that initiated us, not to flows we've started sessions with. + exception + } else { + null + } + ErrorSessionEnd(initiatedState.peerSessionId, errorResponse) + } + sendSessionMessage(initiatedState.peerParty, sessionEnd, fiber) + recentlyClosedSessions[ourSessionId] = initiatedState.peerParty + } + + /** + * Kicks off a brand new state machine of the given class. + * The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is + * restarted with checkpointed state machines in the storage service. + * + * Note that you must be on the [executor] thread. + */ + override fun startFlow(flowLogic: FlowLogic, flowInitiator: FlowInitiator, ourIdentity: Party?): CordaFuture> { + // TODO: Check that logic has @Suspendable on its call method. + executor.checkOnThread() + val fiber = database.transaction { + val fiber = createFiber(flowLogic, flowInitiator, ourIdentity) + updateCheckpoint(fiber) + fiber + } + // If we are not started then our checkpoint will be picked up during start + mutex.locked { + if (started) { + resumeFiber(fiber) + } + } + return doneFuture(fiber) + } + + private fun updateCheckpoint(fiber: FlowStateMachineImpl<*>) { + check(fiber.state != Strand.State.RUNNING) { "Fiber cannot be running when checkpointing" } + val newCheckpoint = Checkpoint(serializeFiber(fiber)) + val previousCheckpoint = mutex.locked { stateMachines.put(fiber, newCheckpoint) } + if (previousCheckpoint != null) { + checkpointStorage.removeCheckpoint(previousCheckpoint) + } + checkpointStorage.addCheckpoint(newCheckpoint) + checkpointingMeter.mark() + + checkpointCheckerThread?.execute { + // Immediately check that the checkpoint is valid by deserialising it. The idea is to plug any holes we have + // in our testing by failing any test where unrestorable checkpoints are created. + if (deserializeFiber(newCheckpoint, fiber.logger) == null) { + unrestorableCheckpoints = true + } + } + } + + private fun resumeFiber(fiber: FlowStateMachineImpl<*>) { + // Avoid race condition when setting stopping to true and then checking liveFibers + incrementLiveFibers() + if (!stopping) { + executor.executeASAP { + fiber.resume(scheduler) + } + } else { + fiber.logger.trace("Not resuming as SMM is stopping.") + decrementLiveFibers() + } + } + + private fun processIORequest(ioRequest: FlowIORequest) { + executor.checkOnThread() + when (ioRequest) { + is SendRequest -> processSendRequest(ioRequest) + is WaitForLedgerCommit -> processWaitForCommitRequest(ioRequest) + is Sleep -> processSleepRequest(ioRequest) + } + } + + private fun processSendRequest(ioRequest: SendRequest) { + val retryId = if (ioRequest.message is SessionInit) { + with(ioRequest.session) { + openSessions[ourSessionId] = this + if (retryable) ourSessionId else null + } + } else null + sendSessionMessage(ioRequest.session.state.sendToParty, ioRequest.message, ioRequest.session.fiber, retryId) + if (ioRequest !is ReceiveRequest<*>) { + // We sent a message, but don't expect a response, so re-enter the continuation to let it keep going. + resumeFiber(ioRequest.session.fiber) + } + } + + private fun processWaitForCommitRequest(ioRequest: WaitForLedgerCommit) { + // Is it already committed? + val stx = database.transaction { + serviceHub.validatedTransactions.getTransaction(ioRequest.hash) + } + if (stx != null) { + resumeFiber(ioRequest.fiber) + } else { + // No, then register to wait. + // + // We assume this code runs on the server thread, which is the only place transactions are committed + // currently. When we liberalise our threading somewhat, handing of wait requests will need to be + // reworked to make the wait atomic in another way. Otherwise there is a race between checking the + // database and updating the waiting list. + mutex.locked { + fibersWaitingForLedgerCommit[ioRequest.hash] += ioRequest.fiber + } + } + } + + private fun processSleepRequest(ioRequest: Sleep) { + // Resume the fiber now we have checkpointed, so we can sleep on the Fiber. + resumeFiber(ioRequest.fiber) + } + + private fun sendSessionMessage(party: Party, message: SessionMessage, fiber: FlowStateMachineImpl<*>? = null, retryId: Long? = null) { + val partyInfo = serviceHub.networkMapCache.getPartyInfo(party) + ?: throw IllegalArgumentException("Don't know about party $party") + val address = serviceHub.networkService.getAddressOfParty(partyInfo) + val logger = fiber?.logger ?: logger + logger.trace { "Sending $message to party $party @ $address" + if (retryId != null) " with retry $retryId" else "" } + + val serialized = try { + message.serialize() + } catch (e: Exception) { + when (e) { + // Handling Kryo and AMQP serialization problems. Unfortunately the two exception types do not share much of a common exception interface. + is KryoException, + is NotSerializableException -> { + if (message !is ErrorSessionEnd || message.errorResponse == null) throw e + logger.warn("Something in ${message.errorResponse.javaClass.name} is not serialisable. " + + "Instead sending back an exception which is serialisable to ensure session end occurs properly.", e) + // The subclass may have overridden toString so we use that + val exMessage = message.errorResponse.let { if (it.javaClass != FlowException::class.java) it.toString() else it.message } + message.copy(errorResponse = FlowException(exMessage)).serialize() + } + else -> throw e + } + } + + serviceHub.networkService.apply { + send(createMessage(sessionTopic, serialized.bytes), address, retryId = retryId) + } + } +} + +class SessionRejectException(val rejectMessage: String, val logMessage: String) : CordaException(rejectMessage) { + constructor(message: String) : this(message, message) +} diff --git a/node/src/main/kotlin/net/corda/node/shell/InteractiveShell.kt b/node/src/main/kotlin/net/corda/node/shell/InteractiveShell.kt index 75c4f85a5c..d818e30b8f 100644 --- a/node/src/main/kotlin/net/corda/node/shell/InteractiveShell.kt +++ b/node/src/main/kotlin/net/corda/node/shell/InteractiveShell.kt @@ -20,6 +20,7 @@ import net.corda.core.internal.concurrent.openFuture import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.DataFeed import net.corda.core.messaging.StateMachineUpdate +import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.loggerFor import net.corda.node.internal.Node import net.corda.node.internal.StartedNode @@ -234,7 +235,7 @@ object InteractiveShell { val clazz: Class> = uncheckedCast(matches.single()) try { // TODO Flow invocation should use startFlowDynamic. - val fsm = runFlowFromString({ node.services.startFlow(it, FlowInitiator.Shell) }, inputData, clazz) + val fsm = runFlowFromString({ node.services.startFlow(it, FlowInitiator.Shell).getOrThrow() }, inputData, clazz) // Show the progress tracker on the console until the flow completes or is interrupted with a // Ctrl-C keypress. val latch = CountDownLatch(1) diff --git a/node/src/test/kotlin/net/corda/node/InteractiveShellTest.kt b/node/src/test/kotlin/net/corda/node/InteractiveShellTest.kt index a1eb79aefb..622fc9667c 100644 --- a/node/src/test/kotlin/net/corda/node/InteractiveShellTest.kt +++ b/node/src/test/kotlin/net/corda/node/InteractiveShellTest.kt @@ -52,7 +52,7 @@ class InteractiveShellTest { private fun check(input: String, expected: String) { var output: DummyFSM? = null InteractiveShell.runFlowFromString({ DummyFSM(it as FlowA).apply { output = this } }, input, FlowA::class.java, om) - assertEquals(expected, output!!.flowA.a, input) + assertEquals(expected, output!!.logic.a, input) } @Test @@ -83,5 +83,5 @@ class InteractiveShellTest { @Test fun party() = check("party: \"${MEGA_CORP.name}\"", MEGA_CORP.name.toString()) - class DummyFSM(val flowA: FlowA) : FlowStateMachine by rigorousMock() + class DummyFSM(override val logic: FlowA) : FlowStateMachine by rigorousMock() } diff --git a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt index 53c7600ca3..d8a0c51096 100644 --- a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt @@ -27,6 +27,7 @@ import net.corda.node.services.network.NetworkMapCacheImpl import net.corda.node.services.persistence.DBCheckpointStorage import net.corda.node.services.statemachine.FlowLogicRefFactoryImpl import net.corda.node.services.statemachine.StateMachineManager +import net.corda.node.services.statemachine.StateMachineManagerImpl import net.corda.node.services.vault.NodeVaultService import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.CordaPersistence @@ -113,14 +114,14 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() { doReturn(this@NodeSchedulerServiceTest).whenever(it).testReference } smmExecutor = AffinityExecutor.ServiceAffinityExecutor("test", 1) - mockSMM = StateMachineManager(services, DBCheckpointStorage(), smmExecutor, database) + mockSMM = StateMachineManagerImpl(services, DBCheckpointStorage(), smmExecutor, database) scheduler = NodeSchedulerService(testClock, database, FlowStarterImpl(smmExecutor, mockSMM), stateLoader, schedulerGatedExecutor, serverThread = smmExecutor) mockSMM.changes.subscribe { change -> if (change is StateMachineManager.Change.Removed && mockSMM.allStateMachines.isEmpty()) { smmHasRemovedAllFlows.countDown() } } - mockSMM.start() + mockSMM.start(emptyList()) scheduler.start() } } diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt index e65592f70b..52b79b2dee 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt @@ -64,6 +64,10 @@ class FlowFrameworkTests { private lateinit var alice: Party private lateinit var bob: Party + private fun StartedNode<*>.flushSmm() { + (this.smm as StateMachineManagerImpl).executor.flush() + } + @Before fun start() { mockNet = MockNetwork(servicePeerAllocationStrategy = RoundRobin(), cordappPackages = listOf("net.corda.finance.contracts", "net.corda.testing.contracts")) @@ -166,7 +170,7 @@ class FlowFrameworkTests { val restoredFlow = charlieNode.getSingleFlow().first assertEquals(false, restoredFlow.flowStarted) // Not started yet as no network activity has been allowed yet mockNet.runNetwork() // Allow network map messages to flow - charlieNode.smm.executor.flush() + charlieNode.flushSmm() assertEquals(true, restoredFlow.flowStarted) // Now we should have run the flow and hopefully cleared the init checkpoint charlieNode.internals.disableDBCloseOnStop() charlieNode.services.networkMapCache.clearNetworkMapCache() // zap persisted NetworkMapCache to force use of network. @@ -175,7 +179,7 @@ class FlowFrameworkTests { // Now it is completed the flow should leave no Checkpoint. charlieNode = mockNet.createNode(charlieNode.internals.id) mockNet.runNetwork() // Allow network map messages to flow - charlieNode.smm.executor.flush() + charlieNode.flushSmm() assertTrue(charlieNode.smm.findStateMachines(NoOpFlow::class.java).isEmpty()) } @@ -184,7 +188,7 @@ class FlowFrameworkTests { aliceNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) } bobNode.services.startFlow(ReceiveFlow(alice).nonTerminating()) // Prepare checkpointed receive flow // Make sure the add() has finished initial processing. - bobNode.smm.executor.flush() + bobNode.flushSmm() bobNode.internals.disableDBCloseOnStop() bobNode.dispose() // kill receiver val restoredFlow = bobNode.restartAndGetRestoredFlow() @@ -210,7 +214,7 @@ class FlowFrameworkTests { assertEquals(1, bobNode.checkpointStorage.checkpoints().size) } // Make sure the add() has finished initial processing. - bobNode.smm.executor.flush() + bobNode.flushSmm() bobNode.internals.disableDBCloseOnStop() // Restart node and thus reload the checkpoint and resend the message with same UUID bobNode.dispose() @@ -223,7 +227,7 @@ class FlowFrameworkTests { val (firstAgain, fut1) = node2b.getSingleFlow() // Run the network which will also fire up the second flow. First message should get deduped. So message data stays in sync. mockNet.runNetwork() - node2b.smm.executor.flush() + node2b.flushSmm() fut1.getOrThrow() val receivedCount = receivedSessionMessages.count { it.isPayloadTransfer } @@ -731,7 +735,7 @@ class FlowFrameworkTests { private fun StartedNode<*>.sendSessionMessage(message: SessionMessage, destination: Party) { services.networkService.apply { val address = getAddressOfParty(PartyInfo.SingleNode(destination, emptyList())) - send(createMessage(StateMachineManager.sessionTopic, message.serialize().bytes), address) + send(createMessage(StateMachineManagerImpl.sessionTopic, message.serialize().bytes), address) } } @@ -755,7 +759,7 @@ class FlowFrameworkTests { } private fun Observable.toSessionTransfers(): Observable { - return filter { it.message.topicSession == StateMachineManager.sessionTopic }.map { + return filter { it.message.topicSession == StateMachineManagerImpl.sessionTopic }.map { val from = it.sender.id val message = it.message.data.deserialize() SessionTransfer(from, sanitise(message), it.recipients) diff --git a/samples/irs-demo/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt b/samples/irs-demo/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt index 3953c876b0..40b84bd462 100644 --- a/samples/irs-demo/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt +++ b/samples/irs-demo/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt @@ -200,13 +200,13 @@ class NodeInterestRatesTest : TestDependencyInjectionBase() { @Test fun `network tearoff`() { - val mockNet = MockNetwork(initialiseSerialization = false, cordappPackages = listOf("net.corda.finance.contracts")) + val mockNet = MockNetwork(initialiseSerialization = false, cordappPackages = listOf("net.corda.finance.contracts", "net.corda.irs")) val n1 = mockNet.createNotaryNode() val oracleNode = mockNet.createNode().apply { internals.registerInitiatedFlow(NodeInterestRates.FixQueryHandler::class.java) internals.registerInitiatedFlow(NodeInterestRates.FixSignHandler::class.java) database.transaction { - installCordaService(NodeInterestRates.Oracle::class.java).knownFixes = TEST_DATA + internals.findTokenizableService(NodeInterestRates.Oracle::class.java)!!.knownFixes = TEST_DATA } } val tx = makePartialTX() diff --git a/samples/network-visualiser/src/main/kotlin/net/corda/netmap/simulation/Simulation.kt b/samples/network-visualiser/src/main/kotlin/net/corda/netmap/simulation/Simulation.kt index 19c5750d65..283eb64eda 100644 --- a/samples/network-visualiser/src/main/kotlin/net/corda/netmap/simulation/Simulation.kt +++ b/samples/network-visualiser/src/main/kotlin/net/corda/netmap/simulation/Simulation.kt @@ -12,11 +12,14 @@ import net.corda.node.internal.StartedNode import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.statemachine.StateMachineManager import net.corda.nodeapi.internal.ServiceInfo -import net.corda.testing.* +import net.corda.testing.DUMMY_MAP +import net.corda.testing.DUMMY_NOTARY +import net.corda.testing.DUMMY_REGULATOR import net.corda.testing.node.InMemoryMessagingNetwork import net.corda.testing.node.MockNetwork import net.corda.testing.node.TestClock import net.corda.testing.node.setTo +import net.corda.testing.testNodeConfiguration import rx.Observable import rx.subjects.PublishSubject import java.math.BigInteger @@ -118,7 +121,7 @@ abstract class Simulation(val networkSendManuallyPumped: Boolean, registerInitiatedFlow(NodeInterestRates.FixSignHandler::class.java) javaClass.classLoader.getResourceAsStream("net/corda/irs/simulation/example.rates.txt").use { database.transaction { - installCordaService(NodeInterestRates.Oracle::class.java).uploadFixes(it.reader().readText()) + findTokenizableService(NodeInterestRates.Oracle::class.java)!!.uploadFixes(it.reader().readText()) } } } @@ -143,7 +146,7 @@ abstract class Simulation(val networkSendManuallyPumped: Boolean, val mockNet = MockNetwork( networkSendManuallyPumped = networkSendManuallyPumped, threadPerNode = runAsync, - cordappPackages = listOf("net.corda.irs.contract", "net.corda.finance.contract")) + cordappPackages = listOf("net.corda.irs.contract", "net.corda.finance.contract", "net.corda.irs")) // This one must come first. val networkMap = mockNet.startNetworkMapNode(nodeFactory = NetworkMapNodeFactory) val notary = mockNet.createNotaryNode(validating = false, nodeFactory = NotaryNodeFactory) diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt index a44cf1c4a5..b009106082 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt @@ -363,14 +363,22 @@ class InMemoryMessagingNetwork( state.locked { check(handlers.remove(registration as Handler)) } } - override fun send(message: Message, target: MessageRecipients, retryId: Long?) { + override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, acknowledgementHandler: (() -> Unit)?) { check(running) msgSend(this, message, target) + acknowledgementHandler?.invoke() if (!sendManuallyPumped) { pumpSend(false) } } + override fun send(addressedMessages: List, acknowledgementHandler: (() -> Unit)?) { + for ((message, target, retryId, sequenceKey) in addressedMessages) { + send(message, target, retryId, sequenceKey, null) + } + acknowledgementHandler?.invoke() + } + override fun stop() { if (backgroundThread != null) { backgroundThread.interrupt() diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNode.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNode.kt index 9efcc1cefb..60d138248f 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNode.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNode.kt @@ -271,7 +271,7 @@ class MockNetwork(defaultParameters: MockNetworkParameters = MockNetworkParamete throw IllegalStateException("Unable to enumerate all nodes in BFT cluster.") } clusterNodes.forEach { - val notaryService = it.started!!.smm.findServices { it is BFTNonValidatingNotaryService }.single() as BFTNonValidatingNotaryService + val notaryService = it.findTokenizableService(BFTNonValidatingNotaryService::class.java)!! notaryService.waitUntilReplicaHasInitialized() } } diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockServices.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockServices.kt index fa881e1703..b0c962232a 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockServices.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockServices.kt @@ -30,6 +30,7 @@ import net.corda.node.services.persistence.HibernateConfiguration import net.corda.node.services.persistence.InMemoryStateMachineRecordedTransactionMappingStorage import net.corda.node.services.schema.HibernateObserver import net.corda.node.services.schema.NodeSchemaService +import net.corda.node.services.statemachine.FlowStateMachineImpl import net.corda.node.services.transactions.InMemoryTransactionVerifierService import net.corda.node.services.vault.NodeVaultService import net.corda.node.utilities.CordaPersistence