From 3a17d4726f0804ab3c83b658a5927f2f3c46d387 Mon Sep 17 00:00:00 2001 From: bpaunescu Date: Fri, 20 Apr 2018 10:52:00 +0100 Subject: [PATCH 1/9] ENT-1775: reworked client to handle failover in HA mode instead of relying on artemis (#759) * ENT-1775: reworked client to handle failover in HA mode instead of relying on artemis * ENT-1775: address PR comments --- .../net/corda/client/rpc/RPCStabilityTests.kt | 90 +++++++++++ .../net/corda/client/rpc/CordaRPCClient.kt | 60 ++++++-- .../corda/client/rpc/internal/RPCClient.kt | 20 ++- .../rpc/internal/RPCClientProxyHandler.kt | 140 ++++++++++++++---- .../corda/testing/node/internal/RPCDriver.kt | 33 +++++ 5 files changed, 304 insertions(+), 39 deletions(-) diff --git a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt index e7452d49ab..d9602f0c43 100644 --- a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt +++ b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt @@ -326,6 +326,96 @@ class RPCStabilityTests { } } + interface ServerOps : RPCOps { + fun serverId(): String + } + + @Test + fun `client connects to first available server`() { + rpcDriver { + val ops = object : ServerOps { + override val protocolVersion = 0 + override fun serverId() = "server" + } + val serverFollower = shutdownManager.follower() + val serverAddress = startRpcServer(ops = ops).getOrThrow().broker.hostAndPort!! + serverFollower.unfollow() + + val clientFollower = shutdownManager.follower() + val client = startRpcClient(listOf(NetworkHostAndPort("localhost", 12345), serverAddress, NetworkHostAndPort("localhost", 54321))).getOrThrow() + clientFollower.unfollow() + + assertEquals("server", client.serverId()) + + clientFollower.shutdown() // Driver would do this after the new server, causing hang. 
+ } + } + + @Test + fun `3 server failover`() { + rpcDriver { + val ops1 = object : ServerOps { + override val protocolVersion = 0 + override fun serverId() = "server1" + } + val ops2 = object : ServerOps { + override val protocolVersion = 0 + override fun serverId() = "server2" + } + val ops3 = object : ServerOps { + override val protocolVersion = 0 + override fun serverId() = "server3" + } + val serverFollower1 = shutdownManager.follower() + val server1 = startRpcServer(ops = ops1).getOrThrow() + serverFollower1.unfollow() + + val serverFollower2 = shutdownManager.follower() + val server2 = startRpcServer(ops = ops2).getOrThrow() + serverFollower2.unfollow() + + val serverFollower3 = shutdownManager.follower() + val server3 = startRpcServer(ops = ops3).getOrThrow() + serverFollower3.unfollow() + val servers = mutableMapOf("server1" to serverFollower1, "server2" to serverFollower2, "server3" to serverFollower3) + + val clientFollower = shutdownManager.follower() + val client = startRpcClient(listOf(server1.broker.hostAndPort!!, server2.broker.hostAndPort!!, server3.broker.hostAndPort!!)).getOrThrow() + clientFollower.unfollow() + + var response = client.serverId() + assertTrue(servers.containsKey(response)) + servers[response]!!.shutdown() + servers.remove(response) + + //failover will take some time + while (true) { + try { + response = client.serverId() + break + } catch (e: RPCException) {} + } + assertTrue(servers.containsKey(response)) + servers[response]!!.shutdown() + servers.remove(response) + + while (true) { + try { + response = client.serverId() + break + } catch (e: RPCException) {} + } + assertTrue(servers.containsKey(response)) + servers[response]!!.shutdown() + servers.remove(response) + + assertTrue(servers.isEmpty()) + + clientFollower.shutdown() // Driver would do this after the new server, causing hang. 
+ + } + } + interface TrackSubscriberOps : RPCOps { fun subscribe(): Observable } diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt index ed4ad14a23..a378808068 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt @@ -93,19 +93,34 @@ interface CordaRPCClientConfiguration { * [CordaRPCClientConfiguration]. While attempting failover, current and future RPC calls will throw * [RPCException] and previously returned observables will call onError(). * + * If the client was created using a list of hosts, automatic failover will occur(the servers have to be started in HA mode) + * * @param hostAndPort The network address to connect to. * @param configuration An optional configuration used to tweak client behaviour. * @param sslConfiguration An optional [SSLConfiguration] used to enable secure communication with the server. + * @param haAddressPool A list of [NetworkHostAndPort] representing the addresses of servers in HA mode. + * The client will attempt to connect to a live server by trying each address in the list. If the servers are not in + * HA mode, the client will round-robin from the beginning of the list and try all servers. */ class CordaRPCClient private constructor( - hostAndPort: NetworkHostAndPort, - configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default(), - sslConfiguration: SSLConfiguration? = null, - classLoader: ClassLoader? = null + private val hostAndPort: NetworkHostAndPort, + private val configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default(), + private val sslConfiguration: SSLConfiguration? = null, + private val classLoader: ClassLoader? 
= null, + private val haAddressPool: List = emptyList() ) { @JvmOverloads constructor(hostAndPort: NetworkHostAndPort, configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default()) : this(hostAndPort, configuration, null) + /** + * @param haAddressPool A list of [NetworkHostAndPort] representing the addresses of servers in HA mode. + * The client will attempt to connect to a live server by trying each address in the list. If the servers are not in + * HA mode, the client will round-robin from the beginning of the list and try all servers. + * @param configuration An optional configuration used to tweak client behaviour. + */ + @JvmOverloads + constructor(haAddressPool: List, configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default()) : this(haAddressPool.first(), configuration, null, null, haAddressPool) + companion object { internal fun createWithSsl( hostAndPort: NetworkHostAndPort, @@ -115,6 +130,14 @@ class CordaRPCClient private constructor( return CordaRPCClient(hostAndPort, configuration, sslConfiguration) } + internal fun createWithSsl( + haAddressPool: List, + configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default(), + sslConfiguration: SSLConfiguration? = null + ): CordaRPCClient { + return CordaRPCClient(haAddressPool.first(), configuration, sslConfiguration, null, haAddressPool) + } + internal fun createWithSslAndClassLoader( hostAndPort: NetworkHostAndPort, configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default(), @@ -123,6 +146,15 @@ class CordaRPCClient private constructor( ): CordaRPCClient { return CordaRPCClient(hostAndPort, configuration, sslConfiguration, classLoader) } + + internal fun createWithSslAndClassLoader( + haAddressPool: List, + configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default(), + sslConfiguration: SSLConfiguration? = null, + classLoader: ClassLoader? 
= null + ): CordaRPCClient { + return CordaRPCClient(haAddressPool.first(), configuration, sslConfiguration, classLoader, haAddressPool) + } } init { @@ -137,11 +169,19 @@ class CordaRPCClient private constructor( } } - private val rpcClient = RPCClient( - tcpTransport(ConnectionDirection.Outbound(), hostAndPort, config = sslConfiguration), - configuration, - if (classLoader != null) KRYO_RPC_CLIENT_CONTEXT.withClassLoader(classLoader) else KRYO_RPC_CLIENT_CONTEXT - ) + private fun getRpcClient() : RPCClient { + return if (haAddressPool.isEmpty()) { + RPCClient( + tcpTransport(ConnectionDirection.Outbound(), hostAndPort, config = sslConfiguration), + configuration, + if (classLoader != null) KRYO_RPC_CLIENT_CONTEXT.withClassLoader(classLoader) else KRYO_RPC_CLIENT_CONTEXT) + } else { + RPCClient(haAddressPool, + sslConfiguration, + configuration, + if (classLoader != null) KRYO_RPC_CLIENT_CONTEXT.withClassLoader(classLoader) else KRYO_RPC_CLIENT_CONTEXT) + } + } /** * Logs in to the target server and returns an active connection. The returned connection is a [java.io.Closeable] @@ -169,7 +209,7 @@ class CordaRPCClient private constructor( * @throws RPCException if the server version is too low or if the server isn't reachable within a reasonable timeout. 
*/ fun start(username: String, password: String, externalTrace: Trace?, impersonatedActor: Actor?): CordaRPCConnection { - return CordaRPCConnection(rpcClient.start(CordaRPCOps::class.java, username, password, externalTrace, impersonatedActor)) + return CordaRPCConnection(getRpcClient().start(CordaRPCOps::class.java, username, password, externalTrace, impersonatedActor)) } /** diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt index e8f33d284f..24096c0952 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt @@ -14,6 +14,7 @@ import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.utilities.* import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport +import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransportsFromList import net.corda.nodeapi.ConnectionDirection import net.corda.nodeapi.RPCApi import net.corda.nodeapi.internal.config.SSLConfiguration @@ -60,7 +61,8 @@ data class CordaRPCClientConfigurationImpl( class RPCClient( val transport: TransportConfiguration, val rpcConfiguration: CordaRPCClientConfiguration = CordaRPCClientConfigurationImpl.default, - val serializationContext: SerializationContext = SerializationDefaults.RPC_CLIENT_CONTEXT + val serializationContext: SerializationContext = SerializationDefaults.RPC_CLIENT_CONTEXT, + val haPoolTransportConfigurations: List = emptyList() ) { constructor( hostAndPort: NetworkHostAndPort, @@ -69,6 +71,14 @@ class RPCClient( serializationContext: SerializationContext = SerializationDefaults.RPC_CLIENT_CONTEXT ) : this(tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration), configuration, serializationContext) + constructor( + haAddressPool: List, + sslConfiguration: 
SSLConfiguration? = null, + configuration: CordaRPCClientConfiguration = CordaRPCClientConfigurationImpl.default, + serializationContext: SerializationContext = SerializationDefaults.RPC_CLIENT_CONTEXT + ) : this(tcpTransport(ConnectionDirection.Outbound(), haAddressPool.first(), sslConfiguration), + configuration, serializationContext, tcpTransportsFromList(ConnectionDirection.Outbound(), haAddressPool, sslConfiguration)) + companion object { private val log = contextLogger() } @@ -83,11 +93,15 @@ class RPCClient( return log.logElapsedTime("Startup") { val clientAddress = SimpleString("${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.$username.${random63BitValue()}") - val serverLocator = ActiveMQClient.createServerLocatorWithoutHA(transport).apply { + val serverLocator = (if (haPoolTransportConfigurations.isEmpty()) { + ActiveMQClient.createServerLocatorWithoutHA(transport) + } else { + ActiveMQClient.createServerLocatorWithoutHA(*haPoolTransportConfigurations.toTypedArray()) + }).apply { retryInterval = rpcConfiguration.connectionRetryInterval.toMillis() retryIntervalMultiplier = rpcConfiguration.connectionRetryIntervalMultiplier maxRetryInterval = rpcConfiguration.connectionMaxRetryInterval.toMillis() - reconnectAttempts = rpcConfiguration.maxReconnectAttempts + reconnectAttempts = if (haPoolTransportConfigurations.isEmpty()) rpcConfiguration.maxReconnectAttempts else 0 minLargeMessageSize = rpcConfiguration.maxFileSize isUseGlobalPools = nodeSerializationEnv != null } diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt index dc07b1ada2..8852a096cd 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt @@ -20,6 +20,7 @@ import net.corda.core.context.Trace.InvocationId import net.corda.core.internal.LazyStickyPool import 
net.corda.core.internal.LifeCycle import net.corda.core.internal.ThreadBox +import net.corda.core.internal.times import net.corda.core.messaging.RPCOps import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.serialize @@ -29,6 +30,7 @@ import net.corda.core.utilities.debug import net.corda.core.utilities.getOrThrow import net.corda.nodeapi.RPCApi import net.corda.nodeapi.internal.DeduplicationChecker +import org.apache.activemq.artemis.api.core.ActiveMQException import org.apache.activemq.artemis.api.core.RoutingType import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.client.* @@ -173,6 +175,8 @@ class RPCClientProxyHandler( private val deduplicationSequenceNumber = AtomicLong(0) private val sendingEnabled = AtomicBoolean(true) + // used to interrupt failover thread (i.e. client is closed while failing over) + private var haFailoverThread: Thread? = null /** * Start the client. This creates the per-client queue, starts the consumer session and the reaper. @@ -192,17 +196,22 @@ class RPCClientProxyHandler( rpcConfiguration.reapInterval.toMillis(), TimeUnit.MILLISECONDS ) + // Create a session factory using the first available server. If more than one transport configuration was + // used when creating the server locator, every one will be tried during failover. The locator will round-robin + // through the available transport configurations with the starting position being generated randomly. + // If there's only one available, that one will be retried continuously as configured in rpcConfiguration. + // There is no failover on first attempt, meaning that if a connection cannot be established, the serverLocator + // will try another transport if it exists or throw an exception otherwise. 
sessionFactory = serverLocator.createSessionFactory() - producerSession = sessionFactory!!.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) - rpcProducer = producerSession!!.createProducer(RPCApi.RPC_SERVER_QUEUE_NAME) - consumerSession = sessionFactory!!.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) - consumerSession!!.createTemporaryQueue(clientAddress, RoutingType.ANYCAST, clientAddress) - rpcConsumer = consumerSession!!.createConsumer(clientAddress) - rpcConsumer!!.setMessageHandler(this::artemisMessageHandler) - producerSession!!.addFailoverListener(this::failoverHandler) + // Depending on how the client is constructed, connection failure is treated differently + if (serverLocator.staticTransportConfigurations.size == 1) { + sessionFactory!!.addFailoverListener(this::failoverHandler) + } else { + sessionFactory!!.addFailoverListener(this::haFailoverHandler) + } + initSessions() lifeCycle.transition(State.UNSTARTED, State.SERVER_VERSION_NOT_SET) - consumerSession!!.start() - producerSession!!.start() + startSessions() } // This is the general function that transforms a client side RPC to internal Artemis messages. @@ -341,6 +350,10 @@ class RPCClientProxyHandler( * @param notify whether to notify observables or not. 
*/ private fun close(notify: Boolean = true) { + haFailoverThread?.apply { + interrupt() + join(1000) + } sessionFactory?.close() reaperScheduledFuture?.cancel(false) observableContext.observableMap.invalidateAll() @@ -403,26 +416,82 @@ class RPCClientProxyHandler( } } + private fun attemptReconnect() { + var reconnectAttempts = rpcConfiguration.maxReconnectAttempts * serverLocator.staticTransportConfigurations.size + var retryInterval = rpcConfiguration.connectionRetryInterval + val maxRetryInterval = rpcConfiguration.connectionMaxRetryInterval + + var transportIterator = serverLocator.staticTransportConfigurations.iterator() + while (transportIterator.hasNext() && reconnectAttempts != 0) { + val transport = transportIterator.next() + if (!transportIterator.hasNext()) + transportIterator = serverLocator.staticTransportConfigurations.iterator() + + log.debug("Trying to connect using ${transport.params}") + try { + if (serverLocator != null && !serverLocator.isClosed) { + sessionFactory = serverLocator.createSessionFactory(transport) + } else { + log.warn("Stopping reconnect attempts.") + log.debug("Server locator is closed or garbage collected. 
Proxy may have been closed during reconnect.") + break + } + } catch (e: ActiveMQException) { + try { + Thread.sleep(retryInterval.toMillis()) + } catch (e: InterruptedException) {} + // could not connect, try with next server transport + reconnectAttempts-- + retryInterval = minOf(maxRetryInterval, retryInterval.times(rpcConfiguration.connectionRetryIntervalMultiplier.toLong())) + continue + } + + log.debug("Connected successfully using ${transport.params}") + log.info("RPC server available.") + sessionFactory!!.addFailoverListener(this::haFailoverHandler) + initSessions() + startSessions() + sendingEnabled.set(true) + break + } + + if (reconnectAttempts == 0 || sessionFactory == null) + log.error("Could not reconnect to the RPC server.") + } + + private fun initSessions() { + producerSession = sessionFactory!!.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) + rpcProducer = producerSession!!.createProducer(RPCApi.RPC_SERVER_QUEUE_NAME) + consumerSession = sessionFactory!!.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) + consumerSession!!.createTemporaryQueue(clientAddress, RoutingType.ANYCAST, clientAddress) + rpcConsumer = consumerSession!!.createConsumer(clientAddress) + rpcConsumer!!.setMessageHandler(this::artemisMessageHandler) + } + + private fun startSessions() { + consumerSession!!.start() + producerSession!!.start() + } + + private fun haFailoverHandler(event: FailoverEventType) { + if (event == FailoverEventType.FAILURE_DETECTED) { + log.warn("Connection failure. 
Attempting to reconnect using back-up addresses.") + cleanUpOnConnectionLoss() + sessionFactory?.apply { + connection.destroy() + cleanup() + close() + } + haFailoverThread = Thread.currentThread() + attemptReconnect() + } + /* Other events are not considered as reconnection is not done by Artemis */ + } + private fun failoverHandler(event: FailoverEventType) { when (event) { FailoverEventType.FAILURE_DETECTED -> { - sendingEnabled.set(false) - - log.warn("Terminating observables.") - val m = observableContext.observableMap.asMap() - m.keys.forEach { k -> - observationExecutorPool.run(k) { - m[k]?.onError(RPCException("Connection failure detected.")) - } - } - observableContext.observableMap.invalidateAll() - - rpcReplyMap.forEach { _, replyFuture -> - replyFuture.setException(RPCException("Connection failure detected.")) - } - - rpcReplyMap.clear() - callSiteMap?.clear() + cleanUpOnConnectionLoss() } FailoverEventType.FAILOVER_COMPLETED -> { @@ -435,6 +504,25 @@ class RPCClientProxyHandler( } } } + + private fun cleanUpOnConnectionLoss() { + sendingEnabled.set(false) + log.warn("Terminating observables.") + val m = observableContext.observableMap.asMap() + m.keys.forEach { k -> + observationExecutorPool.run(k) { + m[k]?.onError(RPCException("Connection failure detected.")) + } + } + observableContext.observableMap.invalidateAll() + + rpcReplyMap.forEach { _, replyFuture -> + replyFuture.setException(RPCException("Connection failure detected.")) + } + + rpcReplyMap.clear() + callSiteMap?.clear() + } } private typealias RpcObservableMap = Cache>> diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/RPCDriver.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/RPCDriver.kt index 042a38ce67..5e04f66414 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/RPCDriver.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/RPCDriver.kt @@ -75,6 +75,13 @@ inline fun 
RPCDriverDSL.startRpcClient( configuration: CordaRPCClientConfigurationImpl = CordaRPCClientConfigurationImpl.default ) = startRpcClient(I::class.java, rpcAddress, username, password, configuration) +inline fun RPCDriverDSL.startRpcClient( + haAddressPool: List, + username: String = rpcTestUser.username, + password: String = rpcTestUser.password, + configuration: CordaRPCClientConfigurationImpl = CordaRPCClientConfigurationImpl.default +) = startRpcClient(I::class.java, haAddressPool, username, password, configuration) + data class RpcBrokerHandle( val hostAndPort: NetworkHostAndPort?, /** null if this is an InVM broker */ @@ -336,6 +343,32 @@ data class RPCDriverDSL( } } + /** + * Starts a Netty RPC client. + * + * @param rpcOpsClass The [Class] of the RPC interface. + * @param haAddressPool The addresses of the RPC servers(configured in HA mode) to connect to. + * @param username The username to authenticate with. + * @param password The password to authenticate with. + * @param configuration The RPC client configuration. + */ + fun startRpcClient( + rpcOpsClass: Class, + haAddressPool: List, + username: String = rpcTestUser.username, + password: String = rpcTestUser.password, + configuration: CordaRPCClientConfigurationImpl = CordaRPCClientConfigurationImpl.default + ): CordaFuture { + return driverDSL.executorService.fork { + val client = RPCClient(haAddressPool, null, configuration) + val connection = client.start(rpcOpsClass, username, password, externalTrace) + driverDSL.shutdownManager.registerShutdown { + connection.close() + } + connection.proxy + } + } + /** * Starts a Netty RPC client in a new JVM process that calls random RPCs with random arguments. 
* From be083d6763da463dbc5be6b7e81b9d6df752f5e7 Mon Sep 17 00:00:00 2001 From: bpaunescu Date: Sun, 22 Apr 2018 15:04:19 +0100 Subject: [PATCH 2/9] Added helper method for creating tcp transports from a list of host:port --- .../client/rpc/internal/RPCClientProxyHandler.kt | 2 +- .../net/corda/nodeapi/ArtemisTcpTransport.kt | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt index 8852a096cd..458bea35ef 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt @@ -417,7 +417,7 @@ class RPCClientProxyHandler( } private fun attemptReconnect() { - var reconnectAttempts = rpcConfiguration.maxReconnectAttempts * serverLocator.staticTransportConfigurations.size + var reconnectAttempts = rpcConfiguration.maxReconnectAttempts.times(serverLocator.staticTransportConfigurations.size) var retryInterval = rpcConfiguration.connectionRetryInterval val maxRetryInterval = rpcConfiguration.connectionMaxRetryInterval diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/ArtemisTcpTransport.kt b/node-api/src/main/kotlin/net/corda/nodeapi/ArtemisTcpTransport.kt index a9438611a4..40333af921 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/ArtemisTcpTransport.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/ArtemisTcpTransport.kt @@ -97,5 +97,19 @@ class ArtemisTcpTransport { } return TransportConfiguration(factoryName, options) } + + /** Create a list of [TransportConfiguration]. 
**/ + fun tcpTransportsFromList( + direction: ConnectionDirection, + hostAndPortList: List, + config: SSLConfiguration?, + enableSSL: Boolean = true): List{ + val tcpTransports = ArrayList(hostAndPortList.size) + hostAndPortList.forEach { + tcpTransports.add(tcpTransport(direction, it, config, enableSSL)) + } + + return tcpTransports + } } } From e51878417b81dd882cc846d3d8f9aa977752e51b Mon Sep 17 00:00:00 2001 From: bpaunescu Date: Mon, 23 Apr 2018 11:20:08 +0100 Subject: [PATCH 3/9] Address PR comments --- .../net/corda/client/rpc/RPCStabilityTests.kt | 2 +- .../net/corda/client/rpc/CordaRPCClient.kt | 20 ++----------------- .../rpc/internal/RPCClientProxyHandler.kt | 14 +++++++++---- 3 files changed, 13 insertions(+), 23 deletions(-) diff --git a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt index d9602f0c43..ada3e1e37c 100644 --- a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt +++ b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt @@ -388,7 +388,7 @@ class RPCStabilityTests { servers[response]!!.shutdown() servers.remove(response) - //failover will take some time + // Failover will take some time. while (true) { try { response = client.serverId() diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt index a378808068..cf82f270c9 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt @@ -93,7 +93,8 @@ interface CordaRPCClientConfiguration { * [CordaRPCClientConfiguration]. While attempting failover, current and future RPC calls will throw * [RPCException] and previously returned observables will call onError(). 
* - * If the client was created using a list of hosts, automatic failover will occur(the servers have to be started in HA mode) + * If the client was created using a list of hosts, automatic failover will occur (the servers have to be started in + * HA mode). * * @param hostAndPort The network address to connect to. * @param configuration An optional configuration used to tweak client behaviour. @@ -130,14 +131,6 @@ class CordaRPCClient private constructor( return CordaRPCClient(hostAndPort, configuration, sslConfiguration) } - internal fun createWithSsl( - haAddressPool: List, - configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default(), - sslConfiguration: SSLConfiguration? = null - ): CordaRPCClient { - return CordaRPCClient(haAddressPool.first(), configuration, sslConfiguration, null, haAddressPool) - } - internal fun createWithSslAndClassLoader( hostAndPort: NetworkHostAndPort, configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default(), @@ -146,15 +139,6 @@ class CordaRPCClient private constructor( ): CordaRPCClient { return CordaRPCClient(hostAndPort, configuration, sslConfiguration, classLoader) } - - internal fun createWithSslAndClassLoader( - haAddressPool: List, - configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default(), - sslConfiguration: SSLConfiguration? = null, - classLoader: ClassLoader? 
= null - ): CordaRPCClient { - return CordaRPCClient(haAddressPool.first(), configuration, sslConfiguration, classLoader, haAddressPool) - } } init { diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt index 458bea35ef..2ddb44cbad 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt @@ -70,6 +70,12 @@ import kotlin.reflect.jvm.javaMethod * unsubscribing from the [Observable], or if the [Observable] is garbage collected the client will eventually * automatically signal the server. This is done using a cache that holds weak references to the [UnicastSubject]s. * The cleanup happens in batches using a dedicated reaper, scheduled on [reaperExecutor]. + * + * The client will attempt to fail over in case the server becomes unreachable. Depending on the [ServerLocator] instance + * passed in the constructor, failover is either handled at Artemis level or client level. If only one transport + * was used to create the [ServerLocator], failover is handled by Artemis (retrying based on [CordaRPCClientConfiguration]). + * If a list of transport configurations was used, failover is handled locally. Artemis is able to do it, however the + * brokers on server side need to be configured in HA mode and the [ServerLocator] needs to be created with HA as well. */ class RPCClientProxyHandler( private val rpcConfiguration: CordaRPCClientConfiguration, @@ -175,7 +181,7 @@ class RPCClientProxyHandler( private val deduplicationSequenceNumber = AtomicLong(0) private val sendingEnabled = AtomicBoolean(true) - // used to interrupt failover thread (i.e. client is closed while failing over) + // Used to interrupt failover thread (i.e. client is closed while failing over). private var haFailoverThread: Thread? 
= null /** @@ -440,13 +446,13 @@ class RPCClientProxyHandler( try { Thread.sleep(retryInterval.toMillis()) } catch (e: InterruptedException) {} - // could not connect, try with next server transport + // Could not connect, try with next server transport. reconnectAttempts-- retryInterval = minOf(maxRetryInterval, retryInterval.times(rpcConfiguration.connectionRetryIntervalMultiplier.toLong())) continue } - log.debug("Connected successfully using ${transport.params}") + log.debug("Connected successfully after $reconnectAttempts attempts using ${transport.params}.") log.info("RPC server available.") sessionFactory!!.addFailoverListener(this::haFailoverHandler) initSessions() @@ -485,7 +491,7 @@ class RPCClientProxyHandler( haFailoverThread = Thread.currentThread() attemptReconnect() } - /* Other events are not considered as reconnection is not done by Artemis */ + // Other events are not considered as reconnection is not done by Artemis. } private fun failoverHandler(event: FailoverEventType) { From ce5fb662605e100e129c08a526007b10c8736cfa Mon Sep 17 00:00:00 2001 From: Andras Slemmer Date: Fri, 13 Apr 2018 15:17:20 +0100 Subject: [PATCH 4/9] StateMachine rewrite --- build.gradle | 1 + constants.properties | 1 + .../core/flows/IdentifiableException.java | 16 + .../net/corda/core/flows/FlowException.kt | 24 +- .../kotlin/net/corda/core/flows/FlowLogic.kt | 171 ++--- .../corda/core/internal/FlowAsyncOperation.kt | 23 + .../net/corda/core/internal/FlowIORequest.kt | 89 +++ .../corda/core/internal/FlowStateMachine.kt | 44 +- .../core/node/services/TransactionStorage.kt | 6 + .../corda/core/utilities/UntrustworthyData.kt | 15 + .../net/corda/core/flows/FlowsInJavaTest.java | 3 +- docs/source/api-flows.rst | 62 -- .../java/net/corda/docs/FlowCookbookJava.java | 7 - .../kotlin/net/corda/docs/FlowCookbook.kt | 7 - .../net/corda/docs/LaunchSpaceshipFlow.kt | 99 --- node/build.gradle | 2 +- .../services/messaging/P2PMessagingTest.kt | 9 +- 
.../net/corda/node/internal/AbstractNode.kt | 12 +- .../node/services/api/CheckpointStorage.kt | 34 +- .../node/services/api/ServiceHubInternal.kt | 4 +- .../services/events/NodeSchedulerService.kt | 83 ++- .../node/services/messaging/Messaging.kt | 49 +- .../services/messaging/MessagingExecutor.kt | 86 +++ .../messaging/P2PMessageDeduplicator.kt | 108 +++ .../services/messaging/P2PMessagingClient.kt | 280 +++---- .../persistence/DBCheckpointStorage.kt | 32 +- .../persistence/DBTransactionStorage.kt | 21 +- .../node/services/schema/NodeSchemaService.kt | 3 +- .../node/services/statemachine/Action.kt | 132 ++++ .../services/statemachine/ActionExecutor.kt | 14 + .../statemachine/ActionExecutorImpl.kt | 228 ++++++ .../services/statemachine/CountUpDownLatch.kt | 66 ++ .../services/statemachine/DeduplicationId.kt | 47 ++ .../corda/node/services/statemachine/Event.kt | 126 ++++ .../node/services/statemachine/FlowFiber.kt | 18 + .../services/statemachine/FlowHospital.kt | 18 + .../services/statemachine/FlowIORequest.kt | 121 --- .../statemachine/FlowLogicRefFactoryImpl.kt | 2 +- .../services/statemachine/FlowMessaging.kt | 91 +++ .../services/statemachine/FlowSessionImpl.kt | 60 +- .../statemachine/FlowSessionInternal.kt | 66 -- .../statemachine/FlowStateMachineImpl.kt | 657 ++++++----------- .../statemachine/PropagatingFlowHospital.kt | 20 + .../statemachine/SessionRejectException.kt | 8 + .../SingleThreadedStateMachineManager.kt | 689 ++++++++++++++++++ .../statemachine/StateMachineManager.kt | 40 +- .../statemachine/StateMachineManagerImpl.kt | 666 ----------------- .../statemachine/StateMachineState.kt | 231 ++++++ .../node/services/statemachine/SubFlow.kt | 74 ++ .../statemachine/TransitionExecutor.kt | 25 + .../statemachine/TransitionExecutorImpl.kt | 67 ++ .../DumpHistoryOnErrorInterceptor.kt | 51 ++ ...FiberDeserializationCheckingInterceptor.kt | 95 +++ .../interceptors/HospitalisingInterceptor.kt | 46 ++ .../interceptors/MetricInterceptor.kt | 24 + 
.../interceptors/PrintingInterceptor.kt | 31 + .../TransitionDiagnosticRecord.kt | 51 ++ .../DeliverSessionMessageTransition.kt | 186 +++++ .../transitions/DoRemainingWorkTransition.kt | 37 + .../transitions/ErrorFlowTransition.kt | 124 ++++ .../transitions/StartedFlowTransition.kt | 410 +++++++++++ .../statemachine/transitions/StateMachine.kt | 30 + .../transitions/TopLevelTransition.kt | 243 ++++++ .../statemachine/transitions/Transition.kt | 32 + .../transitions/TransitionBuilder.kt | 74 ++ .../transitions/TransitionResult.kt | 46 ++ .../transitions/UnstartedFlowTransition.kt | 80 ++ .../services/vault/VaultSoftLockManager.kt | 9 +- .../net/corda/node/utilities/ObjectDiffer.kt | 144 ++++ .../node/messaging/InMemoryMessagingTests.kt | 9 +- .../node/messaging/TwoPartyTradeFlowTests.kt | 6 + .../events/NodeSchedulerServiceTest.kt | 17 +- .../messaging/ArtemisMessagingTest.kt | 4 +- .../persistence/DBCheckpointStorageTests.kt | 80 +- .../statemachine/FlowFrameworkTests.kt | 73 +- .../transactions/MaxTransactionSizeTests.kt | 8 +- .../transactions/NotaryServiceTests.kt | 2 +- samples/irs-demo/build.gradle | 1 + samples/network-visualiser/build.gradle | 1 + .../testing/node/InMemoryMessagingNetwork.kt | 36 +- .../testing/node/internal/InMemoryMessage.kt | 9 +- .../node/internal/InternalMockNetwork.kt | 3 +- .../node/internal/MockTransactionStorage.kt | 9 +- 83 files changed, 4719 insertions(+), 2009 deletions(-) create mode 100644 core/src/main/java/net/corda/core/flows/IdentifiableException.java create mode 100644 core/src/main/kotlin/net/corda/core/internal/FlowAsyncOperation.kt create mode 100644 core/src/main/kotlin/net/corda/core/internal/FlowIORequest.kt delete mode 100644 docs/source/example-code/src/main/kotlin/net/corda/docs/LaunchSpaceshipFlow.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/messaging/MessagingExecutor.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/messaging/P2PMessageDeduplicator.kt create mode 100644 
node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutor.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/CountUpDownLatch.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/DeduplicationId.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/FlowFiber.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/FlowHospital.kt delete mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt delete mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionInternal.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/PropagatingFlowHospital.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/SessionRejectException.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt delete mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/SubFlow.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutor.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutorImpl.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/DumpHistoryOnErrorInterceptor.kt create mode 100644 
node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/FiberDeserializationCheckingInterceptor.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/HospitalisingInterceptor.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/MetricInterceptor.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/PrintingInterceptor.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/TransitionDiagnosticRecord.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DoRemainingWorkTransition.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StateMachine.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/transitions/Transition.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TransitionBuilder.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TransitionResult.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt create mode 100644 node/src/main/kotlin/net/corda/node/utilities/ObjectDiffer.kt diff --git a/build.gradle b/build.gradle index 99667f4a73..fb619d1a17 100644 --- a/build.gradle +++ b/build.gradle @@ -47,6 +47,7 @@ buildscript { ext.bouncycastle_version = constants.getProperty("bouncycastleVersion") 
ext.guava_version = constants.getProperty("guavaVersion") ext.caffeine_version = constants.getProperty("caffeineVersion") + ext.metrics_version = constants.getProperty("metricsVersion") ext.okhttp_version = '3.5.0' ext.netty_version = '4.1.9.Final' ext.typesafe_config_version = constants.getProperty("typesafeConfigVersion") diff --git a/constants.properties b/constants.properties index 7be4b2da09..49a76da863 100644 --- a/constants.properties +++ b/constants.properties @@ -8,3 +8,4 @@ jsr305Version=3.0.2 artifactoryPluginVersion=4.4.18 snakeYamlVersion=1.19 caffeineVersion=2.6.2 +metricsVersion=3.2.5 diff --git a/core/src/main/java/net/corda/core/flows/IdentifiableException.java b/core/src/main/java/net/corda/core/flows/IdentifiableException.java new file mode 100644 index 0000000000..d1d32a97f3 --- /dev/null +++ b/core/src/main/java/net/corda/core/flows/IdentifiableException.java @@ -0,0 +1,16 @@ +package net.corda.core.flows; + +import javax.annotation.Nullable; + +/** + * An exception that may be identified with an ID. If an exception originates in a counter-flow this ID will be + * propagated. This allows correlation of error conditions across different flows. + */ +public interface IdentifiableException { + /** + * @return the ID of the error, or null if the error doesn't have it set (yet). + */ + default @Nullable Long getErrorId() { + return null; + } +} diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowException.kt b/core/src/main/kotlin/net/corda/core/flows/FlowException.kt index 33251020f8..ac0fbdaa23 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowException.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowException.kt @@ -7,16 +7,27 @@ import net.corda.core.CordaRuntimeException /** * Exception which can be thrown by a [FlowLogic] at any point in its logic to unexpectedly bring it to a permanent end. 
* The exception will propagate to all counterparty flows and will be thrown on their end the next time they wait on a - * [FlowSession.receive] or [FlowSession.sendAndReceive]. Any flow which no longer needs to do a receive, or has already ended, - * will not receive the exception (if this is required then have them wait for a confirmation message). + * [FlowSession.receive] or [FlowSession.sendAndReceive]. Any flow which no longer needs to do a receive, or has already + * ended, will not receive the exception (if this is required then have them wait for a confirmation message). + * + * If the *rethrown* [FlowException] is uncaught in counterparty flows and propagation triggers then the exception is + * downgraded to an [UnexpectedFlowEndException]. This means only immediate counterparty flows will receive information + * about what the exception was. * * [FlowException] (or a subclass) can be a valid expected response from a flow, particularly ones which act as a service. * It is recommended a [FlowLogic] document the [FlowException] types it can throw. + * + * @property originalErrorId the ID backing [getErrorId]. If null it will be set dynamically by the flow framework when + * the exception is handled. This ID is propagated to counterparty flows, even when the [FlowException] is + * downgraded to an [UnexpectedFlowEndException]. This is so the error conditions may be correlated later on. */ -open class FlowException(message: String?, cause: Throwable?) : CordaException(message, cause) { +open class FlowException(message: String?, cause: Throwable?) : + CordaException(message, cause), IdentifiableException { constructor(message: String?) : this(message, null) constructor(cause: Throwable?) : this(cause?.toString(), cause) constructor() : this(null, null) + var originalErrorId: Long? = null + override fun getErrorId(): Long? = originalErrorId } // DOCEND 1 @@ -25,6 +36,7 @@ open class FlowException(message: String?, cause: Throwable?) 
: CordaException(m * that we were not expecting), or the other side had an internal error, or the other side terminated when we * were waiting for a response. */ -class UnexpectedFlowEndException(message: String?, cause: Throwable?) : CordaRuntimeException(message, cause) { - constructor(msg: String) : this(msg, null) -} \ No newline at end of file +class UnexpectedFlowEndException(message: String, cause: Throwable?, val originalErrorId: Long) : + CordaRuntimeException(message, cause), IdentifiableException { + override fun getErrorId(): Long = originalErrorId +} diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt index c0a369a730..5050b6253a 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt @@ -6,20 +6,17 @@ import net.corda.core.CordaInternal import net.corda.core.crypto.SecureHash import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate -import net.corda.core.internal.FlowStateMachine -import net.corda.core.internal.abbreviate -import net.corda.core.internal.uncheckedCast +import net.corda.core.internal.* import net.corda.core.messaging.DataFeed import net.corda.core.node.NodeInfo import net.corda.core.node.ServiceHub import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.serialize import net.corda.core.transactions.SignedTransaction -import net.corda.core.utilities.ProgressTracker -import net.corda.core.utilities.UntrustworthyData -import net.corda.core.utilities.debug +import net.corda.core.utilities.* import org.slf4j.Logger import java.time.Duration -import java.time.Instant /** * A sub-class of [FlowLogic] implements a flow using direct, straight line blocking code. 
Thus you @@ -77,12 +74,19 @@ abstract class FlowLogic { */ @Suspendable @JvmStatic + @JvmOverloads @Throws(FlowException::class) - fun sleep(duration: Duration) { + fun sleep(duration: Duration, maySkipCheckpoint: Boolean = false) { if (duration > Duration.ofMinutes(5)) { throw FlowException("Attempt to sleep for longer than 5 minutes is not supported. Consider using SchedulableState.") } - (Strand.currentStrand() as? FlowStateMachine<*>)?.sleepUntil(Instant.now() + duration) ?: Strand.sleep(duration.toMillis()) + val fiber = (Strand.currentStrand() as? FlowStateMachine<*>) + if (fiber == null) { + Strand.sleep(duration.toMillis()) + } else { + val request = FlowIORequest.Sleep(wakeUpAfter = fiber.serviceHub.clock.instant() + duration) + fiber.suspend(request, maySkipCheckpoint = maySkipCheckpoint) + } } } @@ -94,7 +98,7 @@ abstract class FlowLogic { /** * Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is - * only available once the flow has started, which means it cannnot be accessed in the constructor. Either + * only available once the flow has started, which means it cannot be accessed in the constructor. Either * access this lazily or from inside [call]. */ val serviceHub: ServiceHub get() = stateMachine.serviceHub @@ -104,7 +108,7 @@ abstract class FlowLogic { * that this function does not communicate in itself, the counter-flow will be kicked off by the first send/receive. */ @Suspendable - fun initiateFlow(party: Party): FlowSession = stateMachine.initiateFlow(party, flowUsedForSessions) + fun initiateFlow(party: Party): FlowSession = stateMachine.initiateFlow(party) /** * Specifies the identity, with certificate, to use for this flow. This will be one of the multiple identities that @@ -114,7 +118,10 @@ abstract class FlowLogic { * Note: The current implementation returns the single identity of the node. This will change once multiple identities * is implemented. 
*/ - val ourIdentityAndCert: PartyAndCertificate get() = stateMachine.ourIdentityAndCert + val ourIdentityAndCert: PartyAndCertificate get() { + return serviceHub.myInfo.legalIdentitiesAndCerts.find { it.party == stateMachine.ourIdentity } + ?: throw IllegalStateException("Identity specified by ${stateMachine.id} (${stateMachine.ourIdentity}) is not one of ours!") + } /** * Specifies the identity to use for this flow. This will be one of the multiple identities that belong to this node. @@ -124,102 +131,23 @@ abstract class FlowLogic { * Note: The current implementation returns the single identity of the node. This will change once multiple identities * is implemented. */ - val ourIdentity: Party get() = ourIdentityAndCert.party - /** - * Returns a [FlowInfo] object describing the flow [otherParty] is using. With [FlowInfo.flowVersion] it - * provides the necessary information needed for the evolution of flows and enabling backwards compatibility. - * - * This method can be called before any send or receive has been done with [otherParty]. In such a case this will force - * them to start their flow. - */ - @Deprecated("Use FlowSession.getFlowInfo()", level = DeprecationLevel.WARNING) - @Suspendable - fun getFlowInfo(otherParty: Party): FlowInfo = stateMachine.getFlowInfo(otherParty, flowUsedForSessions, maySkipCheckpoint = false) - - /** - * Serializes and queues the given [payload] object for sending to the [otherParty]. Suspends until a response - * is received, which must be of the given [R] type. - * - * Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly - * verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly - * corrupted data in order to exploit your code. 
- * - * Note that this function is not just a simple send+receive pair: it is more efficient and more correct to - * use this when you expect to do a message swap than do use [send] and then [receive] in turn. - * - * @return an [UntrustworthyData] wrapper around the received object. - */ - @Deprecated("Use FlowSession.sendAndReceive()", level = DeprecationLevel.WARNING) - inline fun sendAndReceive(otherParty: Party, payload: Any): UntrustworthyData { - return sendAndReceive(R::class.java, otherParty, payload) - } - - /** - * Serializes and queues the given [payload] object for sending to the [otherParty]. Suspends until a response - * is received, which must be of the given [receiveType]. Remember that when receiving data from other parties the data - * should not be trusted until it's been thoroughly verified for consistency and that all expectations are - * satisfied, as a malicious peer may send you subtly corrupted data in order to exploit your code. - * - * Note that this function is not just a simple send+receive pair: it is more efficient and more correct to - * use this when you expect to do a message swap than do use [send] and then [receive] in turn. - * - * @return an [UntrustworthyData] wrapper around the received object. - */ - @Deprecated("Use FlowSession.sendAndReceive()", level = DeprecationLevel.WARNING) - @Suspendable - open fun sendAndReceive(receiveType: Class, otherParty: Party, payload: Any): UntrustworthyData { - return stateMachine.sendAndReceive(receiveType, otherParty, payload, flowUsedForSessions, retrySend = false, maySkipCheckpoint = false) - } - - /** - * Similar to [sendAndReceive] but also instructs the `payload` to be redelivered until the expected message is received. - * - * Note that this method should NOT be used for regular party-to-party communication, use [sendAndReceive] instead. 
- * It is only intended for the case where the [otherParty] is running a distributed service with an idempotent - * flow which only accepts a single request and sends back a single response – e.g. a notary or certain types of - * oracle services. If one or more nodes in the service cluster go down mid-session, the message will be redelivered - * to a different one, so there is no need to wait until the initial node comes back up to obtain a response. - */ - @Deprecated("Use FlowSession.sendAndReceiveWithRetry()", level = DeprecationLevel.WARNING) - internal inline fun sendAndReceiveWithRetry(otherParty: Party, payload: Any): UntrustworthyData { - return stateMachine.sendAndReceive(R::class.java, otherParty, payload, flowUsedForSessions, retrySend = true, maySkipCheckpoint = false) - } + val ourIdentity: Party get() = stateMachine.ourIdentity @Suspendable internal fun FlowSession.sendAndReceiveWithRetry(receiveType: Class, payload: Any): UntrustworthyData { - return stateMachine.sendAndReceive(receiveType, counterparty, payload, flowUsedForSessions, retrySend = true, maySkipCheckpoint = false) + val request = FlowIORequest.SendAndReceive( + sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT)), + shouldRetrySend = true + ) + return stateMachine.suspend(request, maySkipCheckpoint = false)[this]!!.checkPayloadIs(receiveType) } @Suspendable internal inline fun FlowSession.sendAndReceiveWithRetry(payload: Any): UntrustworthyData { - return stateMachine.sendAndReceive(R::class.java, counterparty, payload, flowUsedForSessions, retrySend = true, maySkipCheckpoint = false) + return sendAndReceiveWithRetry(R::class.java, payload) } - /** - * Suspends until the specified [otherParty] sends us a message of type [R]. 
- * - * Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly - * verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly - * corrupted data in order to exploit your code. - */ - @Deprecated("Use FlowSession.receive()", level = DeprecationLevel.WARNING) - inline fun receive(otherParty: Party): UntrustworthyData = receive(R::class.java, otherParty) - - /** - * Suspends until the specified [otherParty] sends us a message of type [receiveType]. - * - * Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly - * verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly - * corrupted data in order to exploit your code. - * - * @return an [UntrustworthyData] wrapper around the received object. - */ - @Deprecated("Use FlowSession.receive()", level = DeprecationLevel.WARNING) - @Suspendable - open fun receive(receiveType: Class, otherParty: Party): UntrustworthyData { - return stateMachine.receive(receiveType, otherParty, flowUsedForSessions, maySkipCheckpoint = false) - } /** Suspends until a message has been received for each session in the specified [sessions]. * @@ -232,8 +160,14 @@ abstract class FlowLogic { * @returns a [Map] containing the objects received, wrapped in an [UntrustworthyData], by the [FlowSession]s who sent them. */ @Suspendable - open fun receiveAllMap(sessions: Map>): Map> { - return stateMachine.receiveAll(sessions, this) + @JvmOverloads + open fun receiveAllMap(sessions: Map>, maySkipCheckpoint: Boolean = false): Map> { + enforceNoPrimitiveInReceive(sessions.values) + val replies = stateMachine.suspend( + ioRequest = FlowIORequest.Receive(sessions.keys.toNonEmptySet()), + maySkipCheckpoint = maySkipCheckpoint + ) + return replies.mapValues { (session, payload) -> payload.checkPayloadIs(sessions[session]!!) 
} } /** @@ -248,24 +182,13 @@ abstract class FlowLogic { * @returns a [List] containing the objects received, wrapped in an [UntrustworthyData], with the same order of [sessions]. */ @Suspendable - open fun receiveAll(receiveType: Class, sessions: List): List> { + @JvmOverloads + open fun receiveAll(receiveType: Class, sessions: List, maySkipCheckpoint: Boolean = false): List> { + enforceNoPrimitiveInReceive(listOf(receiveType)) enforceNoDuplicates(sessions) return castMapValuesToKnownType(receiveAllMap(associateSessionsToReceiveType(receiveType, sessions))) } - /** - * Queues the given [payload] for sending to the [otherParty] and continues without suspending. - * - * Note that the other party may receive the message at some arbitrary later point or not at all: if [otherParty] - * is offline then message delivery will be retried until it comes back or until the message is older than the - * network's event horizon time. - */ - @Deprecated("Use FlowSession.send()", level = DeprecationLevel.WARNING) - @Suspendable - open fun send(otherParty: Party, payload: Any) { - stateMachine.send(otherParty, payload, flowUsedForSessions, maySkipCheckpoint = false) - } - /** * Invokes the given subflow. This function returns once the subflow completes successfully with the result * returned by that subflow's [call] method. 
If the subflow has a progress tracker, it is attached to the @@ -283,11 +206,8 @@ abstract class FlowLogic { open fun subFlow(subLogic: FlowLogic): R { subLogic.stateMachine = stateMachine maybeWireUpProgressTracking(subLogic) - if (!subLogic.javaClass.isAnnotationPresent(InitiatingFlow::class.java)) { - subLogic.flowUsedForSessions = flowUsedForSessions - } logger.debug { "Calling subflow: $subLogic" } - val result = subLogic.call() + val result = stateMachine.subFlow(subLogic) logger.debug { "Subflow finished with result ${result.toString().abbreviate(300)}" } // It's easy to forget this when writing flows so we just step it to the DONE state when it completes. subLogic.progressTracker?.currentStep = ProgressTracker.DONE @@ -384,7 +304,8 @@ abstract class FlowLogic { @Suspendable @JvmOverloads fun waitForLedgerCommit(hash: SecureHash, maySkipCheckpoint: Boolean = false): SignedTransaction { - return stateMachine.waitForLedgerCommit(hash, this, maySkipCheckpoint = maySkipCheckpoint) + val request = FlowIORequest.WaitForLedgerCommit(hash) + return stateMachine.suspend(request, maySkipCheckpoint = maySkipCheckpoint) } /** @@ -427,11 +348,6 @@ abstract class FlowLogic { _stateMachine = value } - // This is the flow used for managing sessions. It defaults to the current flow but if this is an inlined sub-flow - // then it will point to the flow it's been inlined to. - @Suppress("LeakingThis") - private var flowUsedForSessions: FlowLogic<*> = this - private fun maybeWireUpProgressTracking(subLogic: FlowLogic<*>) { val ours = progressTracker val theirs = subLogic.progressTracker @@ -448,6 +364,11 @@ abstract class FlowLogic { require(sessions.size == sessions.toSet().size) { "A flow session can only appear once as argument." 
} } + private fun enforceNoPrimitiveInReceive(types: Collection>) { + val primitiveTypes = types.filter { it.isPrimitive } + require(primitiveTypes.isEmpty()) { "Cannot receive primitive type(s) $primitiveTypes" } + } + private fun associateSessionsToReceiveType(receiveType: Class, sessions: List): Map> { return sessions.associateByTo(LinkedHashMap(), { it }, { receiveType }) } @@ -472,4 +393,4 @@ data class FlowInfo( * to deduplicate it from other releases of the same CorDapp, typically a version string. See the * [CorDapp JAR format](https://docs.corda.net/cordapp-build-systems.html#cordapp-jar-format) for more details. */ - val appName: String) \ No newline at end of file + val appName: String) diff --git a/core/src/main/kotlin/net/corda/core/internal/FlowAsyncOperation.kt b/core/src/main/kotlin/net/corda/core/internal/FlowAsyncOperation.kt new file mode 100644 index 0000000000..dad14b4b69 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/internal/FlowAsyncOperation.kt @@ -0,0 +1,23 @@ +package net.corda.core.internal + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.concurrent.CordaFuture +import net.corda.core.flows.FlowLogic +import net.corda.core.serialization.CordaSerializable + +/** + * Interface for arbitrary operations that can be invoked in a flow asynchronously - the flow will suspend until the + * operation completes. Operation parameters are expected to be injected via constructor. + */ +@CordaSerializable +interface FlowAsyncOperation { + /** Performs the operation in a non-blocking fashion. */ + fun execute(): CordaFuture +} + +/** Executes the specified [operation] and suspends until operation completion. 
*/ +@Suspendable +fun FlowLogic.executeAsync(operation: FlowAsyncOperation, maySkipCheckpoint: Boolean = false): R { + val request = FlowIORequest.ExecuteAsyncOperation(operation) + return stateMachine.suspend(request, maySkipCheckpoint) +} diff --git a/core/src/main/kotlin/net/corda/core/internal/FlowIORequest.kt b/core/src/main/kotlin/net/corda/core/internal/FlowIORequest.kt new file mode 100644 index 0000000000..55ef39aecf --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/internal/FlowIORequest.kt @@ -0,0 +1,89 @@ +package net.corda.core.internal + +import net.corda.core.crypto.SecureHash +import net.corda.core.flows.FlowInfo +import net.corda.core.flows.FlowSession +import net.corda.core.serialization.SerializedBytes +import net.corda.core.transactions.SignedTransaction +import net.corda.core.utilities.NonEmptySet +import java.time.Instant + +/** + * A [FlowIORequest] represents an IO request of a flow when it suspends. It is persisted in checkpoints. + */ +sealed class FlowIORequest { + /** + * Send messages to sessions. + * + * @property sessionToMessage a map from session to message-to-be-sent. + * @property shouldRetrySend specifies whether the send should be retried. + */ + data class Send( + val sessionToMessage: Map>, + val shouldRetrySend: Boolean + ) : FlowIORequest() { + override fun toString() = "Send(" + + "sessionToMessage=${sessionToMessage.mapValues { it.value.hash }}, " + + "shouldRetrySend=$shouldRetrySend" + + ")" + } + + /** + * Receive messages from sessions. + * + * @property sessions the sessions to receive messages from. + * @return a map from session to received message. + */ + data class Receive( + val sessions: NonEmptySet + ) : FlowIORequest>>() + + /** + * Send and receive messages from the specified sessions. + * + * @property sessionToMessage a map from session to message-to-be-sent. The keys also specify which sessions to + * receive from. + * @property shouldRetrySend specifies whether the send should be retried. 
+ * @return a map from session to received message. + */ + data class SendAndReceive( + val sessionToMessage: Map>, + val shouldRetrySend: Boolean + ) : FlowIORequest>>() { + override fun toString() = "SendAndReceive(${sessionToMessage.mapValues { (key, value) -> + "$key=${value.hash}" }}, shouldRetrySend=$shouldRetrySend)" + } + + /** + * Wait for a transaction to be committed to the database. + * + * @property hash the hash of the transaction. + * @return the committed transaction. + */ + data class WaitForLedgerCommit(val hash: SecureHash) : FlowIORequest() + + /** + * Get the FlowInfo of the specified sessions. + * + * @property sessions the sessions to get the FlowInfo of. + * @return a map from session to FlowInfo. + */ + data class GetFlowInfo(val sessions: NonEmptySet) : FlowIORequest>() + + /** + * Suspend the flow until the specified time. + * + * @property wakeUpAfter the time to sleep until. + */ + data class Sleep(val wakeUpAfter: Instant) : FlowIORequest() + + /** + * Suspend the flow until all Initiating sessions are confirmed. + */ + object WaitForSessionConfirmations : FlowIORequest() + + /** + * Execute the specified [operation], suspend the flow until completion. 
+ */ + data class ExecuteAsyncOperation(val operation: FlowAsyncOperation) : FlowIORequest() +} diff --git a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt index 3d9af23d71..7ea31b7bb7 100644 --- a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt +++ b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt @@ -1,64 +1,42 @@ package net.corda.core.internal import co.paralleluniverse.fibers.Suspendable +import net.corda.core.DoNotImplement import net.corda.core.concurrent.CordaFuture -import net.corda.core.crypto.SecureHash import net.corda.core.flows.* import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate import net.corda.core.context.InvocationContext import net.corda.core.node.ServiceHub -import net.corda.core.transactions.SignedTransaction -import net.corda.core.utilities.UntrustworthyData import org.slf4j.Logger -import java.time.Instant /** This is an internal interface that is implemented by code in the node module. You should look at [FlowLogic]. 
*/ -interface FlowStateMachine { +@DoNotImplement +interface FlowStateMachine { @Suspendable - fun getFlowInfo(otherParty: Party, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): FlowInfo + fun suspend(ioRequest: FlowIORequest, maySkipCheckpoint: Boolean): SUSPENDRETURN @Suspendable - fun initiateFlow(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSession - - @Suspendable - fun sendAndReceive(receiveType: Class, - otherParty: Party, - payload: Any, - sessionFlow: FlowLogic<*>, - retrySend: Boolean, - maySkipCheckpoint: Boolean): UntrustworthyData - - @Suspendable - fun receive(receiveType: Class, otherParty: Party, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): UntrustworthyData - - @Suspendable - fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean) - - @Suspendable - fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): SignedTransaction - - @Suspendable - fun sleepUntil(until: Instant) + fun initiateFlow(party: Party): FlowSession fun checkFlowPermission(permissionName: String, extraAuditData: Map) fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map) + @Suspendable + fun subFlow(subFlow: FlowLogic): SUBFLOWRETURN + @Suspendable fun flowStackSnapshot(flowClass: Class>): FlowStackSnapshot? 
@Suspendable fun persistFlowStackSnapshot(flowClass: Class>) - val logic: FlowLogic + val logic: FlowLogic val serviceHub: ServiceHub val logger: Logger val id: StateMachineRunId - val resultFuture: CordaFuture + val resultFuture: CordaFuture val context: InvocationContext - val ourIdentityAndCert: PartyAndCertificate - - @Suspendable - fun receiveAll(sessions: Map>, sessionFlow: FlowLogic<*>): Map> + val ourIdentity: Party } diff --git a/core/src/main/kotlin/net/corda/core/node/services/TransactionStorage.kt b/core/src/main/kotlin/net/corda/core/node/services/TransactionStorage.kt index 9b6b713ed2..b04c96729f 100644 --- a/core/src/main/kotlin/net/corda/core/node/services/TransactionStorage.kt +++ b/core/src/main/kotlin/net/corda/core/node/services/TransactionStorage.kt @@ -1,6 +1,7 @@ package net.corda.core.node.services import net.corda.core.DoNotImplement +import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.SecureHash import net.corda.core.messaging.DataFeed import net.corda.core.transactions.SignedTransaction @@ -26,4 +27,9 @@ interface TransactionStorage { * Returns all currently stored transactions and further fresh ones. 
*/ fun track(): DataFeed, SignedTransaction> + + /** + * Returns a future that completes with the transaction corresponding to [id] once it has been committed + */ + fun trackTransaction(id: SecureHash): CordaFuture } \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/utilities/UntrustworthyData.kt b/core/src/main/kotlin/net/corda/core/utilities/UntrustworthyData.kt index 272b5ec200..dcc8ba7e00 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/UntrustworthyData.kt +++ b/core/src/main/kotlin/net/corda/core/utilities/UntrustworthyData.kt @@ -2,6 +2,9 @@ package net.corda.core.utilities import co.paralleluniverse.fibers.Suspendable import net.corda.core.flows.FlowException +import net.corda.core.internal.castIfPossible +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.SerializedBytes import java.io.Serializable /** @@ -29,3 +32,15 @@ class UntrustworthyData(@PublishedApi internal val fromUntrustedWorld: T) } inline fun UntrustworthyData.unwrap(validator: (T) -> R): R = validator(fromUntrustedWorld) + +fun SerializedBytes.checkPayloadIs(type: Class): UntrustworthyData { + val payloadData: T = try { + val serializer = SerializationDefaults.SERIALIZATION_FACTORY + serializer.deserialize(this, type, SerializationDefaults.P2P_CONTEXT) + } catch (ex: Exception) { + throw IllegalArgumentException("Payload invalid", ex) + } + return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?: + throw IllegalArgumentException("We were expecting a ${type.name} but we instead got a " + + "${payloadData.javaClass.name} (${payloadData})") +} diff --git a/core/src/test/java/net/corda/core/flows/FlowsInJavaTest.java b/core/src/test/java/net/corda/core/flows/FlowsInJavaTest.java index aa616b3ca0..61bb135f56 100644 --- a/core/src/test/java/net/corda/core/flows/FlowsInJavaTest.java +++ b/core/src/test/java/net/corda/core/flows/FlowsInJavaTest.java @@ -61,9 +61,8 @@ public class FlowsInJavaTest { 
fail("ExecutionException should have been thrown"); } catch (ExecutionException e) { assertThat(e.getCause()) - .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("primitive") - .hasMessageContaining(receiveType.getName()); + .hasMessageContaining(Primitives.unwrap(receiveType).getName()); } } diff --git a/docs/source/api-flows.rst b/docs/source/api-flows.rst index 8f00b6bc0c..6e80c7ae72 100644 --- a/docs/source/api-flows.rst +++ b/docs/source/api-flows.rst @@ -416,68 +416,6 @@ Our side of the flow must mirror these calls. We could do this as follows: :end-before: DOCEND 08 :dedent: 12 -Why sessions? -^^^^^^^^^^^^^ - -Before ``FlowSession`` s were introduced the send/receive API looked a bit different. They were functions on -``FlowLogic`` and took the address ``Party`` as argument. The platform internally maintained a mapping from ``Party`` to -session, hiding sessions from the user completely. - -Although this is a convenient API it introduces subtle issues where a message that was originally meant for a specific -session may end up in another. - -Consider the following contrived example using the old ``Party`` based API: - -.. container:: codeset - - .. literalinclude:: ../../docs/source/example-code/src/main/kotlin/net/corda/docs/LaunchSpaceshipFlow.kt - :language: kotlin - :start-after: DOCSTART LaunchSpaceshipFlow - :end-before: DOCEND LaunchSpaceshipFlow - -The intention of the flows is very clear: LaunchSpaceshipFlow asks the president whether a spaceship should be launched. -It is expecting a boolean reply. The president in return first tells the secretary that they need coffee, which is also -communicated with a boolean. Afterwards the president replies to the launcher that they don't want to launch. - -However the above can go horribly wrong when the ``launcher`` happens to be the same party ``getSecretary`` returns. In -this case the boolean meant for the secretary will be received by the launcher! 
- -This indicates that ``Party`` is not a good identifier for the communication sequence, and indeed the ``Party`` based -API may introduce ways for an attacker to fish for information and even trigger unintended control flow like in the -above case. - -Hence we introduced ``FlowSession``, which identifies the communication sequence. With ``FlowSession`` s the above set -of flows would look like this: - -.. container:: codeset - - .. literalinclude:: ../../docs/source/example-code/src/main/kotlin/net/corda/docs/LaunchSpaceshipFlow.kt - :language: kotlin - :start-after: DOCSTART LaunchSpaceshipFlowCorrect - :end-before: DOCEND LaunchSpaceshipFlowCorrect - -Note how the president is now explicit about which session it wants to send to. - -Porting from the old Party-based API -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In the old API the first ``send`` or ``receive`` to a ``Party`` was the one kicking off the counter-flow. This is now -explicit in the ``initiateFlow`` function call. To port existing code: - -.. container:: codeset - - .. literalinclude:: ../../docs/source/example-code/src/main/kotlin/net/corda/docs/FlowCookbook.kt - :language: kotlin - :start-after: DOCSTART FlowSession porting - :end-before: DOCEND FlowSession porting - :dedent: 8 - - .. literalinclude:: ../../docs/source/example-code/src/main/java/net/corda/docs/FlowCookbookJava.java - :language: java - :start-after: DOCSTART FlowSession porting - :end-before: DOCEND FlowSession porting - :dedent: 12 - Subflows -------- Subflows are pieces of reusable flows that may be run by calling ``FlowLogic.subFlow``. 
There are two broad categories diff --git a/docs/source/example-code/src/main/java/net/corda/docs/FlowCookbookJava.java b/docs/source/example-code/src/main/java/net/corda/docs/FlowCookbookJava.java index bf2c58b858..d17615841d 100644 --- a/docs/source/example-code/src/main/java/net/corda/docs/FlowCookbookJava.java +++ b/docs/source/example-code/src/main/java/net/corda/docs/FlowCookbookJava.java @@ -582,13 +582,6 @@ public class FlowCookbookJava { SignedTransaction notarisedTx2 = subFlow(new FinalityFlow(fullySignedTx, additionalParties, FINALISATION.childProgressTracker())); // DOCEND 10 - // DOCSTART FlowSession porting - send(regulator, new Object()); // Old API - // becomes - FlowSession session = initiateFlow(regulator); - session.send(new Object()); - // DOCEND FlowSession porting - return null; } } diff --git a/docs/source/example-code/src/main/kotlin/net/corda/docs/FlowCookbook.kt b/docs/source/example-code/src/main/kotlin/net/corda/docs/FlowCookbook.kt index 0528caeaf3..880570e2df 100644 --- a/docs/source/example-code/src/main/kotlin/net/corda/docs/FlowCookbook.kt +++ b/docs/source/example-code/src/main/kotlin/net/corda/docs/FlowCookbook.kt @@ -572,13 +572,6 @@ class InitiatorFlow(val arg1: Boolean, val arg2: Int, private val counterparty: val additionalParties: Set = setOf(regulator) val notarisedTx2: SignedTransaction = subFlow(FinalityFlow(fullySignedTx, additionalParties, FINALISATION.childProgressTracker())) // DOCEND 10 - - // DOCSTART FlowSession porting - send(regulator, Any()) // Old API - // becomes - val session = initiateFlow(regulator) - session.send(Any()) - // DOCEND FlowSession porting } } diff --git a/docs/source/example-code/src/main/kotlin/net/corda/docs/LaunchSpaceshipFlow.kt b/docs/source/example-code/src/main/kotlin/net/corda/docs/LaunchSpaceshipFlow.kt deleted file mode 100644 index e6826fa213..0000000000 --- a/docs/source/example-code/src/main/kotlin/net/corda/docs/LaunchSpaceshipFlow.kt +++ /dev/null @@ -1,99 +0,0 @@ -package 
net.corda.docs - -import co.paralleluniverse.fibers.Suspendable -import net.corda.core.flows.FlowLogic -import net.corda.core.flows.FlowSession -import net.corda.core.flows.InitiatedBy -import net.corda.core.flows.InitiatingFlow -import net.corda.core.identity.Party -import net.corda.core.utilities.unwrap - -// DOCSTART LaunchSpaceshipFlow -@InitiatingFlow -class LaunchSpaceshipFlow : FlowLogic() { - @Suspendable - override fun call() { - val shouldLaunchSpaceship = receive(getPresident()).unwrap { it } - if (shouldLaunchSpaceship) { - launchSpaceship() - } - } - - fun launchSpaceship() { - } - - fun getPresident(): Party { - TODO() - } -} - -@InitiatedBy(LaunchSpaceshipFlow::class) -@InitiatingFlow -class PresidentSpaceshipFlow(val launcher: Party) : FlowLogic() { - @Suspendable - override fun call() { - val needCoffee = true - send(getSecretary(), needCoffee) - val shouldLaunchSpaceship = false - send(launcher, shouldLaunchSpaceship) - } - - fun getSecretary(): Party { - TODO() - } -} - -@InitiatedBy(PresidentSpaceshipFlow::class) -class SecretaryFlow(val president: Party) : FlowLogic() { - @Suspendable - override fun call() { - // ignore - } -} -// DOCEND LaunchSpaceshipFlow - -// DOCSTART LaunchSpaceshipFlowCorrect -@InitiatingFlow -class LaunchSpaceshipFlowCorrect : FlowLogic() { - @Suspendable - override fun call() { - val presidentSession = initiateFlow(getPresident()) - val shouldLaunchSpaceship = presidentSession.receive().unwrap { it } - if (shouldLaunchSpaceship) { - launchSpaceship() - } - } - - fun launchSpaceship() { - } - - fun getPresident(): Party { - TODO() - } -} - -@InitiatedBy(LaunchSpaceshipFlowCorrect::class) -@InitiatingFlow -class PresidentSpaceshipFlowCorrect(val launcherSession: FlowSession) : FlowLogic() { - @Suspendable - override fun call() { - val needCoffee = true - val secretarySession = initiateFlow(getSecretary()) - secretarySession.send(needCoffee) - val shouldLaunchSpaceship = false - launcherSession.send(shouldLaunchSpaceship) 
- } - - fun getSecretary(): Party { - TODO() - } -} - -@InitiatedBy(PresidentSpaceshipFlowCorrect::class) -class SecretaryFlowCorrect(val presidentSession: FlowSession) : FlowLogic() { - @Suspendable - override fun call() { - // ignore - } -} -// DOCEND LaunchSpaceshipFlowCorrect diff --git a/node/build.gradle b/node/build.gradle index 2158c98b2c..a5b169d690 100644 --- a/node/build.gradle +++ b/node/build.gradle @@ -107,7 +107,7 @@ dependencies { } // Coda Hale's Metrics: for monitoring of key statistics - compile "io.dropwizard.metrics:metrics-core:3.1.2" + compile "io.dropwizard.metrics:metrics-core:$metrics_version" // JimFS: in memory java.nio filesystem. Used for test and simulation utilities. compile "com.google.jimfs:jimfs:1.1" diff --git a/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt b/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt index b1729281b1..35dffc045c 100644 --- a/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt +++ b/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt @@ -144,7 +144,7 @@ class P2PMessagingTest { distributedServiceNodes.forEach { val nodeName = it.services.myInfo.legalIdentitiesAndCerts.first().name - it.internalServices.networkService.addMessageHandler("test.request") { netMessage, _ -> + it.internalServices.networkService.addMessageHandler("test.request") { netMessage, _, handler -> crashingNodes.requestsReceived.incrementAndGet() crashingNodes.firstRequestReceived.countDown() // The node which receives the first request will ignore all requests @@ -159,6 +159,7 @@ class P2PMessagingTest { val response = it.internalServices.networkService.createMessage("test.response", responseMessage.serialize().bytes) it.internalServices.networkService.send(response, request.replyTo) } + handler.afterDatabaseTransaction() } } return crashingNodes @@ -186,10 +187,11 @@ class P2PMessagingTest { } private fun 
InProcess.respondWith(message: Any) { - internalServices.networkService.addMessageHandler("test.request") { netMessage, _ -> + internalServices.networkService.addMessageHandler("test.request") { netMessage, _, handler -> val request = netMessage.data.deserialize() val response = internalServices.networkService.createMessage("test.response", message.serialize().bytes) internalServices.networkService.send(response, request.replyTo) + handler.afterDatabaseTransaction() } } @@ -211,11 +213,12 @@ class P2PMessagingTest { */ inline fun MessagingService.runOnNextMessage(topic: String, crossinline callback: (ReceivedMessage) -> Unit) { val consumed = AtomicBoolean() - addMessageHandler(topic) { msg, reg -> + addMessageHandler(topic) { msg, reg, handler -> removeMessageHandler(reg) check(!consumed.getAndSet(true)) { "Called more than once" } check(msg.topic == topic) { "Topic/session mismatch: ${msg.topic} vs $topic" } callback(msg) + handler.afterDatabaseTransaction() } } diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index 568f671ad0..9640a8aa10 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -8,6 +8,7 @@ import net.corda.confidential.SwapIdentitiesHandler import net.corda.core.CordaException import net.corda.core.concurrent.CordaFuture import net.corda.core.context.InvocationContext +import net.corda.core.crypto.newSecureRandom import net.corda.core.crypto.sign import net.corda.core.flows.* import net.corda.core.identity.CordaX500Name @@ -47,6 +48,7 @@ import net.corda.node.services.events.NodeSchedulerService import net.corda.node.services.events.ScheduledActivityObserver import net.corda.node.services.identity.PersistentIdentityService import net.corda.node.services.keys.PersistentKeyManagementService +import net.corda.node.services.messaging.DeduplicationHandler import 
net.corda.node.services.messaging.MessagingService import net.corda.node.services.network.* import net.corda.node.services.persistence.* @@ -131,7 +133,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, // We will run as much stuff in this single thread as possible to keep the risk of thread safety bugs low during the // low-performance prototyping period. - protected abstract val serverThread: AffinityExecutor + protected abstract val serverThread: AffinityExecutor.ServiceAffinityExecutor private val cordappServices = MutableClassToInstanceMap.create() private val flowFactories = ConcurrentHashMap>, InitiatedFlowFactory<*>>() @@ -248,7 +250,6 @@ abstract class AbstractNode(val configuration: NodeConfiguration, flowStarter, servicesForResolution, unfinishedSchedules = busyNodeLatch, - serverThread = serverThread, flowLogicRefFactory = flowLogicRefFactory, drainingModePollPeriod = configuration.drainingModePollPeriod, nodeProperties = nodeProperties) @@ -385,11 +386,12 @@ abstract class AbstractNode(val configuration: NodeConfiguration, protected abstract fun myAddresses(): List protected open fun makeStateMachineManager(database: CordaPersistence): StateMachineManager { - return StateMachineManagerImpl( + return SingleThreadedStateMachineManager( services, checkpointStorage, serverThread, database, + newSecureRandom(), busyNodeLatch, cordappLoader.appClassLoader ) @@ -894,8 +896,8 @@ internal fun logVendorString(database: CordaPersistence, log: Logger) { } internal class FlowStarterImpl(private val serverThread: AffinityExecutor, private val smm: StateMachineManager, private val flowLogicRefFactory: FlowLogicRefFactory) : FlowStarter { - override fun startFlow(logic: FlowLogic, context: InvocationContext): CordaFuture> { - return serverThread.fetchFrom { smm.startFlow(logic, context) } + override fun startFlow(logic: FlowLogic, context: InvocationContext, deduplicationHandler: DeduplicationHandler?): CordaFuture> { + return smm.startFlow(logic, 
context, ourIdentity = null, deduplicationHandler = deduplicationHandler) } override fun invokeFlowAsync( diff --git a/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt b/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt index 867e9d6c65..4a55d7163a 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt @@ -1,42 +1,28 @@ package net.corda.node.services.api -import net.corda.core.crypto.SecureHash +import net.corda.core.flows.StateMachineRunId import net.corda.core.serialization.SerializedBytes -import net.corda.node.services.statemachine.FlowStateMachineImpl +import net.corda.node.services.statemachine.Checkpoint +import java.util.stream.Stream /** * Thread-safe storage of fiber checkpoints. */ interface CheckpointStorage { - /** * Add a new checkpoint to the store. */ - fun addCheckpoint(checkpoint: Checkpoint) + fun addCheckpoint(id: StateMachineRunId, checkpoint: SerializedBytes) /** - * Remove existing checkpoint from the store. It is an error to attempt to remove a checkpoint which doesn't exist - * in the store. Doing so will throw an [IllegalArgumentException]. + * Remove existing checkpoint from the store. + * @return whether the id matched a checkpoint that was removed. */ - fun removeCheckpoint(checkpoint: Checkpoint) + fun removeCheckpoint(id: StateMachineRunId): Boolean /** - * Allows the caller to process safely in a thread safe fashion the set of all checkpoints. - * The checkpoints are only valid during the lifetime of a single call to the block, to allow memory management. - * Return false from the block to terminate further iteration. + * Stream all checkpoints from the store. If this is backed by a database the stream will only be valid while the + * underlying database connection is open, so any processing should happen before it is closed.
*/ - fun forEach(block: (Checkpoint) -> Boolean) - -} - -// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo). -class Checkpoint(val serializedFiber: SerializedBytes>) { - - val id: SecureHash get() = serializedFiber.hash - - override fun equals(other: Any?): Boolean = other === this || other is Checkpoint && other.id == this.id - - override fun hashCode(): Int = id.hashCode() - - override fun toString(): String = "${javaClass.simpleName}(id=$id)" + fun getAllCheckpoints(): Stream>> } diff --git a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt index 94bca7dd04..df04be4140 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt @@ -19,6 +19,7 @@ import net.corda.core.utilities.contextLogger import net.corda.node.internal.InitiatedFlowFactory import net.corda.node.internal.cordapp.CordappProviderInternal import net.corda.node.services.config.NodeConfiguration +import net.corda.node.services.messaging.DeduplicationHandler import net.corda.node.services.messaging.MessagingService import net.corda.node.services.network.NetworkMapUpdater import net.corda.node.services.statemachine.FlowStateMachineImpl @@ -135,8 +136,9 @@ interface FlowStarter { /** * Starts an already constructed flow. Note that you must be on the server thread to call this method. * @param context indicates who started the flow, see: [InvocationContext]. + * @param deduplicationHandler allows exactly-once start of the flow, see [DeduplicationHandler] */ - fun startFlow(logic: FlowLogic, context: InvocationContext): CordaFuture> + fun startFlow(logic: FlowLogic, context: InvocationContext, deduplicationHandler: DeduplicationHandler? 
= null): CordaFuture> /** * Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the flow. diff --git a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt index df28f86674..cea738df76 100644 --- a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt +++ b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt @@ -26,10 +26,12 @@ import net.corda.node.MutableClock import net.corda.node.services.api.FlowStarter import net.corda.node.services.api.NodePropertiesStore import net.corda.node.services.api.SchedulerService +import net.corda.node.services.messaging.DeduplicationHandler import net.corda.node.utilities.PersistentMap import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import org.apache.activemq.artemis.utils.ReusableLatch +import org.apache.mina.util.ConcurrentHashSet import org.slf4j.Logger import java.io.Serializable import java.time.Duration @@ -61,7 +63,6 @@ class NodeSchedulerService(private val clock: CordaClock, private val flowStarter: FlowStarter, private val servicesForResolution: ServicesForResolution, private val unfinishedSchedules: ReusableLatch = ReusableLatch(), - private val serverThread: Executor, private val flowLogicRefFactory: FlowLogicRefFactory, private val nodeProperties: NodePropertiesStore, private val drainingModePollPeriod: Duration, @@ -164,6 +165,10 @@ class NodeSchedulerService(private val clock: CordaClock, var rescheduled: GuavaSettableFuture? 
= null } + // Used to de-duplicate flow starts in case a flow is starting but the corresponding entry hasn't been removed yet + // from the database + private val startingStateRefs = ConcurrentHashSet() + private val mutex = ThreadBox(InnerState()) // We need the [StateMachineManager] to be constructed before this is called in case it schedules a flow. fun start() { @@ -173,6 +178,29 @@ class NodeSchedulerService(private val clock: CordaClock, } } + /** + * Stop scheduler service. + */ + fun stop() { + mutex.locked { + schedulerTimerExecutor.shutdown() + scheduledStatesQueue.clear() + scheduledStates.clear() + } + } + + /** + * Resume scheduler service after having called [stop]. + */ + fun resume() { + mutex.locked { + schedulerTimerExecutor = Executors.newSingleThreadExecutor() + scheduledStates.putAll(createMap()) + scheduledStatesQueue.addAll(scheduledStates.values) + rescheduleWakeUp() + } + } + override fun scheduleStateActivity(action: ScheduledStateRef) { log.trace { "Schedule $action" } val previousState = scheduledStates[action.ref] @@ -181,7 +209,7 @@ class NodeSchedulerService(private val clock: CordaClock, val previousEarliest = scheduledStatesQueue.peek() scheduledStatesQueue.remove(previousState) scheduledStatesQueue.add(action) - if (previousState == null) { + if (previousState == null && action !in startingStateRefs) { unfinishedSchedules.countUp() } @@ -212,7 +240,7 @@ class NodeSchedulerService(private val clock: CordaClock, } } - private val schedulerTimerExecutor = Executors.newSingleThreadExecutor() + private var schedulerTimerExecutor = Executors.newSingleThreadExecutor() /** * This method first cancels the [java.util.concurrent.Future] for any pending action so that the * [awaitWithDeadline] used below drops through without running the action. 
We then create a new @@ -254,25 +282,41 @@ class NodeSchedulerService(private val clock: CordaClock, schedulerTimerExecutor.join() } + private inner class FlowStartDeduplicationHandler(val scheduledState: ScheduledStateRef) : DeduplicationHandler { + override fun insideDatabaseTransaction() { + scheduledStates.remove(scheduledState.ref) + } + + override fun afterDatabaseTransaction() { + startingStateRefs.remove(scheduledState) + } + + override fun toString(): String { + return "${javaClass.simpleName}($scheduledState)" + } + } + private fun onTimeReached(scheduledState: ScheduledStateRef) { - serverThread.execute { - var flowName: String? = "(unknown)" - try { - database.transaction { - val scheduledFlow = getScheduledFlow(scheduledState) - if (scheduledFlow != null) { - flowName = scheduledFlow.javaClass.name - // TODO refactor the scheduler to store and propagate the original invocation context - val context = InvocationContext.newInstance(InvocationOrigin.Scheduled(scheduledState)) - val future = flowStarter.startFlow(scheduledFlow, context).flatMap { it.resultFuture } - future.then { - unfinishedSchedules.countDown() - } + var flowName: String? = "(unknown)" + try { + // We need to check this before the database transaction, otherwise there is a subtle race between a + // doubly-reached deadline and the removal from [startingStateRefs]. 
+ if (scheduledState !in startingStateRefs) { + val scheduledFlow = database.transaction { getScheduledFlow(scheduledState) } + if (scheduledFlow != null) { + startingStateRefs.add(scheduledState) + flowName = scheduledFlow.javaClass.name + // TODO refactor the scheduler to store and propagate the original invocation context + val context = InvocationContext.newInstance(InvocationOrigin.Scheduled(scheduledState)) + val deduplicationHandler = FlowStartDeduplicationHandler(scheduledState) + val future = flowStarter.startFlow(scheduledFlow, context, deduplicationHandler).flatMap { it.resultFuture } + future.then { + unfinishedSchedules.countDown() } } - } catch (e: Exception) { - log.error("Failed to start scheduled flow $flowName for $scheduledState due to an internal error", e) } + } catch (e: Exception) { + log.error("Failed to start scheduled flow $flowName for $scheduledState due to an internal error", e) } } @@ -304,7 +348,6 @@ class NodeSchedulerService(private val clock: CordaClock, } else -> { log.trace { "Scheduler starting FlowLogic $flowLogic" } - scheduledStates.remove(scheduledState.ref) scheduledStatesQueue.remove(scheduledState) flowLogic } @@ -328,4 +371,4 @@ class NodeSchedulerService(private val clock: CordaClock, null } } -} \ No newline at end of file +} diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt index 24620f274e..3e72f52b72 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt @@ -1,6 +1,7 @@ package net.corda.node.services.messaging import co.paralleluniverse.fibers.Suspendable +import net.corda.core.crypto.newSecureRandom import net.corda.core.identity.CordaX500Name import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.SingleMessageRecipient @@ -8,8 +9,8 @@ import net.corda.core.node.services.PartyInfo import 
net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.serialize import net.corda.core.utilities.ByteSequence +import net.corda.node.services.statemachine.DeduplicationId import java.time.Instant -import java.util.* import javax.annotation.concurrent.ThreadSafe /** @@ -35,7 +36,7 @@ interface MessagingService { * * @param topic identifier for the topic to listen for messages arriving on. */ - fun addMessageHandler(topic: String, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration + fun addMessageHandler(topic: String, callback: MessageHandler): MessageHandlerRegistration /** * Removes a handler given the object returned from [addMessageHandler]. The callback will no longer be invoked once @@ -66,8 +67,7 @@ interface MessagingService { message: Message, target: MessageRecipients, retryId: Long? = null, - sequenceKey: Any = target, - additionalHeaders: Map = emptyMap() + sequenceKey: Any = target ) /** A message with a target and sequenceKey specified. */ @@ -97,7 +97,7 @@ interface MessagingService { * @param additionalProperties optional additional message headers. * @param topic identifier for the topic the message is sent to. */ - fun createMessage(topic: String, data: ByteArray, deduplicationId: String = UUID.randomUUID().toString()): Message + fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId = DeduplicationId.createRandom(newSecureRandom()), additionalHeaders: Map = emptyMap()): Message /** Given information about either a specific node or a service returns its corresponding address */ fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients @@ -106,9 +106,8 @@ interface MessagingService { val myAddress: SingleMessageRecipient } - -fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: String = UUID.randomUUID().toString(), retryId: Long? 
= null) - = send(createMessage(topicSession, payload.serialize().bytes, deduplicationId), to, retryId) +fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: DeduplicationId = DeduplicationId.createRandom(newSecureRandom()), retryId: Long? = null, additionalHeaders: Map = emptyMap()) + = send(createMessage(topicSession, payload.serialize().bytes, deduplicationId, additionalHeaders), to, retryId) interface MessageHandlerRegistration @@ -127,7 +126,9 @@ interface Message { val topic: String val data: ByteSequence val debugTimestamp: Instant - val uniqueMessageId: String + val uniqueMessageId: DeduplicationId + val senderUUID: String? + val additionalHeaders: Map } // TODO Have ReceivedMessage point to the TLS certificate of the peer, and [peer] would simply be the subject DN of that. @@ -138,6 +139,10 @@ interface ReceivedMessage : Message { val peer: CordaX500Name /** Platform version of the sender's node. */ val platformVersion: Int + /** Sender-assigned sequence number of this message, or null if the sender did not attach one */ + val senderSeqNo: Long? + /** True if a flow session init message */ + val isSessionInit: Boolean } /** A singleton that's useful for validating topic strings */ @@ -147,3 +152,29 @@ object TopicStringValidator { fun check(tag: String) = require(regex.matcher(tag).matches()) } +/** + * This handler is used to implement exactly-once delivery of an event on top of a possibly duplicated one. This is done + * using two hooks that are called from the event processor, one called from the database transaction committing the + * side-effect caused by the event, and another one called after the transaction has committed successfully. + * + * For example for messaging we can use [insideDatabaseTransaction] to store the message's unique ID for later + * deduplication, and [afterDatabaseTransaction] to acknowledge the message and stop retries.
+ * + * We also use this for exactly-once start of a scheduled flow, [insideDatabaseTransaction] is used to remove the + * to-be-scheduled state of the flow, [afterDatabaseTransaction] is used for cleanup of in-memory bookkeeping. + */ +interface DeduplicationHandler { + /** + * This will be run inside a database transaction that commits the side-effect of the event, allowing the + * implementor to persist the event delivery fact atomically with the side-effect. + */ + fun insideDatabaseTransaction() + + /** + * This will be run strictly after the side-effect has been committed successfully and may be used for + * cleanup/acknowledgement/stopping of retries. + */ + fun afterDatabaseTransaction() +} + +typealias MessageHandler = (ReceivedMessage, MessageHandlerRegistration, DeduplicationHandler) -> Unit diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/MessagingExecutor.kt b/node/src/main/kotlin/net/corda/node/services/messaging/MessagingExecutor.kt new file mode 100644 index 0000000000..9c2e124556 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/messaging/MessagingExecutor.kt @@ -0,0 +1,86 @@ +package net.corda.node.services.messaging + +import co.paralleluniverse.fibers.Suspendable +import co.paralleluniverse.strands.SettableFuture +import com.codahale.metrics.MetricRegistry +import net.corda.core.messaging.MessageRecipients +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.trace +import net.corda.node.VersionInfo +import net.corda.node.services.statemachine.FlowMessagingImpl +import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders +import org.apache.activemq.artemis.api.core.ActiveMQDuplicateIdException +import org.apache.activemq.artemis.api.core.SimpleString +import org.apache.activemq.artemis.api.core.client.ClientMessage +import org.apache.activemq.artemis.api.core.client.ClientProducer +import org.apache.activemq.artemis.api.core.client.ClientSession +import 
java.util.concurrent.ArrayBlockingQueue +import java.util.concurrent.ExecutionException +import java.util.concurrent.atomic.AtomicLong +import kotlin.concurrent.thread + +interface AddressToArtemisQueueResolver { + /** + * Resolves a [MessageRecipients] to an Artemis queue name, creating the underlying queue if needed. + */ + fun resolveTargetToArtemisQueue(address: MessageRecipients): String +} + +/** + * The [MessagingExecutor] is responsible for sending messages and acknowledging received ones. [send] resolves the + * target to an Artemis queue name, converts the Corda [Message] into an Artemis message and passes it to [producer] + * on the calling thread; [acknowledge] individually acknowledges a received Artemis message. No batching or + * buffering of jobs is performed by this class. + */ +class MessagingExecutor( + val session: ClientSession, + val producer: ClientProducer, + val versionInfo: VersionInfo, + val resolver: AddressToArtemisQueueResolver, + val ourSenderUUID: String +) { + private val cordaVendor = SimpleString(versionInfo.vendor) + private val releaseVersion = SimpleString(versionInfo.releaseVersion) + private val ourSenderSeqNo = AtomicLong() + + private companion object { + val log = contextLogger() + val amqDelayMillis = System.getProperty("amq.delivery.delay.ms", "0").toInt() + } + + fun send(message: Message, target: MessageRecipients) { + val mqAddress = resolver.resolveTargetToArtemisQueue(target) + val artemisMessage = cordaToArtemisMessage(message) + log.trace { + "Send to: $mqAddress topic: ${message.topic} " + + "sessionID: ${message.topic} id: ${message.uniqueMessageId}" + } + producer.send(SimpleString(mqAddress), artemisMessage) + } + + fun acknowledge(message: ClientMessage) { + message.individualAcknowledge() + } + + internal fun cordaToArtemisMessage(message: Message): ClientMessage? 
{ + return session.createMessage(true).apply { + putStringProperty(P2PMessagingHeaders.cordaVendorProperty, cordaVendor) + putStringProperty(P2PMessagingHeaders.releaseVersionProperty, releaseVersion) + putIntProperty(P2PMessagingHeaders.platformVersionProperty, versionInfo.platformVersion) + putStringProperty(P2PMessagingHeaders.topicProperty, SimpleString(message.topic)) + writeBodyBufferBytes(message.data.bytes) + // Use the magic deduplication property built into Artemis as our message identity too + putStringProperty(org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString)) + // If we are the sender (ie. we are not going through recovery of some sort), use sequence number short cut. + if (ourSenderUUID == message.senderUUID) { + putStringProperty(P2PMessagingHeaders.senderUUID, SimpleString(ourSenderUUID)) + putLongProperty(P2PMessagingHeaders.senderSeqNo, ourSenderSeqNo.getAndIncrement()) + } + // For demo purposes - if set then add a delay to messages in order to demonstrate that the flows are doing as intended + if (amqDelayMillis > 0 && message.topic == FlowMessagingImpl.sessionTopic) { + putLongProperty(org.apache.activemq.artemis.api.core.Message.HDR_SCHEDULED_DELIVERY_TIME, System.currentTimeMillis() + amqDelayMillis) + } + message.additionalHeaders.forEach { key, value -> putStringProperty(key, value) } + } + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessageDeduplicator.kt b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessageDeduplicator.kt new file mode 100644 index 0000000000..76075477c3 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessageDeduplicator.kt @@ -0,0 +1,108 @@ +package net.corda.node.services.messaging + +import net.corda.core.crypto.SecureHash +import net.corda.core.identity.CordaX500Name +import net.corda.node.services.statemachine.DeduplicationId +import 
net.corda.node.utilities.AppendOnlyPersistentMap +import net.corda.nodeapi.internal.persistence.CordaPersistence +import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX +import java.io.Serializable +import java.time.Instant +import java.util.* +import java.util.concurrent.ConcurrentHashMap +import javax.persistence.Column +import javax.persistence.Entity +import javax.persistence.Id + +/** + * Encapsulate the de-duplication logic. + */ +class P2PMessageDeduplicator(private val database: CordaPersistence) { + val ourSenderUUID = UUID.randomUUID().toString() + + // A temporary in-memory set of deduplication IDs and associated high water mark details. + // When we receive a message we don't persist the ID immediately, + // so we store the ID here in the meantime (until the persisting db tx has committed). This is because Artemis may + // redeliver messages to the same consumer if they weren't ACKed. + private val beingProcessedMessages = ConcurrentHashMap() + private val processedMessages = createProcessedMessages() + + private fun createProcessedMessages(): AppendOnlyPersistentMap { + return AppendOnlyPersistentMap( + toPersistentEntityKey = { it.toString }, + fromPersistentEntity = { Pair(DeduplicationId(it.id), MessageMeta(it.insertionTime, it.hash, it.seqNo)) }, + toPersistentEntity = { key: DeduplicationId, value: MessageMeta -> + ProcessedMessage().apply { + id = key.toString + insertionTime = value.insertionTime + hash = value.senderHash + seqNo = value.senderSeqNo + } + }, + persistentEntityClass = ProcessedMessage::class.java + ) + } + + private fun isDuplicateInDatabase(msg: ReceivedMessage): Boolean = database.transaction { msg.uniqueMessageId in processedMessages } + + // We need to incorporate the sending party, and the sessionInit flag as per the in-memory cache. 
+ private fun senderHash(senderKey: SenderKey) = SecureHash.sha256(senderKey.peer.toString() + senderKey.isSessionInit.toString() + senderKey.senderUUID).toString() + + /** + * @return true if we have seen this message before. + */ + fun isDuplicate(msg: ReceivedMessage): Boolean { + if (beingProcessedMessages.containsKey(msg.uniqueMessageId)) { + return true + } + return isDuplicateInDatabase(msg) + } + + /** + * Called the first time we encounter [msg]; tracks its deduplication ID in memory until it has been persisted. + */ + fun signalMessageProcessStart(msg: ReceivedMessage) { + val receivedSenderUUID = msg.senderUUID + val receivedSenderSeqNo = msg.senderSeqNo + // We don't want a mix of nulls and values so we ensure that here. + val senderHash: String? = if (receivedSenderUUID != null && receivedSenderSeqNo != null) senderHash(SenderKey(receivedSenderUUID, msg.peer, msg.isSessionInit)) else null + val senderSeqNo: Long? = if (senderHash != null) msg.senderSeqNo else null + beingProcessedMessages[msg.uniqueMessageId] = MessageMeta(Instant.now(), senderHash, senderSeqNo) + } + + /** + * Called inside a DB transaction to persist [deduplicationId]. + */ + fun persistDeduplicationId(deduplicationId: DeduplicationId) { + processedMessages[deduplicationId] = beingProcessedMessages[deduplicationId]!! + } + + /** + * Called after the DB transaction persisting [deduplicationId] committed. + * Any subsequent redelivery will be deduplicated using the DB. + */ + fun signalMessageProcessFinish(deduplicationId: DeduplicationId) { + beingProcessedMessages.remove(deduplicationId) + } + + @Entity + @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_ids") + class ProcessedMessage( + @Id + @Column(name = "message_id", length = 64) + var id: String = "", + + @Column(name = "insertion_time") + var insertionTime: Instant = Instant.now(), + + @Column(name = "sender", length = 64) + var hash: String? = "", + + @Column(name = "sequence_number") + var seqNo: Long? 
= null + ) : Serializable + + private data class MessageMeta(val insertionTime: Instant, val senderHash: String?, val senderSeqNo: Long?) + + private data class SenderKey(val senderUUID: String, val peer: CordaX500Name, val isSessionInit: Boolean) +} diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt index 3fe6a368e4..73fe65e66f 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt @@ -1,8 +1,11 @@ package net.corda.node.services.messaging +import co.paralleluniverse.fibers.Suspendable +import com.codahale.metrics.MetricRegistry import net.corda.core.crypto.toStringShort import net.corda.core.identity.CordaX500Name import net.corda.core.internal.ThreadBox +import net.corda.core.internal.concurrent.openFuture import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.SingleMessageRecipient @@ -21,9 +24,8 @@ import net.corda.node.internal.artemis.ReactiveArtemisConsumer import net.corda.node.internal.artemis.ReactiveArtemisConsumer.Companion.multiplex import net.corda.node.services.api.NetworkMapCacheInternal import net.corda.node.services.config.NodeConfiguration -import net.corda.node.services.statemachine.StateMachineManagerImpl +import net.corda.node.services.statemachine.DeduplicationId import net.corda.node.utilities.AffinityExecutor -import net.corda.node.utilities.AppendOnlyPersistentMap import net.corda.node.utilities.PersistentMap import net.corda.nodeapi.ArtemisTcpTransport import net.corda.nodeapi.ConnectionDirection @@ -38,7 +40,8 @@ import net.corda.nodeapi.internal.bridging.BridgeEntry import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import 
org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException -import org.apache.activemq.artemis.api.core.Message.* +import org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID +import org.apache.activemq.artemis.api.core.Message.HDR_VALIDATED_USER import org.apache.activemq.artemis.api.core.RoutingType import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.client.* @@ -50,15 +53,16 @@ import java.io.Serializable import java.security.PublicKey import java.time.Instant import java.util.* -import java.util.concurrent.* +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.CountDownLatch +import java.util.concurrent.ScheduledFuture +import java.util.concurrent.TimeUnit import javax.annotation.concurrent.ThreadSafe import javax.persistence.Column import javax.persistence.Entity import javax.persistence.Id import javax.persistence.Lob -// TODO: Stop the wallet explorer and other clients from using this class and get rid of persistentInbox - /** * This class implements the [MessagingService] API using Apache Artemis, the successor to their ActiveMQ product. * Artemis is a message queue broker and here we run a client connecting to the specified broker instance @@ -85,7 +89,7 @@ import javax.persistence.Lob * @param maxMessageSize A bound applied to the message size. 
*/ @ThreadSafe -class P2PMessagingClient(private val config: NodeConfiguration, +class P2PMessagingClient(val config: NodeConfiguration, private val versionInfo: VersionInfo, private val serverAddress: NetworkHostAndPort, private val myIdentity: PublicKey, @@ -97,26 +101,11 @@ class P2PMessagingClient(private val config: NodeConfiguration, private val maxMessageSize: Int, private val isDrainingModeOn: () -> Boolean, private val drainingModeWasChangedEvents: Observable> -) : SingletonSerializeAsToken(), MessagingService, AutoCloseable { +) : SingletonSerializeAsToken(), MessagingService, AddressToArtemisQueueResolver, AutoCloseable { companion object { private val log = contextLogger() - private val amqDelayMillis = System.getProperty("amq.delivery.delay.ms", "0").toInt() private const val messageMaxRetryCount: Int = 3 - fun createProcessedMessage(): AppendOnlyPersistentMap { - return AppendOnlyPersistentMap( - toPersistentEntityKey = { it }, - fromPersistentEntity = { Pair(it.uuid, it.insertionTime) }, - toPersistentEntity = { key: String, value: Instant -> - ProcessedMessage().apply { - uuid = key - insertionTime = value - } - }, - persistentEntityClass = ProcessedMessage::class.java - ) - } - fun createMessageToRedeliver(): PersistentMap, RetryMessage, Long> { return PersistentMap( toPersistentEntityKey = { it }, @@ -137,7 +126,7 @@ class P2PMessagingClient(private val config: NodeConfiguration, ) } - private class NodeClientMessage(override val topic: String, override val data: ByteSequence, override val uniqueMessageId: String) : Message { + private class NodeClientMessage(override val topic: String, override val data: ByteSequence, override val uniqueMessageId: DeduplicationId, override val senderUUID: String?, override val additionalHeaders: Map) : Message { override val debugTimestamp: Instant = Instant.now() override fun toString() = "$topic#${String(data.bytes)}" } @@ -165,32 +154,17 @@ class P2PMessagingClient(private val config: NodeConfiguration, 
private val scheduledMessageRedeliveries = ConcurrentHashMap>() /** A registration to handle messages of different types */ - data class Handler(val topic: String, - val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration - - private val cordaVendor = SimpleString(versionInfo.vendor) - private val releaseVersion = SimpleString(versionInfo.releaseVersion) - /** An executor for sending messages */ - private val messagingExecutor = AffinityExecutor.ServiceAffinityExecutor("Messaging ${myIdentity.toStringShort()}", 1) + data class HandlerRegistration(val topic: String, val callback: Any) : MessageHandlerRegistration override val myAddress: SingleMessageRecipient = NodeAddress(myIdentity, advertisedAddress) private val messageRedeliveryDelaySeconds = config.messageRedeliveryDelaySeconds.toLong() private val state = ThreadBox(InnerState()) private val knownQueues = Collections.newSetFromMap(ConcurrentHashMap()) - private val handlers = CopyOnWriteArrayList() - private val processedMessages = createProcessedMessage() + private val handlers = ConcurrentHashMap() - @Entity - @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_ids") - class ProcessedMessage( - @Id - @Column(name = "message_id", length = 64) - var uuid: String = "", - - @Column(name = "insertion_time") - var insertionTime: Instant = Instant.now() - ) : Serializable + private val deduplicator = P2PMessageDeduplicator(database) + internal var messagingExecutor: MessagingExecutor? = null @Entity @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_retry") @@ -246,6 +220,14 @@ class P2PMessagingClient(private val config: NodeConfiguration, inboxes.forEach { createQueueIfAbsent(it, producerSession!!) 
} p2pConsumer = P2PMessagingConsumer(inboxes, createNewSession, isDrainingModeOn, drainingModeWasChangedEvents) + messagingExecutor = MessagingExecutor( + producerSession!!, + producer!!, + versionInfo, + this@P2PMessagingClient, + ourSenderUUID = deduplicator.ourSenderUUID + ) + registerBridgeControl(bridgeSession!!, inboxes.toList()) enumerateBridges(bridgeSession!!, inboxes.toList()) } @@ -255,7 +237,9 @@ class P2PMessagingClient(private val config: NodeConfiguration, private fun InnerState.registerBridgeControl(session: ClientSession, inboxes: List) { val bridgeNotifyQueue = "$BRIDGE_NOTIFY.${myIdentity.toStringShort()}" - session.createTemporaryQueue(BRIDGE_NOTIFY, RoutingType.MULTICAST, bridgeNotifyQueue) + if (!session.queueQuery(SimpleString(bridgeNotifyQueue)).isExists) { + session.createTemporaryQueue(BRIDGE_NOTIFY, RoutingType.MULTICAST, bridgeNotifyQueue) + } val bridgeConsumer = session.createConsumer(bridgeNotifyQueue) bridgeNotifyConsumer = bridgeConsumer bridgeConsumer.setMessageHandler { msg -> @@ -273,7 +257,7 @@ class P2PMessagingClient(private val config: NodeConfiguration, networkChangeSubscription = networkMap.changed.subscribe { updateBridgesOnNetworkChange(it) } } - private fun sendBridgeControl(message: BridgeControl) { + private fun sendBridgeControl(message: BridgeControl) { state.locked { val controlPacket = message.serialize(context = SerializationDefaults.P2P_CONTEXT).bytes val artemisMessage = producerSession!!.createMessage(false) @@ -343,38 +327,35 @@ class P2PMessagingClient(private val config: NodeConfiguration, private fun resumeMessageRedelivery() { messagesToRedeliver.forEach { retryId, (message, target) -> - sendInternal(message, target, retryId) + send(message, target, retryId) } } private val shutdownLatch = CountDownLatch(1) + var runningFuture = openFuture() + /** * Starts the p2p event loop: this method only returns once [stop] has been called. 
*/ fun run() { - val latch = CountDownLatch(1) try { val consumer = state.locked { check(started) { "start must be called first" } check(!running) { "run can't be called twice" } running = true + runningFuture.set(Unit) // If it's null, it means we already called stop, so return immediately. if (p2pConsumer == null) { return } eventsSubscription = p2pConsumer!!.messages .doOnError { error -> throw error } - .doOnNext { artemisMessage -> - val receivedMessage = artemisToCordaMessage(artemisMessage) - receivedMessage?.let { - deliver(it) - } - artemisMessage.acknowledge() - } + .doOnNext { message -> deliver(message) } // this `run()` method is semantically meant to block until the message consumption runs, hence the latch here .doOnCompleted(latch::countDown) + .doOnError { error -> throw error } .subscribe() p2pConsumer!! } @@ -391,10 +372,13 @@ class P2PMessagingClient(private val config: NodeConfiguration, val user = requireNotNull(message.getStringProperty(HDR_VALIDATED_USER)) { "Message is not authenticated" } val platformVersion = message.required(P2PMessagingHeaders.platformVersionProperty) { getIntProperty(it) } // Use the magic deduplication property built into Artemis as our message identity too - val uuid = message.required(HDR_DUPLICATE_DETECTION_ID) { message.getStringProperty(it) } - log.info("Received message from: ${message.address} user: $user topic: $topic uuid: $uuid") + val uniqueMessageId = message.required(HDR_DUPLICATE_DETECTION_ID) { DeduplicationId(message.getStringProperty(it)) } + val receivedSenderUUID = message.getStringProperty(P2PMessagingHeaders.senderUUID) + val receivedSenderSeqNo = if (message.containsProperty(P2PMessagingHeaders.senderSeqNo)) message.getLongProperty(P2PMessagingHeaders.senderSeqNo) else null + val isSessionInit = message.getStringProperty(P2PMessagingHeaders.Type.KEY) == P2PMessagingHeaders.Type.SESSION_INIT_VALUE + log.trace { "Received message from: ${message.address} user: $user topic: $topic id: 
$uniqueMessageId senderUUID: $receivedSenderUUID senderSeqNo: $receivedSenderSeqNo isSessionInit: $isSessionInit" } - return ArtemisReceivedMessage(topic, CordaX500Name.parse(user), platformVersion, uuid, message) + return ArtemisReceivedMessage(topic, CordaX500Name.parse(user), platformVersion, uniqueMessageId, receivedSenderUUID, receivedSenderSeqNo, isSessionInit, message) } catch (e: Exception) { log.error("Unable to process message, ignoring it: $message", e) return null @@ -409,52 +393,57 @@ class P2PMessagingClient(private val config: NodeConfiguration, private class ArtemisReceivedMessage(override val topic: String, override val peer: CordaX500Name, override val platformVersion: Int, - override val uniqueMessageId: String, + override val uniqueMessageId: DeduplicationId, + override val senderUUID: String?, + override val senderSeqNo: Long?, + override val isSessionInit: Boolean, private val message: ClientMessage) : ReceivedMessage { override val data: ByteSequence by lazy { OpaqueBytes(ByteArray(message.bodySize).apply { message.bodyBuffer.readBytes(this) }) } override val debugTimestamp: Instant get() = Instant.ofEpochMilli(message.timestamp) + override val additionalHeaders: Map = emptyMap() override fun toString() = "$topic#$data" } - private fun deliver(msg: ReceivedMessage): Boolean { - state.checkNotLocked() - // Because handlers is a COW list, the loop inside filter will operate on a snapshot. Handlers being added - // or removed whilst the filter is executing will not affect anything. - val deliverTo = handlers.filter { it.topic.isBlank() || it.topic == msg.topic } - try { - // This will perform a BLOCKING call onto the executor. Thus if the handlers are slow, we will - // be slow, and Artemis can handle that case intelligently. 
We don't just invoke the handler - // directly in order to ensure that we have the features of the AffinityExecutor class throughout - // the bulk of the codebase and other non-messaging jobs can be scheduled onto the server executor - // easily. - // - // Note that handlers may re-enter this class. We aren't holding any locks and methods like - // start/run/stop have re-entrancy assertions at the top, so it is OK. - nodeExecutor.fetchFrom { - database.transaction { - if (msg.uniqueMessageId in processedMessages) { - log.trace { "Discard duplicate message ${msg.uniqueMessageId} for ${msg.topic}" } - } else { - if (deliverTo.isEmpty()) { - // TODO: Implement dead letter queue, and send it there. - log.warn("Received message ${msg.uniqueMessageId} for ${msg.topic} that doesn't have any registered handlers yet") - } else { - callHandlers(msg, deliverTo) - } - // TODO We will at some point need to decide a trimming policy for the id's - processedMessages[msg.uniqueMessageId] = Instant.now() - } - } + internal fun deliver(artemisMessage: ClientMessage) { + + artemisToCordaMessage(artemisMessage)?.let { cordaMessage -> + if (!deduplicator.isDuplicate(cordaMessage)) { + deduplicator.signalMessageProcessStart(cordaMessage) + deliver(cordaMessage, artemisMessage) + } else { + log.trace { "Discard duplicate message ${cordaMessage.uniqueMessageId} for ${cordaMessage.topic}" } + artemisMessage.individualAcknowledge() } - } catch (e: Exception) { - log.error("Caught exception whilst executing message handler for ${msg.topic}", e) } - return true } - private fun callHandlers(msg: ReceivedMessage, deliverTo: List) { - for (handler in deliverTo) { - handler.callback(msg, handler) + private fun deliver(msg: ReceivedMessage, artemisMessage: ClientMessage) { + + state.checkNotLocked() + val deliverTo = handlers[msg.topic] + if (deliverTo != null) { + try { + deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), MessageDeduplicationHandler(artemisMessage, msg)) + } catch (e: 
Exception) { + log.error("Caught exception whilst executing message handler for ${msg.topic}", e) + } + } else { + log.warn("Received message ${msg.uniqueMessageId} for ${msg.topic} that doesn't have any registered handlers yet") + } + } + + inner class MessageDeduplicationHandler(val artemisMessage: ClientMessage, val cordaMessage: ReceivedMessage) : DeduplicationHandler { + override fun insideDatabaseTransaction() { + deduplicator.persistDeduplicationId(cordaMessage.uniqueMessageId) + } + + override fun afterDatabaseTransaction() { + deduplicator.signalMessageProcessFinish(cordaMessage.uniqueMessageId) + messagingExecutor!!.acknowledge(artemisMessage) + } + + override fun toString(): String { + return "${javaClass.simpleName}(${cordaMessage.uniqueMessageId})" } } @@ -470,6 +459,7 @@ class P2PMessagingClient(private val config: NodeConfiguration, check(started) val prevRunning = running running = false + runningFuture = openFuture() networkChangeSubscription?.unsubscribe() require(p2pConsumer != null, { "stop can't be called twice" }) require(producer != null, { "stop can't be called twice" }) @@ -507,75 +497,42 @@ class P2PMessagingClient(private val config: NodeConfiguration, override fun close() = stop() - override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, additionalHeaders: Map) { - sendInternal(message, target, retryId, additionalHeaders) - } - - private fun sendInternal(message: Message, target: MessageRecipients, retryId: Long?, additionalHeaders: Map = emptyMap()) { - // We have to perform sending on a different thread pool, since using the same pool for messaging and - // fibers leads to Netty buffer memory leaks, caused by both Netty and Quasar fiddling with thread-locals. 
- messagingExecutor.fetchFrom { - state.locked { - val mqAddress = getMQAddress(target) - val artemisMessage = producerSession!!.createMessage(true).apply { - putStringProperty(P2PMessagingHeaders.cordaVendorProperty, cordaVendor) - putStringProperty(P2PMessagingHeaders.releaseVersionProperty, releaseVersion) - putIntProperty(P2PMessagingHeaders.platformVersionProperty, versionInfo.platformVersion) - putStringProperty(P2PMessagingHeaders.topicProperty, SimpleString(message.topic)) - writeBodyBufferBytes(message.data.bytes) - // Use the magic deduplication property built into Artemis as our message identity too - putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId)) - - // For demo purposes - if set then add a delay to messages in order to demonstrate that the flows are doing as intended - if (amqDelayMillis > 0 && message.topic == StateMachineManagerImpl.sessionTopic) { - putLongProperty(HDR_SCHEDULED_DELIVERY_TIME, System.currentTimeMillis() + amqDelayMillis) - } - additionalHeaders.forEach { key, value -> putStringProperty(key, value) } - } - log.trace { - "Send to: $mqAddress topic: ${message.topic} uuid: ${message.uniqueMessageId}" - } - sendMessage(mqAddress, artemisMessage) - retryId?.let { - database.transaction { - messagesToRedeliver.computeIfAbsent(it, { Pair(message, target) }) - } - scheduledMessageRedeliveries[it] = messagingExecutor.schedule({ - sendWithRetry(0, mqAddress, artemisMessage, it) - }, messageRedeliveryDelaySeconds, TimeUnit.SECONDS) - - } + @Suspendable + override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) { + messagingExecutor!!.send(message, target) + retryId?.let { + database.transaction { + messagesToRedeliver.computeIfAbsent(it, { Pair(message, target) }) } + scheduledMessageRedeliveries[it] = nodeExecutor.schedule({ + sendWithRetry(0, message, target, retryId) + }, messageRedeliveryDelaySeconds, TimeUnit.SECONDS) } } + @Suspendable override fun 
send(addressedMessages: List) { for ((message, target, retryId, sequenceKey) in addressedMessages) { send(message, target, retryId, sequenceKey) } } - private fun sendWithRetry(retryCount: Int, address: String, message: ClientMessage, retryId: Long) { - fun ClientMessage.randomiseDuplicateId() { - putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString())) - } - + private fun sendWithRetry(retryCount: Int, message: Message, target: MessageRecipients, retryId: Long) { log.trace { "Attempting to retry #$retryCount message delivery for $retryId" } if (retryCount >= messageMaxRetryCount) { - log.warn("Reached the maximum number of retries ($messageMaxRetryCount) for message $message redelivery to $address") + log.warn("Reached the maximum number of retries ($messageMaxRetryCount) for message $message redelivery to $target") scheduledMessageRedeliveries.remove(retryId) return } - message.randomiseDuplicateId() - - state.locked { - log.trace { "Retry #$retryCount sending message $message to $address for $retryId" } - sendMessage(address, message) + val messageWithRetryCount = object : Message by message { + override val uniqueMessageId = DeduplicationId("${message.uniqueMessageId.toString}-$retryCount") } - scheduledMessageRedeliveries[retryId] = messagingExecutor.schedule({ - sendWithRetry(retryCount + 1, address, message, retryId) + messagingExecutor!!.send(messageWithRetryCount, target) + + scheduledMessageRedeliveries[retryId] = nodeExecutor.schedule({ + sendWithRetry(retryCount + 1, message, target, retryId) }, messageRedeliveryDelaySeconds, TimeUnit.SECONDS) } @@ -590,18 +547,14 @@ class P2PMessagingClient(private val config: NodeConfiguration, } } - private fun Pair.deliver() = deliver(second!!) 
- private fun Pair.acknowledge() = first.acknowledge() - - private fun getMQAddress(target: MessageRecipients): String { - return if (target == myAddress) { + override fun resolveTargetToArtemisQueue(address: MessageRecipients): String { + return if (address == myAddress) { // If we are sending to ourselves then route the message directly to our P2P queue. RemoteInboxAddress(myIdentity).queueName } else { // Otherwise we send the message to an internal queue for the target residing on our broker. It's then the // broker's job to route the message to the target's P2P queue. - val internalTargetQueue = (target as? ArtemisAddress)?.queueName - ?: throw IllegalArgumentException("Not an Artemis address") + val internalTargetQueue = (address as? ArtemisAddress)?.queueName ?: throw IllegalArgumentException("Not an Artemis address") state.locked { createQueueIfAbsent(internalTargetQueue, producerSession!!) } @@ -630,24 +583,27 @@ class P2PMessagingClient(private val config: NodeConfiguration, } } - override fun addMessageHandler(topic: String, - callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration { + override fun addMessageHandler(topic: String, callback: MessageHandler): MessageHandlerRegistration { require(!topic.isBlank()) { "Topic must not be blank, as the empty topic is a special case." 
} - val handler = Handler(topic, callback) - handlers.add(handler) - return handler + handlers.compute(topic) { _, handler -> + if (handler != null) { + throw IllegalStateException("Cannot add another acking handler for $topic, there is already an acking one") + } + callback + } + return HandlerRegistration(topic, callback) } override fun removeMessageHandler(registration: MessageHandlerRegistration) { - handlers.remove(registration) + registration as HandlerRegistration + handlers.remove(registration.topic) } - override fun createMessage(topic: String, data: ByteArray, deduplicationId: String): Message { - // TODO: We could write an object that proxies directly to an underlying MQ message here and avoid copying. - return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId) + override fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId, additionalHeaders: Map): Message { + + return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId, deduplicator.ourSenderUUID, additionalHeaders) } - // TODO Rethink PartyInfo idea and merging PeerAddress/ServiceAddress (the only difference is that Service address doesn't hold host and port) override fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients { return when (partyInfo) { is PartyInfo.SingleNode -> NodeAddress(partyInfo.party.owningKey, partyInfo.addresses.single()) @@ -720,4 +676,4 @@ private fun ReactiveArtemisConsumer.switchTo(other: ReactiveArtemisConsumer) { !other.started -> other.start() !other.connected -> other.connect() } -} \ No newline at end of file +} diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt index f8140b4b50..608c5f35d6 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt @@ -1,11 +1,16 @@ package 
net.corda.node.services.persistence +import net.corda.core.flows.StateMachineRunId import net.corda.core.serialization.SerializedBytes -import net.corda.node.services.api.Checkpoint +import net.corda.core.utilities.debug import net.corda.node.services.api.CheckpointStorage +import net.corda.node.services.statemachine.Checkpoint import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.nodeapi.internal.persistence.currentDBSession import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY +import org.slf4j.LoggerFactory +import java.util.* +import java.util.stream.Stream import java.io.Serializable import javax.persistence.Column import javax.persistence.Entity @@ -16,6 +21,7 @@ import javax.persistence.Lob * Simple checkpoint key value storage in DB. */ class DBCheckpointStorage : CheckpointStorage { + val log = LoggerFactory.getLogger(this::class.java) @Entity @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}checkpoints") @@ -29,32 +35,30 @@ class DBCheckpointStorage : CheckpointStorage { var checkpoint: ByteArray = EMPTY_BYTE_ARRAY ) : Serializable - override fun addCheckpoint(checkpoint: Checkpoint) { - currentDBSession().save(DBCheckpoint().apply { - checkpointId = checkpoint.id.toString() - this.checkpoint = checkpoint.serializedFiber.bytes // XXX: Is copying the byte array necessary? 
+ override fun addCheckpoint(id: StateMachineRunId, checkpoint: SerializedBytes) { + currentDBSession().saveOrUpdate(DBCheckpoint().apply { + checkpointId = id.uuid.toString() + this.checkpoint = checkpoint.bytes + log.debug { "Checkpoint $checkpointId, size=${this.checkpoint.size}" } }) } - override fun removeCheckpoint(checkpoint: Checkpoint) { + override fun removeCheckpoint(id: StateMachineRunId): Boolean { val session = currentDBSession() val criteriaBuilder = session.criteriaBuilder val delete = criteriaBuilder.createCriteriaDelete(DBCheckpoint::class.java) val root = delete.from(DBCheckpoint::class.java) - delete.where(criteriaBuilder.equal(root.get(DBCheckpoint::checkpointId.name), checkpoint.id.toString())) - session.createQuery(delete).executeUpdate() + delete.where(criteriaBuilder.equal(root.get(DBCheckpoint::checkpointId.name), id.uuid.toString())) + return session.createQuery(delete).executeUpdate() > 0 } - override fun forEach(block: (Checkpoint) -> Boolean) { + override fun getAllCheckpoints(): Stream>> { val session = currentDBSession() val criteriaQuery = session.criteriaBuilder.createQuery(DBCheckpoint::class.java) val root = criteriaQuery.from(DBCheckpoint::class.java) criteriaQuery.select(root) - for (row in session.createQuery(criteriaQuery).resultList) { - val checkpoint = Checkpoint(SerializedBytes(row.checkpoint)) - if (!block(checkpoint)) { - break - } + return session.createQuery(criteriaQuery).stream().map { + StateMachineRunId(UUID.fromString(it.checkpointId)) to SerializedBytes(it.checkpoint) } } } diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt index e2f298a7e1..9e90665ab9 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt @@ -1,15 +1,19 @@ package 
net.corda.node.services.persistence -import net.corda.core.internal.VisibleForTesting -import net.corda.core.internal.bufferUntilSubscribed +import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.SecureHash import net.corda.core.crypto.TransactionSignature import net.corda.core.internal.ThreadBox +import net.corda.core.internal.VisibleForTesting +import net.corda.core.internal.bufferUntilSubscribed +import net.corda.core.internal.concurrent.doneFuture import net.corda.core.messaging.DataFeed import net.corda.core.serialization.* +import net.corda.core.toFuture import net.corda.core.transactions.CoreTransaction import net.corda.core.transactions.SignedTransaction import net.corda.node.services.api.WritableTransactionStorage -import net.corda.node.utilities.* +import net.corda.node.utilities.AppendOnlyPersistentMapBase +import net.corda.node.utilities.WeightBasedAppendOnlyPersistentMap import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.nodeapi.internal.persistence.bufferUntilDatabaseCommit import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction @@ -96,6 +100,17 @@ class DBTransactionStorage(cacheSizeBytes: Long) : WritableTransactionStorage, S } } + override fun trackTransaction(id: SecureHash): CordaFuture { + return txStorage.locked { + val existingTransaction = get(id) + if (existingTransaction == null) { + updatesPublisher.filter { it.id == id }.toFuture() + } else { + doneFuture(existingTransaction.toSignedTx()) + } + } + } + @VisibleForTesting val transactions: Iterable get() = txStorage.content.allPersisted().map { it.second.toSignedTx() }.toList() diff --git a/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt b/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt index eea216eff5..80a8499495 100644 --- a/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt +++ 
b/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt @@ -14,6 +14,7 @@ import net.corda.node.services.api.SchemaService.SchemaOptions import net.corda.node.services.events.NodeSchedulerService import net.corda.node.services.identity.PersistentIdentityService import net.corda.node.services.keys.PersistentKeyManagementService +import net.corda.node.services.messaging.P2PMessageDeduplicator import net.corda.node.services.messaging.P2PMessagingClient import net.corda.node.services.persistence.DBCheckpointStorage import net.corda.node.services.persistence.DBTransactionMappingStorage @@ -43,7 +44,7 @@ class NodeSchemaService(extraSchemas: Set = emptySet(), includeNot PersistentKeyManagementService.PersistentKey::class.java, NodeSchedulerService.PersistentScheduledState::class.java, NodeAttachmentService.DBAttachment::class.java, - P2PMessagingClient.ProcessedMessage::class.java, + P2PMessageDeduplicator.ProcessedMessage::class.java, P2PMessagingClient.RetryMessage::class.java, PersistentIdentityService.PersistentIdentity::class.java, PersistentIdentityService.PersistentIdentityNames::class.java, diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt new file mode 100644 index 0000000000..193cdecef1 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt @@ -0,0 +1,132 @@ +package net.corda.node.services.statemachine + +import net.corda.core.crypto.SecureHash +import net.corda.core.flows.StateMachineRunId +import net.corda.core.identity.Party +import net.corda.core.internal.FlowAsyncOperation +import net.corda.node.services.messaging.DeduplicationHandler +import java.time.Instant + +/** + * [Action]s are reified IO actions to execute as part of state machine transitions. + */ +sealed class Action { + + /** + * Track a transaction hash and notify the state machine once the corresponding transaction has committed. 
+ */ + data class TrackTransaction(val hash: SecureHash) : Action() + + /** + * Send an initial session message to [party]. + */ + data class SendInitial( + val party: Party, + val initialise: InitialSessionMessage, + val deduplicationId: DeduplicationId + ) : Action() + + /** + * Send a session message to a [peerParty] with which we have an established session. + */ + data class SendExisting( + val peerParty: Party, + val message: ExistingSessionMessage, + val deduplicationId: DeduplicationId + ) : Action() + + /** + * Persist the specified [checkpoint]. + */ + data class PersistCheckpoint(val id: StateMachineRunId, val checkpoint: Checkpoint) : Action() + + /** + * Remove the checkpoint corresponding to [id]. + */ + data class RemoveCheckpoint(val id: StateMachineRunId) : Action() + + /** + * Persist the deduplication facts of [deduplicationHandlers]. + */ + data class PersistDeduplicationFacts(val deduplicationHandlers: List) : Action() + + /** + * Acknowledge messages in [deduplicationHandlers]. + */ + data class AcknowledgeMessages(val deduplicationHandlers: List) : Action() + + /** + * Propagate [errorMessages] to [sessions]. + * @param sessions a map from source session IDs to initiated sessions. + */ + data class PropagateErrors( + val errorMessages: List, + val sessions: List + ) : Action() + + /** + * Create a session binding from [sessionId] to [flowId] to allow routing of incoming messages. + */ + data class AddSessionBinding(val flowId: StateMachineRunId, val sessionId: SessionId) : Action() + + /** + * Remove the session bindings corresponding to [sessionIds]. + */ + data class RemoveSessionBindings(val sessionIds: Set) : Action() + + /** + * Signal that the flow corresponding to [flowId] is considered started. + */ + data class SignalFlowHasStarted(val flowId: StateMachineRunId) : Action() + + /** + * Remove the flow corresponding to [flowId]. 
+ */ + data class RemoveFlow( + val flowId: StateMachineRunId, + val removalReason: FlowRemovalReason, + val lastState: StateMachineState + ) : Action() + + /** + * Schedule [event] to self. + */ + data class ScheduleEvent(val event: Event) : Action() + + /** + * Sleep until [time]. + */ + data class SleepUntil(val time: Instant) : Action() + + /** + * Create a new database transaction. + */ + object CreateTransaction : Action() { override fun toString() = "CreateTransaction" } + + /** + * Roll back the current database transaction. + */ + object RollbackTransaction : Action() { override fun toString() = "RollbackTransaction" } + + /** + * Commit the current database transaction. + */ + object CommitTransaction : Action() { override fun toString() = "CommitTransaction" } + + /** + * Execute the specified [operation]. + */ + data class ExecuteAsyncOperation(val operation: FlowAsyncOperation<*>) : Action() +} + +/** + * Reason for flow removal. + */ +sealed class FlowRemovalReason { + data class OrderlyFinish(val flowReturnValue: Any?) : FlowRemovalReason() + data class ErrorFinish(val flowErrors: List) : FlowRemovalReason() + object SoftShutdown : FlowRemovalReason() { override fun toString() = "SoftShutdown" } + // TODO Should we remove errored flows? How will the flow hospital work? Perhaps keep them in memory for a while, flush + // them after a timeout, reload them on flow hospital request. In any case if we ever want to remove them + // (e.g. temporarily) then add a case for that here. +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutor.kt new file mode 100644 index 0000000000..7c2bd77fd8 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutor.kt @@ -0,0 +1,14 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Suspendable + +/** + * An executor of a single [Action]. 
+ */ +interface ActionExecutor { + /** + * Execute [action] by [fiber]. + */ + @Suspendable + fun executeAction(fiber: FlowFiber, action: Action) +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt new file mode 100644 index 0000000000..996d5832ad --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt @@ -0,0 +1,228 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Fiber +import co.paralleluniverse.fibers.Suspendable +import com.codahale.metrics.* +import net.corda.core.internal.concurrent.thenMatch +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.serialize +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.trace +import net.corda.node.services.api.CheckpointStorage +import net.corda.node.services.api.ServiceHubInternal +import net.corda.nodeapi.internal.persistence.contextDatabase +import net.corda.nodeapi.internal.persistence.contextTransaction +import net.corda.nodeapi.internal.persistence.contextTransactionOrNull +import java.time.Duration +import java.time.Instant +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicLong + +/** + * This is the bottom execution engine of flow side-effects. + */ +class ActionExecutorImpl( + private val services: ServiceHubInternal, + private val checkpointStorage: CheckpointStorage, + private val flowMessaging: FlowMessaging, + private val stateMachineManager: StateMachineManagerInternal, + private val checkpointSerializationContext: SerializationContext, + metrics: MetricRegistry +) : ActionExecutor { + + private companion object { + val log = contextLogger() + } + + /** + * This [Gauge] just reports the sum of the bytes checkpointed during the last second. 
+ */ + private class LatchedGauge(private val reservoir: Reservoir) : Gauge { + override fun getValue(): Long { + return reservoir.snapshot.values.sum() + } + } + + private val checkpointingMeter = metrics.meter("Flows.Checkpointing Rate") + private val checkpointSizesThisSecond = SlidingTimeWindowReservoir(1, TimeUnit.SECONDS) + private val lastBandwidthUpdate = AtomicLong(0) + private val checkpointBandwidthHist = metrics.register("Flows.CheckpointVolumeBytesPerSecondHist", Histogram(SlidingTimeWindowArrayReservoir(1, TimeUnit.DAYS))) + private val checkpointBandwidth = metrics.register("Flows.CheckpointVolumeBytesPerSecondCurrent", LatchedGauge(checkpointSizesThisSecond)) + + @Suspendable + override fun executeAction(fiber: FlowFiber, action: Action) { + log.trace { "Flow ${fiber.id} executing $action" } + return when (action) { + is Action.TrackTransaction -> executeTrackTransaction(fiber, action) + is Action.PersistCheckpoint -> executePersistCheckpoint(action) + is Action.PersistDeduplicationFacts -> executePersistDeduplicationIds(action) + is Action.AcknowledgeMessages -> executeAcknowledgeMessages(action) + is Action.PropagateErrors -> executePropagateErrors(action) + is Action.ScheduleEvent -> executeScheduleEvent(fiber, action) + is Action.SleepUntil -> executeSleepUntil(action) + is Action.RemoveCheckpoint -> executeRemoveCheckpoint(action) + is Action.SendInitial -> executeSendInitial(action) + is Action.SendExisting -> executeSendExisting(action) + is Action.AddSessionBinding -> executeAddSessionBinding(action) + is Action.RemoveSessionBindings -> executeRemoveSessionBindings(action) + is Action.SignalFlowHasStarted -> executeSignalFlowHasStarted(action) + is Action.RemoveFlow -> executeRemoveFlow(action) + is Action.CreateTransaction -> executeCreateTransaction() + is Action.RollbackTransaction -> executeRollbackTransaction() + is Action.CommitTransaction -> executeCommitTransaction() + is Action.ExecuteAsyncOperation -> executeAsyncOperation(fiber, 
action) + } + } + + @Suspendable + private fun executeTrackTransaction(fiber: FlowFiber, action: Action.TrackTransaction) { + services.validatedTransactions.trackTransaction(action.hash).thenMatch( + success = { transaction -> + fiber.scheduleEvent(Event.TransactionCommitted(transaction)) + }, + failure = { exception -> + fiber.scheduleEvent(Event.Error(exception)) + } + ) + } + + @Suspendable + private fun executePersistCheckpoint(action: Action.PersistCheckpoint) { + val checkpointBytes = serializeCheckpoint(action.checkpoint) + checkpointStorage.addCheckpoint(action.id, checkpointBytes) + checkpointingMeter.mark() + checkpointSizesThisSecond.update(checkpointBytes.size.toLong()) + var lastUpdateTime = lastBandwidthUpdate.get() + while (System.nanoTime() - lastUpdateTime > TimeUnit.SECONDS.toNanos(1)) { + if (lastBandwidthUpdate.compareAndSet(lastUpdateTime, System.nanoTime())) { + val checkpointVolume = checkpointSizesThisSecond.snapshot.values.sum() + checkpointBandwidthHist.update(checkpointVolume) + } + lastUpdateTime = lastBandwidthUpdate.get() + } + } + + @Suspendable + private fun executePersistDeduplicationIds(action: Action.PersistDeduplicationFacts) { + for (handle in action.deduplicationHandlers) { + handle.insideDatabaseTransaction() + } + } + + @Suspendable + private fun executeAcknowledgeMessages(action: Action.AcknowledgeMessages) { + action.deduplicationHandlers.forEach { + it.afterDatabaseTransaction() + } + } + + @Suspendable + private fun executePropagateErrors(action: Action.PropagateErrors) { + action.errorMessages.forEach { error -> + val exception = error.flowException + log.debug("Propagating error", exception) + } + for (sessionState in action.sessions) { + // We cannot propagate if the session isn't live. 
+ if (sessionState.initiatedState !is InitiatedSessionState.Live) { + continue + } + // Don't propagate errors to the originating session + for (errorMessage in action.errorMessages) { + val sinkSessionId = sessionState.initiatedState.peerSinkSessionId + val existingMessage = ExistingSessionMessage(sinkSessionId, errorMessage) + val deduplicationId = DeduplicationId.createForError(errorMessage.errorId, sinkSessionId) + flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, deduplicationId) + } + } + } + + @Suspendable + private fun executeScheduleEvent(fiber: FlowFiber, action: Action.ScheduleEvent) { + fiber.scheduleEvent(action.event) + } + + @Suspendable + private fun executeSleepUntil(action: Action.SleepUntil) { + // TODO introduce explicit sleep state + wakeup event instead of relying on Fiber.sleep. This is so shutdown + // conditions may "interrupt" the sleep instead of waiting until wakeup. + val duration = Duration.between(Instant.now(), action.time) + Fiber.sleep(duration.toNanos(), TimeUnit.NANOSECONDS) + } + + @Suspendable + private fun executeRemoveCheckpoint(action: Action.RemoveCheckpoint) { + checkpointStorage.removeCheckpoint(action.id) + } + + @Suspendable + private fun executeSendInitial(action: Action.SendInitial) { + flowMessaging.sendSessionMessage(action.party, action.initialise, action.deduplicationId) + } + + @Suspendable + private fun executeSendExisting(action: Action.SendExisting) { + flowMessaging.sendSessionMessage(action.peerParty, action.message, action.deduplicationId) + } + + @Suspendable + private fun executeAddSessionBinding(action: Action.AddSessionBinding) { + stateMachineManager.addSessionBinding(action.flowId, action.sessionId) + } + + @Suspendable + private fun executeRemoveSessionBindings(action: Action.RemoveSessionBindings) { + stateMachineManager.removeSessionBindings(action.sessionIds) + } + + @Suspendable + private fun executeSignalFlowHasStarted(action: Action.SignalFlowHasStarted) { + 
stateMachineManager.signalFlowHasStarted(action.flowId) + } + + @Suspendable + private fun executeRemoveFlow(action: Action.RemoveFlow) { + stateMachineManager.removeFlow(action.flowId, action.removalReason, action.lastState) + } + + @Suspendable + private fun executeCreateTransaction() { + if (contextTransactionOrNull != null) { + throw IllegalStateException("Refusing to create a second transaction") + } + contextDatabase.newTransaction() + } + + @Suspendable + private fun executeRollbackTransaction() { + contextTransactionOrNull?.close() + } + + @Suspendable + private fun executeCommitTransaction() { + try { + contextTransaction.commit() + } finally { + contextTransaction.close() + contextTransactionOrNull = null + } + } + + @Suspendable + private fun executeAsyncOperation(fiber: FlowFiber, action: Action.ExecuteAsyncOperation) { + val operationFuture = action.operation.execute() + operationFuture.thenMatch( + success = { result -> + fiber.scheduleEvent(Event.AsyncOperationCompletion(result)) + }, + failure = { exception -> + fiber.scheduleEvent(Event.Error(exception)) + } + ) + } + + private fun serializeCheckpoint(checkpoint: Checkpoint): SerializedBytes { + return checkpoint.serialize(context = checkpointSerializationContext) + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/CountUpDownLatch.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/CountUpDownLatch.kt new file mode 100644 index 0000000000..5fe085496e --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/CountUpDownLatch.kt @@ -0,0 +1,66 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.strands.concurrent.AbstractQueuedSynchronizer +import co.paralleluniverse.fibers.Suspendable + +/** + * Quasar-compatible latch that may be incremented. 
+ */ +class CountUpDownLatch(initialValue: Int) { + + // See quasar CountDownLatch + private class Sync(initialValue: Int) : AbstractQueuedSynchronizer() { + init { + state = initialValue + } + + override fun tryAcquireShared(arg: Int): Int { + if (arg >= 0) { + return if (state == arg) 1 else -1 + } else { + return if (state <= -arg) 1 else -1 + } + } + + override fun tryReleaseShared(arg: Int): Boolean { + while (true) { + val c = state + if (c == 0) + return false + val nextc = c - Math.min(c, arg) + if (compareAndSetState(c, nextc)) + return nextc == 0 + } + } + + fun increment() { + while (true) { + val c = state + val nextc = c + 1 + if (compareAndSetState(c, nextc)) + return + } + } + } + + private val sync = Sync(initialValue) + + @Suspendable + fun await() { + sync.acquireSharedInterruptibly(0) + } + + @Suspendable + fun awaitLessThanOrEqual(number: Int) { + sync.acquireSharedInterruptibly(number) + } + + fun countDown(number: Int = 1) { + require(number > 0) + sync.releaseShared(number) + } + + fun countUp() { + sync.increment() + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/DeduplicationId.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/DeduplicationId.kt new file mode 100644 index 0000000000..7e853dc5a8 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/DeduplicationId.kt @@ -0,0 +1,47 @@ +package net.corda.node.services.statemachine + +import java.security.SecureRandom + +/** + * A deduplication ID of a flow message. + */ +data class DeduplicationId(val toString: String) { + companion object { + /** + * Create a random deduplication ID. Note that this isn't deterministic, which means we will never dedupe it, + * unless we persist the ID somehow. + */ + fun createRandom(random: SecureRandom) = DeduplicationId("R-${random.nextLong()}") + + /** + * Create a deduplication ID for a normal clean state message. 
This is used to have a deterministic way of + * creating IDs in case the message-generating flow logic is replayed on hard failure. + * + * A normal deduplication ID consists of: + * 1. A deduplication seed set per flow. This is either the flow's ID or in case of an initated flow the + * initiator's session ID. + * 2. The number of *clean* suspends since the start of the flow. + * 3. An optional additional index, for cases where several messages are sent as part of the state transition. + * Note that care must be taken with this index, it must be a deterministic counter. For example a naive + * iteration over a HashMap will produce a different list of indeces than a previous run, causing the + * message-id map to change, which means deduplication will not happen correctly. + */ + fun createForNormal(checkpoint: Checkpoint, index: Int): DeduplicationId { + return DeduplicationId("N-${checkpoint.deduplicationSeed}-${checkpoint.numberOfSuspends}-$index") + } + + /** + * Create a deduplication ID for an error message. Note that these IDs live in a different namespace than normal + * IDs, as we don't want error conditions to affect the determinism of clean deduplication IDs. This allows the + * dirtiness state to be thrown away for resumption. + * + * An error deduplication ID consists of: + * 1. The error's ID. This is a unique value per "source" of error and is propagated. + * See [net.corda.core.flows.IdentifiableException]. + * 2. The recipient's session ID. 
+ */ + fun createForError(errorId: Long, recipientSessionId: SessionId): DeduplicationId { + return DeduplicationId("E-$errorId-${recipientSessionId.toLong}") + } + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt new file mode 100644 index 0000000000..344a7df1ef --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt @@ -0,0 +1,126 @@ +package net.corda.node.services.statemachine + +import net.corda.core.flows.FlowLogic +import net.corda.core.identity.Party +import net.corda.core.internal.FlowIORequest +import net.corda.core.serialization.SerializedBytes +import net.corda.core.transactions.SignedTransaction +import net.corda.node.services.messaging.DeduplicationHandler + +/** + * Transitions in the flow state machine are triggered by [Event]s that may originate from the flow itself or from + * outside (e.g. in case of message delivery or external event). + */ +sealed class Event { + /** + * Check the current state for pending work. For example if the flow is waiting for a message from a particular + * session this event may cause a flow resume if we have a corresponding message. In general the state machine + * should be idempotent in the [DoRemainingWork] event, meaning a second subsequent event shouldn't modify the state + * or produce [Action]s. + */ + object DoRemainingWork : Event() { override fun toString() = "DoRemainingWork" } + + /** + * Deliver a session message. + * @param sessionMessage the message itself. + * @param deduplicationHandler the handle to acknowledge the message after checkpointing. + * @param sender the sender [Party]. + */ + data class DeliverSessionMessage( + val sessionMessage: ExistingSessionMessage, + val deduplicationHandler: DeduplicationHandler, + val sender: Party + ) : Event() + + /** + * Signal that an error has happened. 
This may be due to an uncaught exception in the flow or some external error. + * @param exception the exception itself. + */ + data class Error(val exception: Throwable) : Event() + + /** + * Signal that a ledger transaction has committed. This is an event completing a [FlowIORequest.WaitForLedgerCommit] + * suspension. + * @param transaction the transaction that was committed. + */ + data class TransactionCommitted(val transaction: SignedTransaction) : Event() + + /** + * Trigger a soft shutdown, removing the flow as soon as possible. This causes the flow to be removed as soon as + * this event is processed. Note that on restart the flow will resume as normal. + */ + object SoftShutdown : Event() { override fun toString() = "SoftShutdown" } + + /** + * Start error propagation on a errored flow. This may be triggered by e.g. a [FlowHospital]. + */ + object StartErrorPropagation : Event() { override fun toString() = "StartErrorPropagation" } + + /** + * + * Scheduled by the flow. + * + * Initiate a flow. This causes a new session object to be created and returned to the flow. Note that no actual + * communication takes place at this time, only on the first send/receive operation on the session. + * @param party the [Party] to create a session with. + */ + data class InitiateFlow(val party: Party) : Event() + + /** + * Signal the entering into a subflow. + * + * Scheduled and executed by the flow. + * + * @param subFlowClass the [Class] of the subflow, to be used to determine whether it's Initiating or inlined. + */ + data class EnterSubFlow(val subFlowClass: Class>) : Event() + + /** + * Signal the leaving of a subflow. + * + * Scheduled by the flow. + * + */ + object LeaveSubFlow : Event() { override fun toString() = "LeaveSubFlow" } + + /** + * Signal a flow suspension. This causes the flow's stack and the state machine's state together with the suspending + * IO request to be persisted into the database. 
+ * + * Scheduled by the flow and executed inside the park closure. + * + * @param ioRequest the request triggering the suspension. + * @param maySkipCheckpoint indicates whether the persistence may be skipped. + * @param fiber the serialised stack of the flow. + */ + data class Suspend( + val ioRequest: FlowIORequest<*>, + val maySkipCheckpoint: Boolean, + val fiber: SerializedBytes> + ) : Event() { + override fun toString() = + "Suspend(" + + "ioRequest=$ioRequest, " + + "maySkipCheckpoint=$maySkipCheckpoint, " + + "fiber=${fiber.hash}, " + + ")" + } + + /** + * Signals clean flow finish. + * + * Scheduled by the flow. + * + * @param returnValue the return value of the flow. + */ + data class FlowFinish(val returnValue: Any?) : Event() + + /** + * Signals the completion of a [FlowAsyncOperation]. + * + * Scheduling is triggered by the service that completes the future returned by the async operation. + * + * @param returnValue the result of the operation. + */ + data class AsyncOperationCompletion(val returnValue: Any?) : Event() +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowFiber.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowFiber.kt new file mode 100644 index 0000000000..40768c261e --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowFiber.kt @@ -0,0 +1,18 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.flows.StateMachineRunId +import net.corda.node.services.statemachine.transitions.StateMachine + +/** + * An interface wrapping a fiber running a flow. 
+ */ +interface FlowFiber { + val id: StateMachineRunId + val stateMachine: StateMachine + + @Suspendable + fun scheduleEvent(event: Event) + + fun snapshot(): StateMachineState +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowHospital.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowHospital.kt new file mode 100644 index 0000000000..bcd60557df --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowHospital.kt @@ -0,0 +1,18 @@ +package net.corda.node.services.statemachine + +/** + * A flow hospital is a class that is notified when a flow transitions into an error state due to an uncaught exception + * or internal error condition, and when it becomes clean again (e.g. due to a resume). + * Also see [net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor]. + */ +interface FlowHospital { + /** + * The flow running in [flowFiber] has errored. + */ + fun flowErrored(flowFiber: FlowFiber) + + /** + * The flow running in [flowFiber] has cleaned, possibly as a result of a flow hospital resume. + */ + fun flowCleaned(flowFiber: FlowFiber) +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt deleted file mode 100644 index 65b24a7046..0000000000 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt +++ /dev/null @@ -1,121 +0,0 @@ -package net.corda.node.services.statemachine - -import co.paralleluniverse.fibers.Suspendable -import net.corda.core.crypto.SecureHash -import net.corda.core.identity.Party -import java.time.Instant - -interface FlowIORequest { - // This is used to identify where we suspended, in case of message mismatch errors and other things where we - // don't have the original stack trace because it's in a suspended fiber. 
- val stackTraceInCaseOfProblems: StackSnapshot -} - -interface WaitingRequest : FlowIORequest { - fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean -} - -interface SessionedFlowIORequest : FlowIORequest { - val session: FlowSessionInternal -} - -interface SendRequest : SessionedFlowIORequest { - val message: SessionMessage -} - -interface ReceiveRequest : SessionedFlowIORequest, WaitingRequest { - val userReceiveType: Class<*>? - - override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = this.session === session -} - -data class SendAndReceive(override val session: FlowSessionInternal, - override val message: SessionMessage, - override val userReceiveType: Class<*>?) : SendRequest, ReceiveRequest { - @Transient - override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() -} - -data class ReceiveOnly( - override val session: FlowSessionInternal, - override val userReceiveType: Class<*>? -) : ReceiveRequest { - @Transient - override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() -} - -class ReceiveAll(val requests: List) : WaitingRequest { - @Transient - override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() - - private fun isComplete(received: LinkedHashMap): Boolean { - return received.keys == requests.map { it.session }.toSet() - } - private fun shouldResumeIfRelevant() = requests.all { hasSuccessfulEndMessage(it) } - - private fun hasSuccessfulEndMessage(it: ReceiveRequest): Boolean { - return it.session.receivedMessages.map { it.message.payload }.any { it is DataSessionMessage || it is EndSessionMessage } - } - - @Suspendable - fun suspendAndExpectReceive(suspend: Suspend): Map { - val receivedMessages = LinkedHashMap() - - poll(receivedMessages) - return if (isComplete(receivedMessages)) { - receivedMessages - } else { - suspend(this) - poll(receivedMessages) - if (isComplete(receivedMessages)) { - receivedMessages - } else { - 
throw IllegalStateException(requests.filter { it.session !in receivedMessages.keys }.map { "Was expecting a message but instead got nothing for $it." }.joinToString { "\n" }) - } - } - } - - interface Suspend { - @Suspendable - operator fun invoke(request: FlowIORequest) - } - - @Suspendable - private fun poll(receivedMessages: LinkedHashMap) { - return requests.filter { it.session !in receivedMessages.keys }.forEach { request -> - poll(request)?.let { - receivedMessages[request.session] = RequestMessage(request, it) - } - } - } - - @Suspendable - private fun poll(request: ReceiveRequest): ExistingSessionMessage? { - return request.session.receivedMessages.poll()?.message - } - - override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = isRelevant(session) && shouldResumeIfRelevant() - - private fun isRelevant(session: FlowSessionInternal) = requests.any { it.session === session } - - data class RequestMessage(val request: ReceiveRequest, val message: ExistingSessionMessage) -} - -data class SendOnly(override val session: FlowSessionInternal, override val message: SessionMessage) : SendRequest { - @Transient - override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() -} - -data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachineImpl<*>) : WaitingRequest { - @Transient - override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() - - override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = message.payload is ErrorSessionMessage -} - -data class Sleep(val until: Instant, val fiber: FlowStateMachineImpl<*>) : FlowIORequest { - @Transient - override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() -} - -class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem") diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowLogicRefFactoryImpl.kt 
b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowLogicRefFactoryImpl.kt index 9ed350c080..b04c59040a 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowLogicRefFactoryImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowLogicRefFactoryImpl.kt @@ -156,4 +156,4 @@ class FlowLogicRefFactoryImpl(private val classloader: ClassLoader) : SingletonS return false } } -} \ No newline at end of file +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt new file mode 100644 index 0000000000..01a90a40d1 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt @@ -0,0 +1,91 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Suspendable +import com.esotericsoftware.kryo.KryoException +import net.corda.core.context.InvocationOrigin +import net.corda.core.flows.FlowException +import net.corda.core.identity.Party +import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.serialize +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.trace +import net.corda.node.services.api.ServiceHubInternal +import net.corda.node.services.messaging.DeduplicationHandler +import net.corda.node.services.messaging.ReceivedMessage +import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders +import java.io.NotSerializableException + +/** + * A wrapper interface around flow messaging. + */ +interface FlowMessaging { + /** + * Send [message] to [party] using [deduplicationId]. No acknowledgement handler can be passed: the only + * parameters are the three listed in the signature below, and the call returns no value. + */ + @Suspendable + fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId) + + /** + * Start the messaging using the [onMessage] message handler. 
+ */ + fun start(onMessage: (ReceivedMessage, deduplicationHandler: DeduplicationHandler) -> Unit) +} + +/** + * Implementation of [FlowMessaging] using a [ServiceHubInternal] to do the messaging and routing. + */ +class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging { + + companion object { + val log = contextLogger() + + val sessionTopic = "platform.session" + } + + override fun start(onMessage: (ReceivedMessage, deduplicationHandler: DeduplicationHandler) -> Unit) { + serviceHub.networkService.addMessageHandler(sessionTopic) { receivedMessage, _, deduplicationHandler -> + onMessage(receivedMessage, deduplicationHandler) + } + } + + @Suspendable + override fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId) { + log.trace { "Sending message $deduplicationId $message to party $party" } + val networkMessage = serviceHub.networkService.createMessage(sessionTopic, serializeSessionMessage(message).bytes, deduplicationId, message.additionalHeaders(party)) + val partyInfo = serviceHub.networkMapCache.getPartyInfo(party) ?: throw IllegalArgumentException("Don't know about $party") + val address = serviceHub.networkService.getAddressOfParty(partyInfo) + val sequenceKey = when (message) { + is InitialSessionMessage -> message.initiatorSessionId + is ExistingSessionMessage -> message.recipientSessionId + } + serviceHub.networkService.send(networkMessage, address, sequenceKey = sequenceKey) + } + + private fun SessionMessage.additionalHeaders(target: Party): Map { + + // This prevents a "deadlock" in case an initiated flow tries to start a session against a draining node that is also the initiator. + // It does not help in case more than 2 nodes are involved in a circle, so the kill switch via RPC should be used in that case. 
+ val mightDeadlockDrainingTarget = FlowStateMachineImpl.currentStateMachine()?.context?.origin.let { it is InvocationOrigin.Peer && it.party == target.name } + return when { + this !is InitialSessionMessage || mightDeadlockDrainingTarget -> emptyMap() + else -> mapOf(P2PMessagingHeaders.Type.KEY to P2PMessagingHeaders.Type.SESSION_INIT_VALUE) + } + } + + private fun serializeSessionMessage(message: SessionMessage): SerializedBytes { + return try { + message.serialize() + } catch (exception: Exception) { + // Handling Kryo and AMQP serialization problems. Unfortunately the two exception types do not share much of a common exception interface. + if ((exception is KryoException || exception is NotSerializableException) + && message is ExistingSessionMessage && message.payload is ErrorSessionMessage) { + val error = message.payload.flowException + val rewrappedError = FlowException(error?.message) + message.copy(payload = message.payload.copy(flowException = rewrappedError)).serialize() + } else { + throw exception + } + } + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionImpl.kt index 479bbe86da..a0c3359aca 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionImpl.kt @@ -1,20 +1,40 @@ package net.corda.node.services.statemachine +import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Suspendable import net.corda.core.flows.FlowInfo -import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowSession import net.corda.core.identity.Party +import net.corda.core.internal.FlowIORequest import net.corda.core.internal.FlowStateMachine +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.serialize +import 
net.corda.core.utilities.NonEmptySet import net.corda.core.utilities.UntrustworthyData +import net.corda.core.utilities.checkPayloadIs -class FlowSessionImpl(override val counterparty: Party) : FlowSession() { - internal lateinit var stateMachine: FlowStateMachine<*> - internal lateinit var sessionFlow: FlowLogic<*> +class FlowSessionImpl( + override val counterparty: Party, + val sourceSessionId: SessionId +) : FlowSession() { + + override fun toString() = "FlowSessionImpl(counterparty=$counterparty, sourceSessionId=$sourceSessionId)" + + override fun equals(other: Any?): Boolean { + return (other as? FlowSessionImpl)?.sourceSessionId == sourceSessionId + } + + override fun hashCode() = sourceSessionId.hashCode() + + private fun getFlowStateMachine(): FlowStateMachine<*> { + return Fiber.currentFiber() as FlowStateMachine<*> + } @Suspendable override fun getCounterpartyFlowInfo(maySkipCheckpoint: Boolean): FlowInfo { - return stateMachine.getFlowInfo(counterparty, sessionFlow, maySkipCheckpoint) + val request = FlowIORequest.GetFlowInfo(NonEmptySet.of(this)) + return getFlowStateMachine().suspend(request, maySkipCheckpoint)[this]!! } @Suspendable @@ -26,14 +46,15 @@ class FlowSessionImpl(override val counterparty: Party) : FlowSession() { payload: Any, maySkipCheckpoint: Boolean ): UntrustworthyData { - return stateMachine.sendAndReceive( - receiveType, - counterparty, - payload, - sessionFlow, - retrySend = false, - maySkipCheckpoint = maySkipCheckpoint + enforceNotPrimitive(receiveType) + val request = FlowIORequest.SendAndReceive( + sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT)), + shouldRetrySend = false ) + val responseValues: Map> = getFlowStateMachine().suspend(request, maySkipCheckpoint) + val responseForCurrentSession = responseValues[this]!! 
+ + return responseForCurrentSession.checkPayloadIs(receiveType) } @Suspendable @@ -41,7 +62,9 @@ class FlowSessionImpl(override val counterparty: Party) : FlowSession() { @Suspendable override fun receive(receiveType: Class, maySkipCheckpoint: Boolean): UntrustworthyData { - return stateMachine.receive(receiveType, counterparty, sessionFlow, maySkipCheckpoint) + enforceNotPrimitive(receiveType) + val request = FlowIORequest.Receive(NonEmptySet.of(this)) + return getFlowStateMachine().suspend(request, maySkipCheckpoint)[this]!!.checkPayloadIs(receiveType) } @Suspendable @@ -49,12 +72,17 @@ class FlowSessionImpl(override val counterparty: Party) : FlowSession() { @Suspendable override fun send(payload: Any, maySkipCheckpoint: Boolean) { - return stateMachine.send(counterparty, payload, sessionFlow, maySkipCheckpoint) + val request = FlowIORequest.Send( + sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT)), + shouldRetrySend = false + ) + return getFlowStateMachine().suspend(request, maySkipCheckpoint) } @Suspendable override fun send(payload: Any) = send(payload, maySkipCheckpoint = false) - override fun toString() = "Flow session with $counterparty" + private fun enforceNotPrimitive(type: Class<*>) { + require(!type.isPrimitive) { "Cannot receive primitive type $type" } + } } - diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionInternal.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionInternal.kt deleted file mode 100644 index 58c134e39c..0000000000 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionInternal.kt +++ /dev/null @@ -1,66 +0,0 @@ -package net.corda.node.services.statemachine - -import net.corda.core.flows.FlowInfo -import net.corda.core.flows.FlowLogic -import net.corda.core.flows.FlowSession -import net.corda.core.identity.Party -import net.corda.node.services.statemachine.FlowSessionState.Initiated -import 
net.corda.node.services.statemachine.FlowSessionState.Initiating -import java.util.concurrent.ConcurrentLinkedQueue - -/** - * @param retryable Indicates that the session initialisation should be retried until an expected [SessionData] response - * is received. Note that this requires the party on the other end to be a distributed service and run an idempotent flow - * that only sends back a single [SessionData] message before termination. - */ -// TODO rename this -class FlowSessionInternal( - val flow: FlowLogic<*>, - val flowSession : FlowSession, - val ourSessionId: SessionId, - val initiatingParty: Party?, - var state: FlowSessionState, - var retryable: Boolean = false) { - val receivedMessages = ConcurrentLinkedQueue() - val fiber: FlowStateMachineImpl<*> get() = flow.stateMachine as FlowStateMachineImpl<*> - - override fun toString(): String { - return "${javaClass.simpleName}(flow=$flow, ourSessionId=$ourSessionId, initiatingParty=$initiatingParty, state=$state)" - } - - fun getPeerSessionId(): SessionId { - val sessionState = state - return when (sessionState) { - is FlowSessionState.Initiated -> sessionState.peerSessionId - else -> throw IllegalStateException("We've somehow held onto a non-initiated session: $this") - } - } -} - -data class ReceivedSessionMessage(val peerParty: Party, val message: ExistingSessionMessage) - -/** - * [FlowSessionState] describes the session's state. - * - * [Uninitiated] is pre-handshake, where no communication has happened. [Initiating.otherParty] at this point holds a - * [Party] corresponding to either a specific peer or a service. - * [Initiating] is pre-handshake, where the initiating message has been sent. - * [Initiated] is post-handshake. At this point [Initiating.otherParty] will have been resolved to a specific peer - * [Initiated.peerParty], and the peer's sessionId has been initialised. 
- */ -sealed class FlowSessionState { - abstract val sendToParty: Party - - data class Uninitiated(val otherParty: Party) : FlowSessionState() { - override val sendToParty: Party get() = otherParty - } - - /** [otherParty] may be a specific peer or a service party */ - data class Initiating(val otherParty: Party) : FlowSessionState() { - override val sendToParty: Party get() = otherParty - } - - data class Initiated(val peerParty: Party, val peerSessionId: SessionId, val context: FlowInfo) : FlowSessionState() { - override val sendToParty: Party get() = peerParty - } -} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index f1897204f3..1e14aa0a36 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -5,262 +5,250 @@ import co.paralleluniverse.fibers.Fiber.parkAndSerialize import co.paralleluniverse.fibers.FiberScheduler import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.strands.Strand -import com.google.common.primitives.Primitives +import co.paralleluniverse.strands.channels.Channel import net.corda.core.concurrent.CordaFuture import net.corda.core.context.InvocationContext -import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.newSecureRandom import net.corda.core.flows.* import net.corda.core.identity.Party -import net.corda.core.identity.PartyAndCertificate import net.corda.core.internal.* -import net.corda.core.internal.concurrent.OpenFuture -import net.corda.core.internal.concurrent.openFuture -import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.serialize -import net.corda.core.transactions.SignedTransaction -import net.corda.core.utilities.* +import 
net.corda.core.utilities.Try +import net.corda.core.utilities.debug +import net.corda.core.utilities.trace import net.corda.node.services.api.FlowAppAuditEvent import net.corda.node.services.api.FlowPermissionAuditEvent import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.logging.pushToLoggingContext -import net.corda.node.services.statemachine.FlowSessionState.Initiating +import net.corda.node.services.statemachine.transitions.FlowContinuation +import net.corda.node.services.statemachine.transitions.StateMachine import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseTransaction import net.corda.nodeapi.internal.persistence.contextTransaction import net.corda.nodeapi.internal.persistence.contextTransactionOrNull import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.io.IOException -import java.sql.SQLException -import java.time.Duration -import java.time.Instant -import java.util.* import java.util.concurrent.TimeUnit +import kotlin.reflect.KProperty1 class FlowPermissionException(message: String) : FlowException(message) +class TransientReference(@Transient val value: A) + class FlowStateMachineImpl(override val id: StateMachineRunId, override val logic: FlowLogic, - scheduler: FiberScheduler, - val ourIdentity: Party, - override val context: InvocationContext) : Fiber(id.toString(), scheduler), FlowStateMachine { - + scheduler: FiberScheduler +) : Fiber(id.toString(), scheduler), FlowStateMachine, FlowFiber { companion object { - // Used to work around a small limitation in Quasar. - private val QUASAR_UNBLOCKER = Fiber::class.staticField("SERIALIZER_BLOCKER").value - /** * Return the current [FlowStateMachineImpl] or null if executing outside of one. */ fun currentStateMachine(): FlowStateMachineImpl<*>? = Strand.currentStrand() as? 
FlowStateMachineImpl<*> + + private val log: Logger = LoggerFactory.getLogger("net.corda.flow") } - // These fields shouldn't be serialised, so they are marked @Transient. - @Transient override lateinit var serviceHub: ServiceHubInternal - @Transient override lateinit var ourIdentityAndCert: PartyAndCertificate - @Transient internal lateinit var database: CordaPersistence - @Transient internal lateinit var actionOnSuspend: (FlowIORequest) -> Unit - @Transient internal lateinit var actionOnEnd: (Try, Boolean) -> Unit - @Transient internal var fromCheckpoint: Boolean = false - @Transient private var txTrampoline: DatabaseTransaction? = null + override val serviceHub get() = getTransientField(TransientValues::serviceHub) + + data class TransientValues( + val eventQueue: Channel, + val resultFuture: CordaFuture, + val database: CordaPersistence, + val transitionExecutor: TransitionExecutor, + val actionExecutor: ActionExecutor, + val stateMachine: StateMachine, + val serviceHub: ServiceHubInternal, + val checkpointSerializationContext: SerializationContext + ) + + internal var transientValues: TransientReference? = null + internal var transientState: TransientReference? = null + + private fun getTransientField(field: KProperty1): A { + val suppliedValues = transientValues ?: throw IllegalStateException("${field.name} wasn't supplied!") + return field.get(suppliedValues.value) + } + + private fun extractThreadLocalTransaction(): TransientReference { + val transaction = contextTransaction + contextTransactionOrNull = null + return TransientReference(transaction) + } /** * Return the logger for this state machine. The logger name incorporates [id] and so including it in the log message * is not necessary. */ - override val logger: Logger = LoggerFactory.getLogger("net.corda.flow.$id") - @Transient private var resultFutureTransient: OpenFuture? 
= openFuture() - private val _resultFuture get() = resultFutureTransient ?: openFuture().also { resultFutureTransient = it } - /** This future will complete when the call method returns. */ - override val resultFuture: CordaFuture get() = _resultFuture - // This state IS serialised, as we need it to know what the fiber is waiting for. - internal val openSessions = HashMap, Party>, FlowSessionInternal>() - internal var waitingForResponse: WaitingRequest? = null + override val logger = log + override val resultFuture: CordaFuture get() = uncheckedCast(getTransientField(TransientValues::resultFuture)) + override val context: InvocationContext get() = transientState!!.value.checkpoint.invocationContext + override val ourIdentity: Party get() = transientState!!.value.checkpoint.ourIdentity internal var hasSoftLockedStates: Boolean = false set(value) { if (value) field = value else throw IllegalArgumentException("Can only set to true") } - init { - logic.stateMachine = this + /** + * Processes an event by creating the associated transition and executing it using the given executor. + * Try to avoid using this directly, instead use [processEventsUntilFlowIsResumed] or [processEventImmediately] + * instead. + */ + @Suspendable + private fun processEvent(transitionExecutor: TransitionExecutor, event: Event): FlowContinuation { + val stateMachine = getTransientField(TransientValues::stateMachine) + val oldState = transientState!!.value + val actionExecutor = getTransientField(TransientValues::actionExecutor) + val transition = stateMachine.transition(event, oldState) + val (continuation, newState) = transitionExecutor.executeTransition(this, oldState, event, transition, actionExecutor) + transientState = TransientReference(newState) + return continuation + } + + /** + * Processes the events in the event queue until a transition indicates that control should be returned to user code + * in the form of a regular resume or a throw of an exception. 
Alternatively the transition may abort the fiber + * completely. + * + * @param isDbTransactionOpenOnEntry indicates whether a DB transaction is expected to be present before the + * processing of the eventloop. Purely used for internal invariant checks. + * @param isDbTransactionOpenOnExit indicates whether a DB transaction is expected to be present once the eventloop + * processing finished. Purely used for internal invariant checks. + */ + @Suspendable + private fun processEventsUntilFlowIsResumed(isDbTransactionOpenOnEntry: Boolean, isDbTransactionOpenOnExit: Boolean): Any? { + checkDbTransaction(isDbTransactionOpenOnEntry) + val transitionExecutor = getTransientField(TransientValues::transitionExecutor) + val eventQueue = getTransientField(TransientValues::eventQueue) + try { + eventLoop@while (true) { + val nextEvent = eventQueue.receive() + val continuation = processEvent(transitionExecutor, nextEvent) + when (continuation) { + is FlowContinuation.Resume -> return continuation.result + is FlowContinuation.Throw -> { + continuation.throwable.fillInStackTrace() + throw continuation.throwable + } + FlowContinuation.ProcessEvents -> continue@eventLoop + FlowContinuation.Abort -> abortFiber() + } + } + } finally { + checkDbTransaction(isDbTransactionOpenOnExit) + } + } + + /** + * Immediately processes the passed in event. Always called with an open database transaction. + * + * @param event the event to be processed. + * @param isDbTransactionOpenOnEntry indicates whether a DB transaction is expected to be present before the + * processing of the event. Purely used for internal invariant checks. + * @param isDbTransactionOpenOnExit indicates whether a DB transaction is expected to be present once the event + * processing finished. Purely used for internal invariant checks. 
+ */ + @Suspendable + private fun processEventImmediately(event: Event, isDbTransactionOpenOnEntry: Boolean, isDbTransactionOpenOnExit: Boolean): FlowContinuation { + checkDbTransaction(isDbTransactionOpenOnEntry) + val transitionExecutor = getTransientField(TransientValues::transitionExecutor) + val continuation = processEvent(transitionExecutor, event) + checkDbTransaction(isDbTransactionOpenOnExit) + return continuation + } + + private fun checkDbTransaction(isPresent: Boolean) { // Invariant check: throws IllegalArgumentException unless a context DB transaction is present exactly when isPresent is true. + if (isPresent) { + requireNotNull(contextTransactionOrNull) { "Expected an open database transaction but none is present" } // check the value itself; a Boolean comparison result is never null + } else { + require(contextTransactionOrNull == null) { "Expected no open database transaction but one is present" } + } } @Suspendable override fun run() { - createTransaction() + logic.stateMachine = this + + context.pushToLoggingContext() + + initialiseFlow() + logger.debug { "Calling flow: $logic" } val startTime = System.nanoTime() - val result = try { - val r = logic.call() - // Only sessions which have done a single send and nothing else will block here - openSessions.values - .filter { it.state is Initiating } - .forEach { it.waitForConfirmation() } - r - } catch (e: FlowException) { - recordDuration(startTime, success = false) - // Check if the FlowException was propagated by looking at where the stack trace originates (see suspendAndExpectReceive). 
- val propagated = e.stackTrace[0].className == javaClass.name - processException(e, propagated) - logger.warn(if (propagated) "Flow ended due to receiving exception" else "Flow finished with exception", e) - return - } catch (t: Throwable) { - recordDuration(startTime, success = false) - logger.warn("Terminated by unexpected exception", t) - processException(t, false) - return + val resultOrError = try { + val result = logic.call() + suspend(FlowIORequest.WaitForSessionConfirmations, maySkipCheckpoint = true) + Try.Success(result) + } catch (throwable: Throwable) { + logger.warn("Flow threw exception", throwable) + Try.Failure(throwable) + } + val finalEvent = when (resultOrError) { + is Try.Success -> { + Event.FlowFinish(resultOrError.value) + } + is Try.Failure -> { + Event.Error(resultOrError.exception) + } + } + // Immediately process the last event. This is to make sure the transition can assume that it has an open + // database transaction. + val continuation = processEventImmediately( + finalEvent, + isDbTransactionOpenOnEntry = true, + isDbTransactionOpenOnExit = false + ) + if (continuation == FlowContinuation.ProcessEvents) { + // This can happen in case there was an error and there are further things to do e.g. to propagate it. 
+ processEventsUntilFlowIsResumed( + isDbTransactionOpenOnEntry = false, + isDbTransactionOpenOnExit = false + ) } recordDuration(startTime) - // This is to prevent actionOnEnd being called twice if it throws an exception - actionOnEnd(Try.Success(result), false) - _resultFuture.set(result) - logic.progressTracker?.currentStep = ProgressTracker.DONE - logger.debug { "Flow finished with result ${result.toString().abbreviate(300)}" } - } - - private fun createTransaction() { - // Make sure we have a database transaction - database.createTransaction() - logger.trace { "Starting database transaction $contextTransactionOrNull on ${Strand.currentStrand()}" } - } - - private fun processException(exception: Throwable, propagated: Boolean) { - actionOnEnd(Try.Failure(exception), propagated) - _resultFuture.setException(exception) - logic.progressTracker?.endWithError(exception) - } - - internal fun commitTransaction() { - val transaction = contextTransaction - try { - logger.trace { "Committing database transaction $transaction on ${Strand.currentStrand()}." } - transaction.commit() - } catch (e: SQLException) { - // TODO: we will get here if the database is not available. Think about how to shutdown and restart cleanly. - logger.error("Transaction commit failed: ${e.message}", e) - System.exit(1) - } finally { - transaction.close() - } } @Suspendable - override fun initiateFlow(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSession { - val sessionKey = Pair(sessionFlow, otherParty) - if (openSessions.containsKey(sessionKey)) { - throw IllegalStateException( - "Attempted to initiateFlow() twice in the same InitiatingFlow $sessionFlow for the same party " + - "$otherParty. This isn't supported in this version of Corda. Alternatively you may " + - "initiate a new flow by calling initiateFlow() in an " + - "@${InitiatingFlow::class.java.simpleName} sub-flow." 
+ private fun initialiseFlow() { + processEventsUntilFlowIsResumed( + isDbTransactionOpenOnEntry = false, + isDbTransactionOpenOnExit = true + ) + } + + @Suspendable + override fun subFlow(subFlow: FlowLogic): R { + processEventImmediately( + Event.EnterSubFlow(subFlow.javaClass), + isDbTransactionOpenOnEntry = true, + isDbTransactionOpenOnExit = true + ) + return try { + subFlow.call() + } finally { + processEventImmediately( + Event.LeaveSubFlow, + isDbTransactionOpenOnEntry = true, + isDbTransactionOpenOnExit = true ) } - val flowSession = FlowSessionImpl(otherParty) - createNewSession(otherParty, flowSession, sessionFlow) - flowSession.stateMachine = this - flowSession.sessionFlow = sessionFlow - return flowSession } @Suspendable - override fun getFlowInfo(otherParty: Party, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): FlowInfo { - val state = getConfirmedSession(otherParty, sessionFlow).state as FlowSessionState.Initiated - return state.context + override fun initiateFlow(party: Party): FlowSession { + val resume = processEventImmediately( + Event.InitiateFlow(party), + isDbTransactionOpenOnEntry = true, + isDbTransactionOpenOnExit = true + ) as FlowContinuation.Resume + return resume.result as FlowSession } @Suspendable - override fun sendAndReceive(receiveType: Class, - otherParty: Party, - payload: Any, - sessionFlow: FlowLogic<*>, - retrySend: Boolean, - maySkipCheckpoint: Boolean): UntrustworthyData { - requireNonPrimitive(receiveType) - logger.debug { "sendAndReceive(${receiveType.name}, $otherParty, ${payload.toString().abbreviate(300)}) ..." 
} - val session = getConfirmedSessionIfPresent(otherParty, sessionFlow) - val receivedSessionMessage: ReceivedSessionMessage = if (session == null) { - val newSession = initiateSession(otherParty, sessionFlow, payload, waitForConfirmation = true, retryable = retrySend) - // Only do a receive here as the session init has carried the payload - receiveInternal(newSession, receiveType) - } else { - val sendData = createSessionData(session, payload) - sendAndReceiveInternal(session, sendData, receiveType) + private fun abortFiber(): Nothing { + while (true) { + Fiber.park() } - val sessionData = receivedSessionMessage.message.checkDataSessionMessage() - logger.debug { "Received ${sessionData.payload.toString().abbreviate(300)}" } - return sessionData.checkPayloadIs(receiveType) - } - - private fun ExistingSessionMessage.checkDataSessionMessage(): DataSessionMessage { - when (payload) { - is DataSessionMessage -> { - return payload - } - else -> { - throw IllegalStateException("Was expecting ${DataSessionMessage::class.java.simpleName} but got ${payload.javaClass.simpleName} instead") - } - } - } - - @Suspendable - override fun receive(receiveType: Class, - otherParty: Party, - sessionFlow: FlowLogic<*>, - maySkipCheckpoint: Boolean): UntrustworthyData { - requireNonPrimitive(receiveType) - logger.debug { "receive(${receiveType.name}, $otherParty) ..." 
} - val session = getConfirmedSession(otherParty, sessionFlow) - val receivedSessionMessage = receiveInternal(session, receiveType).message.checkDataSessionMessage() - logger.debug { "Received ${receivedSessionMessage.payload.toString().abbreviate(300)}" } - return receivedSessionMessage.checkPayloadIs(receiveType) - } - - private fun requireNonPrimitive(receiveType: Class<*>) { - require(!receiveType.isPrimitive) { - "Use the wrapper type ${Primitives.wrap(receiveType).name} instead of the primitive $receiveType.class" - } - } - - @Suspendable - override fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean) { - logger.debug { "send($otherParty, ${payload.toString().abbreviate(300)})" } - val session = getConfirmedSessionIfPresent(otherParty, sessionFlow) - if (session == null) { - // Don't send the payload again if it was already piggy-backed on a session init - initiateSession(otherParty, sessionFlow, payload, waitForConfirmation = false) - } else { - sendInternal(session, ExistingSessionMessage(session.getPeerSessionId(), createSessionData(session, payload))) - } - } - - @Suspendable - override fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): SignedTransaction { - logger.debug { "waitForLedgerCommit($hash) ..." 
} - suspend(WaitForLedgerCommit(hash, sessionFlow.stateMachine as FlowStateMachineImpl<*>)) - val stx = serviceHub.validatedTransactions.getTransaction(hash) - if (stx != null) { - logger.debug { "Transaction $hash committed to ledger" } - return stx - } - - // If the tx isn't committed then we may have been resumed due to an session ending in an error - for (session in openSessions.values) { - for (receivedMessage in session.receivedMessages) { - if (receivedMessage.message.payload is ErrorSessionMessage) { - session.erroredEnd(receivedMessage.message.payload.flowException) - } - } - } - throw IllegalStateException("We were resumed after waiting for $hash but it wasn't found in our local storage") - } - - // Provide a mechanism to sleep within a Strand without locking any transactional state. - // This checkpoints, since we cannot undo any database writes up to this point. - @Suspendable - override fun sleepUntil(until: Instant) { - suspend(Sleep(until, this)) } // TODO Dummy implementation of access to application specific permission controls and audit logging @@ -305,242 +293,51 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } @Suspendable - override fun receiveAll(sessions: Map>, sessionFlow: FlowLogic<*>): Map> { - val requests = ArrayList() - for ((session, receiveType) in sessions) { - val sessionInternal = getConfirmedSession(session.counterparty, sessionFlow) - requests.add(ReceiveOnly(sessionInternal, receiveType)) - } - val receivedMessages = ReceiveAll(requests).suspendAndExpectReceive(suspend) - val result = LinkedHashMap>() - for ((sessionInternal, requestAndMessage) in receivedMessages) { - val message = requestAndMessage.message.confirmNoError(requestAndMessage.request.session) - result[sessionInternal.flowSession] = message.checkDataSessionMessage().checkPayloadIs( - requestAndMessage.request.userReceiveType as Class - ) - } - return result - } - - internal fun pushToLoggingContext() = context.pushToLoggingContext() - - /** - * 
This method will suspend the state machine and wait for incoming session init response from other party. - */ - @Suspendable - private fun FlowSessionInternal.waitForConfirmation() { - val sessionInitResponse = receiveInternal(this, null) - val payload = sessionInitResponse.message.payload - when (payload) { - is ConfirmSessionMessage -> { - state = FlowSessionState.Initiated( - sessionInitResponse. - peerParty, - payload.initiatedSessionId, - payload.initiatedFlowInfo) - } - is RejectSessionMessage -> { - throw UnexpectedFlowEndException("Party ${state.sendToParty} rejected session request: ${payload.message}") - } - else -> { - throw IllegalStateException("Was expecting ${ConfirmSessionMessage::class.java.simpleName} but got ${payload.javaClass.simpleName} instead") - } - } - } - - private fun createSessionData(session: FlowSessionInternal, payload: Any): DataSessionMessage { - return DataSessionMessage(payload.serialize(context = SerializationDefaults.P2P_CONTEXT)) - } - - @Suspendable - private fun sendInternal(session: FlowSessionInternal, message: SessionMessage) = suspend(SendOnly(session, message)) - - @Suspendable - private fun receiveInternal( - session: FlowSessionInternal, - userReceiveType: Class<*>?): ReceivedSessionMessage { - return waitForMessage(ReceiveOnly(session, userReceiveType)) - } - - @Suspendable - private fun sendAndReceiveInternal( - session: FlowSessionInternal, - message: DataSessionMessage, - userReceiveType: Class<*>?): ReceivedSessionMessage { - val sessionMessage = ExistingSessionMessage(session.getPeerSessionId(), message) - return waitForMessage(SendAndReceive(session, sessionMessage, userReceiveType)) - } - - @Suspendable - private fun getConfirmedSessionIfPresent(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSessionInternal? 
{ - val session = openSessions[Pair(sessionFlow, otherParty)] ?: return null - return when (session.state) { - is FlowSessionState.Uninitiated -> null - is FlowSessionState.Initiating -> { - session.waitForConfirmation() - session - } - is FlowSessionState.Initiated -> session - } - } - - @Suspendable - private fun getConfirmedSession(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSessionInternal { - return getConfirmedSessionIfPresent(otherParty, sessionFlow) ?: - initiateSession(otherParty, sessionFlow, null, waitForConfirmation = true) - } - - private fun createNewSession( - otherParty: Party, - flowSession: FlowSession, - sessionFlow: FlowLogic<*> - ) { - logger.trace { "Creating a new session with $otherParty" } - val session = FlowSessionInternal(sessionFlow, flowSession, SessionId.createRandom(newSecureRandom()), null, FlowSessionState.Uninitiated(otherParty)) - openSessions[Pair(sessionFlow, otherParty)] = session - } - - @Suspendable - private fun initiateSession( - otherParty: Party, - sessionFlow: FlowLogic<*>, - firstPayload: Any?, - waitForConfirmation: Boolean, - retryable: Boolean = false - ): FlowSessionInternal { - val session = openSessions[Pair(sessionFlow, otherParty)] ?: throw IllegalStateException("Expected an Uninitiated session for $otherParty") - val state = session.state as? FlowSessionState.Uninitiated ?: throw IllegalStateException("Tried to initiate a session $session, but it's already initiating/initiated") - logger.trace { "Initiating a new session with ${state.otherParty}" } - session.state = FlowSessionState.Initiating(state.otherParty) - session.retryable = retryable - val (version, initiatingFlowClass) = session.flow.javaClass.flowVersionAndInitiatingClass - val payloadBytes = firstPayload?.serialize(context = SerializationDefaults.P2P_CONTEXT) - logger.info("Initiating flow session with party ${otherParty.name}. 
Session id for tracing purposes is ${session.ourSessionId}.") - val sessionInit = InitialSessionMessage(session.ourSessionId, newSecureRandom().nextLong(), initiatingFlowClass.name, version, session.flow.javaClass.appName, payloadBytes) - sendInternal(session, sessionInit) - if (waitForConfirmation) { - session.waitForConfirmation() - } - return session - } - - @Suspendable - private fun waitForMessage(receiveRequest: ReceiveRequest): ReceivedSessionMessage { - val receivedMessage = receiveRequest.suspendAndExpectReceive() - receivedMessage.message.confirmNoError(receiveRequest.session) - return receivedMessage - } - - private val suspend : ReceiveAll.Suspend = object : ReceiveAll.Suspend { - @Suspendable - override fun invoke(request: FlowIORequest) { - suspend(request) - } - } - - @Suspendable - private fun ReceiveRequest.suspendAndExpectReceive(): ReceivedSessionMessage { - val polledMessage = session.receivedMessages.poll() - return if (polledMessage != null) { - if (this is SendAndReceive) { - // Since we've already received the message, we downgrade to a send only to get the payload out and not - // inadvertently block - suspend(SendOnly(session, message)) - } - polledMessage - } else { - // Suspend while we wait for a receive - suspend(this) - session.receivedMessages.poll() ?: - throw IllegalStateException("Was expecting a message but instead got nothing for $this") - } - } - - private fun ExistingSessionMessage.confirmNoError(session: FlowSessionInternal): ExistingSessionMessage { - when (payload) { - is ConfirmSessionMessage, - is DataSessionMessage -> { - return this - } - is ErrorSessionMessage -> { - openSessions.values.remove(session) - session.erroredEnd(payload.flowException) - } - is RejectSessionMessage -> { - session.erroredEnd(UnexpectedFlowEndException("Counterparty sent session rejection message at unexpected time with message ${payload.message}")) - } - EndSessionMessage -> { - openSessions.values.remove(session) - throw 
UnexpectedFlowEndException("Counterparty flow on ${session.state.sendToParty} has completed without " + - "sending data") - } - } - } - - private fun FlowSessionInternal.erroredEnd(exception: Throwable?): Nothing { - if (exception != null) { - @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") - exception.fillInStackTrace() - throw exception - } else { - throw UnexpectedFlowEndException("Counterparty flow on ${state.sendToParty} had an internal error and has terminated") - } - } - - @Suspendable - private fun suspend(ioRequest: FlowIORequest) { - // We have to pass the thread local database transaction across via a transient field as the fiber park - // swaps them out. - txTrampoline = contextTransactionOrNull - contextTransactionOrNull = null - if (ioRequest is WaitingRequest) - waitingForResponse = ioRequest - - var exceptionDuringSuspend: Throwable? = null + override fun suspend(ioRequest: FlowIORequest, maySkipCheckpoint: Boolean): R { + val serializationContext = TransientReference(getTransientField(TransientValues::checkpointSerializationContext)) + val transaction = extractThreadLocalTransaction() + val transitionExecutor = TransientReference(getTransientField(TransientValues::transitionExecutor)) parkAndSerialize { _, _ -> logger.trace { "Suspended on $ioRequest" } - // restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB - try { - contextTransactionOrNull = txTrampoline - txTrampoline = null - actionOnSuspend(ioRequest) - } catch (t: Throwable) { - // Quasar does not terminate the fiber properly if an exception occurs during a suspend. We have to - // resume the fiber just so that we can throw it when it's running. - exceptionDuringSuspend = t - logger.trace("Resuming so fiber can it terminate with the exception thrown during suspend process", t) - resume(scheduler) - } - } - if (exceptionDuringSuspend == null && ioRequest is Sleep) { - // Sleep on the fiber. This will not sleep if it's in the past. 
- Strand.sleep(Duration.between(Instant.now(), ioRequest.until).toNanos(), TimeUnit.NANOSECONDS) + contextTransactionOrNull = transaction.value + val event = try { + Event.Suspend( + ioRequest = ioRequest, + maySkipCheckpoint = maySkipCheckpoint, + fiber = this.serialize(context = serializationContext.value) + ) + } catch (throwable: Throwable) { + Event.Error(throwable) + } + + // We must commit the database transaction before returning from this closure, otherwise Quasar may schedule + // other fibers + val continuation = processEventImmediately( + event, + isDbTransactionOpenOnEntry = true, + isDbTransactionOpenOnExit = false + ) + require(continuation == FlowContinuation.ProcessEvents) + Fiber.unparkDeserialized(this, scheduler) } - createTransaction() - // TODO Now that we're throwing outside of the suspend the FlowLogic can catch it. We need Quasar to terminate - // the fiber when exceptions occur inside a suspend. - exceptionDuringSuspend?.let { throw it } - logger.trace { "Resumed from $ioRequest" } + return uncheckedCast(processEventsUntilFlowIsResumed( + isDbTransactionOpenOnEntry = false, + isDbTransactionOpenOnExit = true + )) } - internal fun resume(scheduler: FiberScheduler) { - try { - if (fromCheckpoint) { - logger.info("Resumed from checkpoint") - fromCheckpoint = false - Fiber.unparkDeserialized(this, scheduler) - } else if (state == State.NEW) { - logger.trace("Started") - start() - } else { - Fiber.unpark(this, QUASAR_UNBLOCKER) - } - } catch (t: Throwable) { - logger.error("Error during resume", t) - } + @Suspendable + override fun scheduleEvent(event: Event) { + getTransientField(TransientValues::eventQueue).send(event) } + override fun snapshot(): StateMachineState { + return transientState!!.value + } + + override val stateMachine get() = getTransientField(TransientValues::stateMachine) + /** * Records the duration of this flow – from call() to completion or failure. 
* Note that the duration will include the time the flow spent being parked, and not just the total @@ -582,15 +379,3 @@ val Class>.appName: String "" } } - -fun DataSessionMessage.checkPayloadIs(type: Class): UntrustworthyData { - val payloadData: T = try { - val serializer = SerializationDefaults.SERIALIZATION_FACTORY - serializer.deserialize(payload, type, SerializationDefaults.P2P_CONTEXT) - } catch (ex: Exception) { - throw IOException("Payload invalid", ex) - } - return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?: - throw UnexpectedFlowEndException("We were expecting a ${type.name} but we instead got a " + - "${payloadData.javaClass.name} (${payloadData})") -} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/PropagatingFlowHospital.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/PropagatingFlowHospital.kt new file mode 100644 index 0000000000..a31656db5d --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/PropagatingFlowHospital.kt @@ -0,0 +1,20 @@ +package net.corda.node.services.statemachine + +import net.corda.core.utilities.debug +import net.corda.core.utilities.loggerFor + +/** + * A simple [FlowHospital] implementation that immediately triggers error propagation when a flow dirties. 
+ */ +object PropagatingFlowHospital : FlowHospital { + private val log = loggerFor() + + override fun flowErrored(flowFiber: FlowFiber) { + log.debug { "Flow ${flowFiber.id} dirtied ${flowFiber.snapshot().checkpoint.errorState}" } + flowFiber.scheduleEvent(Event.StartErrorPropagation) + } + + override fun flowCleaned(flowFiber: FlowFiber) { + throw IllegalStateException("Flow ${flowFiber.id} cleaned after error propagation triggered") + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SessionRejectException.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SessionRejectException.kt new file mode 100644 index 0000000000..90d4432f0d --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SessionRejectException.kt @@ -0,0 +1,8 @@ +package net.corda.node.services.statemachine + +import net.corda.core.CordaException + +/** + * An exception propagated and thrown in case a session initiation fails. + */ +class SessionRejectException(reason: String) : CordaException(reason) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt new file mode 100644 index 0000000000..4c579c9ba8 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt @@ -0,0 +1,689 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Fiber +import co.paralleluniverse.fibers.FiberExecutorScheduler +import co.paralleluniverse.fibers.Suspendable +import co.paralleluniverse.fibers.instrument.SuspendableHelper +import co.paralleluniverse.strands.channels.Channels +import com.codahale.metrics.Gauge +import net.corda.core.concurrent.CordaFuture +import net.corda.core.context.InvocationContext +import net.corda.core.flows.FlowException +import net.corda.core.flows.FlowInfo +import net.corda.core.flows.FlowLogic 
+import net.corda.core.flows.StateMachineRunId +import net.corda.core.identity.Party +import net.corda.core.internal.FlowStateMachine +import net.corda.core.internal.ThreadBox +import net.corda.core.internal.bufferUntilSubscribed +import net.corda.core.internal.castIfPossible +import net.corda.core.internal.concurrent.OpenFuture +import net.corda.core.internal.concurrent.map +import net.corda.core.internal.concurrent.openFuture +import net.corda.core.messaging.DataFeed +import net.corda.core.serialization.* +import net.corda.core.utilities.ProgressTracker +import net.corda.core.utilities.Try +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.debug +import net.corda.node.internal.InitiatedFlowFactory +import net.corda.node.services.api.CheckpointStorage +import net.corda.node.services.api.ServiceHubInternal +import net.corda.node.services.config.shouldCheckCheckpoints +import net.corda.node.services.messaging.DeduplicationHandler +import net.corda.node.services.messaging.ReceivedMessage +import net.corda.node.services.statemachine.interceptors.* +import net.corda.node.services.statemachine.transitions.StateMachine +import net.corda.node.services.statemachine.transitions.StateMachineConfiguration +import net.corda.node.utilities.AffinityExecutor +import net.corda.nodeapi.internal.persistence.CordaPersistence +import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl +import net.corda.nodeapi.internal.serialization.withTokenContext +import org.apache.activemq.artemis.utils.ReusableLatch +import rx.Observable +import rx.subjects.PublishSubject +import java.security.SecureRandom +import java.util.* +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.ExecutorService +import javax.annotation.concurrent.ThreadSafe +import kotlin.collections.ArrayList +import kotlin.streams.toList + +/** + * The StateMachineManagerImpl will always invoke the flow fibers on the given [AffinityExecutor], regardless of 
which + * thread actually starts them via [startFlow]. + */ +@ThreadSafe +class SingleThreadedStateMachineManager( + val serviceHub: ServiceHubInternal, + val checkpointStorage: CheckpointStorage, + val executor: ExecutorService, + val database: CordaPersistence, + val secureRandom: SecureRandom, + private val unfinishedFibers: ReusableLatch = ReusableLatch(), + private val classloader: ClassLoader = SingleThreadedStateMachineManager::class.java.classLoader +) : StateMachineManager, StateMachineManagerInternal { + companion object { + private val logger = contextLogger() + } + + private class Flow(val fiber: FlowStateMachineImpl<*>, val resultFuture: OpenFuture) + + // A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines + // property. + private class InnerState { + val changesPublisher = PublishSubject.create()!! + // True if we're shutting down, so don't resume anything. + var stopping = false + val flows = HashMap() + val startedFutures = HashMap>() + } + + private val mutex = ThreadBox(InnerState()) + private val scheduler = FiberExecutorScheduler("Same thread scheduler", executor) + // How many Fibers are running and not suspended. If zero and stopping is true, then we are halted. + private val liveFibers = ReusableLatch() + // Monitoring support. + private val metrics = serviceHub.monitoringService.metrics + private val sessionToFlow = ConcurrentHashMap() + private val flowMessaging: FlowMessaging = FlowMessagingImpl(serviceHub) + private val fiberDeserializationChecker = if (serviceHub.configuration.shouldCheckCheckpoints()) FiberDeserializationChecker() else null + private val transitionExecutor = makeTransitionExecutor() + + private var checkpointSerializationContext: SerializationContext? = null + private var tokenizableServices: List? = null + private var actionExecutor: ActionExecutor? 
= null + + override val allStateMachines: List> + get() = mutex.locked { flows.values.map { it.fiber.logic } } + + + private val totalStartedFlows = metrics.counter("Flows.Started") + private val totalFinishedFlows = metrics.counter("Flows.Finished") + + /** + * An observable that emits triples of the changing flow, the type of change, and a process-specific ID number + * which may change across restarts. + * + * We use assignment here so that multiple subscribers share the same wrapped Observable. + */ + override val changes: Observable = mutex.content.changesPublisher + + override fun start(tokenizableServices: List) { + checkQuasarJavaAgentPresence() + this.tokenizableServices = tokenizableServices + val checkpointSerializationContext = SerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext( + SerializeAsTokenContextImpl(tokenizableServices, SerializationDefaults.SERIALIZATION_FACTORY, SerializationDefaults.CHECKPOINT_CONTEXT, serviceHub) + ) + this.checkpointSerializationContext = checkpointSerializationContext + this.actionExecutor = makeActionExecutor(checkpointSerializationContext) + fiberDeserializationChecker?.start(checkpointSerializationContext) + val fibers = restoreFlowsFromCheckpoints() + metrics.register("Flows.InFlight", Gauge { mutex.content.flows.size }) + Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable -> + (fiber as FlowStateMachineImpl<*>).logger.warn("Caught exception from flow", throwable) + } + serviceHub.networkMapCache.nodeReady.then { + resumeRestoredFlows(fibers) + flowMessaging.start { receivedMessage, deduplicationHandler -> + executor.execute { + onSessionMessage(receivedMessage, deduplicationHandler) + } + } + } + } + + override fun resume() { + fiberDeserializationChecker?.start(checkpointSerializationContext!!) 
+ val fibers = restoreFlowsFromCheckpoints() + Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable -> + (fiber as FlowStateMachineImpl<*>).logger.warn("Caught exception from flow", throwable) + } + serviceHub.networkMapCache.nodeReady.then { + resumeRestoredFlows(fibers) + } + mutex.locked { + stopping = false + } + } + + override fun > findStateMachines(flowClass: Class): List>> { + return mutex.locked { + flows.values.mapNotNull { + flowClass.castIfPossible(it.fiber.logic)?.let { it to it.stateMachine.resultFuture } + } + } + } + + /** + * Start the shutdown process, bringing the [SingleThreadedStateMachineManager] to a controlled stop. When this method returns, + * all Fibers have been suspended and checkpointed, or have completed. + * + * @param allowedUnsuspendedFiberCount Optional parameter is used in some tests. + */ + override fun stop(allowedUnsuspendedFiberCount: Int) { + require(allowedUnsuspendedFiberCount >= 0) + mutex.locked { + if (stopping) throw IllegalStateException("Already stopping!") + stopping = true + for ((_, flow) in flows) { + flow.fiber.scheduleEvent(Event.SoftShutdown) + } + } + // Account for any expected Fibers in a test scenario. + liveFibers.countDown(allowedUnsuspendedFiberCount) + liveFibers.await() + fiberDeserializationChecker?.let { + val foundUnrestorableFibers = it.stop() + check(!foundUnrestorableFibers) { "Unrestorable checkpoints were created, please check the logs for details." } + } + } + + /** + * Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and + * calls to [allStateMachines] + */ + override fun track(): DataFeed>, StateMachineManager.Change> { + return mutex.locked { + DataFeed(flows.values.map { it.fiber.logic }, changesPublisher.bufferUntilSubscribed()) + } + } + + override fun startFlow( + flowLogic: FlowLogic, + context: InvocationContext, + ourIdentity: Party?, + deduplicationHandler: DeduplicationHandler? 
+ ): CordaFuture> { + return startFlowInternal( + invocationContext = context, + flowLogic = flowLogic, + flowStart = FlowStart.Explicit, + ourIdentity = ourIdentity ?: getOurFirstIdentity(), + deduplicationHandler = deduplicationHandler, + isStartIdempotent = false + ) + } + + override fun killFlow(id: StateMachineRunId): Boolean { + + return mutex.locked { + val flow = flows.remove(id) + if (flow != null) { + logger.debug("Killing flow known to physical node.") + decrementLiveFibers() + totalFinishedFlows.inc() + unfinishedFibers.countDown() + try { + flow.fiber.interrupt() + true + } finally { + database.transaction { + checkpointStorage.removeCheckpoint(id) + } + } + } else { + // TODO replace with a clustered delete after we'll support clustered nodes + logger.debug("Unable to kill a flow unknown to physical node. Might be processed by another physical node.") + false + } + } + } + + override fun addSessionBinding(flowId: StateMachineRunId, sessionId: SessionId) { + val previousFlowId = sessionToFlow.put(sessionId, flowId) + if (previousFlowId != null) { + if (previousFlowId == flowId) { + logger.warn("Session binding from $sessionId to $flowId re-added") + } else { + throw IllegalStateException( + "Attempted to add session binding from session $sessionId to flow $flowId, " + + "however there was already a binding to $previousFlowId" + ) + } + } + } + + override fun removeSessionBindings(sessionIds: Set) { + val reRemovedSessionIds = HashSet() + for (sessionId in sessionIds) { + val flowId = sessionToFlow.remove(sessionId) + if (flowId == null) { + reRemovedSessionIds.add(sessionId) + } + } + if (reRemovedSessionIds.isNotEmpty()) { + logger.warn("Session binding from $reRemovedSessionIds re-removed") + } + } + + override fun removeFlow(flowId: StateMachineRunId, removalReason: FlowRemovalReason, lastState: StateMachineState) { + mutex.locked { + val flow = flows.remove(flowId) + if (flow != null) { + decrementLiveFibers() + totalFinishedFlows.inc() + 
unfinishedFibers.countDown() + return when (removalReason) { + is FlowRemovalReason.OrderlyFinish -> removeFlowOrderly(flow, removalReason, lastState) + is FlowRemovalReason.ErrorFinish -> removeFlowError(flow, removalReason, lastState) + FlowRemovalReason.SoftShutdown -> flow.fiber.scheduleEvent(Event.SoftShutdown) + } + } else { + logger.warn("Flow $flowId re-finished") + } + } + } + + override fun signalFlowHasStarted(flowId: StateMachineRunId) { + mutex.locked { + startedFutures.remove(flowId)?.set(Unit) + flows[flowId]?.let { flow -> + changesPublisher.onNext(StateMachineManager.Change.Add(flow.fiber.logic)) + } + } + } + + private val stateMachineConfiguration = StateMachineConfiguration.default + + private fun checkQuasarJavaAgentPresence() { + check(SuspendableHelper.isJavaAgentActive(), { + """Missing the '-javaagent' JVM argument. Make sure you run the tests with the Quasar java agent attached to your JVM. + #See https://docs.corda.net/troubleshooting.html - 'Fiber classes not instrumented' for more details.""".trimMargin("#") + }) + } + + private fun decrementLiveFibers() { + liveFibers.countDown() + } + + private fun incrementLiveFibers() { + liveFibers.countUp() + } + + private fun restoreFlowsFromCheckpoints(): List { + return checkpointStorage.getAllCheckpoints().map { (id, serializedCheckpoint) -> + // If a flow is added before start() then don't attempt to restore it + mutex.locked { if (flows.containsKey(id)) return@map null } + val checkpoint = deserializeCheckpoint(serializedCheckpoint) + if (checkpoint == null) return@map null + createFlowFromCheckpoint( + id = id, + checkpoint = checkpoint, + initialDeduplicationHandler = null, + isAnyCheckpointPersisted = true, + isStartIdempotent = false + ) + }.toList().filterNotNull() + } + + private fun resumeRestoredFlows(flows: List) { + for (flow in flows) { + addAndStartFlow(flow.fiber.id, flow) + } + } + + private fun onSessionMessage(message: ReceivedMessage, deduplicationHandler: 
DeduplicationHandler) { + val peer = message.peer + val sessionMessage = try { + message.data.deserialize() + } catch (ex: Exception) { + logger.error("Received corrupt SessionMessage data from $peer") + deduplicationHandler.afterDatabaseTransaction() + return + } + val sender = serviceHub.networkMapCache.getPeerByLegalName(peer) + if (sender != null) { + when (sessionMessage) { + is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, deduplicationHandler, sender) + is InitialSessionMessage -> onSessionInit(sessionMessage, message.platformVersion, deduplicationHandler, sender) + } + } else { + logger.error("Unknown peer $peer in $sessionMessage") + } + } + + private fun onExistingSessionMessage(sessionMessage: ExistingSessionMessage, deduplicationHandler: DeduplicationHandler, sender: Party) { + try { + val recipientId = sessionMessage.recipientSessionId + val flowId = sessionToFlow[recipientId] + if (flowId == null) { + deduplicationHandler.afterDatabaseTransaction() + if (sessionMessage.payload is EndSessionMessage) { + logger.debug { + "Got ${EndSessionMessage::class.java.simpleName} for " + + "unknown session $recipientId, discarding..." 
+ } + } else { + throw IllegalArgumentException("Cannot find flow corresponding to session ID $recipientId") + } + } else { + val flow = mutex.locked { flows[flowId] } ?: throw IllegalStateException("Cannot find fiber corresponding to ID $flowId") + flow.fiber.scheduleEvent(Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender)) + } + } catch (exception: Exception) { + logger.error("Exception while routing $sessionMessage", exception) + throw exception + } + } + + private fun onSessionInit(sessionMessage: InitialSessionMessage, senderPlatformVersion: Int, deduplicationHandler: DeduplicationHandler, sender: Party) { + fun createErrorMessage(initiatorSessionId: SessionId, message: String): ExistingSessionMessage { + val errorId = secureRandom.nextLong() + val payload = RejectSessionMessage(message, errorId) + return ExistingSessionMessage(initiatorSessionId, payload) + } + val replyError = try { + val initiatedFlowFactory = getInitiatedFlowFactory(sessionMessage) + val initiatedSessionId = SessionId.createRandom(secureRandom) + val senderSession = FlowSessionImpl(sender, initiatedSessionId) + val flowLogic = initiatedFlowFactory.createFlow(senderSession) + val initiatedFlowInfo = when (initiatedFlowFactory) { + is InitiatedFlowFactory.Core -> FlowInfo(serviceHub.myInfo.platformVersion, "corda") + is InitiatedFlowFactory.CorDapp -> FlowInfo(initiatedFlowFactory.flowVersion, initiatedFlowFactory.appName) + } + val senderCoreFlowVersion = when (initiatedFlowFactory) { + is InitiatedFlowFactory.Core -> senderPlatformVersion + is InitiatedFlowFactory.CorDapp -> null + } + startInitiatedFlow(flowLogic, deduplicationHandler, senderSession, initiatedSessionId, sessionMessage, senderCoreFlowVersion, initiatedFlowInfo) + null + } catch (exception: Exception) { + logger.warn("Exception while creating initiated flow", exception) + createErrorMessage( + sessionMessage.initiatorSessionId, + (exception as? 
SessionRejectException)?.message ?: "Unable to establish session" + ) + } + + if (replyError != null) { + flowMessaging.sendSessionMessage(sender, replyError, DeduplicationId.createRandom(secureRandom)) + deduplicationHandler.afterDatabaseTransaction() + } + } + + // TODO this is a temporary hack until we figure out multiple identities + private fun getOurFirstIdentity(): Party { + return serviceHub.myInfo.legalIdentities[0] + } + + private fun getInitiatedFlowFactory(message: InitialSessionMessage): InitiatedFlowFactory<*> { + val initiatingFlowClass = try { + Class.forName(message.initiatorFlowClassName, true, classloader).asSubclass(FlowLogic::class.java) + } catch (e: ClassNotFoundException) { + throw SessionRejectException("Don't know ${message.initiatorFlowClassName}") + } catch (e: ClassCastException) { + throw SessionRejectException("${message.initiatorFlowClassName} is not a flow") + } + return serviceHub.getFlowFactory(initiatingFlowClass) ?: + throw SessionRejectException("$initiatingFlowClass is not registered") + } + + private fun startInitiatedFlow( + flowLogic: FlowLogic, + initiatingMessageDeduplicationHandler: DeduplicationHandler, + peerSession: FlowSessionImpl, + initiatedSessionId: SessionId, + initiatingMessage: InitialSessionMessage, + senderCoreFlowVersion: Int?, + initiatedFlowInfo: FlowInfo + ) { + val flowStart = FlowStart.Initiated(peerSession, initiatedSessionId, initiatingMessage, senderCoreFlowVersion, initiatedFlowInfo) + val ourIdentity = getOurFirstIdentity() + startFlowInternal( + InvocationContext.peer(peerSession.counterparty.name), flowLogic, flowStart, ourIdentity, + initiatingMessageDeduplicationHandler, + isStartIdempotent = false + ) + } + + private fun startFlowInternal( + invocationContext: InvocationContext, + flowLogic: FlowLogic, + flowStart: FlowStart, + ourIdentity: Party, + deduplicationHandler: DeduplicationHandler?, + isStartIdempotent: Boolean + ): CordaFuture> { + val flowId = StateMachineRunId.createRandom() + 
val deduplicationSeed = when (flowStart) { + FlowStart.Explicit -> flowId.uuid.toString() + is FlowStart.Initiated -> + "${flowStart.initiatingMessage.initiatorSessionId.toLong}-" + + "${flowStart.initiatingMessage.initiationEntropy}" + } + + // Before we construct the state machine state by freezing the FlowLogic we need to make sure that lazy properties + // have access to the fiber (and thereby the service hub) + val flowStateMachineImpl = FlowStateMachineImpl(flowId, flowLogic, scheduler) + val resultFuture = openFuture() + flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture)) + flowLogic.stateMachine = flowStateMachineImpl + val frozenFlowLogic = (flowLogic as FlowLogic<*>).serialize(context = checkpointSerializationContext!!) + + val initialCheckpoint = Checkpoint.create(invocationContext, flowStart, flowLogic.javaClass, frozenFlowLogic, ourIdentity, deduplicationSeed).getOrThrow() + val startedFuture = openFuture() + val initialState = StateMachineState( + checkpoint = initialCheckpoint, + pendingDeduplicationHandlers = deduplicationHandler?.let { listOf(it) } ?: emptyList(), + isFlowResumed = false, + isTransactionTracked = false, + isAnyCheckpointPersisted = false, + isStartIdempotent = isStartIdempotent, + isRemoved = false, + flowLogic = flowLogic + ) + flowStateMachineImpl.transientState = TransientReference(initialState) + mutex.locked { + startedFutures[flowId] = startedFuture + } + totalStartedFlows.inc() + addAndStartFlow(flowId, Flow(flowStateMachineImpl, resultFuture)) + return startedFuture.map { flowStateMachineImpl as FlowStateMachine } + } + + private fun deserializeCheckpoint(serializedCheckpoint: SerializedBytes): Checkpoint? { + return try { + serializedCheckpoint.deserialize(context = checkpointSerializationContext!!) 
+ } catch (exception: Throwable) { + logger.error("Encountered unrestorable checkpoint!", exception) + null + } + } + + private fun verifyFlowLogicIsSuspendable(logic: FlowLogic) { + // Quasar requires (in Java 8) that at least the call method be annotated suspendable. Unfortunately, it's + // easy to forget to add this when creating a new flow, so we check here to give the user a better error. + // + // The Kotlin compiler can sometimes generate a synthetic bridge method from a single call declaration, which + // forwards to the void method and then returns Unit. However annotations do not get copied across to this + // bridge, so we have to do a more complex scan here. + val call = logic.javaClass.methods.first { !it.isSynthetic && it.name == "call" && it.parameterCount == 0 } + if (call.getAnnotation(Suspendable::class.java) == null) { + throw FlowException("${logic.javaClass.name}.call() is not annotated as @Suspendable. Please fix this.") + } + } + + private fun createTransientValues(id: StateMachineRunId, resultFuture: CordaFuture): FlowStateMachineImpl.TransientValues { + return FlowStateMachineImpl.TransientValues( + eventQueue = Channels.newChannel(stateMachineConfiguration.eventQueueSize, Channels.OverflowPolicy.BLOCK), + resultFuture = resultFuture, + database = database, + transitionExecutor = transitionExecutor, + actionExecutor = actionExecutor!!, + stateMachine = StateMachine(id, stateMachineConfiguration, secureRandom), + serviceHub = serviceHub, + checkpointSerializationContext = checkpointSerializationContext!! + ) + } + + private fun createFlowFromCheckpoint( + id: StateMachineRunId, + checkpoint: Checkpoint, + isAnyCheckpointPersisted: Boolean, + isStartIdempotent: Boolean, + initialDeduplicationHandler: DeduplicationHandler? 
+ ): Flow { + val flowState = checkpoint.flowState + val resultFuture = openFuture() + val fiber = when (flowState) { + is FlowState.Unstarted -> { + val logic = flowState.frozenFlowLogic.deserialize(context = checkpointSerializationContext!!) + val state = StateMachineState( + checkpoint = checkpoint, + pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(), + isFlowResumed = false, + isTransactionTracked = false, + isAnyCheckpointPersisted = isAnyCheckpointPersisted, + isStartIdempotent = isStartIdempotent, + isRemoved = false, + flowLogic = logic + ) + val fiber = FlowStateMachineImpl(id, logic, scheduler) + fiber.transientValues = TransientReference(createTransientValues(id, resultFuture)) + fiber.transientState = TransientReference(state) + fiber.logic.stateMachine = fiber + fiber + } + is FlowState.Started -> { + val fiber = flowState.frozenFiber.deserialize(context = checkpointSerializationContext!!) + val state = StateMachineState( + checkpoint = checkpoint, + pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(), + isFlowResumed = false, + isTransactionTracked = false, + isAnyCheckpointPersisted = isAnyCheckpointPersisted, + isStartIdempotent = isStartIdempotent, + isRemoved = false, + flowLogic = fiber.logic + ) + fiber.transientValues = TransientReference(createTransientValues(id, resultFuture)) + fiber.transientState = TransientReference(state) + fiber.logic.stateMachine = fiber + fiber + } + } + + verifyFlowLogicIsSuspendable(fiber.logic) + + return Flow(fiber, resultFuture) + } + + private fun addAndStartFlow(id: StateMachineRunId, flow: Flow) { + val checkpoint = flow.fiber.snapshot().checkpoint + for (sessionId in getFlowSessionIds(checkpoint)) { + sessionToFlow.put(sessionId, id) + } + mutex.locked { + if (stopping) { + startedFutures[id]?.setException(IllegalStateException("Will not start flow as SMM is stopping")) + logger.trace("Not resuming as SMM is stopping.") + } 
else { + incrementLiveFibers() + unfinishedFibers.countUp() + flows.put(id, flow) + flow.fiber.scheduleEvent(Event.DoRemainingWork) + when (checkpoint.flowState) { + is FlowState.Unstarted -> { + flow.fiber.start() + } + is FlowState.Started -> { + Fiber.unparkDeserialized(flow.fiber, scheduler) + } + } + } + } + } + + private fun getFlowSessionIds(checkpoint: Checkpoint): Set { + val initiatedFlowStart = (checkpoint.flowState as? FlowState.Unstarted)?.flowStart as? FlowStart.Initiated + return if (initiatedFlowStart == null) { + checkpoint.sessions.keys + } else { + checkpoint.sessions.keys + initiatedFlowStart.initiatedSessionId + } + } + + private fun makeActionExecutor(checkpointSerializationContext: SerializationContext): ActionExecutor { + return ActionExecutorImpl( + serviceHub, + checkpointStorage, + flowMessaging, + this, + checkpointSerializationContext, + metrics + ) + } + + private fun makeTransitionExecutor(): TransitionExecutor { + val interceptors = ArrayList() + interceptors.add { HospitalisingInterceptor(PropagatingFlowHospital, it) } + if (serviceHub.configuration.devMode) { + interceptors.add { DumpHistoryOnErrorInterceptor(it) } + } + if (serviceHub.configuration.shouldCheckCheckpoints()) { + interceptors.add { FiberDeserializationCheckingInterceptor(fiberDeserializationChecker!!, it) } + } + if (logger.isDebugEnabled) { + interceptors.add { PrintingInterceptor(it) } + } + val transitionExecutor: TransitionExecutor = TransitionExecutorImpl(secureRandom, database) + return interceptors.fold(transitionExecutor) { executor, interceptor -> interceptor(executor) } + } + + private fun InnerState.removeFlowOrderly( + flow: Flow, + removalReason: FlowRemovalReason.OrderlyFinish, + lastState: StateMachineState + ) { + drainFlowEventQueue(flow) + // final sanity checks + require(lastState.pendingDeduplicationHandlers.isEmpty()) + require(lastState.isRemoved) + require(lastState.checkpoint.subFlowStack.size == 1) + require(sessionToFlow.none { it.value == 
flow.fiber.id }) { "Flow must not have remaining session bindings on orderly removal" } + flow.resultFuture.set(removalReason.flowReturnValue) + lastState.flowLogic.progressTracker?.currentStep = ProgressTracker.DONE + changesPublisher.onNext(StateMachineManager.Change.Removed(lastState.flowLogic, Try.Success(removalReason.flowReturnValue))) + } + + private fun InnerState.removeFlowError( + flow: Flow, + removalReason: FlowRemovalReason.ErrorFinish, + lastState: StateMachineState + ) { + drainFlowEventQueue(flow) + val flowError = removalReason.flowErrors[0] // TODO what to do with several? + val exception = flowError.exception + (exception as? FlowException)?.originalErrorId = flowError.errorId + flow.resultFuture.setException(exception) + lastState.flowLogic.progressTracker?.endWithError(exception) + changesPublisher.onNext(StateMachineManager.Change.Removed(lastState.flowLogic, Try.Failure(exception))) + } + + // The flow's event queue may be non-empty in case it shut down abruptly. We handle outstanding events here. + private fun drainFlowEventQueue(flow: Flow) { + while (true) { + val event = flow.fiber.transientValues!!.value.eventQueue.tryReceive() ?: return + when (event) { + is Event.DoRemainingWork -> {} + is Event.DeliverSessionMessage -> { + // Acknowledge the message so it doesn't leak in the broker. 
+ event.deduplicationHandler.afterDatabaseTransaction() + when (event.sessionMessage.payload) { + EndSessionMessage -> { + logger.debug { "Unhandled message ${event.sessionMessage} due to flow shutting down" } + } + else -> { + logger.warn("Unhandled message ${event.sessionMessage} due to flow shutting down") + } + } + } + else -> { + logger.warn("Unhandled event $event due to flow shutting down") + } + } + } + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index fcddc980ea..7d645407a1 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -1,11 +1,14 @@ package net.corda.node.services.statemachine import net.corda.core.concurrent.CordaFuture -import net.corda.core.flows.FlowLogic -import net.corda.core.internal.FlowStateMachine import net.corda.core.context.InvocationContext +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.StateMachineRunId +import net.corda.core.identity.Party +import net.corda.core.internal.FlowStateMachine import net.corda.core.messaging.DataFeed import net.corda.core.utilities.Try +import net.corda.node.services.messaging.DeduplicationHandler import rx.Observable /** @@ -23,7 +26,6 @@ import rx.Observable * TODO: Think about how to bring the system to a clean stop so it can be upgraded without any serialised stacks on disk * TODO: Timeouts * TODO: Surfacing of exceptions via an API and/or management UI - * TODO: Ability to control checkpointing explicitly, for cases where you know replaying a message can't hurt * TODO: Don't store all active flows in memory, load from the database on demand. 
*/ interface StateMachineManager { @@ -37,13 +39,25 @@ interface StateMachineManager { */ fun stop(allowedUnsuspendedFiberCount: Int) + /** + * Resume state machine manager after having called [stop]. + */ + fun resume() + /** * Starts a new flow. * * @param flowLogic The flow's code. * @param context The context of the flow. + * @param ourIdentity The identity to use for the flow. + * @param deduplicationHandler Allows exactly-once start of the flow, see [DeduplicationHandler]. */ - fun startFlow(flowLogic: FlowLogic, context: InvocationContext): CordaFuture> + fun startFlow( + flowLogic: FlowLogic, + context: InvocationContext, + ourIdentity: Party?, + deduplicationHandler: DeduplicationHandler? + ): CordaFuture> /** * Represents an addition/removal of a state machine. @@ -73,4 +87,20 @@ interface StateMachineManager { * Returns all currently live flows. */ val allStateMachines: List> -} \ No newline at end of file + + /** + * Attempts to kill a flow. This is not a clean termination and should be reserved for exceptional cases such as stuck fibers. + * + * @return whether the flow existed and was killed. + */ + fun killFlow(id: StateMachineRunId): Boolean +} + +// These must be idempotent! 
A later failure in the state transition may error the flow state, and a replay may call +// these functions again +interface StateMachineManagerInternal { + fun signalFlowHasStarted(flowId: StateMachineRunId) + fun addSessionBinding(flowId: StateMachineRunId, sessionId: SessionId) + fun removeSessionBindings(sessionIds: Set) + fun removeFlow(flowId: StateMachineRunId, removalReason: FlowRemovalReason, lastState: StateMachineState) +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt deleted file mode 100644 index e848df74ae..0000000000 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt +++ /dev/null @@ -1,666 +0,0 @@ -package net.corda.node.services.statemachine - -import co.paralleluniverse.fibers.Fiber -import co.paralleluniverse.fibers.FiberExecutorScheduler -import co.paralleluniverse.fibers.Suspendable -import co.paralleluniverse.fibers.instrument.SuspendableHelper -import co.paralleluniverse.strands.Strand -import com.codahale.metrics.Gauge -import com.esotericsoftware.kryo.KryoException -import com.google.common.collect.HashMultimap -import com.google.common.util.concurrent.MoreExecutors -import net.corda.core.CordaException -import net.corda.core.concurrent.CordaFuture -import net.corda.core.context.InvocationContext -import net.corda.core.context.InvocationOrigin -import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.newSecureRandom -import net.corda.core.flows.FlowException -import net.corda.core.flows.FlowInfo -import net.corda.core.flows.FlowLogic -import net.corda.core.flows.StateMachineRunId -import net.corda.core.identity.Party -import net.corda.core.internal.* -import net.corda.core.internal.concurrent.doneFuture -import net.corda.core.messaging.DataFeed -import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT -import 
net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY -import net.corda.core.serialization.SerializedBytes -import net.corda.core.serialization.deserialize -import net.corda.core.serialization.serialize -import net.corda.core.utilities.Try -import net.corda.core.utilities.contextLogger -import net.corda.core.utilities.debug -import net.corda.core.utilities.trace -import net.corda.node.internal.InitiatedFlowFactory -import net.corda.node.services.api.Checkpoint -import net.corda.node.services.api.CheckpointStorage -import net.corda.node.services.api.ServiceHubInternal -import net.corda.node.services.config.shouldCheckCheckpoints -import net.corda.node.services.messaging.ReceivedMessage -import net.corda.node.utilities.AffinityExecutor -import net.corda.node.utilities.newNamedSingleThreadExecutor -import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders -import net.corda.nodeapi.internal.persistence.CordaPersistence -import net.corda.nodeapi.internal.persistence.bufferUntilDatabaseCommit -import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction -import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl -import net.corda.nodeapi.internal.serialization.withTokenContext -import org.apache.activemq.artemis.utils.ReusableLatch -import org.slf4j.Logger -import rx.Observable -import rx.subjects.PublishSubject -import java.io.NotSerializableException -import java.util.* -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.TimeUnit.SECONDS -import javax.annotation.concurrent.ThreadSafe - -/** - * The StateMachineManagerImpl will always invoke the flow fibers on the given [AffinityExecutor], regardless of which - * thread actually starts them via [startFlow]. 
- */ -@ThreadSafe -class StateMachineManagerImpl( - val serviceHub: ServiceHubInternal, - val checkpointStorage: CheckpointStorage, - val executor: AffinityExecutor, - val database: CordaPersistence, - private val unfinishedFibers: ReusableLatch = ReusableLatch(), - private val classloader: ClassLoader = StateMachineManagerImpl::class.java.classLoader -) : StateMachineManager { - inner class FiberScheduler : FiberExecutorScheduler("Same thread scheduler", executor) - - companion object { - private val logger = contextLogger() - internal val sessionTopic = "platform.session" - - init { - Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable -> - (fiber as FlowStateMachineImpl<*>).logger.warn("Caught exception from flow", throwable) - } - } - } - - // A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines - // property. - private class InnerState { - var started = false - val stateMachines = LinkedHashMap, Checkpoint>() - val changesPublisher = PublishSubject.create()!! - val fibersWaitingForLedgerCommit = HashMultimap.create>()!! - - fun notifyChangeObservers(change: StateMachineManager.Change) { - changesPublisher.bufferUntilDatabaseCommit().onNext(change) - } - } - - private val scheduler = FiberScheduler() - private val mutex = ThreadBox(InnerState()) - // This thread (only enabled in dev mode) deserialises checkpoints in the background to shake out bugs in checkpoint restore. - private val checkpointCheckerThread = if (serviceHub.configuration.shouldCheckCheckpoints()) { - newNamedSingleThreadExecutor("CheckpointChecker") - } else { - null - } - - @Volatile private var unrestorableCheckpoints = false - - // True if we're shutting down, so don't resume anything. - @Volatile private var stopping = false - // How many Fibers are running and not suspended. If zero and stopping is true, then we are halted. - private val liveFibers = ReusableLatch() - - // Monitoring support. 
- private val metrics = serviceHub.monitoringService.metrics - - init { - metrics.register("Flows.InFlight", Gauge { mutex.content.stateMachines.size }) - } - - private val checkpointingMeter = metrics.meter("Flows.Checkpointing Rate") - private val totalStartedFlows = metrics.counter("Flows.Started") - private val totalFinishedFlows = metrics.counter("Flows.Finished") - - private val openSessions = ConcurrentHashMap() - private val recentlyClosedSessions = ConcurrentHashMap() - - // Context for tokenized services in checkpoints - private lateinit var tokenizableServices: List - private val serializationContext by lazy { - SerializeAsTokenContextImpl(tokenizableServices, SERIALIZATION_FACTORY, CHECKPOINT_CONTEXT, serviceHub) - } - - /** Returns a list of all state machines executing the given flow logic at the top level (subflows do not count) */ - override fun > findStateMachines(flowClass: Class): List>> { - return mutex.locked { - stateMachines.keys.mapNotNull { - flowClass.castIfPossible(it.logic)?.let { it to uncheckedCast, FlowStateMachineImpl<*>>(it.stateMachine).resultFuture } - } - } - } - - override val allStateMachines: List> - get() = mutex.locked { stateMachines.keys.map { it.logic } } - - /** - * An observable that emits triples of the changing flow, the type of change, and a process-specific ID number - * which may change across restarts. - * - * We use assignment here so that multiple subscribers share the same wrapped Observable. 
- */ - override val changes: Observable = mutex.content.changesPublisher.wrapWithDatabaseTransaction() - - override fun start(tokenizableServices: List) { - this.tokenizableServices = tokenizableServices - checkQuasarJavaAgentPresence() - restoreFibersFromCheckpoints() - listenToLedgerTransactions() - serviceHub.networkMapCache.nodeReady.then { executor.execute(this::resumeRestoredFibers) } - } - - private fun checkQuasarJavaAgentPresence() { - check(SuspendableHelper.isJavaAgentActive(), { - """Missing the '-javaagent' JVM argument. Make sure you run the tests with the Quasar java agent attached to your JVM. - #See https://docs.corda.net/troubleshooting.html - 'Fiber classes not instrumented' for more details.""".trimMargin("#") - }) - } - - private fun listenToLedgerTransactions() { - // Observe the stream of committed, validated transactions and resume fibers that are waiting for them. - serviceHub.validatedTransactions.updates.subscribe { stx -> - val hash = stx.id - val fibers: Set> = mutex.locked { fibersWaitingForLedgerCommit.removeAll(hash) } - if (fibers.isNotEmpty()) { - executor.executeASAP { - for (fiber in fibers) { - fiber.logger.trace { "Transaction $hash has committed to the ledger, resuming" } - fiber.waitingForResponse = null - resumeFiber(fiber) - } - } - } - } - } - - private fun decrementLiveFibers() { - liveFibers.countDown() - } - - private fun incrementLiveFibers() { - liveFibers.countUp() - } - - /** - * Start the shutdown process, bringing the [StateMachineManagerImpl] to a controlled stop. When this method returns, - * all Fibers have been suspended and checkpointed, or have completed. - * - * @param allowedUnsuspendedFiberCount Optional parameter is used in some tests. - */ - override fun stop(allowedUnsuspendedFiberCount: Int) { - require(allowedUnsuspendedFiberCount >= 0) - mutex.locked { - if (stopping) throw IllegalStateException("Already stopping!") - stopping = true - } - // Account for any expected Fibers in a test scenario. 
- liveFibers.countDown(allowedUnsuspendedFiberCount) - liveFibers.await() - checkpointCheckerThread?.let { MoreExecutors.shutdownAndAwaitTermination(it, 5, SECONDS) } - check(!unrestorableCheckpoints) { "Unrestorable checkpoints where created, please check the logs for details." } - scheduler.shutdown() - } - - /** - * Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and - * calls to [allStateMachines] - */ - override fun track(): DataFeed>, StateMachineManager.Change> { - return mutex.locked { - DataFeed(stateMachines.keys.map { it.logic }, changesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction()) - } - } - - private fun restoreFibersFromCheckpoints() { - mutex.locked { - checkpointStorage.forEach { checkpoint -> - // If a flow is added before start() then don't attempt to restore it - if (!stateMachines.containsValue(checkpoint)) { - deserializeFiber(checkpoint, logger)?.let { - initFiber(it) - stateMachines[it] = checkpoint - } - } - true - } - } - } - - private fun resumeRestoredFibers() { - mutex.locked { - started = true - stateMachines.keys.forEach { resumeRestoredFiber(it) } - } - serviceHub.networkService.addMessageHandler(sessionTopic) { message, _ -> - executor.checkOnThread() - onSessionMessage(message) - } - } - - private fun resumeRestoredFiber(fiber: FlowStateMachineImpl<*>) { - fiber.openSessions.values.forEach { openSessions[it.ourSessionId] = it } - val waitingForResponse = fiber.waitingForResponse - if (waitingForResponse != null) { - if (waitingForResponse is WaitForLedgerCommit) { - val stx = database.transaction { - serviceHub.validatedTransactions.getTransaction(waitingForResponse.hash) - } - if (stx != null) { - fiber.logger.info("Resuming fiber as tx ${waitingForResponse.hash} has committed") - fiber.waitingForResponse = null - resumeFiber(fiber) - } else { - fiber.logger.info("Restored, pending on ledger commit of ${waitingForResponse.hash}") - mutex.locked { 
fibersWaitingForLedgerCommit.put(waitingForResponse.hash, fiber) } - } - } else { - fiber.logger.info("Restored, pending on receive") - } - } else { - resumeFiber(fiber) - } - } - - private fun onSessionMessage(message: ReceivedMessage) { - val peer = message.peer - val sessionMessage = try { - message.data.deserialize() - } catch (ex: Exception) { - logger.error("Received corrupt SessionMessage data from $peer") - return - } - val sender = serviceHub.networkMapCache.getPeerByLegalName(peer) - if (sender != null) { - when (sessionMessage) { - is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender) - is InitialSessionMessage -> onSessionInit(sessionMessage, message, sender) - } - } else { - logger.error("Unknown peer $peer in $sessionMessage") - } - } - - private fun onExistingSessionMessage(message: ExistingSessionMessage, sender: Party) { - val session = openSessions[message.recipientSessionId] - if (session != null) { - session.fiber.pushToLoggingContext() - session.fiber.logger.trace { "Received $message on $session from $sender" } - if (session.retryable) { - if (message.payload is ConfirmSessionMessage && session.state is FlowSessionState.Initiated) { - session.fiber.logger.trace { "Ignoring duplicate confirmation for session ${session.ourSessionId} – session is idempotent" } - return - } - if (message.payload !is ConfirmSessionMessage) { - serviceHub.networkService.cancelRedelivery(session.ourSessionId.toLong) - } - } - if (message.payload is EndSessionMessage || message.payload is ErrorSessionMessage) { - openSessions.remove(message.recipientSessionId) - } - session.receivedMessages += ReceivedSessionMessage(sender, message) - if (resumeOnMessage(message, session)) { - // It's important that we reset here and not after the fiber's resumed, in case we receive another message - // before then. 
- session.fiber.waitingForResponse = null - updateCheckpoint(session.fiber) - session.fiber.logger.trace { "Resuming due to $message" } - resumeFiber(session.fiber) - } - } else { - val peerParty = recentlyClosedSessions.remove(message.recipientSessionId) - if (peerParty != null) { - if (message.payload is ConfirmSessionMessage) { - logger.trace { "Received session confirmation but associated fiber has already terminated, so sending session end" } - sendSessionMessage(peerParty, ExistingSessionMessage(message.payload.initiatedSessionId, EndSessionMessage)) - } else { - logger.trace { "Ignoring session end message for already closed session: $message" } - } - } else { - logger.warn("Received a session message for unknown session: $message, from $sender") - } - } - } - - // We resume the fiber if it's received a response for which it was waiting for or it's waiting for a ledger - // commit but a counterparty flow has ended with an error (in which case our flow also has to end) - private fun resumeOnMessage(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean { - val waitingForResponse = session.fiber.waitingForResponse - return waitingForResponse?.shouldResume(message, session) ?: false - } - - private fun onSessionInit(sessionInit: InitialSessionMessage, receivedMessage: ReceivedMessage, sender: Party) { - - logger.trace { "Received $sessionInit from $sender" } - val senderSessionId = sessionInit.initiatorSessionId - - fun sendSessionReject(message: String) = sendSessionMessage(sender, ExistingSessionMessage(senderSessionId, RejectSessionMessage(message, errorId = sessionInit.initiatorSessionId.toLong))) - - val (session, initiatedFlowFactory) = try { - val initiatedFlowFactory = getInitiatedFlowFactory(sessionInit) - val flowSession = FlowSessionImpl(sender) - val flow = initiatedFlowFactory.createFlow(flowSession) - val senderFlowVersion = when (initiatedFlowFactory) { - is InitiatedFlowFactory.Core -> receivedMessage.platformVersion // The flow 
version for the core flows is the platform version - is InitiatedFlowFactory.CorDapp -> sessionInit.flowVersion - } - val session = FlowSessionInternal( - flow, - flowSession, - SessionId.createRandom(newSecureRandom()), - sender, - FlowSessionState.Initiated(sender, senderSessionId, FlowInfo(senderFlowVersion, sessionInit.appName))) - if (sessionInit.firstPayload != null) { - session.receivedMessages += ReceivedSessionMessage(sender, ExistingSessionMessage(session.ourSessionId, DataSessionMessage(sessionInit.firstPayload))) - } - openSessions[session.ourSessionId] = session - val context = InvocationContext.peer(sender.name) - val fiber = createFiber(flow, context) - fiber.pushToLoggingContext() - logger.info("Accepting flow session from party ${sender.name}. Session id for tracing purposes is ${sessionInit.initiatorSessionId}.") - flowSession.sessionFlow = flow - flowSession.stateMachine = fiber - fiber.openSessions[Pair(flow, sender)] = session - updateCheckpoint(fiber) - session to initiatedFlowFactory - } catch (e: SessionRejectException) { - logger.warn("${e.logMessage}: $sessionInit") - sendSessionReject(e.rejectMessage) - return - } catch (e: Exception) { - logger.warn("Couldn't start flow session from $sessionInit", e) - sendSessionReject("Unable to establish session") - return - } - - val (ourFlowVersion, appName) = when (initiatedFlowFactory) { - // The flow version for the core flows is the platform version - is InitiatedFlowFactory.Core -> serviceHub.myInfo.platformVersion to "corda" - is InitiatedFlowFactory.CorDapp -> initiatedFlowFactory.flowVersion to initiatedFlowFactory.appName - } - - sendSessionMessage(sender, ExistingSessionMessage(senderSessionId, ConfirmSessionMessage(session.ourSessionId, FlowInfo(ourFlowVersion, appName))), session.fiber) - session.fiber.logger.debug { "Initiated by $sender using ${sessionInit.initiatorFlowClassName}" } - session.fiber.logger.trace { "Initiated from $sessionInit on $session" } - resumeFiber(session.fiber) 
- } - - private fun getInitiatedFlowFactory(sessionInit: InitialSessionMessage): InitiatedFlowFactory<*> { - val initiatingFlowClass = try { - Class.forName(sessionInit.initiatorFlowClassName, true, classloader).asSubclass(FlowLogic::class.java) - } catch (e: ClassNotFoundException) { - throw SessionRejectException("Don't know ${sessionInit.initiatorFlowClassName}") - } catch (e: ClassCastException) { - throw SessionRejectException("${sessionInit.initiatorFlowClassName} is not a flow") - } - return serviceHub.getFlowFactory(initiatingFlowClass) ?: - throw SessionRejectException("$initiatingFlowClass is not registered") - } - - private fun serializeFiber(fiber: FlowStateMachineImpl<*>): SerializedBytes> { - return fiber.serialize(context = CHECKPOINT_CONTEXT.withTokenContext(serializationContext)) - } - - private fun deserializeFiber(checkpoint: Checkpoint, logger: Logger): FlowStateMachineImpl<*>? { - return try { - checkpoint.serializedFiber.deserialize(context = CHECKPOINT_CONTEXT.withTokenContext(serializationContext)).apply { - fromCheckpoint = true - } - } catch (t: Throwable) { - logger.error("Encountered unrestorable checkpoint!", t) - null - } - } - - private fun createFiber(logic: FlowLogic, context: InvocationContext, ourIdentity: Party? 
= null): FlowStateMachineImpl { - val fsm = FlowStateMachineImpl( - StateMachineRunId.createRandom(), - logic, - scheduler, - ourIdentity ?: serviceHub.myInfo.legalIdentities[0], - context) - initFiber(fsm) - return fsm - } - - private fun initFiber(fiber: FlowStateMachineImpl<*>) { - verifyFlowLogicIsSuspendable(fiber.logic) - fiber.database = database - fiber.serviceHub = serviceHub - fiber.ourIdentityAndCert = serviceHub.myInfo.legalIdentitiesAndCerts.find { it.party == fiber.ourIdentity } - ?: throw IllegalStateException("Identity specified by ${fiber.id} (${fiber.ourIdentity.name}) is not one of ours!") - fiber.actionOnSuspend = { ioRequest -> - updateCheckpoint(fiber) - // We commit on the fibers transaction that was copied across ThreadLocals during suspend - // This will free up the ThreadLocal so on return the caller can carry on with other transactions - fiber.commitTransaction() - processIORequest(ioRequest) - decrementLiveFibers() - } - fiber.actionOnEnd = { result, propagated -> - try { - mutex.locked { - stateMachines.remove(fiber)?.let { checkpointStorage.removeCheckpoint(it) } - notifyChangeObservers(StateMachineManager.Change.Removed(fiber.logic, result)) - } - endAllFiberSessions(fiber, result, propagated) - } finally { - fiber.commitTransaction() - decrementLiveFibers() - totalFinishedFlows.inc() - unfinishedFibers.countDown() - } - } - mutex.locked { - totalStartedFlows.inc() - unfinishedFibers.countUp() - notifyChangeObservers(StateMachineManager.Change.Add(fiber.logic)) - } - } - - private fun verifyFlowLogicIsSuspendable(logic: FlowLogic) { - // Quasar requires (in Java 8) that at least the call method be annotated suspendable. Unfortunately, it's - // easy to forget to add this when creating a new flow, so we check here to give the user a better error. - // - // The Kotlin compiler can sometimes generate a synthetic bridge method from a single call declaration, which - // forwards to the void method and then returns Unit. 
However annotations do not get copied across to this - // bridge, so we have to do a more complex scan here. - val call = logic.javaClass.methods.first { !it.isSynthetic && it.name == "call" && it.parameterCount == 0 } - if (call.getAnnotation(Suspendable::class.java) == null) { - throw FlowException("${logic.javaClass.name}.call() is not annotated as @Suspendable. Please fix this.") - } - } - - private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>, result: Try<*>, propagated: Boolean) { - openSessions.values.removeIf { session -> - if (session.fiber == fiber) { - session.endSession(fiber.context, (result as? Try.Failure)?.exception, propagated) - true - } else { - false - } - } - } - - private fun FlowSessionInternal.endSession(context: InvocationContext, exception: Throwable?, propagated: Boolean) { - val initiatedState = state as? FlowSessionState.Initiated ?: return - val sessionEnd = if (exception == null) { - EndSessionMessage - } else { - val errorResponse = if (exception is FlowException && (!propagated || initiatingParty != null)) { - // Only propagate this FlowException if our local flow threw it or it was propagated to us and we only - // pass it down invocation chain to the flow that initiated us, not to flows we've started sessions with. - exception - } else { - null - } - ErrorSessionMessage(errorResponse, 0) - } - sendSessionMessage(initiatedState.peerParty, ExistingSessionMessage(initiatedState.peerSessionId, sessionEnd), fiber) - recentlyClosedSessions[ourSessionId] = initiatedState.peerParty - } - - /** - * Kicks off a brand new state machine of the given class. - * The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is - * restarted with checkpointed state machines in the storage service. - * - * Note that you must be on the [executor] thread. 
- */ - override fun startFlow(flowLogic: FlowLogic, context: InvocationContext): CordaFuture> { - // TODO: Check that logic has @Suspendable on its call method. - executor.checkOnThread() - val fiber = database.transaction { - val fiber = createFiber(flowLogic, context) - updateCheckpoint(fiber) - fiber - } - // If we are not started then our checkpoint will be picked up during start - mutex.locked { - if (started) { - resumeFiber(fiber) - } - } - return doneFuture(fiber) - } - - private fun updateCheckpoint(fiber: FlowStateMachineImpl<*>) { - check(fiber.state != Strand.State.RUNNING) { "Fiber cannot be running when checkpointing" } - val newCheckpoint = Checkpoint(serializeFiber(fiber)) - val previousCheckpoint = mutex.locked { stateMachines.put(fiber, newCheckpoint) } - if (previousCheckpoint != null) { - checkpointStorage.removeCheckpoint(previousCheckpoint) - } - checkpointStorage.addCheckpoint(newCheckpoint) - checkpointingMeter.mark() - - checkpointCheckerThread?.execute { - // Immediately check that the checkpoint is valid by deserialising it. The idea is to plug any holes we have - // in our testing by failing any test where unrestorable checkpoints are created. 
- if (deserializeFiber(newCheckpoint, fiber.logger) == null) { - unrestorableCheckpoints = true - } - } - } - - private fun resumeFiber(fiber: FlowStateMachineImpl<*>) { - // Avoid race condition when setting stopping to true and then checking liveFibers - incrementLiveFibers() - if (!stopping) { - executor.executeASAP { - fiber.resume(scheduler) - } - } else { - fiber.logger.trace("Not resuming as SMM is stopping.") - decrementLiveFibers() - } - } - - private fun processIORequest(ioRequest: FlowIORequest) { - executor.checkOnThread() - when (ioRequest) { - is SendRequest -> processSendRequest(ioRequest) - is WaitForLedgerCommit -> processWaitForCommitRequest(ioRequest) - is Sleep -> processSleepRequest(ioRequest) - } - } - - private fun processSendRequest(ioRequest: SendRequest) { - val retryId = if (ioRequest.message is InitialSessionMessage) { - with(ioRequest.session) { - openSessions[ourSessionId] = this - if (retryable) ourSessionId.toLong else null - } - } else null - sendSessionMessage(ioRequest.session.state.sendToParty, ioRequest.message, ioRequest.session.fiber, retryId) - if (ioRequest !is ReceiveRequest) { - // We sent a message, but don't expect a response, so re-enter the continuation to let it keep going. - resumeFiber(ioRequest.session.fiber) - } - } - - private fun processWaitForCommitRequest(ioRequest: WaitForLedgerCommit) { - // Is it already committed? - val stx = database.transaction { - serviceHub.validatedTransactions.getTransaction(ioRequest.hash) - } - if (stx != null) { - resumeFiber(ioRequest.fiber) - } else { - // No, then register to wait. - // - // We assume this code runs on the server thread, which is the only place transactions are committed - // currently. When we liberalise our threading somewhat, handing of wait requests will need to be - // reworked to make the wait atomic in another way. Otherwise there is a race between checking the - // database and updating the waiting list. 
- mutex.locked { - fibersWaitingForLedgerCommit[ioRequest.hash] += ioRequest.fiber - } - } - } - - private fun processSleepRequest(ioRequest: Sleep) { - // Resume the fiber now we have checkpointed, so we can sleep on the Fiber. - resumeFiber(ioRequest.fiber) - } - - private fun sendSessionMessage(party: Party, message: SessionMessage, fiber: FlowStateMachineImpl<*>? = null, retryId: Long? = null) { - val partyInfo = serviceHub.networkMapCache.getPartyInfo(party) - ?: throw IllegalArgumentException("Don't know about party $party") - val address = serviceHub.networkService.getAddressOfParty(partyInfo) - val logger = fiber?.logger ?: logger - logger.trace { "Sending $message to party $party @ $address" + if (retryId != null) " with retry $retryId" else "" } - - val serialized = try { - message.serialize() - } catch (e: Exception) { - when (e) { - // Handling Kryo and AMQP serialization problems. Unfortunately the two exception types do not share much of a common exception interface. - is KryoException, - is NotSerializableException -> { - if (message is ExistingSessionMessage && message.payload is ErrorSessionMessage && message.payload.flowException != null) { - logger.warn("Something in ${message.payload.flowException.javaClass.name} is not serialisable. " + - "Instead sending back an exception which is serialisable to ensure session end occurs properly.", e) - // The subclass may have overridden toString so we use that - val exMessage = message.payload.flowException.message - message.copy(payload = message.payload.copy(flowException = FlowException(exMessage))).serialize() - } else { - throw e - } - } - else -> throw e - } - } - - // This prevents a "deadlock" in case an initiated flow tries to start a session against a draining node that is also the initiator. - // It does not help in case more than 2 nodes are involved in a circle, so the kill switch via RPC should be used in that case. 
- val additionalHeaders = if (mightDeadlockDrainingSender(fiber, party)) emptyMap() else message.additionalHeaders() - serviceHub.networkService.apply { - send(createMessage(sessionTopic, serialized.bytes), address, retryId = retryId, additionalHeaders = additionalHeaders) - } - } - - private fun mightDeadlockDrainingSender(fiber: FlowStateMachineImpl<*>?, target: Party): Boolean { - return fiber?.context?.origin.let { it is InvocationOrigin.Peer && it.party == target.name } - } -} - -private fun SessionMessage.additionalHeaders(): Map { - return when (this) { - is InitialSessionMessage -> mapOf(P2PMessagingHeaders.Type.KEY to P2PMessagingHeaders.Type.SESSION_INIT_VALUE) - else -> emptyMap() - } -} - -class SessionRejectException(val rejectMessage: String, val logMessage: String) : CordaException(rejectMessage) { - constructor(message: String) : this(message, message) -} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt new file mode 100644 index 0000000000..b2d572b475 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt @@ -0,0 +1,231 @@ +package net.corda.node.services.statemachine + +import net.corda.core.context.InvocationContext +import net.corda.core.flows.FlowInfo +import net.corda.core.flows.FlowLogic +import net.corda.core.identity.Party +import net.corda.core.internal.FlowIORequest +import net.corda.core.serialization.SerializedBytes +import net.corda.core.utilities.Try +import net.corda.node.services.messaging.DeduplicationHandler + +/** + * The state of the state machine, capturing the state of a flow. It consists of two parts, an *immutable* part that is + * persisted to the database ([Checkpoint]), and the rest, which is an in-memory-only state. + * + * @param checkpoint the persisted part of the state. + * @param flowLogic the [FlowLogic] associated with the flow. 
Note that this is mutable by the user. + * @param pendingDeduplicationHandlers the list of incomplete deduplication handlers. + * @param isFlowResumed true if the control is returned (or being returned) to "user-space" flow code. This is used + * to make [Event.DoRemainingWork] idempotent. + * @param isTransactionTracked true if a ledger transaction has been tracked as part of a + * [FlowIORequest.WaitForLedgerCommit]. This used is to make tracking idempotent. + * @param isAnyCheckpointPersisted true if at least a single checkpoint has been persisted. This is used to determine + * whether we should DELETE the checkpoint at the end of the flow. + * @param isStartIdempotent true if the start of the flow is idempotent, making the skipping of the initial checkpoint + * possible. + * @param isRemoved true if the flow has been removed from the state machine manager. This is used to avoid any further + * work. + */ +// TODO perhaps add a read-only environment to the state machine for things that don't change over time? +// TODO evaluate persistent datastructure libraries to replace the inefficient copying we currently do. +data class StateMachineState( + val checkpoint: Checkpoint, + val flowLogic: FlowLogic<*>, + val pendingDeduplicationHandlers: List, + val isFlowResumed: Boolean, + val isTransactionTracked: Boolean, + val isAnyCheckpointPersisted: Boolean, + val isStartIdempotent: Boolean, + val isRemoved: Boolean +) + +/** + * @param invocationContext the initiator of the flow. + * @param ourIdentity the identity the flow is run as. + * @param sessions map of source session ID to session state. + * @param subFlowStack the stack of currently executing subflows. + * @param flowState the state of the flow itself, including the frozen fiber/FlowLogic. + * @param errorState the "dirtiness" state including the involved errors and their propagation status. + * @param numberOfSuspends the number of flow suspends due to IO API calls. 
+ * @param deduplicationSeed the basis seed for the deduplication ID. This is used to produce replayable IDs. + */ +data class Checkpoint( + val invocationContext: InvocationContext, + val ourIdentity: Party, + val sessions: SessionMap, // This must preserve the insertion order! + val subFlowStack: List, + val flowState: FlowState, + val errorState: ErrorState, + val numberOfSuspends: Int, + val deduplicationSeed: String +) { + companion object { + + fun create( + invocationContext: InvocationContext, + flowStart: FlowStart, + flowLogicClass: Class>, + frozenFlowLogic: SerializedBytes>, + ourIdentity: Party, + deduplicationSeed: String + ): Try { + return SubFlow.create(flowLogicClass).map { topLevelSubFlow -> + Checkpoint( + invocationContext = invocationContext, + ourIdentity = ourIdentity, + sessions = emptyMap(), + subFlowStack = listOf(topLevelSubFlow), + flowState = FlowState.Unstarted(flowStart, frozenFlowLogic), + errorState = ErrorState.Clean, + numberOfSuspends = 0, + deduplicationSeed = deduplicationSeed + ) + } + } + } +} + +/** + * The state of a session. + */ +sealed class SessionState { + + /** + * We haven't yet sent the initialisation message + */ + data class Uninitiated( + val party: Party, + val initiatingSubFlow: SubFlow.Initiating + ) : SessionState() + + /** + * We have sent the initialisation message but have not yet received a confirmation. + * @property rejectionError if non-null the initiation failed. + */ + data class Initiating( + val bufferedMessages: List>, + val rejectionError: FlowError? + ) : SessionState() + + /** + * We have received a confirmation, the peer party and session id is resolved. + * @property errors if not empty the session is in an errored state. 
+ */ + data class Initiated( + val peerParty: Party, + val peerFlowInfo: FlowInfo, + val receivedMessages: List, + val initiatedState: InitiatedSessionState, + val errors: List + ) : SessionState() +} + +typealias SessionMap = Map + +/** + * Tracks whether an initiated session state is live or has ended. This is a separate state, as we still need the rest + * of [SessionState.Initiated], even when the session has ended, for un-drained session messages and potential future + * [FlowInfo] requests. + */ +sealed class InitiatedSessionState { + data class Live(val peerSinkSessionId: SessionId) : InitiatedSessionState() + object Ended : InitiatedSessionState() { override fun toString() = "Ended" } +} + +/** + * Represents the way the flow has started. + */ +sealed class FlowStart { + /** + * The flow was started explicitly e.g. through RPC or a scheduled state. + */ + object Explicit : FlowStart() { override fun toString() = "Explicit" } + + /** + * The flow was started implicitly as part of session initiation. + */ + data class Initiated( + val peerSession: FlowSessionImpl, + val initiatedSessionId: SessionId, + val initiatingMessage: InitialSessionMessage, + val senderCoreFlowVersion: Int?, + val initiatedFlowInfo: FlowInfo + ) : FlowStart() { override fun toString() = "Initiated" } +} + +/** + * Represents the user-space related state of the flow. + */ +sealed class FlowState { + + /** + * The flow's unstarted state. We should always be able to start a fresh flow fiber from this datastructure. + * + * @param flowStart How the flow was started. + * @param frozenFlowLogic The serialized user-provided [FlowLogic]. + */ + data class Unstarted( + val flowStart: FlowStart, + val frozenFlowLogic: SerializedBytes> + ) : FlowState() { + override fun toString() = "Unstarted(flowStart=$flowStart, frozenFlowLogic=${frozenFlowLogic.hash})" + } + + /** + * The flow's started state, this means the user-code has suspended on an IO request. 
+ * + * @param flowIORequest what IO request the flow has suspended on. + * @param frozenFiber the serialized fiber itself. + */ + data class Started( + val flowIORequest: FlowIORequest<*>, + val frozenFiber: SerializedBytes> + ) : FlowState() { + override fun toString() = "Started(flowIORequest=$flowIORequest, frozenFiber=${frozenFiber.hash})" + } +} + +/** + * @param errorId the ID of the error. This is generated once for the source error and is propagated to neighbour + * sessions. + * @param exception the exception itself. Note that this may not contain information about the source error depending + * on whether the source error was a FlowException or otherwise. + */ +data class FlowError(val errorId: Long, val exception: Throwable) + +/** + * The flow's error state. + */ +sealed class ErrorState { + abstract fun addErrors(newErrors: List): ErrorState + + /** + * The flow is in a clean state. + */ + object Clean : ErrorState() { + override fun addErrors(newErrors: List): ErrorState { + return Errored(newErrors, 0, false) + } + override fun toString() = "Clean" + } + + /** + * The flow has dirtied because of an uncaught exception from user code or other error condition during a state + * transition. + * @param errors the list of errors. Multiple errors may be associated with the errored flow e.g. when multiple + * sessions are errored and have been waited on. + * @param propagatedIndex the index of the first error that hasn't yet been propagated. + * @param propagating true if error propagation was triggered. If this is set the dirtiness is permanent as the + * sessions associated with the flow have been (or about to be) dirtied in counter-flows. 
+ */ + data class Errored( + val errors: List, + val propagatedIndex: Int, + val propagating: Boolean + ) : ErrorState() { + override fun addErrors(newErrors: List): ErrorState { + return copy(errors = errors + newErrors) + } + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SubFlow.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SubFlow.kt new file mode 100644 index 0000000000..25f9d228e9 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SubFlow.kt @@ -0,0 +1,74 @@ +package net.corda.node.services.statemachine + +import net.corda.core.flows.FlowInfo +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.InitiatingFlow +import net.corda.core.utilities.Try + +/** + * A [SubFlow] contains metadata about a currently executing sub-flow. At any point the flow execution is + * characterised with a stack of [SubFlow]s. This stack is used to determine the initiating-initiated flow mapping. + * + * Note that Initiat*ed*ness is an orthogonal property of the top-level subflow, so we don't store any information about + * it here. + */ +sealed class SubFlow { + abstract val flowClass: Class> + + /** + * An inlined subflow. + */ + data class Inlined(override val flowClass: Class>) : SubFlow() + + /** + * An initiating subflow. + * @param [flowClass] the concrete class of the subflow. + * @param [classToInitiateWith] an ancestor class of [flowClass] with the [InitiatingFlow] annotation, to be sent + * to the initiated side. + * @param flowInfo the [FlowInfo] associated with the initiating flow. + */ + data class Initiating( + override val flowClass: Class>, + val classToInitiateWith: Class>, + val flowInfo: FlowInfo + ) : SubFlow() + + companion object { + fun create(flowClass: Class>): Try { + // Are we an InitiatingFlow? 
+ val initiatingAnnotations = getInitiatingFlowAnnotations(flowClass) + return when (initiatingAnnotations.size) { + 0 -> { + Try.Success(Inlined(flowClass)) + } + 1 -> { + val initiatingAnnotation = initiatingAnnotations[0] + val flowContext = FlowInfo(initiatingAnnotation.second.version, flowClass.appName) + Try.Success(Initiating(flowClass, initiatingAnnotation.first, flowContext)) + } + else -> { + Try.Failure(IllegalArgumentException("${InitiatingFlow::class.java.name} can only be annotated " + + "once, however the following classes all have the annotation: " + + "${initiatingAnnotations.map { it.first }}")) + } + } + } + + private fun getSuperClasses(clazz: Class): List> { + var currentClass: Class? = clazz + val result = ArrayList>() + while (currentClass != null) { + result.add(currentClass) + currentClass = currentClass.superclass + } + return result + } + + private fun getInitiatingFlowAnnotations(flowClass: Class>): List>, InitiatingFlow>> { + return getSuperClasses(flowClass).mapNotNull { clazz -> + val initiatingAnnotation = clazz.getDeclaredAnnotation(InitiatingFlow::class.java) + initiatingAnnotation?.let { Pair(clazz, it) } + } + } + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutor.kt new file mode 100644 index 0000000000..7bf29e3f14 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutor.kt @@ -0,0 +1,25 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Suspendable +import net.corda.node.services.statemachine.transitions.FlowContinuation +import net.corda.node.services.statemachine.transitions.TransitionResult + +/** + * An executor of state machine transitions. This is mostly a wrapper interface around an [ActionExecutor], but can be + * used to create interceptors of transitions. 
+ */ +interface TransitionExecutor { + @Suspendable + fun executeTransition( + fiber: FlowFiber, + previousState: StateMachineState, + event: Event, + transition: TransitionResult, + actionExecutor: ActionExecutor + ): Pair +} + +/** + * An interceptor of a transition. These are currently explicitly hooked up in [SingleThreadedStateMachineManager]. + */ +typealias TransitionInterceptor = (TransitionExecutor) -> TransitionExecutor diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutorImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutorImpl.kt new file mode 100644 index 0000000000..2cf328a450 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutorImpl.kt @@ -0,0 +1,67 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.utilities.contextLogger +import net.corda.node.services.statemachine.transitions.FlowContinuation +import net.corda.node.services.statemachine.transitions.TransitionResult +import net.corda.nodeapi.internal.persistence.CordaPersistence +import net.corda.nodeapi.internal.persistence.contextDatabase +import net.corda.nodeapi.internal.persistence.contextTransactionOrNull +import java.security.SecureRandom + +/** + * This [TransitionExecutor] runs the transition actions using the passed in [ActionExecutor] and manually dirties the + * state on failure. + * + * If a failure happens when we're already transitioning into a errored state then the transition and the flow fiber is + * completely aborted to avoid error loops. 
+ */ +class TransitionExecutorImpl( + val secureRandom: SecureRandom, + val database: CordaPersistence +) : TransitionExecutor { + private companion object { + val log = contextLogger() + } + + @Suspendable + override fun executeTransition( + fiber: FlowFiber, + previousState: StateMachineState, + event: Event, + transition: TransitionResult, + actionExecutor: ActionExecutor + ): Pair { + contextDatabase = database + for (action in transition.actions) { + try { + actionExecutor.executeAction(fiber, action) + } catch (exception: Throwable) { + contextTransactionOrNull?.close() + if (transition.newState.checkpoint.errorState is ErrorState.Errored) { + // If we errored while transitioning to an error state then we cannot record the additional + // error as that may result in an infinite loop, e.g. error propagation fails -> record error -> propagate fails again. + // Instead we just keep around the old error state and wait for a new schedule, perhaps + // triggered from a flow hospital + log.error("Error while executing $action during transition to errored state, aborting transition", exception) + return Pair(FlowContinuation.Abort, previousState.copy(isFlowResumed = false)) + } else { + // Otherwise error the state manually keeping the old flow state and schedule a DoRemainingWork + // to trigger error propagation + log.error("Error while executing $action, erroring state", exception) + val newState = previousState.copy( + checkpoint = previousState.checkpoint.copy( + errorState = previousState.checkpoint.errorState.addErrors( + listOf(FlowError(secureRandom.nextLong(), exception)) + ) + ), + isFlowResumed = false + ) + fiber.scheduleEvent(Event.DoRemainingWork) + return Pair(FlowContinuation.ProcessEvents, newState) + } + } + } + return Pair(transition.continuation, transition.newState) + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/DumpHistoryOnErrorInterceptor.kt 
b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/DumpHistoryOnErrorInterceptor.kt new file mode 100644 index 0000000000..86f8e239f0 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/DumpHistoryOnErrorInterceptor.kt @@ -0,0 +1,51 @@ +package net.corda.node.services.statemachine.interceptors + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.flows.StateMachineRunId +import net.corda.core.utilities.contextLogger +import net.corda.node.services.statemachine.* +import net.corda.node.services.statemachine.transitions.FlowContinuation +import net.corda.node.services.statemachine.transitions.TransitionResult +import java.time.Instant +import java.util.concurrent.ConcurrentHashMap + +/** + * This interceptor records a trace of all of the flows' states and transitions. If the flow dirties it dumps the trace + * transition to the logger. + */ +class DumpHistoryOnErrorInterceptor(val delegate: TransitionExecutor) : TransitionExecutor { + companion object { + private val log = contextLogger() + } + + private val records = ConcurrentHashMap>() + + @Suspendable + override fun executeTransition( + fiber: FlowFiber, + previousState: StateMachineState, + event: Event, + transition: TransitionResult, + actionExecutor: ActionExecutor + ): Pair { + val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor) + val transitionRecord = TransitionDiagnosticRecord(Instant.now(), fiber.id, previousState, nextState, event, transition, continuation) + val record = records.compute(fiber.id) { _, record -> + (record ?: ArrayList()).apply { add(transitionRecord) } + } + + if (nextState.checkpoint.errorState is ErrorState.Errored) { + log.warn("Flow ${fiber.id} errored, dumping all transitions:\n${record!!.joinToString("\n")}") + for (error in nextState.checkpoint.errorState.errors) { + log.warn("Flow ${fiber.id} error", error.exception) + } + } + + if 
(nextState.isRemoved) { + records.remove(fiber.id) + } + + return Pair(continuation, nextState) + } + +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/FiberDeserializationCheckingInterceptor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/FiberDeserializationCheckingInterceptor.kt new file mode 100644 index 0000000000..2a9b96cc5b --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/FiberDeserializationCheckingInterceptor.kt @@ -0,0 +1,95 @@ +package net.corda.node.services.statemachine.interceptors + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.deserialize +import net.corda.core.utilities.contextLogger +import net.corda.node.services.statemachine.* +import net.corda.node.services.statemachine.transitions.FlowContinuation +import net.corda.node.services.statemachine.transitions.TransitionResult +import java.util.concurrent.LinkedBlockingQueue +import kotlin.concurrent.thread + +/** + * This interceptor checks whether a checkpointed fiber state can be deserialised in a separate thread. 
+ */ +class FiberDeserializationCheckingInterceptor( + val fiberDeserializationChecker: FiberDeserializationChecker, + val delegate: TransitionExecutor +) : TransitionExecutor { + @Suspendable + override fun executeTransition( + fiber: FlowFiber, + previousState: StateMachineState, + event: Event, + transition: TransitionResult, + actionExecutor: ActionExecutor + ): Pair { + val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor) + val previousFlowState = previousState.checkpoint.flowState + val nextFlowState = nextState.checkpoint.flowState + if (nextFlowState is FlowState.Started) { + if (previousFlowState !is FlowState.Started || previousFlowState.frozenFiber != nextFlowState.frozenFiber) { + fiberDeserializationChecker.submitCheck(nextFlowState.frozenFiber) + } + } + return Pair(continuation, nextState) + } +} + +/** + * A fiber deserialisation checker thread. It checks the queued up serialised checkpoints to see if they can be + * deserialised. This is only run in development mode to allow detecting of corrupt serialised checkpoints before they + * are actually used. + */ +class FiberDeserializationChecker { + companion object { + val log = contextLogger() + } + + private sealed class Job { + class Check(val serializedFiber: SerializedBytes>) : Job() + object Finish : Job() + } + + private var checkerThread: Thread? 
= null + private val jobQueue = LinkedBlockingQueue() + private var foundUnrestorableFibers: Boolean = false + + fun start(checkpointSerializationContext: SerializationContext) { + require(checkerThread == null) + checkerThread = thread(name = "FiberDeserializationChecker") { + while (true) { + val job = jobQueue.take() + when (job) { + is Job.Check -> { + try { + job.serializedFiber.deserialize(context = checkpointSerializationContext) + } catch (throwable: Throwable) { + log.error("Encountered unrestorable checkpoint!", throwable) + foundUnrestorableFibers = true + } + } + Job.Finish -> { + return@thread + } + } + } + } + } + + fun submitCheck(serializedFiber: SerializedBytes>) { + jobQueue.add(Job.Check(serializedFiber)) + } + + /** + * Returns true if some unrestorable checkpoints were encountered, false otherwise + */ + fun stop(): Boolean { + jobQueue.add(Job.Finish) + checkerThread?.join() + checkerThread = null + return foundUnrestorableFibers + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/HospitalisingInterceptor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/HospitalisingInterceptor.kt new file mode 100644 index 0000000000..8ed5f67b95 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/HospitalisingInterceptor.kt @@ -0,0 +1,46 @@ +package net.corda.node.services.statemachine.interceptors + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.flows.StateMachineRunId +import net.corda.node.services.statemachine.* +import net.corda.node.services.statemachine.transitions.FlowContinuation +import net.corda.node.services.statemachine.transitions.TransitionResult +import java.util.concurrent.ConcurrentHashMap + +/** + * This interceptor notifies the passed in [flowHospital] in case a flow went through a clean->errored or a errored->clean + * transition. 
+ */ +class HospitalisingInterceptor( + private val flowHospital: FlowHospital, + private val delegate: TransitionExecutor +) : TransitionExecutor { + private val hospitalisedFlows = ConcurrentHashMap() + + @Suspendable + override fun executeTransition( + fiber: FlowFiber, + previousState: StateMachineState, + event: Event, + transition: TransitionResult, + actionExecutor: ActionExecutor + ): Pair { + val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor) + when (nextState.checkpoint.errorState) { + ErrorState.Clean -> { + if (hospitalisedFlows.remove(fiber.id) != null) { + flowHospital.flowCleaned(fiber) + } + } + is ErrorState.Errored -> { + if (hospitalisedFlows.putIfAbsent(fiber.id, fiber) == null) { + flowHospital.flowErrored(fiber) + } + } + } + if (nextState.isRemoved) { + hospitalisedFlows.remove(fiber.id) + } + return Pair(continuation, nextState) + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/MetricInterceptor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/MetricInterceptor.kt new file mode 100644 index 0000000000..f5e298fed8 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/MetricInterceptor.kt @@ -0,0 +1,24 @@ +package net.corda.node.services.statemachine.interceptors + +import co.paralleluniverse.fibers.Suspendable +import com.codahale.metrics.MetricRegistry +import net.corda.node.services.statemachine.* +import net.corda.node.services.statemachine.transitions.FlowContinuation +import net.corda.node.services.statemachine.transitions.TransitionResult + +class MetricInterceptor(val metrics: MetricRegistry, val delegate: TransitionExecutor): TransitionExecutor { + @Suspendable + override fun executeTransition(fiber: FlowFiber, previousState: StateMachineState, event: Event, transition: TransitionResult, actionExecutor: ActionExecutor): Pair { + val metricActionInterceptor = 
MetricActionInterceptor(metrics, actionExecutor) + return delegate.executeTransition(fiber, previousState, event, transition, metricActionInterceptor) + } +} + +class MetricActionInterceptor(val metrics: MetricRegistry, val delegate: ActionExecutor): ActionExecutor { + @Suspendable + override fun executeAction(fiber: FlowFiber, action: Action) { + val context = metrics.timer("Flows.Actions.${action.javaClass.simpleName}").time() + delegate.executeAction(fiber, action) + context.stop() + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/PrintingInterceptor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/PrintingInterceptor.kt new file mode 100644 index 0000000000..a0ca6d6660 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/PrintingInterceptor.kt @@ -0,0 +1,31 @@ +package net.corda.node.services.statemachine.interceptors + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.utilities.contextLogger +import net.corda.node.services.statemachine.* +import net.corda.node.services.statemachine.transitions.FlowContinuation +import net.corda.node.services.statemachine.transitions.TransitionResult +import java.time.Instant + +/** + * This interceptor simply prints all state machine transitions. Useful for debugging. 
+ */ +class PrintingInterceptor(val delegate: TransitionExecutor) : TransitionExecutor { + companion object { + val log = contextLogger() + } + + @Suspendable + override fun executeTransition( + fiber: FlowFiber, + previousState: StateMachineState, + event: Event, + transition: TransitionResult, + actionExecutor: ActionExecutor + ): Pair { + val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor) + val transitionRecord = TransitionDiagnosticRecord(Instant.now(), fiber.id, previousState, nextState, event, transition, continuation) + log.info("Transition for flow ${fiber.id} $transitionRecord") + return Pair(continuation, nextState) + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/TransitionDiagnosticRecord.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/TransitionDiagnosticRecord.kt new file mode 100644 index 0000000000..1832f6e848 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/TransitionDiagnosticRecord.kt @@ -0,0 +1,51 @@ +package net.corda.node.services.statemachine.interceptors + +import net.corda.core.flows.StateMachineRunId +import net.corda.node.services.statemachine.transitions.FlowContinuation +import net.corda.node.services.statemachine.Event +import net.corda.node.services.statemachine.StateMachineState +import net.corda.node.services.statemachine.transitions.TransitionResult +import net.corda.node.utilities.ObjectDiffer +import java.time.Instant + +/** + * This is a diagnostic record that stores information about a state machine transition and provides pretty printing + * by diffing the two states. 
+ */ +data class TransitionDiagnosticRecord( + val timestamp: Instant, + val flowId: StateMachineRunId, + val previousState: StateMachineState, + val nextState: StateMachineState, + val event: Event, + val transition: TransitionResult, + val continuation: FlowContinuation +) { + override fun toString(): String { + val diffIntended = ObjectDiffer.diff(previousState, transition.newState) + val diffNext = ObjectDiffer.diff(previousState, nextState) + return ( + listOf( + "", + " --- Transition of flow $flowId ---", + " Timestamp: $timestamp", + " Event: $event", + " Actions: ", + " ${transition.actions.joinToString("\n ")}", + " Continuation: ${transition.continuation}" + ) + + if (diffIntended != diffNext) { + listOf( + " Diff between previous and intended state:", + "${diffIntended?.toPaths()?.joinToString("")}" + ) + } else { + emptyList() + } + listOf( + + " Diff between previous and next state:", + "${diffNext?.toPaths()?.joinToString("")}" + ) + ).joinToString("\n") + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt new file mode 100644 index 0000000000..074700e3f7 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt @@ -0,0 +1,186 @@ +package net.corda.node.services.statemachine.transitions + +import net.corda.core.flows.UnexpectedFlowEndException +import net.corda.node.services.statemachine.* + +/** + * This transition handles incoming session messages. It handles the following cases: + * - DataSessionMessage: these arrive to initiated and confirmed sessions and are expected to be received by the flow. + * - ConfirmSessionMessage: these arrive as a response to an InitialSessionMessage and include information about the + * counterparty flow's session ID as well as their [FlowInfo]. 
+ * - ErrorSessionMessage: these arrive to initiated and confirmed sessions and put the corresponding session into an
+ *   "errored" state. This means that whenever that session is subsequently interacted with the error will be thrown
+ *   in the flow.
+ * - RejectSessionMessage: these arrive as a response to an InitialSessionMessage when the initiation failed. It
+ *   behaves similarly to ErrorSessionMessage aside from the type of exceptions stored/raised.
+ * - EndSessionMessage: these are sent when the counterparty flow has finished. They put the corresponding session into
+ *   an "ended" state. This means that subsequent sends on this session will fail, and receives will start failing
+ *   after the buffer of already received messages is drained.
+ */
+class DeliverSessionMessageTransition(
+        override val context: TransitionContext,
+        override val startingState: StateMachineState,
+        val event: Event.DeliverSessionMessage
+) : Transition {
+    /**
+     * Records the message's deduplication handler, dispatches on the payload type, persists the checkpoint unless the
+     * flow became errored, and always schedules a [Event.DoRemainingWork].
+     */
+    override fun transition(): TransitionResult {
+        return builder {
+            // Add the DeduplicationHandler to the pending ones ASAP so in case an error happens we still know
+            // about the message. Note that in case of an error during deliver this message *will be acked*.
+            // For example if the session corresponding to the message is not found the message is still acked to free
+            // up the broker but the flow will error.
+            currentState = currentState.copy(
+                    pendingDeduplicationHandlers = currentState.pendingDeduplicationHandlers + event.deduplicationHandler
+            )
+            // Check whether we have a session corresponding to the message.
+            val existingSession = startingState.checkpoint.sessions[event.sessionMessage.recipientSessionId]
+            if (existingSession == null) {
+                freshErrorTransition(CannotFindSessionException(event.sessionMessage.recipientSessionId))
+            } else {
+                val payload = event.sessionMessage.payload
+                // Dispatch based on what kind of message it is.
+                val _exhaustive = when (payload) {
+                    is ConfirmSessionMessage -> confirmMessageTransition(existingSession, payload)
+                    is DataSessionMessage -> dataMessageTransition(existingSession, payload)
+                    is ErrorSessionMessage -> errorMessageTransition(existingSession, payload)
+                    is RejectSessionMessage -> rejectMessageTransition(existingSession, payload)
+                    is EndSessionMessage -> endMessageTransition()
+                }
+            }
+            if (!isErrored()) {
+                persistCheckpoint()
+            }
+            // Schedule a DoRemainingWork to check whether the flow needs to be woken up.
+            actions.add(Action.ScheduleEvent(Event.DoRemainingWork))
+            FlowContinuation.ProcessEvents
+        }
+    }
+
+    /**
+     * Moves an Initiating session to Initiated and emits a send action for every message that was buffered while
+     * waiting for the confirmation. Any other session state is an error.
+     */
+    private fun TransitionBuilder.confirmMessageTransition(sessionState: SessionState, message: ConfirmSessionMessage) {
+        // We received a confirmation message. The corresponding session state must be Initiating.
+        when (sessionState) {
+            is SessionState.Initiating -> {
+                // Create the new session state that is now Initiated.
+                val initiatedSession = SessionState.Initiated(
+                        peerParty = event.sender,
+                        peerFlowInfo = message.initiatedFlowInfo,
+                        receivedMessages = emptyList(),
+                        initiatedState = InitiatedSessionState.Live(message.initiatedSessionId),
+                        errors = emptyList()
+                )
+                val newCheckpoint = currentState.checkpoint.copy(
+                        sessions = currentState.checkpoint.sessions + (event.sessionMessage.recipientSessionId to initiatedSession)
+                )
+                // Send messages that were buffered pending confirmation of session.
+                val sendActions = sessionState.bufferedMessages.map { (deduplicationId, bufferedMessage) ->
+                    val existingMessage = ExistingSessionMessage(message.initiatedSessionId, bufferedMessage)
+                    Action.SendExisting(initiatedSession.peerParty, existingMessage, deduplicationId)
+                }
+                actions.addAll(sendActions)
+                currentState = currentState.copy(checkpoint = newCheckpoint)
+            }
+            else -> freshErrorTransition(UnexpectedEventInState())
+        }
+    }
+
+    /** Appends the data message to an Initiated session's received-message buffer; other states are an error. */
+    private fun TransitionBuilder.dataMessageTransition(sessionState: SessionState, message: DataSessionMessage) {
+        // We received a data message. The corresponding session must be Initiated.
+        return when (sessionState) {
+            is SessionState.Initiated -> {
+                // Buffer the message in the session's receivedMessages buffer.
+                val newSessionState = sessionState.copy(
+                        receivedMessages = sessionState.receivedMessages + message
+                )
+                // Read the sessions from currentState, as the sibling handlers do. Before dispatch transition()
+                // only changes pendingDeduplicationHandlers, so the session map is the same as in startingState.
+                currentState = currentState.copy(
+                        checkpoint = currentState.checkpoint.copy(
+                                sessions = currentState.checkpoint.sessions + (event.sessionMessage.recipientSessionId to newSessionState)
+                        )
+                )
+            }
+            else -> freshErrorTransition(UnexpectedEventInState())
+        }
+    }
+
+    /**
+     * Adds a [FlowError] to an Initiated session. The stored exception is the payload's flow exception (stamped with
+     * the payload's error id) or, when the payload carries none, an [UnexpectedFlowEndException].
+     */
+    private fun TransitionBuilder.errorMessageTransition(sessionState: SessionState, payload: ErrorSessionMessage) {
+        val exception: Throwable = if (payload.flowException == null) {
+            UnexpectedFlowEndException("Counter-flow errored", cause = null, originalErrorId = payload.errorId)
+        } else {
+            payload.flowException.originalErrorId = payload.errorId
+            payload.flowException
+        }
+
+        return when (sessionState) {
+            is SessionState.Initiated -> {
+                val checkpoint = currentState.checkpoint
+                val sessionId = event.sessionMessage.recipientSessionId
+                val flowError = FlowError(payload.errorId, exception)
+                val newSessionState = sessionState.copy(errors = sessionState.errors + flowError)
+                currentState = currentState.copy(
+                        checkpoint = checkpoint.copy(
+                                sessions = checkpoint.sessions + (sessionId to newSessionState)
+                        )
+                )
+            }
+            else -> freshErrorTransition(UnexpectedEventInState())
+        }
+    }
+
+    /**
+     * Stores the rejection as the Initiating session's rejectionError. A second rejection, or a rejection arriving
+     * in any other session state, is an error.
+     */
+    private fun TransitionBuilder.rejectMessageTransition(sessionState: SessionState, payload: RejectSessionMessage) {
+        val exception = UnexpectedFlowEndException(payload.message, cause = null, originalErrorId = payload.errorId)
+        return when (sessionState) {
+            is SessionState.Initiating -> {
+                if (sessionState.rejectionError != null) {
+                    // Double reject
+                    freshErrorTransition(UnexpectedEventInState())
+                } else {
+                    val checkpoint = currentState.checkpoint
+                    val sessionId = event.sessionMessage.recipientSessionId
+                    val flowError = FlowError(payload.errorId, exception)
+                    currentState = currentState.copy(
+                            checkpoint = checkpoint.copy(
+                                    sessions = checkpoint.sessions + (sessionId to sessionState.copy(rejectionError = flowError))
+                            )
+                    )
+                }
+            }
+            else -> freshErrorTransition(UnexpectedEventInState())
+        }
+    }
+
+    /**
+     * Emits the actions that persist the checkpoint and the deduplication facts in one transaction and then
+     * acknowledge the pending messages; afterwards clears the pending handlers in the state.
+     */
+    private fun TransitionBuilder.persistCheckpoint() {
+        // We persist the message as soon as it arrives.
+        actions.addAll(arrayOf(
+                Action.CreateTransaction,
+                Action.PersistCheckpoint(context.id, currentState.checkpoint),
+                Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers),
+                Action.CommitTransaction,
+                Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers)
+        ))
+        currentState = currentState.copy(
+                pendingDeduplicationHandlers = emptyList(),
+                isAnyCheckpointPersisted = true
+        )
+    }
+
+    /** Marks an Initiated session as Ended; a missing session or any other session state is an error. */
+    private fun TransitionBuilder.endMessageTransition() {
+        val sessionId = event.sessionMessage.recipientSessionId
+        val sessions = currentState.checkpoint.sessions
+        val sessionState = sessions[sessionId]
+        if (sessionState == null) {
+            return freshErrorTransition(CannotFindSessionException(sessionId))
+        }
+        when (sessionState) {
+            is SessionState.Initiated -> {
+                val newSessionState = sessionState.copy(initiatedState = InitiatedSessionState.Ended)
+                currentState = currentState.copy(
+                        checkpoint = currentState.checkpoint.copy(
+                                sessions = sessions + (sessionId to newSessionState)
+                        )
+                )
+            }
+            else -> {
+                freshErrorTransition(UnexpectedEventInState())
+            }
+        }
+    }
+
+}
diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DoRemainingWorkTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DoRemainingWorkTransition.kt
new file mode 100644
index 0000000000..53214c71fb
--- /dev/null
+++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DoRemainingWorkTransition.kt
@@ -0,0 +1,37 @@
+package net.corda.node.services.statemachine.transitions
+
+import net.corda.node.services.statemachine.*
+
+/**
+ * This transition checks the current state of the flow and determines whether anything needs to be done.
+ */
+class DoRemainingWorkTransition(
+        override val context: TransitionContext,
+        override val startingState: StateMachineState
+) : Transition {
+    override fun transition(): TransitionResult {
+        val checkpoint = startingState.checkpoint
+        // If the flow is removed or has been resumed don't do work.
+        if (startingState.isFlowResumed || startingState.isRemoved) {
+            return TransitionResult(startingState)
+        }
+        // Check whether the flow is errored
+        return when (checkpoint.errorState) {
+            is ErrorState.Clean -> cleanTransition()
+            is ErrorState.Errored -> erroredTransition(checkpoint.errorState)
+        }
+    }
+
+    // If the flow is clean check the FlowState
+    private fun cleanTransition(): TransitionResult {
+        val checkpoint = startingState.checkpoint
+        return when (checkpoint.flowState) {
+            is FlowState.Unstarted -> UnstartedFlowTransition(context, startingState, checkpoint.flowState).transition()
+            is FlowState.Started -> StartedFlowTransition(context, startingState, checkpoint.flowState).transition()
+        }
+    }
+
+    // If the flow is errored delegate to ErrorFlowTransition
+    private fun erroredTransition(errorState: ErrorState.Errored): TransitionResult {
+        return ErrorFlowTransition(context, startingState, errorState).transition()
+    }
+}
diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt
new file mode 100644
index 0000000000..660f7c6574
--- /dev/null
+++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt
@@ -0,0 +1,124 @@
+package net.corda.node.services.statemachine.transitions
+
+import net.corda.core.flows.FlowException
+import net.corda.node.services.statemachine.*
+
+/**
+ * This transition defines what should happen when a flow has errored.
+ *
+ * In general there are two flow-level error conditions:
+ *
+ * - Internal exceptions. These may arise due to problems in the flow framework or errors during state machine
+ * transitions e.g. network or database failure.
+ * - User-raised exceptions. These are exceptions that are (re)raised in user code, allowing the user to catch them.
+ * These may come from illegal flow API calls, and FlowExceptions or other counterparty failures that are re-raised
+ * when the flow tries to use the corresponding sessions.
+ *
+ * Both internal exceptions and uncaught user-raised exceptions cause the flow to be errored. This flags the flow as
+ * unable to be resumed. When a flow is in this state an external source (e.g. Flow hospital) may decide to
+ *
+ * 1. Retry it (not implemented yet). This throws away the errored state and re-tries from the last clean checkpoint.
+ * 2. Start error propagation. This seals the flow as errored permanently and propagates the associated error(s) to
+ *   all live sessions. This causes these sessions to error on the other side, which may in turn cause the
+ *   counter-flows themselves to error.
+ *
+ * See [net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor] for how to detect flow errors.
+ *
+ * Note that in general we handle multiple errors at a time as several error conditions may arise at the same time and
+ * new errors may arise while the flow is in the errored state already.
+ */
+// NOTE(review): generic type arguments in this class were missing in this copy of the patch and are restored from
+// how the values are used below - confirm against the upstream commit.
+class ErrorFlowTransition(
+        override val context: TransitionContext,
+        override val startingState: StateMachineState,
+        private val errorState: ErrorState.Errored
+) : Transition {
+    override fun transition(): TransitionResult {
+        val allErrors: List<FlowError> = errorState.errors
+        // Only the errors past propagatedIndex still need to be sent out.
+        val remainingErrorsToPropagate: List<FlowError> = allErrors.subList(errorState.propagatedIndex, allErrors.size)
+        val errorMessages: List<ErrorSessionMessage> = remainingErrorsToPropagate.map(this::createErrorMessageFromError)
+
+        return builder {
+            // If we're errored and propagating do the actual propagation and update the index.
+            if (remainingErrorsToPropagate.isNotEmpty() && errorState.propagating) {
+                val (initiatedSessions, newSessions) = bufferErrorMessagesInInitiatingSessions(startingState.checkpoint.sessions, errorMessages)
+                val newCheckpoint = startingState.checkpoint.copy(
+                        errorState = errorState.copy(propagatedIndex = allErrors.size),
+                        sessions = newSessions
+                )
+                currentState = currentState.copy(checkpoint = newCheckpoint)
+                actions.add(Action.PropagateErrors(errorMessages, initiatedSessions))
+            }
+
+            // If we're errored but not propagating keep processing events.
+            if (remainingErrorsToPropagate.isNotEmpty() && !errorState.propagating) {
+                return@builder FlowContinuation.ProcessEvents
+            }
+
+            // If we haven't been removed yet remove the flow.
+            if (!currentState.isRemoved) {
+                actions.add(Action.CreateTransaction)
+                if (currentState.isAnyCheckpointPersisted) {
+                    actions.add(Action.RemoveCheckpoint(context.id))
+                }
+                actions.addAll(arrayOf(
+                        Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers),
+                        Action.CommitTransaction,
+                        Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers),
+                        Action.RemoveSessionBindings(currentState.checkpoint.sessions.keys)
+                ))
+
+                currentState = currentState.copy(
+                        pendingDeduplicationHandlers = emptyList(),
+                        isRemoved = true
+                )
+
+                val removalReason = FlowRemovalReason.ErrorFinish(allErrors)
+                actions.add(Action.RemoveFlow(context.id, removalReason, currentState))
+                FlowContinuation.Abort
+            } else {
+                // Otherwise keep processing events. This branch happens when there are some outstanding initiating
+                // sessions that prevent the removal of the flow.
+                FlowContinuation.ProcessEvents
+            }
+        }
+    }
+
+    private fun createErrorMessageFromError(error: FlowError): ErrorSessionMessage {
+        val exception = error.exception
+        // If the exception doesn't contain an originalErrorId that means it's a fresh FlowException that should
+        // propagate to the neighbouring flows. If it has the ID filled in that means it's a rethrown FlowException and
+        // shouldn't be propagated.
+        return if (exception is FlowException && exception.originalErrorId == null) {
+            ErrorSessionMessage(flowException = exception, errorId = error.errorId)
+        } else {
+            ErrorSessionMessage(flowException = null, errorId = error.errorId)
+        }
+    }
+
+    // Buffer error messages in Initiating sessions, return the initialised ones.
+    private fun bufferErrorMessagesInInitiatingSessions(
+            sessions: Map<SessionId, SessionState>,
+            errorMessages: List<ErrorSessionMessage>
+    ): Pair<List<SessionState.Initiated>, Map<SessionId, SessionState>> {
+        val newSessions = sessions.mapValues { (sourceSessionId, sessionState) ->
+            if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) {
+                // *prepend* the error messages in order to error the other sessions ASAP. The other messages will
+                // be delivered all the same, they just won't trigger flow resumption because of dirtiness.
+                val errorMessagesWithDeduplication = errorMessages.map {
+                    DeduplicationId.createForError(it.errorId, sourceSessionId) to it
+                }
+                sessionState.copy(bufferedMessages = errorMessagesWithDeduplication + sessionState.bufferedMessages)
+            } else {
+                sessionState
+            }
+        }
+        // Only Initiated sessions that have not recorded any error are returned for propagation.
+        val initiatedSessions = sessions.values.mapNotNull { session ->
+            if (session is SessionState.Initiated && session.errors.isEmpty()) {
+                session
+            } else {
+                null
+            }
+        }
+        return Pair(initiatedSessions, newSessions)
+    }
+}
diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt
new file mode 100644
index 0000000000..55c7458d00
--- /dev/null
+++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt
@@ -0,0 +1,410 @@
+package net.corda.node.services.statemachine.transitions
+
+import net.corda.core.flows.FlowInfo
+import net.corda.core.flows.FlowSession
+import net.corda.core.flows.UnexpectedFlowEndException
+import
net.corda.core.internal.FlowIORequest
+import net.corda.core.serialization.SerializedBytes
+import net.corda.core.utilities.toNonEmptySet
+import net.corda.node.services.statemachine.*
+
+/**
+ * This transition describes what should happen with a specific [FlowIORequest]. Note that at this time the request
+ * is persisted (unless checkpoint was skipped) and the user-space DB transaction is committed.
+ *
+ * Before this transition we either did a checkpoint or the checkpoint was restored from the database.
+ */
+// NOTE(review): generic type arguments in this class were missing in this copy of the patch and are restored from
+// how the values are used below (SessionId keys, FlowSessionImpl/FlowSession sessions, SerializedBytes<Any>
+// payloads, Throwable errors) - confirm against the upstream commit.
+class StartedFlowTransition(
+        override val context: TransitionContext,
+        override val startingState: StateMachineState,
+        val started: FlowState.Started
+) : Transition {
+    /**
+     * If any session relevant to the pending request carries an error, resumes the flow by throwing the first one;
+     * otherwise dispatches on the type of the pending [FlowIORequest].
+     */
+    override fun transition(): TransitionResult {
+        val flowIORequest = started.flowIORequest
+        val checkpoint = startingState.checkpoint
+        val errorsToThrow = collectRelevantErrorsToThrow(flowIORequest, checkpoint)
+        if (errorsToThrow.isNotEmpty()) {
+            return TransitionResult(
+                    newState = startingState.copy(isFlowResumed = true),
+                    // throw the first exception. TODO should this aggregate all of them somehow?
+                    actions = listOf(Action.CreateTransaction),
+                    continuation = FlowContinuation.Throw(errorsToThrow[0])
+            )
+        }
+        return when (flowIORequest) {
+            is FlowIORequest.Send -> sendTransition(flowIORequest)
+            is FlowIORequest.Receive -> receiveTransition(flowIORequest)
+            is FlowIORequest.SendAndReceive -> sendAndReceiveTransition(flowIORequest)
+            is FlowIORequest.WaitForLedgerCommit -> waitForLedgerCommitTransition(flowIORequest)
+            is FlowIORequest.Sleep -> sleepTransition(flowIORequest)
+            is FlowIORequest.GetFlowInfo -> getFlowInfoTransition(flowIORequest)
+            is FlowIORequest.WaitForSessionConfirmations -> waitForSessionConfirmationsTransition()
+            is FlowIORequest.ExecuteAsyncOperation<*> -> executeAsyncOperation(flowIORequest)
+        }
+    }
+
+    /** Resumes the flow once no session is in the Initiating state any more; until then keeps processing events. */
+    private fun waitForSessionConfirmationsTransition(): TransitionResult {
+        return builder {
+            if (currentState.checkpoint.sessions.values.any { it is SessionState.Initiating }) {
+                FlowContinuation.ProcessEvents
+            } else {
+                resumeFlowLogic(Unit)
+            }
+        }
+    }
+
+    /** Resumes the flow with the peer [FlowInfo] of every requested session once all of them are Initiated. */
+    private fun getFlowInfoTransition(flowIORequest: FlowIORequest.GetFlowInfo): TransitionResult {
+        val sessionIdToSession = LinkedHashMap<SessionId, FlowSessionImpl>()
+        for (session in flowIORequest.sessions) {
+            sessionIdToSession[(session as FlowSessionImpl).sourceSessionId] = session
+        }
+        return builder {
+            // Initialise uninitialised sessions in order to receive the associated FlowInfo. Some or all sessions may
+            // not be initialised yet.
+            sendInitialSessionMessagesIfNeeded(sessionIdToSession.keys)
+            val flowInfoMap = getFlowInfoFromSessions(sessionIdToSession)
+            if (flowInfoMap == null) {
+                FlowContinuation.ProcessEvents
+            } else {
+                resumeFlowLogic(flowInfoMap)
+            }
+        }
+    }
+
+    /** Returns the peer [FlowInfo] per session, or null if at least one of the sessions is not Initiated. */
+    private fun TransitionBuilder.getFlowInfoFromSessions(sessionIdToSession: Map<SessionId, FlowSessionImpl>): Map<FlowSession, FlowInfo>? {
+        val checkpoint = currentState.checkpoint
+        val resultMap = LinkedHashMap<FlowSession, FlowInfo>()
+        for ((sessionId, session) in sessionIdToSession) {
+            val sessionState = checkpoint.sessions[sessionId]
+            if (sessionState is SessionState.Initiated) {
+                resultMap[session] = sessionState.peerFlowInfo
+            } else {
+                return null
+            }
+        }
+        return resultMap
+    }
+
+    /** Emits a [Action.SleepUntil] for the requested wake-up time and resumes the flow. */
+    private fun sleepTransition(flowIORequest: FlowIORequest.Sleep): TransitionResult {
+        return builder {
+            actions.add(Action.SleepUntil(flowIORequest.wakeUpAfter))
+            resumeFlowLogic(Unit)
+        }
+    }
+
+    /** Starts tracking the awaited transaction if it is not tracked yet; otherwise leaves the state unchanged. */
+    private fun waitForLedgerCommitTransition(flowIORequest: FlowIORequest.WaitForLedgerCommit): TransitionResult {
+        return if (!startingState.isTransactionTracked) {
+            TransitionResult(
+                    newState = startingState.copy(isTransactionTracked = true),
+                    actions = listOf(
+                            Action.CreateTransaction,
+                            Action.TrackTransaction(flowIORequest.hash),
+                            Action.CommitTransaction
+                    )
+            )
+        } else {
+            TransitionResult(startingState)
+        }
+    }
+
+    /**
+     * Sends the messages, then resumes the flow with the replies if they are all buffered already. If they are not,
+     * the suspension is rewritten to a plain Receive on the same sessions.
+     */
+    private fun sendAndReceiveTransition(flowIORequest: FlowIORequest.SendAndReceive): TransitionResult {
+        val sessionIdToMessage = LinkedHashMap<SessionId, SerializedBytes<Any>>()
+        val sessionIdToSession = LinkedHashMap<SessionId, FlowSessionImpl>()
+        for ((session, message) in flowIORequest.sessionToMessage) {
+            val sessionId = (session as FlowSessionImpl).sourceSessionId
+            sessionIdToMessage[sessionId] = message
+            sessionIdToSession[sessionId] = session
+        }
+        return builder {
+            sendToSessionsTransition(sessionIdToMessage)
+            if (isErrored()) {
+                FlowContinuation.ProcessEvents
+            } else {
+                val receivedMap = receiveFromSessionsTransition(sessionIdToSession)
+                if (receivedMap == null) {
+                    // We don't yet have the messages, change the suspension to be on Receive
+                    val newIoRequest = FlowIORequest.Receive(flowIORequest.sessionToMessage.keys.toNonEmptySet())
+                    currentState = currentState.copy(
+                            checkpoint = currentState.checkpoint.copy(
+                                    flowState = FlowState.Started(newIoRequest, started.frozenFiber)
+                            )
+                    )
+                    FlowContinuation.ProcessEvents
+                } else {
+                    resumeFlowLogic(receivedMap)
+                }
+            }
+        }
+    }
+
+    /** Resumes the flow with one buffered message per requested session once every session has one. */
+    private fun receiveTransition(flowIORequest: FlowIORequest.Receive): TransitionResult {
+        return builder {
+            val sessionIdToSession = LinkedHashMap<SessionId, FlowSessionImpl>()
+            for (session in flowIORequest.sessions) {
+                sessionIdToSession[(session as FlowSessionImpl).sourceSessionId] = session
+            }
+            // send initialises to uninitialised sessions
+            sendInitialSessionMessagesIfNeeded(sessionIdToSession.keys)
+            val receivedMap = receiveFromSessionsTransition(sessionIdToSession)
+            if (receivedMap == null) {
+                FlowContinuation.ProcessEvents
+            } else {
+                resumeFlowLogic(receivedMap)
+            }
+        }
+    }
+
+    /**
+     * Takes the first buffered message of every given session and stores the shortened buffers in the state.
+     * Returns null, leaving the state untouched, if any of the sessions has nothing to deliver.
+     */
+    private fun TransitionBuilder.receiveFromSessionsTransition(
+            sourceSessionIdToSessionMap: Map<SessionId, FlowSessionImpl>
+    ): Map<FlowSession, SerializedBytes<Any>>? {
+        val checkpoint = currentState.checkpoint
+        val pollResult = pollSessionMessages(checkpoint.sessions, sourceSessionIdToSessionMap.keys) ?: return null
+        val resultMap = LinkedHashMap<FlowSession, SerializedBytes<Any>>()
+        for ((sessionId, message) in pollResult.messages) {
+            val session = sourceSessionIdToSessionMap[sessionId]!!
+            resultMap[session] = message
+        }
+        currentState = currentState.copy(
+                checkpoint = checkpoint.copy(sessions = pollResult.newSessionMap)
+        )
+        return resultMap
+    }
+
+    /** The messages taken from the polled sessions together with the session map after removing them. */
+    data class PollResult(
+            val messages: Map<SessionId, SerializedBytes<Any>>,
+            val newSessionMap: SessionMap
+    )
+
+    /**
+     * Takes the first received message of each of [sessionIds]. Returns null unless every one of those sessions is
+     * Initiated and has at least one buffered message.
+     */
+    private fun pollSessionMessages(sessions: SessionMap, sessionIds: Set<SessionId>): PollResult? {
+        val newSessionMessages = LinkedHashMap(sessions)
+        val resultMessages = LinkedHashMap<SessionId, SerializedBytes<Any>>()
+        var someNotFound = false
+        for (sessionId in sessionIds) {
+            val sessionState = sessions[sessionId]
+            when (sessionState) {
+                is SessionState.Initiated -> {
+                    val messages = sessionState.receivedMessages
+                    if (messages.isEmpty()) {
+                        someNotFound = true
+                    } else {
+                        newSessionMessages[sessionId] = sessionState.copy(receivedMessages = messages.subList(1, messages.size).toList())
+                        resultMessages[sessionId] = messages[0].payload
+                    }
+                }
+                else -> {
+                    someNotFound = true
+                }
+            }
+        }
+        return if (someNotFound) {
+            null
+        } else {
+            PollResult(resultMessages, newSessionMessages)
+        }
+    }
+
+    /**
+     * Sends an initial session message (without payload) for every Uninitiated session among [sourceSessions] and
+     * moves it to Initiating. An unknown session id is an error.
+     */
+    private fun TransitionBuilder.sendInitialSessionMessagesIfNeeded(sourceSessions: Set<SessionId>) {
+        val checkpoint = startingState.checkpoint
+        val newSessions = LinkedHashMap(checkpoint.sessions)
+        var index = 0
+        for (sourceSessionId in sourceSessions) {
+            val sessionState = checkpoint.sessions[sourceSessionId]
+            if (sessionState == null) {
+                return freshErrorTransition(CannotFindSessionException(sourceSessionId))
+            }
+            if (sessionState !is SessionState.Uninitiated) {
+                continue
+            }
+            val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++)
+            val initialMessage = createInitialSessionMessage(sessionState.initiatingSubFlow, sourceSessionId, null)
+            actions.add(Action.SendInitial(sessionState.party, initialMessage, deduplicationId))
+            newSessions[sourceSessionId] = SessionState.Initiating(
+                    bufferedMessages = emptyList(),
+                    rejectionError = null
+            )
+        }
+        currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions))
+    }
+
+    /** Sends the requested messages and resumes the flow unless sending put the flow into an errored state. */
+    private fun sendTransition(flowIORequest: FlowIORequest.Send): TransitionResult {
+        return builder {
+            val sessionIdToMessage = flowIORequest.sessionToMessage.mapKeys {
+                sessionToSessionId(it.key)
+            }
+            sendToSessionsTransition(sessionIdToMessage)
+            if (isErrored()) {
+                FlowContinuation.ProcessEvents
+            } else {
+                resumeFlowLogic(Unit)
+            }
+        }
+    }
+
+    /**
+     * Sends each message according to its session's state: Uninitiated sessions get it as the first payload of the
+     * initial message, Initiating sessions buffer it, live Initiated sessions send it directly. An unknown session
+     * or an ended session is an error.
+     */
+    private fun TransitionBuilder.sendToSessionsTransition(sourceSessionIdToMessage: Map<SessionId, SerializedBytes<Any>>) {
+        val checkpoint = startingState.checkpoint
+        val newSessions = LinkedHashMap(checkpoint.sessions)
+        var index = 0
+        for ((sourceSessionId, message) in sourceSessionIdToMessage) {
+            val existingSessionState = checkpoint.sessions[sourceSessionId]
+            if (existingSessionState == null) {
+                return freshErrorTransition(CannotFindSessionException(sourceSessionId))
+            } else {
+                val sessionMessage = DataSessionMessage(message)
+                val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++)
+                val _exhaustive = when (existingSessionState) {
+                    is SessionState.Uninitiated -> {
+                        val initialMessage = createInitialSessionMessage(existingSessionState.initiatingSubFlow, sourceSessionId, message)
+                        actions.add(Action.SendInitial(existingSessionState.party, initialMessage, deduplicationId))
+                        newSessions[sourceSessionId] = SessionState.Initiating(
+                                bufferedMessages = emptyList(),
+                                rejectionError = null
+                        )
+                        Unit
+                    }
+                    is SessionState.Initiating -> {
+                        // We're initiating this session, buffer the message
+                        val newBufferedMessages = existingSessionState.bufferedMessages + Pair(deduplicationId, sessionMessage)
+                        newSessions[sourceSessionId] = existingSessionState.copy(bufferedMessages = newBufferedMessages)
+                    }
+                    is SessionState.Initiated -> {
+                        when (existingSessionState.initiatedState) {
+                            is InitiatedSessionState.Live -> {
+                                val sinkSessionId = existingSessionState.initiatedState.peerSinkSessionId
+                                val existingMessage = ExistingSessionMessage(sinkSessionId, sessionMessage)
+                                actions.add(Action.SendExisting(existingSessionState.peerParty, existingMessage, deduplicationId))
+                                Unit
+                            }
+                            InitiatedSessionState.Ended -> {
+                                return freshErrorTransition(IllegalStateException("Tried to send to ended session $sourceSessionId"))
+                            }
+                        }
+                    }
+                }
+            }
+
+        }
+        currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions))
+    }
+
+    private fun sessionToSessionId(session: FlowSession): SessionId {
+        return (session as FlowSessionImpl).sourceSessionId
+    }
+
+    /**
+     * Collects the stored errors of the given sessions: the rejection error of Initiating sessions and all errors of
+     * Initiated sessions. Session ids missing from the checkpoint are skipped here, so that the send/receive
+     * transitions can report them as [CannotFindSessionException] instead of this method failing on a null lookup.
+     */
+    private fun collectErroredSessionErrors(sessionIds: Collection<SessionId>, checkpoint: Checkpoint): List<Throwable> {
+        return sessionIds.flatMap { sessionId ->
+            val sessionState = checkpoint.sessions[sessionId] ?: return@flatMap emptyList<Throwable>()
+            when (sessionState) {
+                is SessionState.Uninitiated -> emptyList<Throwable>()
+                is SessionState.Initiating -> {
+                    if (sessionState.rejectionError == null) {
+                        emptyList<Throwable>()
+                    } else {
+                        listOf(sessionState.rejectionError.exception)
+                    }
+                }
+                is SessionState.Initiated -> sessionState.errors.map(FlowError::exception)
+            }
+        }
+    }
+
+    /** Collects the rejection errors of all Initiating sessions in the checkpoint. */
+    private fun collectErroredInitiatingSessionErrors(checkpoint: Checkpoint): List<Throwable> {
+        return checkpoint.sessions.values.mapNotNull { sessionState ->
+            (sessionState as? SessionState.Initiating)?.rejectionError?.exception
+        }
+    }
+
+    /**
+     * Creates an [UnexpectedFlowEndException] for every given session that is Initiated and Ended. Session ids
+     * missing from the checkpoint produce no error here.
+     */
+    private fun collectEndedSessionErrors(sessionIds: Collection<SessionId>, checkpoint: Checkpoint): List<Throwable> {
+        return sessionIds.mapNotNull { sessionId ->
+            val sessionState = checkpoint.sessions[sessionId]
+            when (sessionState) {
+                is SessionState.Initiated -> {
+                    if (sessionState.initiatedState is InitiatedSessionState.Ended) {
+                        UnexpectedFlowEndException(
+                                "Tried to access ended session $sessionId",
+                                cause = null,
+                                originalErrorId = context.secureRandom.nextLong()
+                        )
+                    } else {
+                        null
+                    }
+                }
+                else -> null
+            }
+        }
+    }
+
+    /**
+     * Like [collectEndedSessionErrors], but only for ended sessions whose received-message buffer is empty, so
+     * already buffered messages can still be received.
+     */
+    private fun collectEndedEmptySessionErrors(sessionIds: Collection<SessionId>, checkpoint: Checkpoint): List<Throwable> {
+        return sessionIds.mapNotNull { sessionId ->
+            val sessionState = checkpoint.sessions[sessionId]
+            when (sessionState) {
+                is SessionState.Initiated -> {
+                    if (sessionState.initiatedState is InitiatedSessionState.Ended &&
+                            sessionState.receivedMessages.isEmpty()) {
+                        UnexpectedFlowEndException(
+                                "Tried to access ended session $sessionId with empty buffer",
+                                cause = null,
+                                originalErrorId = context.secureRandom.nextLong()
+                        )
+                    } else {
+                        null
+                    }
+                }
+                else -> null
+            }
+        }
+    }
+
+    /** Selects, per request type, which session errors must be thrown into the flow before the request is served. */
+    private fun collectRelevantErrorsToThrow(flowIORequest: FlowIORequest<*>, checkpoint: Checkpoint): List<Throwable> {
+        return when (flowIORequest) {
+            is FlowIORequest.Send -> {
+                val sessionIds = flowIORequest.sessionToMessage.keys.map(this::sessionToSessionId)
+                collectErroredSessionErrors(sessionIds, checkpoint) + collectEndedSessionErrors(sessionIds, checkpoint)
+            }
+            is FlowIORequest.Receive -> {
+                val sessionIds = flowIORequest.sessions.map(this::sessionToSessionId)
+                collectErroredSessionErrors(sessionIds, checkpoint) + collectEndedEmptySessionErrors(sessionIds, checkpoint)
+            }
+            is FlowIORequest.SendAndReceive -> {
+                val sessionIds = flowIORequest.sessionToMessage.keys.map(this::sessionToSessionId)
+                collectErroredSessionErrors(sessionIds, checkpoint) + collectEndedSessionErrors(sessionIds, checkpoint)
+            }
+            is FlowIORequest.WaitForLedgerCommit -> {
+                collectErroredSessionErrors(checkpoint.sessions.keys, checkpoint)
+            }
+            is FlowIORequest.GetFlowInfo -> {
+                collectErroredSessionErrors(flowIORequest.sessions.map(this::sessionToSessionId), checkpoint)
+            }
+            is FlowIORequest.Sleep -> {
+                emptyList()
+            }
+            is FlowIORequest.WaitForSessionConfirmations -> {
+                collectErroredInitiatingSessionErrors(checkpoint)
+            }
+            is FlowIORequest.ExecuteAsyncOperation<*> -> {
+                emptyList()
+            }
+        }
+    }
+
+    private fun createInitialSessionMessage(
+            initiatingSubFlow: SubFlow.Initiating,
+            sourceSessionId: SessionId,
+            payload: SerializedBytes<Any>?
+    ): InitialSessionMessage {
+        return InitialSessionMessage(
+                initiatorSessionId = sourceSessionId,
+                // We add additional entropy to add to the initiated side's deduplication seed.
+                initiationEntropy = context.secureRandom.nextLong(),
+                initiatorFlowClassName = initiatingSubFlow.classToInitiateWith.name,
+                flowVersion = initiatingSubFlow.flowInfo.flowVersion,
+                appName = initiatingSubFlow.flowInfo.appName,
+                firstPayload = payload
+        )
+    }
+
+    /** Emits the action that runs the async operation and keeps processing events; the flow is not resumed here. */
+    private fun executeAsyncOperation(flowIORequest: FlowIORequest.ExecuteAsyncOperation<*>): TransitionResult {
+        return builder {
+            actions.add(Action.ExecuteAsyncOperation(flowIORequest.operation))
+            FlowContinuation.ProcessEvents
+        }
+    }
+}
diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StateMachine.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StateMachine.kt
new file mode 100644
index 0000000000..29e29746cc
--- /dev/null
+++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StateMachine.kt
@@ -0,0 +1,30 @@
+package net.corda.node.services.statemachine.transitions
+
+import net.corda.core.flows.*
+import net.corda.node.services.statemachine.*
+import java.security.SecureRandom
+
+/**
+ * @property eventQueueSize the size of a flow's event queue. If the queue gets full the thread scheduling the event
+ * will block. An example scenario would be if the flow is waiting for a lot of messages at once, but is slow at
+ * processing each.
+ */
+data class StateMachineConfiguration(
+        val eventQueueSize: Int
+) {
+    companion object {
+        val default = StateMachineConfiguration(eventQueueSize = 16)
+    }
+}
+
+/**
+ * Entry point for computing transitions of a single flow: bundles the flow's [id], the [configuration] and the
+ * [secureRandom] source into a [TransitionContext] and delegates to [TopLevelTransition].
+ */
+class StateMachine(
+        val id: StateMachineRunId,
+        val configuration: StateMachineConfiguration,
+        val secureRandom: SecureRandom
+) {
+    /** Computes the [TransitionResult] of applying [event] to [state]. */
+    fun transition(event: Event, state: StateMachineState): TransitionResult {
+        val transitionContext = TransitionContext(id, configuration, secureRandom)
+        val topLevelTransition = TopLevelTransition(transitionContext, state, event)
+        return topLevelTransition.transition()
+    }
+}
diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt
new file mode 100644
index 0000000000..684c74b1ab
--- /dev/null
+++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt
@@ -0,0 +1,243 @@
+package net.corda.node.services.statemachine.transitions
+
+import net.corda.core.flows.InitiatingFlow
+import net.corda.core.internal.FlowIORequest
+import net.corda.core.utilities.Try
+import net.corda.node.services.statemachine.*
+
+/**
+ * This is the top level event-handling transition function capable of handling any [Event].
+ *
+ * It is a *pure* function taking a state machine state and an event, returning the next state along with a list of IO
+ * actions to execute.
+ */ +class TopLevelTransition( + override val context: TransitionContext, + override val startingState: StateMachineState, + val event: Event +) : Transition { + override fun transition(): TransitionResult { + return when (event) { + is Event.DoRemainingWork -> DoRemainingWorkTransition(context, startingState).transition() + is Event.DeliverSessionMessage -> DeliverSessionMessageTransition(context, startingState, event).transition() + is Event.Error -> errorTransition(event) + is Event.TransactionCommitted -> transactionCommittedTransition(event) + is Event.SoftShutdown -> softShutdownTransition() + is Event.StartErrorPropagation -> startErrorPropagationTransition() + is Event.EnterSubFlow -> enterSubFlowTransition(event) + is Event.LeaveSubFlow -> leaveSubFlowTransition() + is Event.Suspend -> suspendTransition(event) + is Event.FlowFinish -> flowFinishTransition(event) + is Event.InitiateFlow -> initiateFlowTransition(event) + is Event.AsyncOperationCompletion -> asyncOperationCompletionTransition(event) + } + } + + private fun errorTransition(event: Event.Error): TransitionResult { + return builder { + freshErrorTransition(event.exception) + FlowContinuation.ProcessEvents + } + } + + private fun transactionCommittedTransition(event: Event.TransactionCommitted): TransitionResult { + return builder { + val checkpoint = currentState.checkpoint + if (currentState.isTransactionTracked && + checkpoint.flowState is FlowState.Started && + checkpoint.flowState.flowIORequest is FlowIORequest.WaitForLedgerCommit && + checkpoint.flowState.flowIORequest.hash == event.transaction.id) { + currentState = currentState.copy(isTransactionTracked = false) + if (isErrored()) { + return@builder FlowContinuation.ProcessEvents + } + resumeFlowLogic(event.transaction) + } else { + freshErrorTransition(UnexpectedEventInState()) + FlowContinuation.ProcessEvents + } + } + } + + private fun softShutdownTransition(): TransitionResult { + val lastState = startingState.copy(isRemoved = true) 
+ return TransitionResult( + newState = lastState, + actions = listOf( + Action.RemoveSessionBindings(startingState.checkpoint.sessions.keys), + Action.RemoveFlow(context.id, FlowRemovalReason.SoftShutdown, lastState) + ), + continuation = FlowContinuation.Abort + ) + } + + private fun startErrorPropagationTransition(): TransitionResult { + return builder { + val errorState = currentState.checkpoint.errorState + when (errorState) { + ErrorState.Clean -> freshErrorTransition(UnexpectedEventInState()) + is ErrorState.Errored -> { + currentState = currentState.copy( + checkpoint = currentState.checkpoint.copy( + errorState = errorState.copy(propagating = true) + ) + ) + actions.add(Action.ScheduleEvent(Event.DoRemainingWork)) + } + } + FlowContinuation.ProcessEvents + } + } + + private fun enterSubFlowTransition(event: Event.EnterSubFlow): TransitionResult { + return builder { + val subFlow = SubFlow.create(event.subFlowClass) + when (subFlow) { + is Try.Success -> { + currentState = currentState.copy( + checkpoint = currentState.checkpoint.copy( + subFlowStack = currentState.checkpoint.subFlowStack + subFlow.value + ) + ) + } + is Try.Failure -> { + freshErrorTransition(subFlow.exception) + } + } + FlowContinuation.ProcessEvents + } + } + + private fun leaveSubFlowTransition(): TransitionResult { + return builder { + val checkpoint = currentState.checkpoint + if (checkpoint.subFlowStack.isEmpty()) { + freshErrorTransition(UnexpectedEventInState()) + } else { + currentState = currentState.copy( + checkpoint = checkpoint.copy( + subFlowStack = checkpoint.subFlowStack.subList(0, checkpoint.subFlowStack.size - 1).toList() + ) + ) + } + FlowContinuation.ProcessEvents + } + } + + private fun suspendTransition(event: Event.Suspend): TransitionResult { + return builder { + val newCheckpoint = currentState.checkpoint.copy( + flowState = FlowState.Started(event.ioRequest, event.fiber), + numberOfSuspends = currentState.checkpoint.numberOfSuspends + 1 + ) + if 
(event.maySkipCheckpoint) { + actions.addAll(arrayOf( + Action.CommitTransaction, + Action.ScheduleEvent(Event.DoRemainingWork) + )) + currentState = currentState.copy( + checkpoint = newCheckpoint, + isFlowResumed = false + ) + } else { + actions.addAll(arrayOf( + Action.PersistCheckpoint(context.id, newCheckpoint), + Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers), + Action.CommitTransaction, + Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers), + Action.ScheduleEvent(Event.DoRemainingWork) + )) + currentState = currentState.copy( + checkpoint = newCheckpoint, + pendingDeduplicationHandlers = emptyList(), + isFlowResumed = false, + isAnyCheckpointPersisted = true + ) + } + FlowContinuation.ProcessEvents + } + } + + private fun flowFinishTransition(event: Event.FlowFinish): TransitionResult { + return builder { + val checkpoint = currentState.checkpoint + when (checkpoint.errorState) { + ErrorState.Clean -> { + val pendingDeduplicationHandlers = currentState.pendingDeduplicationHandlers + currentState = currentState.copy( + checkpoint = checkpoint.copy( + numberOfSuspends = checkpoint.numberOfSuspends + 1 + ), + pendingDeduplicationHandlers = emptyList(), + isFlowResumed = false, + isRemoved = true + ) + val allSourceSessionIds = checkpoint.sessions.keys + if (currentState.isAnyCheckpointPersisted) { + actions.add(Action.RemoveCheckpoint(context.id)) + } + actions.addAll(arrayOf( + Action.PersistDeduplicationFacts(pendingDeduplicationHandlers), + Action.CommitTransaction, + Action.AcknowledgeMessages(pendingDeduplicationHandlers), + Action.RemoveSessionBindings(allSourceSessionIds), + Action.RemoveFlow(context.id, FlowRemovalReason.OrderlyFinish(event.returnValue), currentState) + )) + sendEndMessages() + // Resume to end fiber + FlowContinuation.Resume(null) + } + is ErrorState.Errored -> { + currentState = currentState.copy(isFlowResumed = false) + actions.add(Action.RollbackTransaction) + 
FlowContinuation.ProcessEvents + } + } + } + } + + private fun TransitionBuilder.sendEndMessages() { + val sendEndMessageActions = currentState.checkpoint.sessions.values.mapIndexed { index, state -> + if (state is SessionState.Initiated && state.initiatedState is InitiatedSessionState.Live) { + val message = ExistingSessionMessage(state.initiatedState.peerSinkSessionId, EndSessionMessage) + val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index) + Action.SendExisting(state.peerParty, message, deduplicationId) + } else { + null + } + }.filterNotNull() + actions.addAll(sendEndMessageActions) + } + + private fun initiateFlowTransition(event: Event.InitiateFlow): TransitionResult { + return builder { + val checkpoint = currentState.checkpoint + val initiatingSubFlow = getClosestAncestorInitiatingSubFlow(checkpoint) + if (initiatingSubFlow == null) { + freshErrorTransition(IllegalStateException("Tried to initiate in a flow not annotated with @${InitiatingFlow::class.java.simpleName}")) + return@builder FlowContinuation.ProcessEvents + } + val sourceSessionId = SessionId.createRandom(context.secureRandom) + val sessionImpl = FlowSessionImpl(event.party, sourceSessionId) + val newSessions = checkpoint.sessions + (sourceSessionId to SessionState.Uninitiated(event.party, initiatingSubFlow)) + currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions)) + actions.add(Action.AddSessionBinding(context.id, sourceSessionId)) + FlowContinuation.Resume(sessionImpl) + } + } + + private fun getClosestAncestorInitiatingSubFlow(checkpoint: Checkpoint): SubFlow.Initiating? 
{ + for (subFlow in checkpoint.subFlowStack.asReversed()) { + if (subFlow is SubFlow.Initiating) { + return subFlow + } + } + return null + } + + private fun asyncOperationCompletionTransition(event: Event.AsyncOperationCompletion): TransitionResult { + return builder { + resumeFlowLogic(event.returnValue) + } + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/Transition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/Transition.kt new file mode 100644 index 0000000000..20441dbab3 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/Transition.kt @@ -0,0 +1,32 @@ +package net.corda.node.services.statemachine.transitions + +import net.corda.core.flows.StateMachineRunId +import net.corda.node.services.statemachine.StateMachineState +import java.security.SecureRandom + +/** + * An interface used to separate out different parts of the state machine transition function. + */ +interface Transition { + /** The context of the transition. */ + val context: TransitionContext + /** The state the transition is starting in. */ + val startingState: StateMachineState + /** The (almost) pure transition function. The only side-effect we allow is random number generation. 
*/ + fun transition(): TransitionResult + + /** + * A helper + */ + fun builder(build: TransitionBuilder.() -> FlowContinuation): TransitionResult { + val builder = TransitionBuilder(context, startingState) + val continuation = build(builder) + return TransitionResult(builder.currentState, builder.actions, continuation) + } +} + +class TransitionContext( + val id: StateMachineRunId, + val configuration: StateMachineConfiguration, + val secureRandom: SecureRandom +) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TransitionBuilder.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TransitionBuilder.kt new file mode 100644 index 0000000000..01715adde6 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TransitionBuilder.kt @@ -0,0 +1,74 @@ +package net.corda.node.services.statemachine.transitions + +import net.corda.core.flows.IdentifiableException +import net.corda.node.services.statemachine.* + +// This is a file defining some common utilities for creating state machine transitions. + +/** + * A builder that helps creating [Transition]s. This allows for a more imperative style of specifying the transition. + */ +class TransitionBuilder(val context: TransitionContext, initialState: StateMachineState) { + /** The current state machine state of the builder */ + var currentState = initialState + /** The list of actions to execute */ + val actions = ArrayList() + + /** Check if [currentState] state is errored */ + fun isErrored(): Boolean = currentState.checkpoint.errorState is ErrorState.Errored + + /** + * Transition the builder into an error state because of a fresh error that happened. + * Existing actions and the current state are thrown away, and the initial state is dirtied. + * + * @param error the error. + */ + fun freshErrorTransition(error: Throwable) { + val flowError = FlowError( + errorId = (error as? 
IdentifiableException)?.errorId ?: context.secureRandom.nextLong(), + exception = error + ) + errorTransition(flowError) + } + + /** + * Transition the builder into an error state because of a list of errors that happened. + * Existing actions and the current state are thrown away, and the initial state is dirtied. + * + * @param error the error. + */ + fun errorsTransition(errors: List) { + currentState = currentState.copy( + checkpoint = currentState.checkpoint.copy( + errorState = currentState.checkpoint.errorState.addErrors(errors) + ), + isFlowResumed = false + ) + actions.clear() + actions.addAll(arrayOf( + Action.RollbackTransaction, + Action.ScheduleEvent(Event.DoRemainingWork) + )) + } + + /** + * Transition the builder into an error state because of a non-fresh error has happened. + * Existing actions and the current state are thrown away, and the initial state is dirtied. + * + * @param error the error. + */ + fun errorTransition(error: FlowError) { + errorsTransition(listOf(error)) + } + + fun resumeFlowLogic(result: Any?): FlowContinuation { + actions.add(Action.CreateTransaction) + currentState = currentState.copy(isFlowResumed = true) + return FlowContinuation.Resume(result) + } +} + + + +class CannotFindSessionException(sessionId: SessionId) : IllegalStateException("Couldn't find session with id $sessionId") +class UnexpectedEventInState : IllegalStateException("Unexpected event") diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TransitionResult.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TransitionResult.kt new file mode 100644 index 0000000000..43e934634b --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TransitionResult.kt @@ -0,0 +1,46 @@ +package net.corda.node.services.statemachine.transitions + +import net.corda.node.services.statemachine.Action +import net.corda.node.services.statemachine.StateMachineState + +/** + * A datastructure 
capturing the intended new state of the flow, the actions to be executed as part of the transition + * and a [FlowContinuation]. + * + * Read this datastructure as an instruction to the state machine executor: + * "Transition to [newState] *if* [actions] execute cleanly. If so, use [continuation] to decide what to do next. If + * there was an error it's up to you what to do". + * Also see [net.corda.node.services.statemachine.TransitionExecutorImpl] on how this is interpreted. + */ +data class TransitionResult( + val newState: StateMachineState, + val actions: List = emptyList(), + val continuation: FlowContinuation = FlowContinuation.ProcessEvents +) + +/** + * A datastructure describing what to do after a transition has succeeded. + */ +sealed class FlowContinuation { + /** + * Return to user code with the supplied [result]. + */ + data class Resume(val result: Any?) : FlowContinuation() { + override fun toString() = "Resume(result=${result?.javaClass})" + } + + /** + * Throw an exception [throwable] in user code. + */ + data class Throw(val throwable: Throwable) : FlowContinuation() + + /** + * Keep processing pending events. + */ + object ProcessEvents : FlowContinuation() { override fun toString() = "ProcessEvents" } + + /** + * Immediately abort the flow. Note that this does not imply an error condition. 
+ */ + object Abort : FlowContinuation() { override fun toString() = "Abort" } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt new file mode 100644 index 0000000000..bd9f30c65c --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt @@ -0,0 +1,80 @@ +package net.corda.node.services.statemachine.transitions + +import net.corda.core.flows.FlowInfo +import net.corda.node.services.statemachine.* + +/** + * This transition is responsible for starting the flow from a FlowLogic instance. It creates the first checkpoint and + * initialises the initiated session in case the flow is an initiated one. + */ +class UnstartedFlowTransition( + override val context: TransitionContext, + override val startingState: StateMachineState, + val unstarted: FlowState.Unstarted +) : Transition { + override fun transition(): TransitionResult { + return builder { + if (!currentState.isAnyCheckpointPersisted && !currentState.isStartIdempotent) { + createInitialCheckpoint() + } + + actions.add(Action.SignalFlowHasStarted(context.id)) + + if (unstarted.flowStart is FlowStart.Initiated) { + initialiseInitiatedSession(unstarted.flowStart) + } + + currentState = currentState.copy(isFlowResumed = true) + actions.add(Action.CreateTransaction) + FlowContinuation.Resume(null) + } + } + + // Initialise initiated session, store initial payload, send confirmation back. 
+ private fun TransitionBuilder.initialiseInitiatedSession(flowStart: FlowStart.Initiated) { + val initiatingMessage = flowStart.initiatingMessage + val initiatedState = SessionState.Initiated( + peerParty = flowStart.peerSession.counterparty, + initiatedState = InitiatedSessionState.Live(initiatingMessage.initiatorSessionId), + peerFlowInfo = FlowInfo( + flowVersion = flowStart.senderCoreFlowVersion ?: initiatingMessage.flowVersion, + appName = initiatingMessage.appName + ), + receivedMessages = if (initiatingMessage.firstPayload == null) { + emptyList() + } else { + listOf(DataSessionMessage(initiatingMessage.firstPayload)) + }, + errors = emptyList() + ) + val confirmationMessage = ConfirmSessionMessage(flowStart.initiatedSessionId, flowStart.initiatedFlowInfo) + val sessionMessage = ExistingSessionMessage(initiatingMessage.initiatorSessionId, confirmationMessage) + currentState = currentState.copy( + checkpoint = currentState.checkpoint.copy( + sessions = mapOf(flowStart.initiatedSessionId to initiatedState) + ) + ) + actions.add( + Action.SendExisting( + flowStart.peerSession.counterparty, + sessionMessage, + DeduplicationId.createForNormal(currentState.checkpoint, 0) + ) + ) + } + + // Create initial checkpoint and acknowledge triggering messages. 
+ private fun TransitionBuilder.createInitialCheckpoint() { + actions.addAll(arrayOf( + Action.CreateTransaction, + Action.PersistCheckpoint(context.id, currentState.checkpoint), + Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers), + Action.CommitTransaction, + Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers) + )) + currentState = currentState.copy( + pendingDeduplicationHandlers = emptyList(), + isAnyCheckpointPersisted = true + ) + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/vault/VaultSoftLockManager.kt b/node/src/main/kotlin/net/corda/node/services/vault/VaultSoftLockManager.kt index b57ac3a45c..a1a8b61270 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/VaultSoftLockManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/VaultSoftLockManager.kt @@ -7,6 +7,7 @@ import net.corda.core.node.services.VaultService import net.corda.core.utilities.* import net.corda.node.services.statemachine.FlowStateMachineImpl import net.corda.node.services.statemachine.StateMachineManager +import net.corda.nodeapi.internal.persistence.contextDatabase import java.util.* class VaultSoftLockManager private constructor(private val vault: VaultService) { @@ -48,11 +49,15 @@ class VaultSoftLockManager private constructor(private val vault: VaultService) private fun registerSoftLocks(flowId: UUID, stateRefs: NonEmptySet) { log.trace { "Reserving soft locks for flow id $flowId and states $stateRefs" } - vault.softLockReserve(flowId, stateRefs) + contextDatabase.transaction { + vault.softLockReserve(flowId, stateRefs) + } } private fun unregisterSoftLocks(flowId: UUID, logic: FlowLogic<*>) { log.trace { "Releasing soft locks for flow ${logic.javaClass.simpleName} with flow id $flowId" } - vault.softLockRelease(flowId) + contextDatabase.transaction { + vault.softLockRelease(flowId) + } } } \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/utilities/ObjectDiffer.kt 
b/node/src/main/kotlin/net/corda/node/utilities/ObjectDiffer.kt new file mode 100644 index 0000000000..3f0d73d2ed --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/utilities/ObjectDiffer.kt @@ -0,0 +1,144 @@ +package net.corda.node.utilities + +import java.lang.reflect.Method +import java.lang.reflect.Modifier +import java.lang.reflect.Type +import java.time.Instant + +/** + * A tree describing the diff between two objects. + * + * For example: + * data class A(val field1: Int, val field2: String, val field3: Unit) + * fun main(args: Array) { + * val someA = A(1, "hello", Unit) + * val someOtherA = A(2, "bello", Unit) + * println(ObjectDiffer.diff(someA, someOtherA)) + * } + * + * Will give back Step(branches=[(field1, Last(a=1, b=2)), (field2, Last(a=hello, b=bello))]) + */ +sealed class DiffTree { + /** + * Describes a "step" from the object root. It contains a list of field-subtree pairs. + */ + data class Step(val branches: List>) : DiffTree() + + /** + * Describes the leaf of the diff. This is either where the diffing was cutoff (e.g. primitives) or where it failed. + */ + data class Last(val a: Any?, val b: Any?) : DiffTree() + + /** + * Flattens the [DiffTree] into a list of [DiffPath]s + */ + fun toPaths(): List { + return when (this) { + is Step -> branches.flatMap { (step, tree) -> tree.toPaths().map { it.copy(path = listOf(step) + it.path) } } + is Last -> listOf(DiffPath(emptyList(), a, b)) + } + } +} + +/** + * A diff focused on a single [DiffTree.Last] diff, including the path leading there. + */ +data class DiffPath( + val path: List, + val a: Any?, + val b: Any? +) { + override fun toString(): String { + return "${path.joinToString(".")}: \n $a\n $b\n" + } +} + +/** + * This is a very simple differ used to diff objects of any kind, to be used for diagnostic. + */ +object ObjectDiffer { + fun diff(a: Any?, b: Any?): DiffTree? 
{ + if (a == null || b == null) { + if (a == b) { + return null + } else { + return DiffTree.Last(a, b) + } + } + if (a != b) { + if (a.javaClass.isPrimitive || a.javaClass in diffCutoffClasses) { + return DiffTree.Last(a, b) + } + // TODO deduplicate this code + if (a is Map<*, *> && b is Map<*, *>) { + val allKeys = a.keys + b.keys + val branches = allKeys.mapNotNull { key -> diff(a.get(key), b.get(key))?.let { key.toString() to it } } + if (branches.isEmpty()) { + return null + } else { + return DiffTree.Step(branches) + } + } + if (a is java.util.Map<*, *> && b is java.util.Map<*, *>) { + val allKeys = a.keySet() + b.keySet() + val branches = allKeys.mapNotNull { key -> diff(a.get(key), b.get(key))?.let { key.toString() to it } } + if (branches.isEmpty()) { + return null + } else { + return DiffTree.Step(branches) + } + } + val aFields = getFieldFoci(a) + val bFields = getFieldFoci(b) + try { + if (aFields != bFields) { + return DiffTree.Last(a, b) + } else { + // TODO need to account for cases where the fields don't match up (different subclasses) + val branches = aFields.map { field -> diff(field.get(a), field.get(b))?.let { field.name to it } }.filterNotNull() + if (branches.isEmpty()) { + return DiffTree.Last(a, b) + } else { + return DiffTree.Step(branches) + } + } + } catch (throwable: Exception) { + Exception("Error while diffing $a with $b", throwable).printStackTrace(System.out) + return DiffTree.Last(a, b) + } + } else { + return null + } + } + + // List of types to cutoff the diffing at. + private val diffCutoffClasses: Set> = setOf( + String::class.java, + Class::class.java, + Instant::class.java + ) + + // A type capturing the accessor to a field. This is a separate abstraction to simple reflection as we identify + // getX() and isX() calls as fields as well. + private data class FieldFocus(val name: String, val type: Type, val getter: Method) { + fun get(obj: Any): Any? 
{ + return getter.invoke(obj) + } + } + + private fun getFieldFoci(obj: Any) : List { + val foci = ArrayList() + for (method in obj.javaClass.declaredMethods) { + if (Modifier.isStatic(method.modifiers)) { + continue + } + if (method.name.startsWith("get") && method.name.length > 3 && method.parameterCount == 0) { + val fieldName = method.name[3].toLowerCase() + method.name.substring(4) + foci.add(FieldFocus(fieldName, method.returnType, method)) + } else if (method.name.startsWith("is") && method.parameterCount == 0) { + foci.add(FieldFocus(method.name, method.returnType, method)) + } + } + return foci + } +} diff --git a/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt b/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt index 7998a9df32..d250ddf7b1 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt @@ -49,10 +49,10 @@ class InMemoryMessagingTests { val bits = "test-content".toByteArray() var finalDelivery: Message? 
= null - node2.network.addMessageHandler("test.topic") { msg, _ -> + node2.network.addMessageHandler("test.topic") { msg, _, _ -> node2.network.send(msg, node3.network.myAddress) } - node3.network.addMessageHandler("test.topic") { msg, _ -> + node3.network.addMessageHandler("test.topic") { msg, _, _ -> finalDelivery = msg } @@ -73,7 +73,7 @@ class InMemoryMessagingTests { val bits = "test-content".toByteArray() var counter = 0 - listOf(node1, node2, node3).forEach { it.network.addMessageHandler("test.topic") { _, _ -> counter++ } } + listOf(node1, node2, node3).forEach { it.network.addMessageHandler("test.topic") { _, _, _ -> counter++ } } node1.network.send(node2.network.createMessage("test.topic", data = bits), rigorousMock()) mockNet.runNetwork(rounds = 1) assertEquals(3, counter) @@ -89,9 +89,10 @@ class InMemoryMessagingTests { val node2 = mockNet.createNode() var received = 0 - node1.network.addMessageHandler("valid_message") { _, _ -> + node1.network.addMessageHandler("valid_message") { _, _, _ -> received++ } + val invalidMessage = node2.network.createMessage("invalid_message", data = ByteArray(1)) val validMessage = node2.network.createMessage("valid_message", data = ByteArray(1)) node2.network.send(invalidMessage, node1.network.myAddress) diff --git a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt index 604eca0f7a..dc7a3505e0 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt @@ -744,6 +744,12 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { private val database: CordaPersistence, private val delegate: WritableTransactionStorage ) : WritableTransactionStorage, SingletonSerializeAsToken() { + override fun trackTransaction(id: SecureHash): CordaFuture { + return database.transaction { + delegate.trackTransaction(id) + } + } + 
override fun track(): DataFeed, SignedTransaction> { return database.transaction { delegate.track() diff --git a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt index b722d41f5d..db14bc1632 100644 --- a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt @@ -1,6 +1,5 @@ package net.corda.node.services.events -import com.google.common.util.concurrent.MoreExecutors import com.nhaarman.mockito_kotlin.* import net.corda.core.contracts.* import net.corda.core.crypto.SecureHash @@ -22,6 +21,7 @@ import net.corda.testing.internal.doLookup import net.corda.testing.internal.rigorousMock import net.corda.testing.node.MockServices import net.corda.testing.node.TestClock +import org.junit.Ignore import org.junit.Rule import org.junit.Test import org.junit.rules.TestWatcher @@ -42,8 +42,14 @@ open class NodeSchedulerServiceTestBase { protected val testClock = TestClock(rigorousMock().also { doReturn(mark).whenever(it).instant() }) + private val database = rigorousMock().also { + doAnswer { + val block: DatabaseTransaction.() -> Any? 
= uncheckedCast(it.arguments[0]) + rigorousMock().block() + }.whenever(it).transaction(any()) + } protected val flowStarter = rigorousMock().also { - doReturn(openFuture>()).whenever(it).startFlow(any>(), any()) + doReturn(openFuture>()).whenever(it).startFlow(any>(), any(), any()) } private val flowsDraingMode = rigorousMock().also { doReturn(false).whenever(it).isEnabled() @@ -76,7 +82,7 @@ open class NodeSchedulerServiceTestBase { protected fun assertStarted(flowLogic: FlowLogic<*>) { // Like in assertWaitingFor, use timeout to make verify wait as we often race the call to startFlow: - verify(flowStarter, timeout(5000)).startFlow(same(flowLogic)!!, any()) + verify(flowStarter, timeout(5000)).startFlow(same(flowLogic)!!, any(), any()) } protected fun assertStarted(event: Event) = assertStarted(event.flowLogic) @@ -95,7 +101,6 @@ class NodeSchedulerServiceTest : NodeSchedulerServiceTestBase() { database, flowStarter, servicesForResolution, - serverThread = MoreExecutors.directExecutor(), flowLogicRefFactory = flowLogicRefFactory, nodeProperties = nodeProperties, drainingModePollPeriod = Duration.ofSeconds(5), @@ -209,7 +214,6 @@ class NodeSchedulerPersistenceTest : NodeSchedulerServiceTestBase() { db, flowStarter, servicesForResolution, - serverThread = MoreExecutors.directExecutor(), flowLogicRefFactory = flowLogicRefFactory, nodeProperties = nodeProperties, drainingModePollPeriod = Duration.ofSeconds(5), @@ -262,6 +266,7 @@ class NodeSchedulerPersistenceTest : NodeSchedulerServiceTestBase() { newDatabase.close() } + @Ignore("Temporarily") @Test fun `test that if schedule is updated then the flow is invoked on the correct schedule`() { val dataSourceProps = MockServices.makeTestDataSourceProperties() @@ -293,4 +298,4 @@ class NodeSchedulerPersistenceTest : NodeSchedulerServiceTestBase() { scheduler.join() database.close() } -} \ No newline at end of file +} diff --git a/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt 
b/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt index 5dadb85eac..bc255c91aa 100644 --- a/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt @@ -151,7 +151,9 @@ class ArtemisMessagingTest { createMessagingServer().start() val messagingClient = createMessagingClient(platformVersion = platformVersion) - messagingClient.addMessageHandler(TOPIC) { message, _ -> + messagingClient.addMessageHandler(TOPIC) { message, _, handle -> + database.transaction { handle.insideDatabaseTransaction() } + handle.afterDatabaseTransaction() // We ACK first so that if it fails we won't get a duplicate in [receivedMessages] receivedMessages.add(message) } startNodeMessagingClient() diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt index 7851fdc874..9bba77f66a 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt @@ -1,33 +1,40 @@ package net.corda.node.services.persistence -import com.google.common.primitives.Ints +import net.corda.core.context.InvocationContext +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.StateMachineRunId +import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializedBytes -import net.corda.node.services.api.Checkpoint -import net.corda.node.services.api.CheckpointStorage -import net.corda.node.services.transactions.PersistentUniquenessProvider +import net.corda.core.serialization.serialize import net.corda.node.internal.configureDatabase +import net.corda.node.services.api.CheckpointStorage +import net.corda.node.services.statemachine.Checkpoint +import 
net.corda.node.services.statemachine.FlowStart +import net.corda.node.services.transactions.PersistentUniquenessProvider import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseConfig -import net.corda.testing.internal.LogHelper +import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.SerializationEnvironmentRule -import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties +import net.corda.testing.core.TestIdentity +import net.corda.testing.internal.LogHelper import net.corda.testing.internal.rigorousMock +import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties import org.assertj.core.api.Assertions.assertThat import org.junit.After import org.junit.Before import org.junit.Rule import org.junit.Test +import kotlin.streams.toList -internal fun CheckpointStorage.checkpoints(): List { - val checkpoints = mutableListOf() - forEach { - checkpoints += it - true - } - return checkpoints +internal fun CheckpointStorage.checkpoints(): List> { + val checkpoints = getAllCheckpoints().toList() + return checkpoints.map { it.second } } class DBCheckpointStorageTests { + private companion object { + val ALICE = TestIdentity(ALICE_NAME, 70).party + } @Rule @JvmField val testSerialization = SerializationEnvironmentRule() @@ -50,9 +57,9 @@ class DBCheckpointStorageTests { @Test fun `add new checkpoint`() { - val checkpoint = newCheckpoint() + val (id, checkpoint) = newCheckpoint() database.transaction { - checkpointStorage.addCheckpoint(checkpoint) + checkpointStorage.addCheckpoint(id, checkpoint) } database.transaction { assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint) @@ -65,12 +72,12 @@ class DBCheckpointStorageTests { @Test fun `remove checkpoint`() { - val checkpoint = newCheckpoint() + val (id, checkpoint) = newCheckpoint() database.transaction { - checkpointStorage.addCheckpoint(checkpoint) + checkpointStorage.addCheckpoint(id, 
checkpoint) } database.transaction { - checkpointStorage.removeCheckpoint(checkpoint) + checkpointStorage.removeCheckpoint(id) } database.transaction { assertThat(checkpointStorage.checkpoints()).isEmpty() @@ -83,12 +90,12 @@ class DBCheckpointStorageTests { @Test fun `add and remove checkpoint in single commit operate`() { - val checkpoint = newCheckpoint() - val checkpoint2 = newCheckpoint() + val (id, checkpoint) = newCheckpoint() + val (id2, checkpoint2) = newCheckpoint() database.transaction { - checkpointStorage.addCheckpoint(checkpoint) - checkpointStorage.addCheckpoint(checkpoint2) - checkpointStorage.removeCheckpoint(checkpoint) + checkpointStorage.addCheckpoint(id, checkpoint) + checkpointStorage.addCheckpoint(id2, checkpoint2) + checkpointStorage.removeCheckpoint(id) } database.transaction { assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint2) @@ -101,16 +108,16 @@ class DBCheckpointStorageTests { @Test fun `add two checkpoints then remove first one`() { - val firstCheckpoint = newCheckpoint() + val (id, firstCheckpoint) = newCheckpoint() database.transaction { - checkpointStorage.addCheckpoint(firstCheckpoint) + checkpointStorage.addCheckpoint(id, firstCheckpoint) } - val secondCheckpoint = newCheckpoint() + val (id2, secondCheckpoint) = newCheckpoint() database.transaction { - checkpointStorage.addCheckpoint(secondCheckpoint) + checkpointStorage.addCheckpoint(id2, secondCheckpoint) } database.transaction { - checkpointStorage.removeCheckpoint(firstCheckpoint) + checkpointStorage.removeCheckpoint(id) } database.transaction { assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint) @@ -123,9 +130,9 @@ class DBCheckpointStorageTests { @Test fun `add checkpoint and then remove after 'restart'`() { - val originalCheckpoint = newCheckpoint() + val (id, originalCheckpoint) = newCheckpoint() database.transaction { - checkpointStorage.addCheckpoint(originalCheckpoint) + checkpointStorage.addCheckpoint(id, 
originalCheckpoint) } newCheckpointStorage() val reconstructedCheckpoint = database.transaction { @@ -135,7 +142,7 @@ class DBCheckpointStorageTests { assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint) } database.transaction { - checkpointStorage.removeCheckpoint(reconstructedCheckpoint) + checkpointStorage.removeCheckpoint(id) } database.transaction { assertThat(checkpointStorage.checkpoints()).isEmpty() @@ -148,7 +155,14 @@ class DBCheckpointStorageTests { } } - private var checkpointCount = 1 - private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++))) + private fun newCheckpoint(): Pair> { + val id = StateMachineRunId.createRandom() + val logic: FlowLogic<*> = object : FlowLogic() { + override fun call() {} + } + val frozenLogic = logic.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT) + val checkpoint = Checkpoint.create(InvocationContext.shell(), FlowStart.Explicit, logic.javaClass, frozenLogic, ALICE, "").getOrThrow() + return id to checkpoint.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT) + } } diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt index 55d7073e5b..e4512c01af 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt @@ -2,6 +2,7 @@ package net.corda.node.services.statemachine import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Suspendable +import co.paralleluniverse.strands.Strand import co.paralleluniverse.strands.concurrent.Semaphore import net.corda.core.concurrent.CordaFuture import net.corda.core.contracts.ContractState @@ -67,10 +68,6 @@ class FlowFrameworkTests { private lateinit var alice: Party private lateinit var bob: Party - private fun StartedNode<*>.flushSmm() 
{ - (this.smm as StateMachineManagerImpl).executor.flush() - } - @Before fun start() { mockNet = InternalMockNetwork( @@ -109,6 +106,19 @@ class FlowFrameworkTests { assertThat(flow.lazyTime).isNotNull() } + class ThrowingActionExecutor(private val exception: Exception, val delegate: ActionExecutor) : ActionExecutor { + var thrown = false + @Suspendable + override fun executeAction(fiber: FlowFiber, action: Action) { + if (thrown) { + delegate.executeAction(fiber, action) + } else { + thrown = true + throw exception + } + } + } + @Test fun `exception while fiber suspended`() { bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) } @@ -116,16 +126,15 @@ class FlowFrameworkTests { val fiber = aliceNode.services.startFlow(flow) as FlowStateMachineImpl // Before the flow runs change the suspend action to throw an exception val exceptionDuringSuspend = Exception("Thrown during suspend") - fiber.actionOnSuspend = { - throw exceptionDuringSuspend - } + val throwingActionExecutor = ThrowingActionExecutor(exceptionDuringSuspend, fiber.transientValues!!.value.actionExecutor) + fiber.transientValues = TransientReference(fiber.transientValues!!.value.copy(actionExecutor = throwingActionExecutor)) mockNet.runNetwork() assertThatThrownBy { fiber.resultFuture.getOrThrow() }.isSameAs(exceptionDuringSuspend) assertThat(aliceNode.smm.allStateMachines).isEmpty() // Make sure the fiber does actually terminate - assertThat(fiber.isTerminated).isTrue() + assertThat(fiber.state).isEqualTo(Strand.State.WAITING) } @Test @@ -148,7 +157,6 @@ class FlowFrameworkTests { aliceNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) } bobNode.services.startFlow(ReceiveFlow(alice).nonTerminating()) // Prepare checkpointed receive flow // Make sure the add() has finished initial processing. 
- bobNode.flushSmm() bobNode.internals.disableDBCloseOnStop() bobNode.dispose() // kill receiver val restoredFlow = bobNode.restartAndGetRestoredFlow() @@ -174,7 +182,6 @@ class FlowFrameworkTests { assertEquals(1, bobNode.checkpointStorage.checkpoints().size) } // Make sure the add() has finished initial processing. - bobNode.flushSmm() bobNode.internals.disableDBCloseOnStop() // Restart node and thus reload the checkpoint and resend the message with same UUID bobNode.dispose() @@ -187,7 +194,6 @@ class FlowFrameworkTests { val (firstAgain, fut1) = node2b.getSingleFlow() // Run the network which will also fire up the second flow. First message should get deduped. So message data stays in sync. mockNet.runNetwork() - node2b.flushSmm() fut1.getOrThrow() val receivedCount = receivedSessionMessages.count { it.isPayloadTransfer } @@ -216,6 +222,8 @@ class FlowFrameworkTests { val payload = "Hello World" aliceNode.services.startFlow(SendFlow(payload, bob, charlie)) mockNet.runNetwork() + bobNode.internals.acceptableLiveFiberCountOnStop = 1 + charlieNode.internals.acceptableLiveFiberCountOnStop = 1 val bobFlow = bobNode.getSingleFlow().first val charlieFlow = charlieNode.getSingleFlow().first assertThat(bobFlow.receivedPayloads[0]).isEqualTo(payload) @@ -234,9 +242,6 @@ class FlowFrameworkTests { aliceNode sent normalEnd to charlieNode //There's no session end from the other flows as they're manually suspended ) - - bobNode.internals.acceptableLiveFiberCountOnStop = 1 - charlieNode.internals.acceptableLiveFiberCountOnStop = 1 } @Test @@ -338,7 +343,9 @@ class FlowFrameworkTests { mockNet.runNetwork() - assertThat(erroringFlowSteps.get()).containsExactly( + erroringFlowFuture.getOrThrow() + val flowSteps = erroringFlowSteps.get() + assertThat(flowSteps).containsExactly( Notification.createOnNext(ExceptionFlow.START_STEP), Notification.createOnError(erroringFlowFuture.get().exceptionThrown) ) @@ -378,8 +385,8 @@ class FlowFrameworkTests { 
assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty() } - assertThat(receivingFiber.isTerminated).isTrue() - assertThat((erroringFlow.get().stateMachine as FlowStateMachineImpl).isTerminated).isTrue() + assertThat(receivingFiber.state).isEqualTo(Strand.State.WAITING) + assertThat((erroringFlow.get().stateMachine as FlowStateMachineImpl).state).isEqualTo(Strand.State.WAITING) assertThat(erroringFlowSteps.get()).containsExactly( Notification.createOnNext(ExceptionFlow.START_STEP), Notification.createOnError(erroringFlow.get().exceptionThrown) @@ -396,7 +403,7 @@ class FlowFrameworkTests { } @Test - fun `FlowException propagated in invocation chain`() { + fun `FlowException only propagated to parent`() { val charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME)) val charlie = charlieNode.info.singleIdentity() @@ -404,9 +411,8 @@ class FlowFrameworkTests { bobNode.registerFlowFactory(ReceiveFlow::class) { ReceiveFlow(charlie) } val receivingFiber = aliceNode.services.startFlow(ReceiveFlow(bob)) mockNet.runNetwork() - assertThatExceptionOfType(MyFlowException::class.java) + assertThatExceptionOfType(UnexpectedFlowEndException::class.java) .isThrownBy { receivingFiber.resultFuture.getOrThrow() } - .withMessage("Chain") } @Test @@ -558,10 +564,8 @@ class FlowFrameworkTests { @Test fun `customised client flow which has annotated @InitiatingFlow again`() { - val result = aliceNode.services.startFlow(IncorrectCustomSendFlow("Hello", bob)).resultFuture - mockNet.runNetwork() assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy { - result.getOrThrow() + aliceNode.services.startFlow(IncorrectCustomSendFlow("Hello", bob)).resultFuture }.withMessageContaining(InitiatingFlow::class.java.simpleName) } @@ -635,24 +639,6 @@ class FlowFrameworkTests { assertThat(result.getOrThrow()).isEqualTo("HelloHello") } - @Test - fun `double initiateFlow throws`() { - val future = 
aliceNode.services.startFlow(DoubleInitiatingFlow()).resultFuture - mockNet.runNetwork() - assertThatExceptionOfType(IllegalStateException::class.java) - .isThrownBy { future.getOrThrow() } - .withMessageContaining("Attempted to initiateFlow() twice") - } - - @InitiatingFlow - private class DoubleInitiatingFlow : FlowLogic() { - @Suspendable - override fun call() { - initiateFlow(ourIdentity) - initiateFlow(ourIdentity) - } - } - //////////////////////////////////////////////////////////////////////////////////////////////////////////// //region Helpers @@ -685,7 +671,6 @@ class FlowFrameworkTests { private fun sessionInit(clientFlowClass: KClass>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage { return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize()) } - private fun sessionConfirm(flowVersion: Int = 1) = ExistingSessionMessage(SessionId(0), ConfirmSessionMessage(SessionId(0), FlowInfo(flowVersion, ""))) private fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(0), DataSessionMessage(payload.serialize())) private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0) @@ -694,7 +679,7 @@ class FlowFrameworkTests { private fun StartedNode<*>.sendSessionMessage(message: SessionMessage, destination: Party) { services.networkService.apply { val address = getAddressOfParty(PartyInfo.SingleNode(destination, emptyList())) - send(createMessage(StateMachineManagerImpl.sessionTopic, message.serialize().bytes), address) + send(createMessage(FlowMessagingImpl.sessionTopic, message.serialize().bytes), address) } } @@ -720,7 +705,7 @@ class FlowFrameworkTests { } private fun Observable.toSessionTransfers(): Observable { - return filter { it.getMessage().topic == StateMachineManagerImpl.sessionTopic }.map { + return filter { it.getMessage().topic == FlowMessagingImpl.sessionTopic }.map { val from = it.sender.id val message = 
it.messageData.deserialize() SessionTransfer(from, sanitise(message), it.recipients) diff --git a/node/src/test/kotlin/net/corda/node/services/transactions/MaxTransactionSizeTests.kt b/node/src/test/kotlin/net/corda/node/services/transactions/MaxTransactionSizeTests.kt index 1e17c6b952..9195a5d468 100644 --- a/node/src/test/kotlin/net/corda/node/services/transactions/MaxTransactionSizeTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/transactions/MaxTransactionSizeTests.kt @@ -5,10 +5,8 @@ import net.corda.core.crypto.SecureHash import net.corda.core.flows.* import net.corda.core.identity.Party import net.corda.core.internal.InputStreamAndHash -import net.corda.core.node.ServiceHub import net.corda.core.transactions.TransactionBuilder import net.corda.core.utilities.getOrThrow -import net.corda.node.services.api.StartedNodeServices import net.corda.testing.common.internal.testNetworkParameters import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyState @@ -22,7 +20,6 @@ import net.corda.testing.node.StartedMockNode import org.assertj.core.api.Assertions.assertThat import org.junit.After import org.junit.Before -import org.junit.Ignore import org.junit.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith @@ -90,12 +87,11 @@ class MaxTransactionSizeTests { assertEquals(hash1, bigFile1.sha256) SendLargeTransactionFlow(notary, bob, hash1, hash2, hash3, hash4, verify = false) } - val ex = assertFailsWith { + assertFailsWith { val future = aliceNode.startFlow(flow) mockNet.runNetwork() future.getOrThrow() } - assertThat(ex).hasMessageContaining("Counterparty flow on O=Bob Plc, L=Rome, C=IT had an internal error and has terminated") } @StartableByRPC @@ -135,4 +131,4 @@ class MaxTransactionSizeTests { otherSide.send(Unit) } } -} \ No newline at end of file +} diff --git a/node/src/test/kotlin/net/corda/node/services/transactions/NotaryServiceTests.kt 
b/node/src/test/kotlin/net/corda/node/services/transactions/NotaryServiceTests.kt index 56f0197372..b4c1eb6fbf 100644 --- a/node/src/test/kotlin/net/corda/node/services/transactions/NotaryServiceTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/transactions/NotaryServiceTests.kt @@ -210,7 +210,7 @@ class NotaryServiceTests { private fun runNotarisationAndInterceptClientPayload(payloadModifier: (NotarisationPayload) -> NotarisationPayload) { aliceNode.setMessagingServiceSpy(object : MessagingServiceSpy(aliceNode.network) { - override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, additionalHeaders: Map) { + override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) { val messageData = message.data.deserialize() as? InitialSessionMessage val payload = messageData?.firstPayload!!.deserialize() diff --git a/samples/irs-demo/build.gradle b/samples/irs-demo/build.gradle index e781ba6d81..938dd005af 100644 --- a/samples/irs-demo/build.gradle +++ b/samples/irs-demo/build.gradle @@ -18,6 +18,7 @@ ext['artemis.version'] = "$artemis_version" ext['hibernate.version'] = "$hibernate_version" ext['selenium.version'] = "$selenium_version" ext['jackson.version'] = "$jackson_version" +ext['dropwizard-metrics.version'] = "$metrics_version" apply plugin: 'java' apply plugin: 'kotlin' diff --git a/samples/network-visualiser/build.gradle b/samples/network-visualiser/build.gradle index 543f5dbf92..25b8c89ae7 100644 --- a/samples/network-visualiser/build.gradle +++ b/samples/network-visualiser/build.gradle @@ -18,6 +18,7 @@ buildscript { ext['artemis.version'] = "$artemis_version" ext['hibernate.version'] = "$hibernate_version" ext['jackson.version'] = "$jackson_version" +ext['dropwizard-metrics.version'] = "$metrics_version" apply plugin: 'java' diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt 
b/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt index adbd5175dd..5fe858355f 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt @@ -18,10 +18,8 @@ import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.trace -import net.corda.node.services.messaging.Message -import net.corda.node.services.messaging.MessageHandlerRegistration -import net.corda.node.services.messaging.MessagingService -import net.corda.node.services.messaging.ReceivedMessage +import net.corda.node.services.messaging.* +import net.corda.node.services.statemachine.DeduplicationId import net.corda.node.utilities.AffinityExecutor import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.testing.node.internal.InMemoryMessage @@ -290,9 +288,16 @@ class InMemoryMessagingNetwork private constructor( private data class InMemoryReceivedMessage(override val topic: String, override val data: ByteSequence, override val platformVersion: Int, - override val uniqueMessageId: String, + override val uniqueMessageId: DeduplicationId, override val debugTimestamp: Instant, - override val peer: CordaX500Name) : ReceivedMessage + override val peer: CordaX500Name, + override val senderUUID: String? = null, + override val senderSeqNo: Long? = null, + /** Note this flag is never set in the in memory network. 
*/ + override val isSessionInit: Boolean = false) : ReceivedMessage { + + override val additionalHeaders: Map = emptyMap() + } /** * A class that provides an abstraction over the nodes' messaging service that also contains the ability to @@ -319,7 +324,7 @@ class InMemoryMessagingNetwork private constructor( private val peerHandle: PeerHandle, private val executor: AffinityExecutor, private val database: CordaPersistence) : SingletonSerializeAsToken(), InternalMockMessagingService { - private inner class Handler(val topicSession: String, val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration + private inner class Handler(val topicSession: String, val callback: MessageHandler) : MessageHandlerRegistration @Volatile private var running = true @@ -330,7 +335,7 @@ class InMemoryMessagingNetwork private constructor( } private val state = ThreadBox(InnerState()) - private val processedMessages: MutableSet = Collections.synchronizedSet(HashSet()) + private val processedMessages: MutableSet = Collections.synchronizedSet(HashSet()) override val myAddress: PeerHandle get() = peerHandle @@ -353,7 +358,7 @@ class InMemoryMessagingNetwork private constructor( } } - override fun addMessageHandler(topic: String, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration { + override fun addMessageHandler(topic: String, callback: MessageHandler): MessageHandlerRegistration { check(running) val (handler, transfers) = state.locked { val handler = Handler(topic, callback).apply { handlers.add(this) } @@ -374,7 +379,7 @@ class InMemoryMessagingNetwork private constructor( state.locked { check(handlers.remove(registration as Handler)) } } - override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, additionalHeaders: Map) { + override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) { check(running) msgSend(this, message, target) if 
(!sendManuallyPumped) { @@ -400,7 +405,7 @@ class InMemoryMessagingNetwork private constructor( override fun cancelRedelivery(retryId: Long) {} /** Returns the given (topic & session, data) pair as a newly created message object. */ - override fun createMessage(topic: String, data: ByteArray, deduplicationId: String): Message { + override fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId, additionalHeaders: Map): Message { return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId) } @@ -465,7 +470,7 @@ class InMemoryMessagingNetwork private constructor( database.transaction { for (handler in deliverTo) { try { - handler.callback(transfer.toReceivedMessage(), handler) + handler.callback(transfer.toReceivedMessage(), handler, DummyDeduplicationHandler()) } catch (e: Exception) { log.error("Caught exception in handler for $this/${handler.topicSession}", e) } @@ -489,5 +494,12 @@ class InMemoryMessagingNetwork private constructor( message.debugTimestamp, sender.name) } + + private class DummyDeduplicationHandler : DeduplicationHandler { + override fun afterDatabaseTransaction() { + } + override fun insideDatabaseTransaction() { + } + } } diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InMemoryMessage.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InMemoryMessage.kt index 6ccf99c703..52132ebc92 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InMemoryMessage.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InMemoryMessage.kt @@ -2,6 +2,7 @@ package net.corda.testing.node.internal import net.corda.core.utilities.ByteSequence import net.corda.node.services.messaging.Message +import net.corda.node.services.statemachine.DeduplicationId import java.time.Instant /** @@ -9,7 +10,11 @@ import java.time.Instant */ data class InMemoryMessage(override val topic: String, override val data: ByteSequence, - override val 
uniqueMessageId: String, - override val debugTimestamp: Instant = Instant.now()) : Message { + override val uniqueMessageId: DeduplicationId, + override val debugTimestamp: Instant = Instant.now(), + override val senderUUID: String? = null) : Message { + + override val additionalHeaders: Map = emptyMap() + override fun toString() = "$topic#${String(data.bytes)}" } \ No newline at end of file diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalMockNetwork.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalMockNetwork.kt index d37db63aa6..3bb7b419d5 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalMockNetwork.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalMockNetwork.kt @@ -15,7 +15,6 @@ import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.createDirectories import net.corda.core.internal.createDirectory import net.corda.core.internal.uncheckedCast -import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.RPCOps import net.corda.core.messaging.SingleMessageRecipient @@ -231,7 +230,7 @@ open class InternalMockNetwork(private val cordappPackages: List, private val entropyRoot = args.entropyRoot var counter = entropyRoot override val log get() = staticLog - override val serverThread: AffinityExecutor = + override val serverThread: AffinityExecutor.ServiceAffinityExecutor = if (mockNet.threadPerNode) { ServiceAffinityExecutor("Mock node $id thread", 1) } else { diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt index d18766c4f1..2ffdeba22d 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt +++ 
b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt @@ -1,18 +1,25 @@ package net.corda.testing.node.internal +import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.SecureHash +import net.corda.core.internal.concurrent.doneFuture import net.corda.core.messaging.DataFeed import net.corda.core.serialization.SingletonSerializeAsToken +import net.corda.core.toFuture import net.corda.core.transactions.SignedTransaction import net.corda.node.services.api.WritableTransactionStorage import rx.Observable import rx.subjects.PublishSubject -import java.util.HashMap +import java.util.* /** * A class which provides an implementation of [WritableTransactionStorage] which is used in [MockServices] */ open class MockTransactionStorage : WritableTransactionStorage, SingletonSerializeAsToken() { + override fun trackTransaction(id: SecureHash): CordaFuture { + return txns[id]?.let { doneFuture(it) } ?: _updatesPublisher.filter { it.id == id }.toFuture() + } + override fun track(): DataFeed, SignedTransaction> { return DataFeed(txns.values.toList(), _updatesPublisher) } From 19dad6da96b01d3e5a1c036c148cab17f9890bb2 Mon Sep 17 00:00:00 2001 From: Andras Slemmer Date: Mon, 16 Apr 2018 16:20:10 +0100 Subject: [PATCH 5/9] Add back deprecated functions --- .ci/api-current.txt | 9 +- .../kotlin/net/corda/core/flows/FlowLogic.kt | 104 ++++++++++++++++++ docs/source/api-flows.rst | 62 +++++++++++ .../java/net/corda/docs/FlowCookbookJava.java | 7 ++ .../kotlin/net/corda/docs/FlowCookbook.kt | 7 ++ .../net/corda/docs/LaunchSpaceshipFlow.kt | 99 +++++++++++++++++ .../transitions/TopLevelTransition.kt | 37 +++---- 7 files changed, 297 insertions(+), 28 deletions(-) create mode 100644 docs/source/example-code/src/main/kotlin/net/corda/docs/LaunchSpaceshipFlow.kt diff --git a/.ci/api-current.txt b/.ci/api-current.txt index fac366c8fa..3e83150ab4 100644 --- a/.ci/api-current.txt +++ b/.ci/api-current.txt @@ -1200,7 +1200,7 @@ public 
static final class net.corda.core.flows.FinalityFlow$Companion extends ja @org.jetbrains.annotations.NotNull public net.corda.core.utilities.ProgressTracker childProgressTracker() public static final net.corda.core.flows.FinalityFlow$Companion$NOTARISING INSTANCE ## -@net.corda.core.serialization.CordaSerializable public class net.corda.core.flows.FlowException extends net.corda.core.CordaException +@net.corda.core.serialization.CordaSerializable public class net.corda.core.flows.FlowException extends net.corda.core.CordaException implements net.corda.core.flows.IdentifiableException public () public (String) public (String, Throwable) @@ -1589,9 +1589,10 @@ public final class net.corda.core.flows.TransactionParts extends java.lang.Objec public int hashCode() public String toString() ## -@net.corda.core.serialization.CordaSerializable public final class net.corda.core.flows.UnexpectedFlowEndException extends net.corda.core.CordaRuntimeException - public (String) - public (String, Throwable) +@net.corda.core.serialization.CordaSerializable public final class net.corda.core.flows.UnexpectedFlowEndException extends net.corda.core.CordaRuntimeException implements net.corda.core.flows.IdentifiableException + public (String, Throwable, long) + @org.jetbrains.annotations.NotNull public Long getErrorId() + public final long getOriginalErrorId() ## @net.corda.core.DoNotImplement @net.corda.core.serialization.CordaSerializable public abstract class net.corda.core.identity.AbstractParty extends java.lang.Object public (java.security.PublicKey) diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt index 5050b6253a..9eeddc32d8 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt @@ -134,6 +134,110 @@ abstract class FlowLogic { val ourIdentity: Party get() = stateMachine.ourIdentity + // Used to implement the deprecated send/receive 
functions using Party. When such a deprecated function is used we + // create a fresh session for the Party, put it here and use it in subsequent deprecated calls. + private val deprecatedPartySessionMap = HashMap() + private fun getDeprecatedSessionForParty(party: Party): FlowSession { + return deprecatedPartySessionMap.getOrPut(party) { initiateFlow(party) } + } + /** + * Returns a [FlowInfo] object describing the flow [otherParty] is using. With [FlowInfo.flowVersion] it + * provides the necessary information needed for the evolution of flows and enabling backwards compatibility. + * + * This method can be called before any send or receive has been done with [otherParty]. In such a case this will force + * them to start their flow. + */ + @Deprecated("Use FlowSession.getCounterpartyFlowInfo()", level = DeprecationLevel.WARNING) + @Suspendable + fun getFlowInfo(otherParty: Party): FlowInfo = getDeprecatedSessionForParty(otherParty).getCounterpartyFlowInfo() + + /** + * Serializes and queues the given [payload] object for sending to the [otherParty]. Suspends until a response + * is received, which must be of the given [R] type. + * + * Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly + * verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly + * corrupted data in order to exploit your code. + * + * Note that this function is not just a simple send+receive pair: it is more efficient and more correct to + * use this when you expect to do a message swap than to use [send] and then [receive] in turn. + * + * @return an [UntrustworthyData] wrapper around the received object. 
+ */ + @Deprecated("Use FlowSession.sendAndReceive()", level = DeprecationLevel.WARNING) + inline fun sendAndReceive(otherParty: Party, payload: Any): UntrustworthyData { + return sendAndReceive(R::class.java, otherParty, payload) + } + + /** + * Serializes and queues the given [payload] object for sending to the [otherParty]. Suspends until a response + * is received, which must be of the given [receiveType]. Remember that when receiving data from other parties the data + * should not be trusted until it's been thoroughly verified for consistency and that all expectations are + * satisfied, as a malicious peer may send you subtly corrupted data in order to exploit your code. + * + * Note that this function is not just a simple send+receive pair: it is more efficient and more correct to + * use this when you expect to do a message swap than to use [send] and then [receive] in turn. + * + * @return an [UntrustworthyData] wrapper around the received object. + */ + @Deprecated("Use FlowSession.sendAndReceive()", level = DeprecationLevel.WARNING) + @Suspendable + open fun sendAndReceive(receiveType: Class, otherParty: Party, payload: Any): UntrustworthyData { + return getDeprecatedSessionForParty(otherParty).sendAndReceive(receiveType, payload) + } + + /** + * Similar to [sendAndReceive] but also instructs the `payload` to be redelivered until the expected message is received. + * + * Note that this method should NOT be used for regular party-to-party communication; use [sendAndReceive] instead. + * It is only intended for the case where the [otherParty] is running a distributed service with an idempotent + * flow which only accepts a single request and sends back a single response – e.g. a notary or certain types of + * oracle services. If one or more nodes in the service cluster go down mid-session, the message will be redelivered + * to a different one, so there is no need to wait until the initial node comes back up to obtain a response. 
+ */ + @Deprecated("Use FlowSession.sendAndReceiveWithRetry()", level = DeprecationLevel.WARNING) + internal inline fun sendAndReceiveWithRetry(otherParty: Party, payload: Any): UntrustworthyData { + return getDeprecatedSessionForParty(otherParty).sendAndReceiveWithRetry(payload) + } + + /** + * Suspends until the specified [otherParty] sends us a message of type [R]. + * + * Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly + * verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly + * corrupted data in order to exploit your code. + */ + @Deprecated("Use FlowSession.receive()", level = DeprecationLevel.WARNING) + inline fun receive(otherParty: Party): UntrustworthyData = receive(R::class.java, otherParty) + + /** + * Suspends until the specified [otherParty] sends us a message of type [receiveType]. + * + * Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly + * verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly + * corrupted data in order to exploit your code. + * + * @return an [UntrustworthyData] wrapper around the received object. + */ + @Deprecated("Use FlowSession.receive()", level = DeprecationLevel.WARNING) + @Suspendable + open fun receive(receiveType: Class, otherParty: Party): UntrustworthyData { + return getDeprecatedSessionForParty(otherParty).receive(receiveType) + } + + /** + * Queues the given [payload] for sending to the [otherParty] and continues without suspending. + * + * Note that the other party may receive the message at some arbitrary later point or not at all: if [otherParty] + * is offline then message delivery will be retried until it comes back or until the message is older than the + * network's event horizon time. 
+ */ + @Deprecated("Use FlowSession.send()", level = DeprecationLevel.WARNING) + @Suspendable + open fun send(otherParty: Party, payload: Any) { + getDeprecatedSessionForParty(otherParty).send(payload) + } + @Suspendable internal fun FlowSession.sendAndReceiveWithRetry(receiveType: Class, payload: Any): UntrustworthyData { val request = FlowIORequest.SendAndReceive( diff --git a/docs/source/api-flows.rst b/docs/source/api-flows.rst index 6e80c7ae72..8f00b6bc0c 100644 --- a/docs/source/api-flows.rst +++ b/docs/source/api-flows.rst @@ -416,6 +416,68 @@ Our side of the flow must mirror these calls. We could do this as follows: :end-before: DOCEND 08 :dedent: 12 +Why sessions? +^^^^^^^^^^^^^ + +Before ``FlowSession`` s were introduced the send/receive API looked a bit different. They were functions on +``FlowLogic`` and took the address ``Party`` as argument. The platform internally maintained a mapping from ``Party`` to +session, hiding sessions from the user completely. + +Although this is a convenient API it introduces subtle issues where a message that was originally meant for a specific +session may end up in another. + +Consider the following contrived example using the old ``Party`` based API: + +.. container:: codeset + + .. literalinclude:: ../../docs/source/example-code/src/main/kotlin/net/corda/docs/LaunchSpaceshipFlow.kt + :language: kotlin + :start-after: DOCSTART LaunchSpaceshipFlow + :end-before: DOCEND LaunchSpaceshipFlow + +The intention of the flows is very clear: LaunchSpaceshipFlow asks the president whether a spaceship should be launched. +It is expecting a boolean reply. The president in return first tells the secretary that they need coffee, which is also +communicated with a boolean. Afterwards the president replies to the launcher that they don't want to launch. + +However the above can go horribly wrong when the ``launcher`` happens to be the same party ``getSecretary`` returns. 
In +this case the boolean meant for the secretary will be received by the launcher! + +This indicates that ``Party`` is not a good identifier for the communication sequence, and indeed the ``Party`` based +API may introduce ways for an attacker to fish for information and even trigger unintended control flow like in the +above case. + +Hence we introduced ``FlowSession``, which identifies the communication sequence. With ``FlowSession`` s the above set +of flows would look like this: + +.. container:: codeset + + .. literalinclude:: ../../docs/source/example-code/src/main/kotlin/net/corda/docs/LaunchSpaceshipFlow.kt + :language: kotlin + :start-after: DOCSTART LaunchSpaceshipFlowCorrect + :end-before: DOCEND LaunchSpaceshipFlowCorrect + +Note how the president is now explicit about which session it wants to send to. + +Porting from the old Party-based API +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In the old API the first ``send`` or ``receive`` to a ``Party`` was the one kicking off the counter-flow. This is now +explicit in the ``initiateFlow`` function call. To port existing code: + +.. container:: codeset + + .. literalinclude:: ../../docs/source/example-code/src/main/kotlin/net/corda/docs/FlowCookbook.kt + :language: kotlin + :start-after: DOCSTART FlowSession porting + :end-before: DOCEND FlowSession porting + :dedent: 8 + + .. literalinclude:: ../../docs/source/example-code/src/main/java/net/corda/docs/FlowCookbookJava.java + :language: java + :start-after: DOCSTART FlowSession porting + :end-before: DOCEND FlowSession porting + :dedent: 12 + Subflows -------- Subflows are pieces of reusable flows that may be run by calling ``FlowLogic.subFlow``. 
There are two broad categories diff --git a/docs/source/example-code/src/main/java/net/corda/docs/FlowCookbookJava.java b/docs/source/example-code/src/main/java/net/corda/docs/FlowCookbookJava.java index d17615841d..bf2c58b858 100644 --- a/docs/source/example-code/src/main/java/net/corda/docs/FlowCookbookJava.java +++ b/docs/source/example-code/src/main/java/net/corda/docs/FlowCookbookJava.java @@ -582,6 +582,13 @@ public class FlowCookbookJava { SignedTransaction notarisedTx2 = subFlow(new FinalityFlow(fullySignedTx, additionalParties, FINALISATION.childProgressTracker())); // DOCEND 10 + // DOCSTART FlowSession porting + send(regulator, new Object()); // Old API + // becomes + FlowSession session = initiateFlow(regulator); + session.send(new Object()); + // DOCEND FlowSession porting + return null; } } diff --git a/docs/source/example-code/src/main/kotlin/net/corda/docs/FlowCookbook.kt b/docs/source/example-code/src/main/kotlin/net/corda/docs/FlowCookbook.kt index 880570e2df..0528caeaf3 100644 --- a/docs/source/example-code/src/main/kotlin/net/corda/docs/FlowCookbook.kt +++ b/docs/source/example-code/src/main/kotlin/net/corda/docs/FlowCookbook.kt @@ -572,6 +572,13 @@ class InitiatorFlow(val arg1: Boolean, val arg2: Int, private val counterparty: val additionalParties: Set = setOf(regulator) val notarisedTx2: SignedTransaction = subFlow(FinalityFlow(fullySignedTx, additionalParties, FINALISATION.childProgressTracker())) // DOCEND 10 + + // DOCSTART FlowSession porting + send(regulator, Any()) // Old API + // becomes + val session = initiateFlow(regulator) + session.send(Any()) + // DOCEND FlowSession porting } } diff --git a/docs/source/example-code/src/main/kotlin/net/corda/docs/LaunchSpaceshipFlow.kt b/docs/source/example-code/src/main/kotlin/net/corda/docs/LaunchSpaceshipFlow.kt new file mode 100644 index 0000000000..e6826fa213 --- /dev/null +++ b/docs/source/example-code/src/main/kotlin/net/corda/docs/LaunchSpaceshipFlow.kt @@ -0,0 +1,99 @@ +package 
net.corda.docs + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.FlowSession +import net.corda.core.flows.InitiatedBy +import net.corda.core.flows.InitiatingFlow +import net.corda.core.identity.Party +import net.corda.core.utilities.unwrap + +// DOCSTART LaunchSpaceshipFlow +@InitiatingFlow +class LaunchSpaceshipFlow : FlowLogic() { + @Suspendable + override fun call() { + val shouldLaunchSpaceship = receive(getPresident()).unwrap { it } + if (shouldLaunchSpaceship) { + launchSpaceship() + } + } + + fun launchSpaceship() { + } + + fun getPresident(): Party { + TODO() + } +} + +@InitiatedBy(LaunchSpaceshipFlow::class) +@InitiatingFlow +class PresidentSpaceshipFlow(val launcher: Party) : FlowLogic() { + @Suspendable + override fun call() { + val needCoffee = true + send(getSecretary(), needCoffee) + val shouldLaunchSpaceship = false + send(launcher, shouldLaunchSpaceship) + } + + fun getSecretary(): Party { + TODO() + } +} + +@InitiatedBy(PresidentSpaceshipFlow::class) +class SecretaryFlow(val president: Party) : FlowLogic() { + @Suspendable + override fun call() { + // ignore + } +} +// DOCEND LaunchSpaceshipFlow + +// DOCSTART LaunchSpaceshipFlowCorrect +@InitiatingFlow +class LaunchSpaceshipFlowCorrect : FlowLogic() { + @Suspendable + override fun call() { + val presidentSession = initiateFlow(getPresident()) + val shouldLaunchSpaceship = presidentSession.receive().unwrap { it } + if (shouldLaunchSpaceship) { + launchSpaceship() + } + } + + fun launchSpaceship() { + } + + fun getPresident(): Party { + TODO() + } +} + +@InitiatedBy(LaunchSpaceshipFlowCorrect::class) +@InitiatingFlow +class PresidentSpaceshipFlowCorrect(val launcherSession: FlowSession) : FlowLogic() { + @Suspendable + override fun call() { + val needCoffee = true + val secretarySession = initiateFlow(getSecretary()) + secretarySession.send(needCoffee) + val shouldLaunchSpaceship = false + launcherSession.send(shouldLaunchSpaceship) 
+ } + + fun getSecretary(): Party { + TODO() + } +} + +@InitiatedBy(PresidentSpaceshipFlowCorrect::class) +class SecretaryFlowCorrect(val presidentSession: FlowSession) : FlowLogic() { + @Suspendable + override fun call() { + // ignore + } +} +// DOCEND LaunchSpaceshipFlowCorrect diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt index 684c74b1ab..5260942c36 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt @@ -130,30 +130,19 @@ class TopLevelTransition( flowState = FlowState.Started(event.ioRequest, event.fiber), numberOfSuspends = currentState.checkpoint.numberOfSuspends + 1 ) - if (event.maySkipCheckpoint) { - actions.addAll(arrayOf( - Action.CommitTransaction, - Action.ScheduleEvent(Event.DoRemainingWork) - )) - currentState = currentState.copy( - checkpoint = newCheckpoint, - isFlowResumed = false - ) - } else { - actions.addAll(arrayOf( - Action.PersistCheckpoint(context.id, newCheckpoint), - Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers), - Action.CommitTransaction, - Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers), - Action.ScheduleEvent(Event.DoRemainingWork) - )) - currentState = currentState.copy( - checkpoint = newCheckpoint, - pendingDeduplicationHandlers = emptyList(), - isFlowResumed = false, - isAnyCheckpointPersisted = true - ) - } + actions.addAll(arrayOf( + Action.PersistCheckpoint(context.id, newCheckpoint), + Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers), + Action.CommitTransaction, + Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers), + Action.ScheduleEvent(Event.DoRemainingWork) + )) + currentState = currentState.copy( + checkpoint = newCheckpoint, 
+ pendingDeduplicationHandlers = emptyList(), + isFlowResumed = false, + isAnyCheckpointPersisted = true + ) FlowContinuation.ProcessEvents } } From 57caf9af285e9b0b2a13a1c1342c8e0ff0123983 Mon Sep 17 00:00:00 2001 From: Rick Parker Date: Mon, 16 Apr 2018 13:51:50 +0100 Subject: [PATCH 6/9] Cherry pick 34f871936315097fd54c440912c51ce62b4f922a --- .../net/corda/node/internal/AbstractNode.kt | 2 - .../node/services/statemachine/Action.kt | 6 ++ .../statemachine/ActionExecutorImpl.kt | 5 ++ .../corda/node/services/statemachine/Event.kt | 4 +- .../statemachine/FlowStateMachineImpl.kt | 3 +- .../transitions/ErrorFlowTransition.kt | 1 + .../transitions/TopLevelTransition.kt | 1 + .../node/services/vault/NodeVaultService.kt | 10 ++- .../services/vault/VaultSoftLockManager.kt | 63 ------------------- .../vault/VaultSoftLockManagerTest.kt | 3 + 10 files changed, 30 insertions(+), 68 deletions(-) delete mode 100644 node/src/main/kotlin/net/corda/node/services/vault/VaultSoftLockManager.kt diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index 9640a8aa10..2bae06e785 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -58,7 +58,6 @@ import net.corda.node.services.statemachine.* import net.corda.node.services.transactions.* import net.corda.node.services.upgrade.ContractUpgradeServiceImpl import net.corda.node.services.vault.NodeVaultService -import net.corda.node.services.vault.VaultSoftLockManager import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.JVMAgentRegistry import net.corda.node.utilities.NamedThreadFactory @@ -641,7 +640,6 @@ abstract class AbstractNode(val configuration: NodeConfiguration, protected open fun makeTransactionStorage(database: CordaPersistence, transactionCacheSizeBytes: Long): WritableTransactionStorage = 
DBTransactionStorage(transactionCacheSizeBytes) private fun makeVaultObservers(schedulerService: SchedulerService, hibernateConfig: HibernateConfiguration, smm: StateMachineManager, schemaService: SchemaService, flowLogicRefFactory: FlowLogicRefFactory) { - VaultSoftLockManager.install(services.vaultService, smm) ScheduledActivityObserver.install(services.vaultService, schedulerService, flowLogicRefFactory) HibernateObserver.install(services.vaultService.rawUpdates, hibernateConfig, schemaService) } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt index 193cdecef1..46332f0362 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt @@ -6,6 +6,7 @@ import net.corda.core.identity.Party import net.corda.core.internal.FlowAsyncOperation import net.corda.node.services.messaging.DeduplicationHandler import java.time.Instant +import java.util.* /** * [Action]s are reified IO actions to execute as part of state machine transitions. @@ -117,6 +118,11 @@ sealed class Action { * Execute the specified [operation]. */ data class ExecuteAsyncOperation(val operation: FlowAsyncOperation<*>) : Action() + + /** + * Release soft locks associated with given ID (currently the flow ID). + */ + data class ReleaseSoftLocks(val uuid: UUID?) 
: Action() } /** diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt index 996d5832ad..36e968cf4b 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt @@ -72,9 +72,14 @@ class ActionExecutorImpl( is Action.RollbackTransaction -> executeRollbackTransaction() is Action.CommitTransaction -> executeCommitTransaction() is Action.ExecuteAsyncOperation -> executeAsyncOperation(fiber, action) + is Action.ReleaseSoftLocks -> executeReleaseSoftLocks(action) } } + private fun executeReleaseSoftLocks(action: Action.ReleaseSoftLocks) { + if (action.uuid != null) services.vaultService.softLockRelease(action.uuid) + } + @Suspendable private fun executeTrackTransaction(fiber: FlowFiber, action: Action.TrackTransaction) { services.validatedTransactions.trackTransaction(action.hash).thenMatch( diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt index 344a7df1ef..9694e64a09 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt @@ -6,6 +6,7 @@ import net.corda.core.internal.FlowIORequest import net.corda.core.serialization.SerializedBytes import net.corda.core.transactions.SignedTransaction import net.corda.node.services.messaging.DeduplicationHandler +import java.util.* /** * Transitions in the flow state machine are triggered by [Event]s that may originate from the flow itself or from @@ -112,8 +113,9 @@ sealed class Event { * Scheduled by the flow. * * @param returnValue the return value of the flow. + * @param softLocksId the flow ID of the flow if it is holding soft locks, else null. */ - data class FlowFinish(val returnValue: Any?) 
: Event() + data class FlowFinish(val returnValue: Any?, val softLocksId: UUID?) : Event() /** * Signals the completion of a [FlowAsyncOperation]. diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index 1e14aa0a36..8654508858 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -182,9 +182,10 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, logger.warn("Flow threw exception", throwable) Try.Failure(throwable) } + val softLocksId = if (hasSoftLockedStates) logic.runId.uuid else null val finalEvent = when (resultOrError) { is Try.Success -> { - Event.FlowFinish(resultOrError.value) + Event.FlowFinish(resultOrError.value, softLocksId) } is Try.Failure -> { Event.Error(resultOrError.exception) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt index 660f7c6574..97cb3be926 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt @@ -62,6 +62,7 @@ class ErrorFlowTransition( } actions.addAll(arrayOf( Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers), + Action.ReleaseSoftLocks(context.id.uuid), Action.CommitTransaction, Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers), Action.RemoveSessionBindings(currentState.checkpoint.sessions.keys) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt index 5260942c36..a0cdc389f2 
100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt @@ -167,6 +167,7 @@ class TopLevelTransition( } actions.addAll(arrayOf( Action.PersistDeduplicationFacts(pendingDeduplicationHandlers), + Action.ReleaseSoftLocks(event.softLocksId), Action.CommitTransaction, Action.AcknowledgeMessages(pendingDeduplicationHandlers), Action.RemoveSessionBindings(allSourceSessionIds), diff --git a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt index 3134c57d1b..b1cdf318f5 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt @@ -197,9 +197,17 @@ class NodeVaultService( if (!netUpdate.isEmpty()) { recordUpdate(netUpdate) mutex.locked { - // flowId required by SoftLockManager to perform auto-registration of soft locks for new states + // flowId was required by SoftLockManager to perform auto-registration of soft locks for new states val uuid = (Strand.currentStrand() as? 
FlowStateMachineImpl<*>)?.id?.uuid val vaultUpdate = if (uuid != null) netUpdate.copy(flowId = uuid) else netUpdate + if (uuid != null) { + val fungible = netUpdate.produced.filter { it.state.data is FungibleAsset<*> } + if (fungible.isNotEmpty()) { + val stateRefs = fungible.map { it.ref }.toNonEmptySet() + log.trace { "Reserving soft locks for flow id $uuid and states $stateRefs" } + softLockReserve(uuid, stateRefs) + } + } updatesPublisher.onNext(vaultUpdate) } } diff --git a/node/src/main/kotlin/net/corda/node/services/vault/VaultSoftLockManager.kt b/node/src/main/kotlin/net/corda/node/services/vault/VaultSoftLockManager.kt deleted file mode 100644 index a1a8b61270..0000000000 --- a/node/src/main/kotlin/net/corda/node/services/vault/VaultSoftLockManager.kt +++ /dev/null @@ -1,63 +0,0 @@ -package net.corda.node.services.vault - -import net.corda.core.contracts.FungibleAsset -import net.corda.core.contracts.StateRef -import net.corda.core.flows.FlowLogic -import net.corda.core.node.services.VaultService -import net.corda.core.utilities.* -import net.corda.node.services.statemachine.FlowStateMachineImpl -import net.corda.node.services.statemachine.StateMachineManager -import net.corda.nodeapi.internal.persistence.contextDatabase -import java.util.* - -class VaultSoftLockManager private constructor(private val vault: VaultService) { - companion object { - private val log = contextLogger() - @JvmStatic - fun install(vault: VaultService, smm: StateMachineManager) { - val manager = VaultSoftLockManager(vault) - smm.changes.subscribe { change -> - if (change is StateMachineManager.Change.Removed) { - val logic = change.logic - // Don't run potentially expensive query if the flow didn't lock any states: - if ((logic.stateMachine as FlowStateMachineImpl<*>).hasSoftLockedStates) { - manager.unregisterSoftLocks(logic.runId.uuid, logic) - } - } - } - // Discussion - // - // The intent of the following approach is to support what might be a common pattern in a flow: - // 1. 
Create state - // 2. Do something with state - // without possibility of another flow intercepting the state between 1 and 2, - // since we cannot lock the state before it exists. e.g. Issue and then Move some Cash. - // - // The downside is we could have a long running flow that holds a lock for a long period of time. - // However, the lock can be programmatically released, like any other soft lock, - // should we want a long running flow that creates a visible state mid way through. - vault.rawUpdates.subscribe { (_, produced, flowId) -> - if (flowId != null) { - val fungible = produced.filter { it.state.data is FungibleAsset<*> } - if (fungible.isNotEmpty()) { - manager.registerSoftLocks(flowId, fungible.map { it.ref }.toNonEmptySet()) - } - } - } - } - } - - private fun registerSoftLocks(flowId: UUID, stateRefs: NonEmptySet) { - log.trace { "Reserving soft locks for flow id $flowId and states $stateRefs" } - contextDatabase.transaction { - vault.softLockReserve(flowId, stateRefs) - } - } - - private fun unregisterSoftLocks(flowId: UUID, logic: FlowLogic<*>) { - log.trace { "Releasing soft locks for flow ${logic.javaClass.simpleName} with flow id $flowId" } - contextDatabase.transaction { - vault.softLockRelease(flowId) - } - } -} \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt index dea24d1d7d..41ee9fad51 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt @@ -84,9 +84,12 @@ class VaultSoftLockManagerTest { private val mockNet = InternalMockNetwork(cordappPackages = listOf(ContractImpl::class.packageName), defaultFactory = { args -> object : InternalMockNetwork.MockNode(args) { override fun makeVaultService(keyManagementService: KeyManagementService, services: ServicesForResolution, 
hibernateConfig: HibernateConfiguration): VaultServiceInternal { + val node = this val realVault = super.makeVaultService(keyManagementService, services, hibernateConfig) return object : VaultServiceInternal by realVault { override fun softLockRelease(lockId: UUID, stateRefs: NonEmptySet?) { + // Should be called before flow is removed + assertEquals(1, node.smm.allStateMachines.size) mockVault.softLockRelease(lockId, stateRefs) // No need to also call the real one for these tests. } } From 5b4fd6fe64a2d08f99ef432b222d7f01a3d745e2 Mon Sep 17 00:00:00 2001 From: Andras Slemmer Date: Wed, 18 Apr 2018 16:44:31 +0100 Subject: [PATCH 7/9] Address comments --- .ci/api-current.txt | 6 +++--- core/src/main/kotlin/net/corda/core/flows/FlowException.kt | 6 ++++-- core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt | 1 - .../kotlin/net/corda/node/services/api/CheckpointStorage.kt | 2 +- .../kotlin/net/corda/node/services/messaging/Messaging.kt | 2 +- .../net/corda/node/services/messaging/P2PMessagingClient.kt | 3 --- .../corda/node/services/persistence/DBTransactionStorage.kt | 4 ++-- .../net/corda/node/services/statemachine/FlowMessaging.kt | 2 -- .../corda/node/services/vault/VaultSoftLockManagerTest.kt | 2 +- 9 files changed, 12 insertions(+), 16 deletions(-) diff --git a/.ci/api-current.txt b/.ci/api-current.txt index 3e83150ab4..8b9039bbd1 100644 --- a/.ci/api-current.txt +++ b/.ci/api-current.txt @@ -1590,9 +1590,9 @@ public final class net.corda.core.flows.TransactionParts extends java.lang.Objec public String toString() ## @net.corda.core.serialization.CordaSerializable public final class net.corda.core.flows.UnexpectedFlowEndException extends net.corda.core.CordaRuntimeException implements net.corda.core.flows.IdentifiableException - public (String, Throwable, long) - @org.jetbrains.annotations.NotNull public Long getErrorId() - public final long getOriginalErrorId() + public (String) + public (String, Throwable) + public (String, Throwable, Long) ## 
@net.corda.core.DoNotImplement @net.corda.core.serialization.CordaSerializable public abstract class net.corda.core.identity.AbstractParty extends java.lang.Object public (java.security.PublicKey) diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowException.kt b/core/src/main/kotlin/net/corda/core/flows/FlowException.kt index ac0fbdaa23..8dbe892a74 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowException.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowException.kt @@ -36,7 +36,9 @@ open class FlowException(message: String?, cause: Throwable?) : * that we were not expecting), or the other side had an internal error, or the other side terminated when we * were waiting for a response. */ -class UnexpectedFlowEndException(message: String, cause: Throwable?, val originalErrorId: Long) : +class UnexpectedFlowEndException(message: String, cause: Throwable?, val originalErrorId: Long?) : CordaRuntimeException(message, cause), IdentifiableException { - override fun getErrorId(): Long = originalErrorId + constructor(message: String, cause: Throwable?) : this(message, cause, null) + constructor(message: String) : this(message, null) + override fun getErrorId(): Long? = originalErrorId } diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt index 9eeddc32d8..b5c8a1b399 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt @@ -131,7 +131,6 @@ abstract class FlowLogic { * Note: The current implementation returns the single identity of the node. This will change once multiple identities * is implemented. */ - val ourIdentity: Party get() = stateMachine.ourIdentity // Used to implement the deprecated send/receive functions using Party. 
When such a deprecated function is used we diff --git a/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt b/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt index 4a55d7163a..5d0cf99dde 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt @@ -22,7 +22,7 @@ interface CheckpointStorage { /** * Stream all checkpoints from the store. If this is backed by a database the stream will be valid until the - * underlying database connection is open, so any processing should happen before it is closed. + * underlying database connection is closed, so any processing should happen before it is closed. */ fun getAllCheckpoints(): Stream>> } diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt index 3e72f52b72..fce17e32f3 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt @@ -139,7 +139,7 @@ interface ReceivedMessage : Message { val peer: CordaX500Name /** Platform version of the sender's node. */ val platformVersion: Int - /** UUID representing the sending JVM */ + /** Sequence number of message with respect to senderUUID */ val senderSeqNo: Long? 
/** True if a flow session init message */ val isSessionInit: Boolean diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt index 73fe65e66f..189d5db813 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt @@ -405,7 +405,6 @@ class P2PMessagingClient(val config: NodeConfiguration, } internal fun deliver(artemisMessage: ClientMessage) { - artemisToCordaMessage(artemisMessage)?.let { cordaMessage -> if (!deduplicator.isDuplicate(cordaMessage)) { deduplicator.signalMessageProcessStart(cordaMessage) @@ -418,7 +417,6 @@ class P2PMessagingClient(val config: NodeConfiguration, } private fun deliver(msg: ReceivedMessage, artemisMessage: ClientMessage) { - state.checkNotLocked() val deliverTo = handlers[msg.topic] if (deliverTo != null) { @@ -600,7 +598,6 @@ class P2PMessagingClient(val config: NodeConfiguration, } override fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId, additionalHeaders: Map): Message { - return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId, deduplicator.ourSenderUUID, additionalHeaders) } diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt index 9e90665ab9..2bd8cac107 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt @@ -96,7 +96,7 @@ class DBTransactionStorage(cacheSizeBytes: Long) : WritableTransactionStorage, S override fun track(): DataFeed, SignedTransaction> { return txStorage.locked { - DataFeed(allPersisted().map { it.second.toSignedTx() }.toList(), 
updatesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction()) + DataFeed(allPersisted().map { it.second.toSignedTx() }.toList(), updates.wrapWithDatabaseTransaction()) } } @@ -104,7 +104,7 @@ class DBTransactionStorage(cacheSizeBytes: Long) : WritableTransactionStorage, S return txStorage.locked { val existingTransaction = get(id) if (existingTransaction == null) { - updatesPublisher.filter { it.id == id }.toFuture() + updates.filter { it.id == id }.toFuture() } else { doneFuture(existingTransaction.toSignedTx()) } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt index 01a90a40d1..3ee44d3803 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt @@ -36,7 +36,6 @@ interface FlowMessaging { * Implementation of [FlowMessaging] using a [ServiceHubInternal] to do the messaging and routing. */ class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging { - companion object { val log = contextLogger() @@ -63,7 +62,6 @@ class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging { } private fun SessionMessage.additionalHeaders(target: Party): Map { - // This prevents a "deadlock" in case an initiated flow tries to start a session against a draining node that is also the initiator. // It does not help in case more than 2 nodes are involved in a circle, so the kill switch via RPC should be used in that case. 
val mightDeadlockDrainingTarget = FlowStateMachineImpl.currentStateMachine()?.context?.origin.let { it is InvocationOrigin.Peer && it.party == target.name } diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt index 41ee9fad51..82c507f9ba 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt @@ -89,7 +89,7 @@ class VaultSoftLockManagerTest { return object : VaultServiceInternal by realVault { override fun softLockRelease(lockId: UUID, stateRefs: NonEmptySet?) { // Should be called before flow is removed - assertEquals(1, node.smm.allStateMachines.size) + assertEquals(1, node.started!!.smm.allStateMachines.size) mockVault.softLockRelease(lockId, stateRefs) // No need to also call the real one for these tests. } } From 6bf34ed5c7406328c526e4ef36e46b03bad947a8 Mon Sep 17 00:00:00 2001 From: Andras Slemmer Date: Mon, 23 Apr 2018 13:28:34 +0100 Subject: [PATCH 8/9] Fix bugs --- .../main/kotlin/net/corda/core/flows/FlowLogic.kt | 14 -------------- .../services/persistence/DBTransactionStorage.kt | 2 +- .../services/statemachine/FlowStateMachineImpl.kt | 5 ++--- 3 files changed, 3 insertions(+), 18 deletions(-) diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt index b5c8a1b399..29a84a6812 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt @@ -185,20 +185,6 @@ abstract class FlowLogic { return getDeprecatedSessionForParty(otherParty).sendAndReceive(receiveType, payload) } - /** - * Similar to [sendAndReceive] but also instructs the `payload` to be redelivered until the expected message is received. 
- * - * Note that this method should NOT be used for regular party-to-party communication, use [sendAndReceive] instead. - * It is only intended for the case where the [otherParty] is running a distributed service with an idempotent - * flow which only accepts a single request and sends back a single response – e.g. a notary or certain types of - * oracle services. If one or more nodes in the service cluster go down mid-session, the message will be redelivered - * to a different one, so there is no need to wait until the initial node comes back up to obtain a response. - */ - @Deprecated("Use FlowSession.sendAndReceiveWithRetry()", level = DeprecationLevel.WARNING) - internal inline fun sendAndReceiveWithRetry(otherParty: Party, payload: Any): UntrustworthyData { - return getDeprecatedSessionForParty(otherParty).sendAndReceiveWithRetry(payload) - } - /** * Suspends until the specified [otherParty] sends us a message of type [R]. * diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt index 2bd8cac107..4dd4d4b8b6 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt @@ -96,7 +96,7 @@ class DBTransactionStorage(cacheSizeBytes: Long) : WritableTransactionStorage, S override fun track(): DataFeed, SignedTransaction> { return txStorage.locked { - DataFeed(allPersisted().map { it.second.toSignedTx() }.toList(), updates.wrapWithDatabaseTransaction()) + DataFeed(allPersisted().map { it.second.toSignedTx() }.toList(), updates.bufferUntilSubscribed().wrapWithDatabaseTransaction()) } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index 8654508858..aabc75bb4c 100644 --- 
a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -297,7 +297,6 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, override fun suspend(ioRequest: FlowIORequest, maySkipCheckpoint: Boolean): R { val serializationContext = TransientReference(getTransientField(TransientValues::checkpointSerializationContext)) val transaction = extractThreadLocalTransaction() - val transitionExecutor = TransientReference(getTransientField(TransientValues::transitionExecutor)) parkAndSerialize { _, _ -> logger.trace { "Suspended on $ioRequest" } @@ -312,8 +311,8 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, Event.Error(throwable) } - // We must commit the database transaction before returning from this closure, otherwise Quasar may schedule - // other fibers + // We must commit the database transaction before returning from this closure otherwise Quasar may schedule + // other fibers, so we process the event immediately val continuation = processEventImmediately( event, isDbTransactionOpenOnEntry = true, From b0d2a258c0fa9db726c8ecb9e85b93af019ae900 Mon Sep 17 00:00:00 2001 From: Chris Burlinchon Date: Mon, 16 Apr 2018 18:05:01 +0100 Subject: [PATCH 9/9] cherry-pick 7759fdbb71ea9b2021afd8af0ac05447c5305b3a --- .../net/corda/node/internal/AbstractNode.kt | 4 +-- .../security/RPCSecurityManagerImpl.kt | 3 +-- .../services/events/NodeSchedulerService.kt | 25 +------------------ .../services/messaging/P2PMessagingClient.kt | 5 ---- .../SingleThreadedStateMachineManager.kt | 14 ----------- .../statemachine/StateMachineManager.kt | 5 ---- ...FiberDeserializationCheckingInterceptor.kt | 1 - 7 files changed, 4 insertions(+), 53 deletions(-) diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index 2bae06e785..b20b0022f7 100644 --- 
a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -242,7 +242,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val notaryService = makeNotaryService(nodeServices, database) val smm = makeStateMachineManager(database) val flowLogicRefFactory = FlowLogicRefFactoryImpl(cordappLoader.appClassLoader) - val flowStarter = FlowStarterImpl(serverThread, smm, flowLogicRefFactory) + val flowStarter = FlowStarterImpl(smm, flowLogicRefFactory) val schedulerService = NodeSchedulerService( platformClock, database, @@ -893,7 +893,7 @@ internal fun logVendorString(database: CordaPersistence, log: Logger) { } } -internal class FlowStarterImpl(private val serverThread: AffinityExecutor, private val smm: StateMachineManager, private val flowLogicRefFactory: FlowLogicRefFactory) : FlowStarter { +internal class FlowStarterImpl(private val smm: StateMachineManager, private val flowLogicRefFactory: FlowLogicRefFactory) : FlowStarter { override fun startFlow(logic: FlowLogic, context: InvocationContext, deduplicationHandler: DeduplicationHandler?): CordaFuture> { return smm.startFlow(logic, context, ourIdentity = null, deduplicationHandler = deduplicationHandler) } diff --git a/node/src/main/kotlin/net/corda/node/internal/security/RPCSecurityManagerImpl.kt b/node/src/main/kotlin/net/corda/node/internal/security/RPCSecurityManagerImpl.kt index cf961636f9..4b9af66259 100644 --- a/node/src/main/kotlin/net/corda/node/internal/security/RPCSecurityManagerImpl.kt +++ b/node/src/main/kotlin/net/corda/node/internal/security/RPCSecurityManagerImpl.kt @@ -77,8 +77,7 @@ class RPCSecurityManagerImpl(config: AuthServiceConfig) : RPCSecurityManager { * Instantiate RPCSecurityManager initialised with users data from a list of [User] */ fun fromUserList(id: AuthServiceId, users: List) = - RPCSecurityManagerImpl( - AuthServiceConfig.fromUsers(users).copy(id = id)) + 
RPCSecurityManagerImpl(AuthServiceConfig.fromUsers(users).copy(id = id)) // Build internal Shiro securityManager instance private fun buildImpl(config: AuthServiceConfig): DefaultSecurityManager { diff --git a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt index cea738df76..bfada7354d 100644 --- a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt +++ b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt @@ -178,29 +178,6 @@ class NodeSchedulerService(private val clock: CordaClock, } } - /** - * Stop scheduler service. - */ - fun stop() { - mutex.locked { - schedulerTimerExecutor.shutdown() - scheduledStatesQueue.clear() - scheduledStates.clear() - } - } - - /** - * Resume scheduler service after having called [stop]. - */ - fun resume() { - mutex.locked { - schedulerTimerExecutor = Executors.newSingleThreadExecutor() - scheduledStates.putAll(createMap()) - scheduledStatesQueue.addAll(scheduledStates.values) - rescheduleWakeUp() - } - } - override fun scheduleStateActivity(action: ScheduledStateRef) { log.trace { "Schedule $action" } val previousState = scheduledStates[action.ref] @@ -240,7 +217,7 @@ class NodeSchedulerService(private val clock: CordaClock, } } - private var schedulerTimerExecutor = Executors.newSingleThreadExecutor() + private val schedulerTimerExecutor = Executors.newSingleThreadExecutor() /** * This method first cancels the [java.util.concurrent.Future] for any pending action so that the * [awaitWithDeadline] used below drops through without running the action. 
We then create a new diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt index 189d5db813..88fa4d8e1f 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt @@ -5,7 +5,6 @@ import com.codahale.metrics.MetricRegistry import net.corda.core.crypto.toStringShort import net.corda.core.identity.CordaX500Name import net.corda.core.internal.ThreadBox -import net.corda.core.internal.concurrent.openFuture import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.SingleMessageRecipient @@ -333,8 +332,6 @@ class P2PMessagingClient(val config: NodeConfiguration, private val shutdownLatch = CountDownLatch(1) - var runningFuture = openFuture() - /** * Starts the p2p event loop: this method only returns once [stop] has been called. */ @@ -345,7 +342,6 @@ class P2PMessagingClient(val config: NodeConfiguration, check(started) { "start must be called first" } check(!running) { "run can't be called twice" } running = true - runningFuture.set(Unit) // If it's null, it means we already called stop, so return immediately. 
if (p2pConsumer == null) { return @@ -457,7 +453,6 @@ class P2PMessagingClient(val config: NodeConfiguration, check(started) val prevRunning = running running = false - runningFuture = openFuture() networkChangeSubscription?.unsubscribe() require(p2pConsumer != null, { "stop can't be called twice" }) require(producer != null, { "stop can't be called twice" }) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt index 4c579c9ba8..f7460f4bab 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt @@ -134,20 +134,6 @@ class SingleThreadedStateMachineManager( } } - override fun resume() { - fiberDeserializationChecker?.start(checkpointSerializationContext!!) - val fibers = restoreFlowsFromCheckpoints() - Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable -> - (fiber as FlowStateMachineImpl<*>).logger.warn("Caught exception from flow", throwable) - } - serviceHub.networkMapCache.nodeReady.then { - resumeRestoredFlows(fibers) - } - mutex.locked { - stopping = false - } - } - override fun > findStateMachines(flowClass: Class): List>> { return mutex.locked { flows.values.mapNotNull { diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index 7d645407a1..96edb248b0 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -39,11 +39,6 @@ interface StateMachineManager { */ fun stop(allowedUnsuspendedFiberCount: Int) - /** - * Resume state machine manager after having called [stop]. 
- */ - fun resume() - /** * Starts a new flow. * diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/FiberDeserializationCheckingInterceptor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/FiberDeserializationCheckingInterceptor.kt index 2a9b96cc5b..cbde382f4d 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/FiberDeserializationCheckingInterceptor.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/FiberDeserializationCheckingInterceptor.kt @@ -89,7 +89,6 @@ class FiberDeserializationChecker { fun stop(): Boolean { jobQueue.add(Job.Finish) checkerThread?.join() - checkerThread = null return foundUnrestorableFibers } }