Mirror of https://github.com/corda/corda.git
Merge pull request #2964 from corda/CORDA-1334/aslemmer-enterprise-smm-port

CORDA-1334: port enterprise statemachine

Commit: 640e5c6088
@ -1200,7 +1200,7 @@ public static final class net.corda.core.flows.FinalityFlow$Companion extends ja
@org.jetbrains.annotations.NotNull public net.corda.core.utilities.ProgressTracker childProgressTracker()
public static final net.corda.core.flows.FinalityFlow$Companion$NOTARISING INSTANCE
##
@net.corda.core.serialization.CordaSerializable public class net.corda.core.flows.FlowException extends net.corda.core.CordaException
@net.corda.core.serialization.CordaSerializable public class net.corda.core.flows.FlowException extends net.corda.core.CordaException implements net.corda.core.flows.IdentifiableException
public <init>()
public <init>(String)
public <init>(String, Throwable)
@ -1589,9 +1589,10 @@ public final class net.corda.core.flows.TransactionParts extends java.lang.Objec
public int hashCode()
public String toString()
##
@net.corda.core.serialization.CordaSerializable public final class net.corda.core.flows.UnexpectedFlowEndException extends net.corda.core.CordaRuntimeException
@net.corda.core.serialization.CordaSerializable public final class net.corda.core.flows.UnexpectedFlowEndException extends net.corda.core.CordaRuntimeException implements net.corda.core.flows.IdentifiableException
public <init>(String)
public <init>(String, Throwable)
public <init>(String, Throwable, Long)
##
@net.corda.core.DoNotImplement @net.corda.core.serialization.CordaSerializable public abstract class net.corda.core.identity.AbstractParty extends java.lang.Object
public <init>(java.security.PublicKey)
@ -47,6 +47,7 @@ buildscript {
ext.bouncycastle_version = constants.getProperty("bouncycastleVersion")
ext.guava_version = constants.getProperty("guavaVersion")
ext.caffeine_version = constants.getProperty("caffeineVersion")
ext.metrics_version = constants.getProperty("metricsVersion")
ext.okhttp_version = '3.5.0'
ext.netty_version = '4.1.9.Final'
ext.typesafe_config_version = constants.getProperty("typesafeConfigVersion")
@ -8,3 +8,4 @@ jsr305Version=3.0.2
artifactoryPluginVersion=4.4.18
snakeYamlVersion=1.19
caffeineVersion=2.6.2
metricsVersion=3.2.5
@ -0,0 +1,16 @@
package net.corda.core.flows;

import javax.annotation.Nullable;

/**
 * An exception that may be identified with an ID. If an exception originates in a counter-flow this ID will be
 * propagated. This allows correlation of error conditions across different flows.
 */
public interface IdentifiableException {
    /**
     * @return the ID of the error, or null if the error doesn't have it set (yet).
     */
    default @Nullable Long getErrorId() {
        return null;
    }
}
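For illustration only (not part of this commit), a minimal Kotlin sketch of how calling code might surface the ID; the helper name is made up and it assumes nothing beyond the interface above:

import net.corda.core.flows.IdentifiableException

// Hypothetical helper: pulls the correlation ID out of any exception that carries one.
fun describeError(e: Exception): String {
    val id = (e as? IdentifiableException)?.errorId  // null until the flow framework assigns an ID
    return "flow error ${id ?: "<no id yet>"}: ${e.message}"
}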
@ -7,16 +7,27 @@ import net.corda.core.CordaRuntimeException
/**
 * Exception which can be thrown by a [FlowLogic] at any point in its logic to unexpectedly bring it to a permanent end.
 * The exception will propagate to all counterparty flows and will be thrown on their end the next time they wait on a
 * [FlowSession.receive] or [FlowSession.sendAndReceive]. Any flow which no longer needs to do a receive, or has already ended,
 * will not receive the exception (if this is required then have them wait for a confirmation message).
 * [FlowSession.receive] or [FlowSession.sendAndReceive]. Any flow which no longer needs to do a receive, or has already
 * ended, will not receive the exception (if this is required then have them wait for a confirmation message).
 *
 * If the *rethrown* [FlowException] is uncaught in counterparty flows and propagation triggers then the exception is
 * downgraded to an [UnexpectedFlowEndException]. This means only immediate counterparty flows will receive information
 * about what the exception was.
 *
 * [FlowException] (or a subclass) can be a valid expected response from a flow, particularly ones which act as a service.
 * It is recommended a [FlowLogic] document the [FlowException] types it can throw.
 *
 * @property originalErrorId the ID backing [getErrorId]. If null it will be set dynamically by the flow framework when
 *     the exception is handled. This ID is propagated to counterparty flows, even when the [FlowException] is
 *     downgraded to an [UnexpectedFlowEndException]. This is so the error conditions may be correlated later on.
 */
open class FlowException(message: String?, cause: Throwable?) : CordaException(message, cause) {
open class FlowException(message: String?, cause: Throwable?) :
        CordaException(message, cause), IdentifiableException {
    constructor(message: String?) : this(message, null)
    constructor(cause: Throwable?) : this(cause?.toString(), cause)
    constructor() : this(null, null)
    var originalErrorId: Long? = null
    override fun getErrorId(): Long? = originalErrorId
}
// DOCEND 1

@ -25,6 +36,9 @@ open class FlowException(message: String?, cause: Throwable?) : CordaException(m
 * that we were not expecting), or the other side had an internal error, or the other side terminated when we
 * were waiting for a response.
 */
class UnexpectedFlowEndException(message: String?, cause: Throwable?) : CordaRuntimeException(message, cause) {
    constructor(msg: String) : this(msg, null)
}
class UnexpectedFlowEndException(message: String, cause: Throwable?, val originalErrorId: Long?) :
        CordaRuntimeException(message, cause), IdentifiableException {
    constructor(message: String, cause: Throwable?) : this(message, cause, null)
    constructor(message: String) : this(message, null)
    override fun getErrorId(): Long? = originalErrorId
}
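The new error-ID plumbing is easiest to see from the calling side. A hedged, hypothetical initiating flow (not part of this commit; the class name and payload are made up) that treats FlowException as an expected response and only logs the propagated ID when the session ends unexpectedly:

import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.InitiatingFlow
import net.corda.core.flows.UnexpectedFlowEndException
import net.corda.core.identity.Party
import net.corda.core.utilities.unwrap

@InitiatingFlow
class QueryCounterpartyFlow(private val counterparty: Party) : FlowLogic<String?>() {
    @Suspendable
    override fun call(): String? {
        val session = initiateFlow(counterparty)
        return try {
            session.sendAndReceive(String::class.java, "query").unwrap { it }
        } catch (e: FlowException) {
            // An expected, documented failure mode of the responder; the full exception propagates here.
            logger.info("Counterparty rejected the query (errorId=${e.originalErrorId}): ${e.message}")
            null
        } catch (e: UnexpectedFlowEndException) {
            // The responder ended or threw something that was not a FlowException; only the ID survives.
            logger.warn("Counterparty ended unexpectedly (errorId=${e.originalErrorId})")
            null
        }
    }
}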
@ -6,20 +6,17 @@ import net.corda.core.CordaInternal
|
||||
import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.identity.PartyAndCertificate
|
||||
import net.corda.core.internal.FlowStateMachine
|
||||
import net.corda.core.internal.abbreviate
|
||||
import net.corda.core.internal.uncheckedCast
|
||||
import net.corda.core.internal.*
|
||||
import net.corda.core.messaging.DataFeed
|
||||
import net.corda.core.node.NodeInfo
|
||||
import net.corda.core.node.ServiceHub
|
||||
import net.corda.core.serialization.CordaSerializable
|
||||
import net.corda.core.serialization.SerializationDefaults
|
||||
import net.corda.core.serialization.serialize
|
||||
import net.corda.core.transactions.SignedTransaction
|
||||
import net.corda.core.utilities.ProgressTracker
|
||||
import net.corda.core.utilities.UntrustworthyData
|
||||
import net.corda.core.utilities.debug
|
||||
import net.corda.core.utilities.*
|
||||
import org.slf4j.Logger
|
||||
import java.time.Duration
|
||||
import java.time.Instant
|
||||
|
||||
/**
|
||||
* A sub-class of [FlowLogic<T>] implements a flow using direct, straight line blocking code. Thus you
|
||||
@ -77,12 +74,19 @@ abstract class FlowLogic<out T> {
|
||||
*/
|
||||
@Suspendable
|
||||
@JvmStatic
|
||||
@JvmOverloads
|
||||
@Throws(FlowException::class)
|
||||
fun sleep(duration: Duration) {
|
||||
fun sleep(duration: Duration, maySkipCheckpoint: Boolean = false) {
|
||||
if (duration > Duration.ofMinutes(5)) {
|
||||
throw FlowException("Attempt to sleep for longer than 5 minutes is not supported. Consider using SchedulableState.")
|
||||
}
|
||||
(Strand.currentStrand() as? FlowStateMachine<*>)?.sleepUntil(Instant.now() + duration) ?: Strand.sleep(duration.toMillis())
|
||||
val fiber = (Strand.currentStrand() as? FlowStateMachine<*>)
|
||||
if (fiber == null) {
|
||||
Strand.sleep(duration.toMillis())
|
||||
} else {
|
||||
val request = FlowIORequest.Sleep(wakeUpAfter = fiber.serviceHub.clock.instant() + duration)
|
||||
fiber.suspend(request, maySkipCheckpoint = maySkipCheckpoint)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -94,7 +98,7 @@ abstract class FlowLogic<out T> {
|
||||
|
||||
/**
|
||||
* Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is
|
||||
* only available once the flow has started, which means it cannnot be accessed in the constructor. Either
|
||||
* only available once the flow has started, which means it cannot be accessed in the constructor. Either
|
||||
* access this lazily or from inside [call].
|
||||
*/
|
||||
val serviceHub: ServiceHub get() = stateMachine.serviceHub
|
||||
@ -104,7 +108,7 @@ abstract class FlowLogic<out T> {
|
||||
* that this function does not communicate in itself, the counter-flow will be kicked off by the first send/receive.
|
||||
*/
|
||||
@Suspendable
|
||||
fun initiateFlow(party: Party): FlowSession = stateMachine.initiateFlow(party, flowUsedForSessions)
|
||||
fun initiateFlow(party: Party): FlowSession = stateMachine.initiateFlow(party)
|
||||
|
||||
/**
|
||||
* Specifies the identity, with certificate, to use for this flow. This will be one of the multiple identities that
|
||||
@ -114,7 +118,10 @@ abstract class FlowLogic<out T> {
|
||||
* Note: The current implementation returns the single identity of the node. This will change once multiple identities
|
||||
* is implemented.
|
||||
*/
|
||||
val ourIdentityAndCert: PartyAndCertificate get() = stateMachine.ourIdentityAndCert
|
||||
val ourIdentityAndCert: PartyAndCertificate get() {
|
||||
return serviceHub.myInfo.legalIdentitiesAndCerts.find { it.party == stateMachine.ourIdentity }
|
||||
?: throw IllegalStateException("Identity specified by ${stateMachine.id} (${stateMachine.ourIdentity}) is not one of ours!")
|
||||
}
|
||||
|
||||
/**
|
||||
* Specifies the identity to use for this flow. This will be one of the multiple identities that belong to this node.
|
||||
@ -124,8 +131,14 @@ abstract class FlowLogic<out T> {
|
||||
* Note: The current implementation returns the single identity of the node. This will change once multiple identities
|
||||
* is implemented.
|
||||
*/
|
||||
val ourIdentity: Party get() = ourIdentityAndCert.party
|
||||
val ourIdentity: Party get() = stateMachine.ourIdentity
|
||||
|
||||
// Used to implement the deprecated send/receive functions using Party. When such a deprecated function is used we
|
||||
// create a fresh session for the Party, put it here and use it in subsequent deprecated calls.
|
||||
private val deprecatedPartySessionMap = HashMap<Party, FlowSession>()
|
||||
private fun getDeprecatedSessionForParty(party: Party): FlowSession {
|
||||
return deprecatedPartySessionMap.getOrPut(party) { initiateFlow(party) }
|
||||
}
|
||||
/**
|
||||
* Returns a [FlowInfo] object describing the flow [otherParty] is using. With [FlowInfo.flowVersion] it
|
||||
* provides the necessary information needed for the evolution of flows and enabling backwards compatibility.
|
||||
@ -133,9 +146,9 @@ abstract class FlowLogic<out T> {
|
||||
* This method can be called before any send or receive has been done with [otherParty]. In such a case this will force
|
||||
* them to start their flow.
|
||||
*/
|
||||
@Deprecated("Use FlowSession.getFlowInfo()", level = DeprecationLevel.WARNING)
|
||||
@Deprecated("Use FlowSession.getCounterpartyFlowInfo()", level = DeprecationLevel.WARNING)
|
||||
@Suspendable
|
||||
fun getFlowInfo(otherParty: Party): FlowInfo = stateMachine.getFlowInfo(otherParty, flowUsedForSessions, maySkipCheckpoint = false)
|
||||
fun getFlowInfo(otherParty: Party): FlowInfo = getDeprecatedSessionForParty(otherParty).getCounterpartyFlowInfo()
|
||||
|
||||
/**
|
||||
* Serializes and queues the given [payload] object for sending to the [otherParty]. Suspends until a response
|
||||
@ -169,31 +182,7 @@ abstract class FlowLogic<out T> {
|
||||
@Deprecated("Use FlowSession.sendAndReceive()", level = DeprecationLevel.WARNING)
|
||||
@Suspendable
|
||||
open fun <R : Any> sendAndReceive(receiveType: Class<R>, otherParty: Party, payload: Any): UntrustworthyData<R> {
|
||||
return stateMachine.sendAndReceive(receiveType, otherParty, payload, flowUsedForSessions, retrySend = false, maySkipCheckpoint = false)
|
||||
}
|
||||
|
||||
/**
|
||||
* Similar to [sendAndReceive] but also instructs the `payload` to be redelivered until the expected message is received.
|
||||
*
|
||||
* Note that this method should NOT be used for regular party-to-party communication, use [sendAndReceive] instead.
|
||||
* It is only intended for the case where the [otherParty] is running a distributed service with an idempotent
|
||||
* flow which only accepts a single request and sends back a single response – e.g. a notary or certain types of
|
||||
* oracle services. If one or more nodes in the service cluster go down mid-session, the message will be redelivered
|
||||
* to a different one, so there is no need to wait until the initial node comes back up to obtain a response.
|
||||
*/
|
||||
@Deprecated("Use FlowSession.sendAndReceiveWithRetry()", level = DeprecationLevel.WARNING)
|
||||
internal inline fun <reified R : Any> sendAndReceiveWithRetry(otherParty: Party, payload: Any): UntrustworthyData<R> {
|
||||
return stateMachine.sendAndReceive(R::class.java, otherParty, payload, flowUsedForSessions, retrySend = true, maySkipCheckpoint = false)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
internal fun <R : Any> FlowSession.sendAndReceiveWithRetry(receiveType: Class<R>, payload: Any): UntrustworthyData<R> {
|
||||
return stateMachine.sendAndReceive(receiveType, counterparty, payload, flowUsedForSessions, retrySend = true, maySkipCheckpoint = false)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
internal inline fun <reified R : Any> FlowSession.sendAndReceiveWithRetry(payload: Any): UntrustworthyData<R> {
|
||||
return stateMachine.sendAndReceive(R::class.java, counterparty, payload, flowUsedForSessions, retrySend = true, maySkipCheckpoint = false)
|
||||
return getDeprecatedSessionForParty(otherParty).sendAndReceive(receiveType, payload)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -218,9 +207,37 @@ abstract class FlowLogic<out T> {
|
||||
@Deprecated("Use FlowSession.receive()", level = DeprecationLevel.WARNING)
|
||||
@Suspendable
|
||||
open fun <R : Any> receive(receiveType: Class<R>, otherParty: Party): UntrustworthyData<R> {
|
||||
return stateMachine.receive(receiveType, otherParty, flowUsedForSessions, maySkipCheckpoint = false)
|
||||
return getDeprecatedSessionForParty(otherParty).receive(receiveType)
|
||||
}
|
||||
|
||||
/**
|
||||
* Queues the given [payload] for sending to the [otherParty] and continues without suspending.
|
||||
*
|
||||
* Note that the other party may receive the message at some arbitrary later point or not at all: if [otherParty]
|
||||
* is offline then message delivery will be retried until it comes back or until the message is older than the
|
||||
* network's event horizon time.
|
||||
*/
|
||||
@Deprecated("Use FlowSession.send()", level = DeprecationLevel.WARNING)
|
||||
@Suspendable
|
||||
open fun send(otherParty: Party, payload: Any) {
|
||||
getDeprecatedSessionForParty(otherParty).send(payload)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
internal fun <R : Any> FlowSession.sendAndReceiveWithRetry(receiveType: Class<R>, payload: Any): UntrustworthyData<R> {
|
||||
val request = FlowIORequest.SendAndReceive(
|
||||
sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT)),
|
||||
shouldRetrySend = true
|
||||
)
|
||||
return stateMachine.suspend(request, maySkipCheckpoint = false)[this]!!.checkPayloadIs(receiveType)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
internal inline fun <reified R : Any> FlowSession.sendAndReceiveWithRetry(payload: Any): UntrustworthyData<R> {
|
||||
return sendAndReceiveWithRetry(R::class.java, payload)
|
||||
}
|
||||
|
||||
|
||||
/** Suspends until a message has been received for each session in the specified [sessions].
|
||||
*
|
||||
* Consider [receiveAll(receiveType: Class<R>, sessions: List<FlowSession>): List<UntrustworthyData<R>>] when the same type is expected from all sessions.
|
||||
@ -232,8 +249,14 @@ abstract class FlowLogic<out T> {
|
||||
* @returns a [Map] containing the objects received, wrapped in an [UntrustworthyData], by the [FlowSession]s who sent them.
|
||||
*/
|
||||
@Suspendable
|
||||
open fun receiveAllMap(sessions: Map<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>> {
|
||||
return stateMachine.receiveAll(sessions, this)
|
||||
@JvmOverloads
|
||||
open fun receiveAllMap(sessions: Map<FlowSession, Class<out Any>>, maySkipCheckpoint: Boolean = false): Map<FlowSession, UntrustworthyData<Any>> {
|
||||
enforceNoPrimitiveInReceive(sessions.values)
|
||||
val replies = stateMachine.suspend(
|
||||
ioRequest = FlowIORequest.Receive(sessions.keys.toNonEmptySet()),
|
||||
maySkipCheckpoint = maySkipCheckpoint
|
||||
)
|
||||
return replies.mapValues { (session, payload) -> payload.checkPayloadIs(sessions[session]!!) }
|
||||
}
|
||||
|
||||
/**
|
||||
@ -248,24 +271,13 @@ abstract class FlowLogic<out T> {
|
||||
* @returns a [List] containing the objects received, wrapped in an [UntrustworthyData], with the same order of [sessions].
|
||||
*/
|
||||
@Suspendable
|
||||
open fun <R : Any> receiveAll(receiveType: Class<R>, sessions: List<FlowSession>): List<UntrustworthyData<R>> {
|
||||
@JvmOverloads
|
||||
open fun <R : Any> receiveAll(receiveType: Class<R>, sessions: List<FlowSession>, maySkipCheckpoint: Boolean = false): List<UntrustworthyData<R>> {
|
||||
enforceNoPrimitiveInReceive(listOf(receiveType))
|
||||
enforceNoDuplicates(sessions)
|
||||
return castMapValuesToKnownType(receiveAllMap(associateSessionsToReceiveType(receiveType, sessions)))
|
||||
}
|
||||
|
||||
/**
|
||||
* Queues the given [payload] for sending to the [otherParty] and continues without suspending.
|
||||
*
|
||||
* Note that the other party may receive the message at some arbitrary later point or not at all: if [otherParty]
|
||||
* is offline then message delivery will be retried until it comes back or until the message is older than the
|
||||
* network's event horizon time.
|
||||
*/
|
||||
@Deprecated("Use FlowSession.send()", level = DeprecationLevel.WARNING)
|
||||
@Suspendable
|
||||
open fun send(otherParty: Party, payload: Any) {
|
||||
stateMachine.send(otherParty, payload, flowUsedForSessions, maySkipCheckpoint = false)
|
||||
}
|
||||
|
||||
/**
|
||||
* Invokes the given subflow. This function returns once the subflow completes successfully with the result
|
||||
* returned by that subflow's [call] method. If the subflow has a progress tracker, it is attached to the
|
||||
@ -283,11 +295,8 @@ abstract class FlowLogic<out T> {
|
||||
open fun <R> subFlow(subLogic: FlowLogic<R>): R {
|
||||
subLogic.stateMachine = stateMachine
|
||||
maybeWireUpProgressTracking(subLogic)
|
||||
if (!subLogic.javaClass.isAnnotationPresent(InitiatingFlow::class.java)) {
|
||||
subLogic.flowUsedForSessions = flowUsedForSessions
|
||||
}
|
||||
logger.debug { "Calling subflow: $subLogic" }
|
||||
val result = subLogic.call()
|
||||
val result = stateMachine.subFlow(subLogic)
|
||||
logger.debug { "Subflow finished with result ${result.toString().abbreviate(300)}" }
|
||||
// It's easy to forget this when writing flows so we just step it to the DONE state when it completes.
|
||||
subLogic.progressTracker?.currentStep = ProgressTracker.DONE
|
||||
@ -384,7 +393,8 @@ abstract class FlowLogic<out T> {
|
||||
@Suspendable
|
||||
@JvmOverloads
|
||||
fun waitForLedgerCommit(hash: SecureHash, maySkipCheckpoint: Boolean = false): SignedTransaction {
|
||||
return stateMachine.waitForLedgerCommit(hash, this, maySkipCheckpoint = maySkipCheckpoint)
|
||||
val request = FlowIORequest.WaitForLedgerCommit(hash)
|
||||
return stateMachine.suspend(request, maySkipCheckpoint = maySkipCheckpoint)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -427,11 +437,6 @@ abstract class FlowLogic<out T> {
|
||||
_stateMachine = value
|
||||
}
|
||||
|
||||
// This is the flow used for managing sessions. It defaults to the current flow but if this is an inlined sub-flow
|
||||
// then it will point to the flow it's been inlined to.
|
||||
@Suppress("LeakingThis")
|
||||
private var flowUsedForSessions: FlowLogic<*> = this
|
||||
|
||||
private fun maybeWireUpProgressTracking(subLogic: FlowLogic<*>) {
|
||||
val ours = progressTracker
|
||||
val theirs = subLogic.progressTracker
|
||||
@ -448,6 +453,11 @@ abstract class FlowLogic<out T> {
|
||||
require(sessions.size == sessions.toSet().size) { "A flow session can only appear once as argument." }
|
||||
}
|
||||
|
||||
private fun enforceNoPrimitiveInReceive(types: Collection<Class<*>>) {
|
||||
val primitiveTypes = types.filter { it.isPrimitive }
|
||||
require(primitiveTypes.isEmpty()) { "Cannot receive primitive type(s) $primitiveTypes" }
|
||||
}
|
||||
|
||||
private fun <R> associateSessionsToReceiveType(receiveType: Class<R>, sessions: List<FlowSession>): Map<FlowSession, Class<R>> {
|
||||
return sessions.associateByTo(LinkedHashMap(), { it }, { receiveType })
|
||||
}
|
||||
@ -472,4 +482,4 @@ data class FlowInfo(
|
||||
* to deduplicate it from other releases of the same CorDapp, typically a version string. See the
|
||||
* [CorDapp JAR format](https://docs.corda.net/cordapp-build-systems.html#cordapp-jar-format) for more details.
|
||||
*/
|
||||
val appName: String)
|
||||
val appName: String)
|
||||
|
@ -0,0 +1,23 @@
package net.corda.core.internal

import co.paralleluniverse.fibers.Suspendable
import net.corda.core.concurrent.CordaFuture
import net.corda.core.flows.FlowLogic
import net.corda.core.serialization.CordaSerializable

/**
 * Interface for arbitrary operations that can be invoked in a flow asynchronously - the flow will suspend until the
 * operation completes. Operation parameters are expected to be injected via constructor.
 */
@CordaSerializable
interface FlowAsyncOperation<R : Any> {
    /** Performs the operation in a non-blocking fashion. */
    fun execute(): CordaFuture<R>
}

/** Executes the specified [operation] and suspends until operation completion. */
@Suspendable
fun <T, R : Any> FlowLogic<T>.executeAsync(operation: FlowAsyncOperation<R>, maySkipCheckpoint: Boolean = false): R {
    val request = FlowIORequest.ExecuteAsyncOperation(operation)
    return stateMachine.suspend(request, maySkipCheckpoint)
}
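executeAsync is internal API, but the shape of a FlowAsyncOperation is easy to sketch. The names below are hypothetical and the operation just wraps a ready value; a real implementation would complete the future from its own thread pool or an external service:

import net.corda.core.concurrent.CordaFuture
import net.corda.core.internal.FlowAsyncOperation
import net.corda.core.internal.concurrent.doneFuture

// Hypothetical operation that completes immediately with a precomputed value.
class ConstantOperation(private val value: Int) : FlowAsyncOperation<Int> {
    override fun execute(): CordaFuture<Int> = doneFuture(value)
}

// Inside a FlowLogic's call(): the flow suspends until the returned future completes.
//     val answer: Int = executeAsync(ConstantOperation(42))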
@ -0,0 +1,89 @@
|
||||
package net.corda.core.internal
|
||||
|
||||
import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.flows.FlowInfo
|
||||
import net.corda.core.flows.FlowSession
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.transactions.SignedTransaction
|
||||
import net.corda.core.utilities.NonEmptySet
|
||||
import java.time.Instant
|
||||
|
||||
/**
|
||||
* A [FlowIORequest] represents an IO request of a flow when it suspends. It is persisted in checkpoints.
|
||||
*/
|
||||
sealed class FlowIORequest<out R : Any> {
|
||||
/**
|
||||
* Send messages to sessions.
|
||||
*
|
||||
* @property sessionToMessage a map from session to message-to-be-sent.
|
||||
* @property shouldRetrySend specifies whether the send should be retried.
|
||||
*/
|
||||
data class Send(
|
||||
val sessionToMessage: Map<FlowSession, SerializedBytes<Any>>,
|
||||
val shouldRetrySend: Boolean
|
||||
) : FlowIORequest<Unit>() {
|
||||
override fun toString() = "Send(" +
|
||||
"sessionToMessage=${sessionToMessage.mapValues { it.value.hash }}, " +
|
||||
"shouldRetrySend=$shouldRetrySend" +
|
||||
")"
|
||||
}
|
||||
|
||||
/**
|
||||
* Receive messages from sessions.
|
||||
*
|
||||
* @property sessions the sessions to receive messages from.
|
||||
* @return a map from session to received message.
|
||||
*/
|
||||
data class Receive(
|
||||
val sessions: NonEmptySet<FlowSession>
|
||||
) : FlowIORequest<Map<FlowSession, SerializedBytes<Any>>>()
|
||||
|
||||
/**
|
||||
* Send and receive messages from the specified sessions.
|
||||
*
|
||||
* @property sessionToMessage a map from session to message-to-be-sent. The keys also specify which sessions to
|
||||
* receive from.
|
||||
* @property shouldRetrySend specifies whether the send should be retried.
|
||||
* @return a map from session to received message.
|
||||
*/
|
||||
data class SendAndReceive(
|
||||
val sessionToMessage: Map<FlowSession, SerializedBytes<Any>>,
|
||||
val shouldRetrySend: Boolean
|
||||
) : FlowIORequest<Map<FlowSession, SerializedBytes<Any>>>() {
|
||||
override fun toString() = "SendAndReceive(${sessionToMessage.mapValues { (key, value) ->
|
||||
"$key=${value.hash}" }}, shouldRetrySend=$shouldRetrySend)"
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait for a transaction to be committed to the database.
|
||||
*
|
||||
* @property hash the hash of the transaction.
|
||||
* @return the committed transaction.
|
||||
*/
|
||||
data class WaitForLedgerCommit(val hash: SecureHash) : FlowIORequest<SignedTransaction>()
|
||||
|
||||
/**
|
||||
* Get the FlowInfo of the specified sessions.
|
||||
*
|
||||
* @property sessions the sessions to get the FlowInfo of.
|
||||
* @return a map from session to FlowInfo.
|
||||
*/
|
||||
data class GetFlowInfo(val sessions: NonEmptySet<FlowSession>) : FlowIORequest<Map<FlowSession, FlowInfo>>()
|
||||
|
||||
/**
|
||||
* Suspend the flow until the specified time.
|
||||
*
|
||||
* @property wakeUpAfter the time to sleep until.
|
||||
*/
|
||||
data class Sleep(val wakeUpAfter: Instant) : FlowIORequest<Unit>()
|
||||
|
||||
/**
|
||||
* Suspend the flow until all Initiating sessions are confirmed.
|
||||
*/
|
||||
object WaitForSessionConfirmations : FlowIORequest<Unit>()
|
||||
|
||||
/**
|
||||
* Execute the specified [operation], suspend the flow until completion.
|
||||
*/
|
||||
data class ExecuteAsyncOperation<T : Any>(val operation: FlowAsyncOperation<T>) : FlowIORequest<T>()
|
||||
}
|
@ -1,64 +1,42 @@
|
||||
package net.corda.core.internal
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.core.DoNotImplement
|
||||
import net.corda.core.concurrent.CordaFuture
|
||||
import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.flows.*
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.identity.PartyAndCertificate
|
||||
import net.corda.core.context.InvocationContext
|
||||
import net.corda.core.node.ServiceHub
|
||||
import net.corda.core.transactions.SignedTransaction
|
||||
import net.corda.core.utilities.UntrustworthyData
|
||||
import org.slf4j.Logger
|
||||
import java.time.Instant
|
||||
|
||||
/** This is an internal interface that is implemented by code in the node module. You should look at [FlowLogic]. */
|
||||
interface FlowStateMachine<R> {
|
||||
@DoNotImplement
|
||||
interface FlowStateMachine<FLOWRETURN> {
|
||||
@Suspendable
|
||||
fun getFlowInfo(otherParty: Party, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): FlowInfo
|
||||
fun <SUSPENDRETURN : Any> suspend(ioRequest: FlowIORequest<SUSPENDRETURN>, maySkipCheckpoint: Boolean): SUSPENDRETURN
|
||||
|
||||
@Suspendable
|
||||
fun initiateFlow(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSession
|
||||
|
||||
@Suspendable
|
||||
fun <T : Any> sendAndReceive(receiveType: Class<T>,
|
||||
otherParty: Party,
|
||||
payload: Any,
|
||||
sessionFlow: FlowLogic<*>,
|
||||
retrySend: Boolean,
|
||||
maySkipCheckpoint: Boolean): UntrustworthyData<T>
|
||||
|
||||
@Suspendable
|
||||
fun <T : Any> receive(receiveType: Class<T>, otherParty: Party, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): UntrustworthyData<T>
|
||||
|
||||
@Suspendable
|
||||
fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean)
|
||||
|
||||
@Suspendable
|
||||
fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): SignedTransaction
|
||||
|
||||
@Suspendable
|
||||
fun sleepUntil(until: Instant)
|
||||
fun initiateFlow(party: Party): FlowSession
|
||||
|
||||
fun checkFlowPermission(permissionName: String, extraAuditData: Map<String, String>)
|
||||
|
||||
fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map<String, String>)
|
||||
|
||||
@Suspendable
|
||||
fun <SUBFLOWRETURN> subFlow(subFlow: FlowLogic<SUBFLOWRETURN>): SUBFLOWRETURN
|
||||
|
||||
@Suspendable
|
||||
fun flowStackSnapshot(flowClass: Class<out FlowLogic<*>>): FlowStackSnapshot?
|
||||
|
||||
@Suspendable
|
||||
fun persistFlowStackSnapshot(flowClass: Class<out FlowLogic<*>>)
|
||||
|
||||
val logic: FlowLogic<R>
|
||||
val logic: FlowLogic<FLOWRETURN>
|
||||
val serviceHub: ServiceHub
|
||||
val logger: Logger
|
||||
val id: StateMachineRunId
|
||||
val resultFuture: CordaFuture<R>
|
||||
val resultFuture: CordaFuture<FLOWRETURN>
|
||||
val context: InvocationContext
|
||||
val ourIdentityAndCert: PartyAndCertificate
|
||||
|
||||
@Suspendable
|
||||
fun receiveAll(sessions: Map<FlowSession, Class<out Any>>, sessionFlow: FlowLogic<*>): Map<FlowSession, UntrustworthyData<Any>>
|
||||
val ourIdentity: Party
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package net.corda.core.node.services
|
||||
|
||||
import net.corda.core.DoNotImplement
|
||||
import net.corda.core.concurrent.CordaFuture
|
||||
import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.messaging.DataFeed
|
||||
import net.corda.core.transactions.SignedTransaction
|
||||
@ -26,4 +27,9 @@ interface TransactionStorage {
|
||||
* Returns all currently stored transactions and further fresh ones.
|
||||
*/
|
||||
fun track(): DataFeed<List<SignedTransaction>, SignedTransaction>
|
||||
|
||||
/**
|
||||
* Returns a future that completes with the transaction corresponding to [id] once it has been committed
|
||||
*/
|
||||
fun trackTransaction(id: SecureHash): CordaFuture<SignedTransaction>
|
||||
}
|
@ -2,6 +2,9 @@ package net.corda.core.utilities
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.core.flows.FlowException
|
||||
import net.corda.core.internal.castIfPossible
|
||||
import net.corda.core.serialization.SerializationDefaults
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import java.io.Serializable
|
||||
|
||||
/**
|
||||
@ -29,3 +32,15 @@ class UntrustworthyData<out T>(@PublishedApi internal val fromUntrustedWorld: T)
|
||||
}
|
||||
|
||||
inline fun <T, R> UntrustworthyData<T>.unwrap(validator: (T) -> R): R = validator(fromUntrustedWorld)
|
||||
|
||||
fun <T : Any> SerializedBytes<Any>.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
|
||||
val payloadData: T = try {
|
||||
val serializer = SerializationDefaults.SERIALIZATION_FACTORY
|
||||
serializer.deserialize(this, type, SerializationDefaults.P2P_CONTEXT)
|
||||
} catch (ex: Exception) {
|
||||
throw IllegalArgumentException("Payload invalid", ex)
|
||||
}
|
||||
return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?:
|
||||
throw IllegalArgumentException("We were expecting a ${type.name} but we instead got a " +
|
||||
"${payloadData.javaClass.name} (${payloadData})")
|
||||
}
|
||||
|
@ -61,9 +61,8 @@ public class FlowsInJavaTest {
|
||||
fail("ExecutionException should have been thrown");
|
||||
} catch (ExecutionException e) {
|
||||
assertThat(e.getCause())
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("primitive")
|
||||
.hasMessageContaining(receiveType.getName());
|
||||
.hasMessageContaining(Primitives.unwrap(receiveType).getName());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -107,7 +107,7 @@ dependencies {
|
||||
}
|
||||
|
||||
// Coda Hale's Metrics: for monitoring of key statistics
|
||||
compile "io.dropwizard.metrics:metrics-core:3.1.2"
|
||||
compile "io.dropwizard.metrics:metrics-core:$metrics_version"
|
||||
|
||||
// JimFS: in memory java.nio filesystem. Used for test and simulation utilities.
|
||||
compile "com.google.jimfs:jimfs:1.1"
|
||||
|
@ -144,7 +144,7 @@ class P2PMessagingTest {
|
||||
|
||||
distributedServiceNodes.forEach {
|
||||
val nodeName = it.services.myInfo.legalIdentitiesAndCerts.first().name
|
||||
it.internalServices.networkService.addMessageHandler("test.request") { netMessage, _ ->
|
||||
it.internalServices.networkService.addMessageHandler("test.request") { netMessage, _, handler ->
|
||||
crashingNodes.requestsReceived.incrementAndGet()
|
||||
crashingNodes.firstRequestReceived.countDown()
|
||||
// The node which receives the first request will ignore all requests
|
||||
@ -159,6 +159,7 @@ class P2PMessagingTest {
|
||||
val response = it.internalServices.networkService.createMessage("test.response", responseMessage.serialize().bytes)
|
||||
it.internalServices.networkService.send(response, request.replyTo)
|
||||
}
|
||||
handler.afterDatabaseTransaction()
|
||||
}
|
||||
}
|
||||
return crashingNodes
|
||||
@ -186,10 +187,11 @@ class P2PMessagingTest {
|
||||
}
|
||||
|
||||
private fun InProcess.respondWith(message: Any) {
|
||||
internalServices.networkService.addMessageHandler("test.request") { netMessage, _ ->
|
||||
internalServices.networkService.addMessageHandler("test.request") { netMessage, _, handler ->
|
||||
val request = netMessage.data.deserialize<TestRequest>()
|
||||
val response = internalServices.networkService.createMessage("test.response", message.serialize().bytes)
|
||||
internalServices.networkService.send(response, request.replyTo)
|
||||
handler.afterDatabaseTransaction()
|
||||
}
|
||||
}
|
||||
|
||||
@ -211,11 +213,12 @@ class P2PMessagingTest {
|
||||
*/
|
||||
inline fun MessagingService.runOnNextMessage(topic: String, crossinline callback: (ReceivedMessage) -> Unit) {
|
||||
val consumed = AtomicBoolean()
|
||||
addMessageHandler(topic) { msg, reg ->
|
||||
addMessageHandler(topic) { msg, reg, handler ->
|
||||
removeMessageHandler(reg)
|
||||
check(!consumed.getAndSet(true)) { "Called more than once" }
|
||||
check(msg.topic == topic) { "Topic/session mismatch: ${msg.topic} vs $topic" }
|
||||
callback(msg)
|
||||
handler.afterDatabaseTransaction()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -8,6 +8,7 @@ import net.corda.confidential.SwapIdentitiesHandler
|
||||
import net.corda.core.CordaException
|
||||
import net.corda.core.concurrent.CordaFuture
|
||||
import net.corda.core.context.InvocationContext
|
||||
import net.corda.core.crypto.newSecureRandom
|
||||
import net.corda.core.crypto.sign
|
||||
import net.corda.core.flows.*
|
||||
import net.corda.core.identity.CordaX500Name
|
||||
@ -47,6 +48,7 @@ import net.corda.node.services.events.NodeSchedulerService
|
||||
import net.corda.node.services.events.ScheduledActivityObserver
|
||||
import net.corda.node.services.identity.PersistentIdentityService
|
||||
import net.corda.node.services.keys.PersistentKeyManagementService
|
||||
import net.corda.node.services.messaging.DeduplicationHandler
|
||||
import net.corda.node.services.messaging.MessagingService
|
||||
import net.corda.node.services.network.*
|
||||
import net.corda.node.services.persistence.*
|
||||
@ -56,7 +58,6 @@ import net.corda.node.services.statemachine.*
|
||||
import net.corda.node.services.transactions.*
|
||||
import net.corda.node.services.upgrade.ContractUpgradeServiceImpl
|
||||
import net.corda.node.services.vault.NodeVaultService
|
||||
import net.corda.node.services.vault.VaultSoftLockManager
|
||||
import net.corda.node.utilities.AffinityExecutor
|
||||
import net.corda.node.utilities.JVMAgentRegistry
|
||||
import net.corda.node.utilities.NamedThreadFactory
|
||||
@ -131,7 +132,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
|
||||
|
||||
// We will run as much stuff in this single thread as possible to keep the risk of thread safety bugs low during the
|
||||
// low-performance prototyping period.
|
||||
protected abstract val serverThread: AffinityExecutor
|
||||
protected abstract val serverThread: AffinityExecutor.ServiceAffinityExecutor
|
||||
|
||||
private val cordappServices = MutableClassToInstanceMap.create<SerializeAsToken>()
|
||||
private val flowFactories = ConcurrentHashMap<Class<out FlowLogic<*>>, InitiatedFlowFactory<*>>()
|
||||
@ -248,7 +249,6 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
|
||||
flowStarter,
|
||||
servicesForResolution,
|
||||
unfinishedSchedules = busyNodeLatch,
|
||||
serverThread = serverThread,
|
||||
flowLogicRefFactory = flowLogicRefFactory,
|
||||
drainingModePollPeriod = configuration.drainingModePollPeriod,
|
||||
nodeProperties = nodeProperties)
|
||||
@ -385,11 +385,12 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
|
||||
protected abstract fun myAddresses(): List<NetworkHostAndPort>
|
||||
|
||||
protected open fun makeStateMachineManager(database: CordaPersistence): StateMachineManager {
|
||||
return StateMachineManagerImpl(
|
||||
return SingleThreadedStateMachineManager(
|
||||
services,
|
||||
checkpointStorage,
|
||||
serverThread,
|
||||
database,
|
||||
newSecureRandom(),
|
||||
busyNodeLatch,
|
||||
cordappLoader.appClassLoader
|
||||
)
|
||||
@ -639,7 +640,6 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
|
||||
|
||||
protected open fun makeTransactionStorage(database: CordaPersistence, transactionCacheSizeBytes: Long): WritableTransactionStorage = DBTransactionStorage(transactionCacheSizeBytes)
|
||||
private fun makeVaultObservers(schedulerService: SchedulerService, hibernateConfig: HibernateConfiguration, smm: StateMachineManager, schemaService: SchemaService, flowLogicRefFactory: FlowLogicRefFactory) {
|
||||
VaultSoftLockManager.install(services.vaultService, smm)
|
||||
ScheduledActivityObserver.install(services.vaultService, schedulerService, flowLogicRefFactory)
|
||||
HibernateObserver.install(services.vaultService.rawUpdates, hibernateConfig, schemaService)
|
||||
}
|
||||
@ -894,8 +894,8 @@ internal fun logVendorString(database: CordaPersistence, log: Logger) {
|
||||
}
|
||||
|
||||
internal class FlowStarterImpl(private val serverThread: AffinityExecutor, private val smm: StateMachineManager, private val flowLogicRefFactory: FlowLogicRefFactory) : FlowStarter {
|
||||
override fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext): CordaFuture<FlowStateMachine<T>> {
|
||||
return serverThread.fetchFrom { smm.startFlow(logic, context) }
|
||||
override fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext, deduplicationHandler: DeduplicationHandler?): CordaFuture<FlowStateMachine<T>> {
|
||||
return smm.startFlow(logic, context, ourIdentity = null, deduplicationHandler = deduplicationHandler)
|
||||
}
|
||||
|
||||
override fun <T> invokeFlowAsync(
|
||||
|
@ -1,42 +1,28 @@
|
||||
package net.corda.node.services.api
|
||||
|
||||
import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.node.services.statemachine.FlowStateMachineImpl
|
||||
import net.corda.node.services.statemachine.Checkpoint
|
||||
import java.util.stream.Stream
|
||||
|
||||
/**
|
||||
* Thread-safe storage of fiber checkpoints.
|
||||
*/
|
||||
interface CheckpointStorage {
|
||||
|
||||
/**
|
||||
* Add a new checkpoint to the store.
|
||||
*/
|
||||
fun addCheckpoint(checkpoint: Checkpoint)
|
||||
fun addCheckpoint(id: StateMachineRunId, checkpoint: SerializedBytes<Checkpoint>)
|
||||
|
||||
/**
|
||||
* Remove existing checkpoint from the store. It is an error to attempt to remove a checkpoint which doesn't exist
|
||||
* in the store. Doing so will throw an [IllegalArgumentException].
|
||||
* Remove existing checkpoint from the store.
|
||||
* @return whether the id matched a checkpoint that was removed.
|
||||
*/
|
||||
fun removeCheckpoint(checkpoint: Checkpoint)
|
||||
fun removeCheckpoint(id: StateMachineRunId): Boolean
|
||||
|
||||
/**
|
||||
* Allows the caller to process safely in a thread safe fashion the set of all checkpoints.
|
||||
* The checkpoints are only valid during the lifetime of a single call to the block, to allow memory management.
|
||||
* Return false from the block to terminate further iteration.
|
||||
* Stream all checkpoints from the store. If this is backed by a database the stream will be valid until the
|
||||
* underlying database connection is closed, so any processing should happen before it is closed.
|
||||
*/
|
||||
fun forEach(block: (Checkpoint) -> Boolean)
|
||||
|
||||
}
|
||||
|
||||
// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
|
||||
class Checkpoint(val serializedFiber: SerializedBytes<FlowStateMachineImpl<*>>) {
|
||||
|
||||
val id: SecureHash get() = serializedFiber.hash
|
||||
|
||||
override fun equals(other: Any?): Boolean = other === this || other is Checkpoint && other.id == this.id
|
||||
|
||||
override fun hashCode(): Int = id.hashCode()
|
||||
|
||||
override fun toString(): String = "${javaClass.simpleName}(id=$id)"
|
||||
fun getAllCheckpoints(): Stream<Pair<StateMachineRunId, SerializedBytes<Checkpoint>>>
|
||||
}
|
||||
|
@ -19,6 +19,7 @@ import net.corda.core.utilities.contextLogger
|
||||
import net.corda.node.internal.InitiatedFlowFactory
|
||||
import net.corda.node.internal.cordapp.CordappProviderInternal
|
||||
import net.corda.node.services.config.NodeConfiguration
|
||||
import net.corda.node.services.messaging.DeduplicationHandler
|
||||
import net.corda.node.services.messaging.MessagingService
|
||||
import net.corda.node.services.network.NetworkMapUpdater
|
||||
import net.corda.node.services.statemachine.FlowStateMachineImpl
|
||||
@ -135,8 +136,9 @@ interface FlowStarter {
|
||||
/**
|
||||
* Starts an already constructed flow. Note that you must be on the server thread to call this method.
|
||||
* @param context indicates who started the flow, see: [InvocationContext].
|
||||
* @param deduplicationHandler allows exactly-once start of the flow, see [DeduplicationHandler]
|
||||
*/
|
||||
fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext): CordaFuture<FlowStateMachine<T>>
|
||||
fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext, deduplicationHandler: DeduplicationHandler? = null): CordaFuture<FlowStateMachine<T>>
|
||||
|
||||
/**
|
||||
* Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the flow.
|
||||
|
@ -26,10 +26,12 @@ import net.corda.node.MutableClock
|
||||
import net.corda.node.services.api.FlowStarter
|
||||
import net.corda.node.services.api.NodePropertiesStore
|
||||
import net.corda.node.services.api.SchedulerService
|
||||
import net.corda.node.services.messaging.DeduplicationHandler
|
||||
import net.corda.node.utilities.PersistentMap
|
||||
import net.corda.nodeapi.internal.persistence.CordaPersistence
|
||||
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
|
||||
import org.apache.activemq.artemis.utils.ReusableLatch
|
||||
import org.apache.mina.util.ConcurrentHashSet
|
||||
import org.slf4j.Logger
|
||||
import java.io.Serializable
|
||||
import java.time.Duration
|
||||
@ -61,7 +63,6 @@ class NodeSchedulerService(private val clock: CordaClock,
|
||||
private val flowStarter: FlowStarter,
|
||||
private val servicesForResolution: ServicesForResolution,
|
||||
private val unfinishedSchedules: ReusableLatch = ReusableLatch(),
|
||||
private val serverThread: Executor,
|
||||
private val flowLogicRefFactory: FlowLogicRefFactory,
|
||||
private val nodeProperties: NodePropertiesStore,
|
||||
private val drainingModePollPeriod: Duration,
|
||||
@ -164,6 +165,10 @@ class NodeSchedulerService(private val clock: CordaClock,
|
||||
var rescheduled: GuavaSettableFuture<Boolean>? = null
|
||||
}
|
||||
|
||||
// Used to de-duplicate flow starts in case a flow is starting but the corresponding entry hasn't been removed yet
|
||||
// from the database
|
||||
private val startingStateRefs = ConcurrentHashSet<ScheduledStateRef>()
|
||||
|
||||
private val mutex = ThreadBox(InnerState())
|
||||
// We need the [StateMachineManager] to be constructed before this is called in case it schedules a flow.
|
||||
fun start() {
|
||||
@ -173,6 +178,29 @@ class NodeSchedulerService(private val clock: CordaClock,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop scheduler service.
|
||||
*/
|
||||
fun stop() {
|
||||
mutex.locked {
|
||||
schedulerTimerExecutor.shutdown()
|
||||
scheduledStatesQueue.clear()
|
||||
scheduledStates.clear()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resume scheduler service after having called [stop].
|
||||
*/
|
||||
fun resume() {
|
||||
mutex.locked {
|
||||
schedulerTimerExecutor = Executors.newSingleThreadExecutor()
|
||||
scheduledStates.putAll(createMap())
|
||||
scheduledStatesQueue.addAll(scheduledStates.values)
|
||||
rescheduleWakeUp()
|
||||
}
|
||||
}
|
||||
|
||||
override fun scheduleStateActivity(action: ScheduledStateRef) {
|
||||
log.trace { "Schedule $action" }
|
||||
val previousState = scheduledStates[action.ref]
|
||||
@ -181,7 +209,7 @@ class NodeSchedulerService(private val clock: CordaClock,
|
||||
val previousEarliest = scheduledStatesQueue.peek()
|
||||
scheduledStatesQueue.remove(previousState)
|
||||
scheduledStatesQueue.add(action)
|
||||
if (previousState == null) {
|
||||
if (previousState == null && action !in startingStateRefs) {
|
||||
unfinishedSchedules.countUp()
|
||||
}
|
||||
|
||||
@ -212,7 +240,7 @@ class NodeSchedulerService(private val clock: CordaClock,
|
||||
}
|
||||
}
|
||||
|
||||
private val schedulerTimerExecutor = Executors.newSingleThreadExecutor()
|
||||
private var schedulerTimerExecutor = Executors.newSingleThreadExecutor()
|
||||
/**
|
||||
* This method first cancels the [java.util.concurrent.Future] for any pending action so that the
|
||||
* [awaitWithDeadline] used below drops through without running the action. We then create a new
|
||||
@ -254,25 +282,41 @@ class NodeSchedulerService(private val clock: CordaClock,
|
||||
schedulerTimerExecutor.join()
|
||||
}
|
||||
|
||||
private inner class FlowStartDeduplicationHandler(val scheduledState: ScheduledStateRef) : DeduplicationHandler {
|
||||
override fun insideDatabaseTransaction() {
|
||||
scheduledStates.remove(scheduledState.ref)
|
||||
}
|
||||
|
||||
override fun afterDatabaseTransaction() {
|
||||
startingStateRefs.remove(scheduledState)
|
||||
}
|
||||
|
||||
override fun toString(): String {
|
||||
return "${javaClass.simpleName}($scheduledState)"
|
||||
}
|
||||
}
|
||||
|
||||
private fun onTimeReached(scheduledState: ScheduledStateRef) {
|
||||
serverThread.execute {
|
||||
var flowName: String? = "(unknown)"
|
||||
try {
|
||||
database.transaction {
|
||||
val scheduledFlow = getScheduledFlow(scheduledState)
|
||||
if (scheduledFlow != null) {
|
||||
flowName = scheduledFlow.javaClass.name
|
||||
// TODO refactor the scheduler to store and propagate the original invocation context
|
||||
val context = InvocationContext.newInstance(InvocationOrigin.Scheduled(scheduledState))
|
||||
val future = flowStarter.startFlow(scheduledFlow, context).flatMap { it.resultFuture }
|
||||
future.then {
|
||||
unfinishedSchedules.countDown()
|
||||
}
|
||||
var flowName: String? = "(unknown)"
|
||||
try {
|
||||
// We need to check this before the database transaction, otherwise there is a subtle race between a
|
||||
// doubly-reached deadline and the removal from [startingStateRefs].
|
||||
if (scheduledState !in startingStateRefs) {
|
||||
val scheduledFlow = database.transaction { getScheduledFlow(scheduledState) }
|
||||
if (scheduledFlow != null) {
|
||||
startingStateRefs.add(scheduledState)
|
||||
flowName = scheduledFlow.javaClass.name
|
||||
// TODO refactor the scheduler to store and propagate the original invocation context
|
||||
val context = InvocationContext.newInstance(InvocationOrigin.Scheduled(scheduledState))
|
||||
val deduplicationHandler = FlowStartDeduplicationHandler(scheduledState)
|
||||
val future = flowStarter.startFlow(scheduledFlow, context, deduplicationHandler).flatMap { it.resultFuture }
|
||||
future.then {
|
||||
unfinishedSchedules.countDown()
|
||||
}
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
log.error("Failed to start scheduled flow $flowName for $scheduledState due to an internal error", e)
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
log.error("Failed to start scheduled flow $flowName for $scheduledState due to an internal error", e)
|
||||
}
|
||||
}
|
||||
|
||||
@ -304,7 +348,6 @@ class NodeSchedulerService(private val clock: CordaClock,
|
||||
}
|
||||
else -> {
|
||||
log.trace { "Scheduler starting FlowLogic $flowLogic" }
|
||||
scheduledStates.remove(scheduledState.ref)
|
||||
scheduledStatesQueue.remove(scheduledState)
|
||||
flowLogic
|
||||
}
|
||||
@ -328,4 +371,4 @@ class NodeSchedulerService(private val clock: CordaClock,
|
||||
null
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package net.corda.node.services.messaging
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.core.crypto.newSecureRandom
|
||||
import net.corda.core.identity.CordaX500Name
|
||||
import net.corda.core.messaging.MessageRecipients
|
||||
import net.corda.core.messaging.SingleMessageRecipient
|
||||
@ -8,8 +9,8 @@ import net.corda.core.node.services.PartyInfo
|
||||
import net.corda.core.serialization.CordaSerializable
|
||||
import net.corda.core.serialization.serialize
|
||||
import net.corda.core.utilities.ByteSequence
|
||||
import net.corda.node.services.statemachine.DeduplicationId
|
||||
import java.time.Instant
|
||||
import java.util.*
|
||||
import javax.annotation.concurrent.ThreadSafe
|
||||
|
||||
/**
|
||||
@ -35,7 +36,7 @@ interface MessagingService {
|
||||
*
|
||||
* @param topic identifier for the topic to listen for messages arriving on.
|
||||
*/
|
||||
fun addMessageHandler(topic: String, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
|
||||
fun addMessageHandler(topic: String, callback: MessageHandler): MessageHandlerRegistration
|
||||
|
||||
/**
|
||||
* Removes a handler given the object returned from [addMessageHandler]. The callback will no longer be invoked once
|
||||
@ -66,8 +67,7 @@ interface MessagingService {
|
||||
message: Message,
|
||||
target: MessageRecipients,
|
||||
retryId: Long? = null,
|
||||
sequenceKey: Any = target,
|
||||
additionalHeaders: Map<String, String> = emptyMap()
|
||||
sequenceKey: Any = target
|
||||
)
|
||||
|
||||
/** A message with a target and sequenceKey specified. */
|
||||
@ -97,7 +97,7 @@ interface MessagingService {
|
||||
* @param additionalProperties optional additional message headers.
|
||||
* @param topic identifier for the topic the message is sent to.
|
||||
*/
|
||||
fun createMessage(topic: String, data: ByteArray, deduplicationId: String = UUID.randomUUID().toString()): Message
|
||||
fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId = DeduplicationId.createRandom(newSecureRandom()), additionalHeaders: Map<String, String> = emptyMap()): Message
|
||||
|
||||
/** Given information about either a specific node or a service returns its corresponding address */
|
||||
fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients
|
||||
@ -106,9 +106,8 @@ interface MessagingService {
|
||||
val myAddress: SingleMessageRecipient
|
||||
}
|
||||
|
||||
|
||||
fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: String = UUID.randomUUID().toString(), retryId: Long? = null)
|
||||
= send(createMessage(topicSession, payload.serialize().bytes, deduplicationId), to, retryId)
|
||||
fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: DeduplicationId = DeduplicationId.createRandom(newSecureRandom()), retryId: Long? = null, additionalHeaders: Map<String, String> = emptyMap())
|
||||
= send(createMessage(topicSession, payload.serialize().bytes, deduplicationId, additionalHeaders), to, retryId)
|
||||
|
||||
interface MessageHandlerRegistration
|
||||
|
||||
@ -127,7 +126,9 @@ interface Message {
|
||||
val topic: String
|
||||
val data: ByteSequence
|
||||
val debugTimestamp: Instant
|
||||
val uniqueMessageId: String
|
||||
val uniqueMessageId: DeduplicationId
|
||||
val senderUUID: String?
|
||||
val additionalHeaders: Map<String, String>
|
||||
}
|
||||
|
||||
// TODO Have ReceivedMessage point to the TLS certificate of the peer, and [peer] would simply be the subject DN of that.
|
||||
@ -138,6 +139,10 @@ interface ReceivedMessage : Message {
|
||||
val peer: CordaX500Name
|
||||
/** Platform version of the sender's node. */
|
||||
val platformVersion: Int
|
||||
/** Sequence number of message with respect to senderUUID */
|
||||
val senderSeqNo: Long?
|
||||
/** True if a flow session init message */
|
||||
val isSessionInit: Boolean
|
||||
}
|
||||
|
||||
/** A singleton that's useful for validating topic strings */
|
||||
@ -147,3 +152,29 @@ object TopicStringValidator {
|
||||
fun check(tag: String) = require(regex.matcher(tag).matches())
|
||||
}
|
||||
|
||||
/**
 * This handler is used to implement exactly-once delivery of an event on top of a possibly duplicated one. This is done
 * using two hooks that are called from the event processor, one called from the database transaction committing the
 * side-effect caused by the event, and another one called after the transaction has committed successfully.
 *
 * For example for messaging we can use [insideDatabaseTransaction] to store the message's unique ID for later
 * deduplication, and [afterDatabaseTransaction] to acknowledge the message and stop retries.
 *
 * We also use this for exactly-once start of a scheduled flow, [insideDatabaseTransaction] is used to remove the
 * to-be-scheduled state of the flow, [afterDatabaseTransaction] is used for cleanup of in-memory bookkeeping.
 */
interface DeduplicationHandler {
    /**
     * This will be run inside a database transaction that commits the side-effect of the event, allowing the
     * implementor to persist the event delivery fact atomically with the side-effect.
     */
    fun insideDatabaseTransaction()

    /**
     * This will be run strictly after the side-effect has been committed successfully and may be used for
     * cleanup/acknowledgement/stopping of retries.
     */
    fun afterDatabaseTransaction()
}

typealias MessageHandler = (ReceivedMessage, MessageHandlerRegistration, DeduplicationHandler) -> Unit
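As a rough sketch of the contract (the scheduler's FlowStartDeduplicationHandler earlier in this diff is the real in-tree example; the class and callback names here are made up), persist the deduplication fact inside the transaction and acknowledge only after it commits:

import net.corda.node.services.statemachine.DeduplicationId

// Hypothetical handler, assumed to live alongside DeduplicationHandler in net.corda.node.services.messaging:
// records the message ID with the side-effect's transaction and acknowledges afterwards.
class RecordingDeduplicationHandler(
        private val messageId: DeduplicationId,
        private val recordId: (DeduplicationId) -> Unit,    // e.g. insert into a dedupe table in the same transaction
        private val acknowledge: () -> Unit                  // e.g. ack the broker message, stopping redelivery
) : DeduplicationHandler {
    override fun insideDatabaseTransaction() = recordId(messageId)
    override fun afterDatabaseTransaction() = acknowledge()
}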
@ -0,0 +1,86 @@
package net.corda.node.services.messaging

import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.SettableFuture
import com.codahale.metrics.MetricRegistry
import net.corda.core.messaging.MessageRecipients
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.trace
import net.corda.node.VersionInfo
import net.corda.node.services.statemachine.FlowMessagingImpl
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders
import org.apache.activemq.artemis.api.core.ActiveMQDuplicateIdException
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ClientProducer
import org.apache.activemq.artemis.api.core.client.ClientSession
import java.util.concurrent.ArrayBlockingQueue
import java.util.concurrent.ExecutionException
import java.util.concurrent.atomic.AtomicLong
import kotlin.concurrent.thread

interface AddressToArtemisQueueResolver {
    /**
     * Resolves a [MessageRecipients] to an Artemis queue name, creating the underlying queue if needed.
     */
    fun resolveTargetToArtemisQueue(address: MessageRecipients): String
}

/**
 * The [MessagingExecutor] is responsible for handling send and acknowledge jobs. It batches them using a bounded
 * blocking queue, submits the jobs asynchronously and then waits for them to flush using [ClientSession.commit].
 * Note that even though we buffer, in theory this shouldn't increase latency, as the executor is woken up immediately
 * if it was waiting. The number of jobs in the queue only ever exceeds 1 if the commit takes a long time.
 */
class MessagingExecutor(
        val session: ClientSession,
        val producer: ClientProducer,
        val versionInfo: VersionInfo,
        val resolver: AddressToArtemisQueueResolver,
        val ourSenderUUID: String
) {
    private val cordaVendor = SimpleString(versionInfo.vendor)
    private val releaseVersion = SimpleString(versionInfo.releaseVersion)
    private val ourSenderSeqNo = AtomicLong()

    private companion object {
        val log = contextLogger()
        val amqDelayMillis = System.getProperty("amq.delivery.delay.ms", "0").toInt()
    }

    fun send(message: Message, target: MessageRecipients) {
        val mqAddress = resolver.resolveTargetToArtemisQueue(target)
        val artemisMessage = cordaToArtemisMessage(message)
        log.trace {
            "Send to: $mqAddress topic: ${message.topic} " +
                    "sessionID: ${message.topic} id: ${message.uniqueMessageId}"
        }
        producer.send(SimpleString(mqAddress), artemisMessage)
    }

    fun acknowledge(message: ClientMessage) {
        message.individualAcknowledge()
    }

    internal fun cordaToArtemisMessage(message: Message): ClientMessage? {
        return session.createMessage(true).apply {
            putStringProperty(P2PMessagingHeaders.cordaVendorProperty, cordaVendor)
            putStringProperty(P2PMessagingHeaders.releaseVersionProperty, releaseVersion)
            putIntProperty(P2PMessagingHeaders.platformVersionProperty, versionInfo.platformVersion)
            putStringProperty(P2PMessagingHeaders.topicProperty, SimpleString(message.topic))
            writeBodyBufferBytes(message.data.bytes)
            // Use the magic deduplication property built into Artemis as our message identity too.
            putStringProperty(org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString))
            // If we are the sender (i.e. we are not going through recovery of some sort), use the sequence number shortcut.
            if (ourSenderUUID == message.senderUUID) {
                putStringProperty(P2PMessagingHeaders.senderUUID, SimpleString(ourSenderUUID))
                putLongProperty(P2PMessagingHeaders.senderSeqNo, ourSenderSeqNo.getAndIncrement())
            }
            // For demo purposes: if set, add a delay to messages in order to demonstrate that the flows are doing as intended.
            if (amqDelayMillis > 0 && message.topic == FlowMessagingImpl.sessionTopic) {
                putLongProperty(org.apache.activemq.artemis.api.core.Message.HDR_SCHEDULED_DELIVERY_TIME, System.currentTimeMillis() + amqDelayMillis)
            }
            message.additionalHeaders.forEach { key, value -> putStringProperty(key, value) }
        }
    }
}
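
A rough, self-contained sketch of the batching pattern the class comment above describes (this is an illustration only, not the executor's actual internals; `BatchingExecutor`, `submit` and `flush` are placeholder names):

import java.util.concurrent.ArrayBlockingQueue
import kotlin.concurrent.thread

class BatchingExecutor<T>(
        capacity: Int,
        private val submit: (T) -> Unit,   // e.g. producer.send(...)
        private val flush: () -> Unit      // e.g. session.commit()
) {
    private val queue = ArrayBlockingQueue<T>(capacity)

    // Enqueue blocks if the queue is full, which naturally applies back-pressure to senders.
    fun enqueue(job: T) = queue.put(job)

    fun start() = thread(isDaemon = true, name = "batching-executor") {
        val batch = ArrayList<T>()
        while (true) {
            batch.add(queue.take())        // wait for at least one job
            queue.drainTo(batch)           // then grab everything else already queued
            batch.forEach(submit)
            flush()                        // one flush per batch amortises the round trip
            batch.clear()
        }
    }
}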
@ -0,0 +1,108 @@
package net.corda.node.services.messaging

import net.corda.core.crypto.SecureHash
import net.corda.core.identity.CordaX500Name
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.utilities.AppendOnlyPersistentMap
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import java.io.Serializable
import java.time.Instant
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import javax.persistence.Column
import javax.persistence.Entity
import javax.persistence.Id

/**
 * Encapsulates the de-duplication logic.
 */
class P2PMessageDeduplicator(private val database: CordaPersistence) {
    val ourSenderUUID = UUID.randomUUID().toString()

    // A temporary in-memory map of deduplication IDs and associated high water mark details.
    // When we receive a message we don't persist the ID immediately, so we store the ID here in the meantime
    // (until the persisting db tx has committed). This is because Artemis may redeliver messages to the same
    // consumer if they weren't ACKed.
    private val beingProcessedMessages = ConcurrentHashMap<DeduplicationId, MessageMeta>()
    private val processedMessages = createProcessedMessages()

    private fun createProcessedMessages(): AppendOnlyPersistentMap<DeduplicationId, MessageMeta, ProcessedMessage, String> {
        return AppendOnlyPersistentMap(
                toPersistentEntityKey = { it.toString },
                fromPersistentEntity = { Pair(DeduplicationId(it.id), MessageMeta(it.insertionTime, it.hash, it.seqNo)) },
                toPersistentEntity = { key: DeduplicationId, value: MessageMeta ->
                    ProcessedMessage().apply {
                        id = key.toString
                        insertionTime = value.insertionTime
                        hash = value.senderHash
                        seqNo = value.senderSeqNo
                    }
                },
                persistentEntityClass = ProcessedMessage::class.java
        )
    }

    private fun isDuplicateInDatabase(msg: ReceivedMessage): Boolean = database.transaction { msg.uniqueMessageId in processedMessages }

    // We need to incorporate the sending party and the sessionInit flag, as for the in-memory cache.
    private fun senderHash(senderKey: SenderKey) = SecureHash.sha256(senderKey.peer.toString() + senderKey.isSessionInit.toString() + senderKey.senderUUID).toString()

    /**
     * @return true if we have seen this message before.
     */
    fun isDuplicate(msg: ReceivedMessage): Boolean {
        if (beingProcessedMessages.containsKey(msg.uniqueMessageId)) {
            return true
        }
        return isDuplicateInDatabase(msg)
    }

    /**
     * Called the first time we encounter [msg]'s deduplication ID.
     */
    fun signalMessageProcessStart(msg: ReceivedMessage) {
        val receivedSenderUUID = msg.senderUUID
        val receivedSenderSeqNo = msg.senderSeqNo
        // We don't want a mix of nulls and values, so we ensure that here.
        val senderHash: String? = if (receivedSenderUUID != null && receivedSenderSeqNo != null) senderHash(SenderKey(receivedSenderUUID, msg.peer, msg.isSessionInit)) else null
        val senderSeqNo: Long? = if (senderHash != null) msg.senderSeqNo else null
        beingProcessedMessages[msg.uniqueMessageId] = MessageMeta(Instant.now(), senderHash, senderSeqNo)
    }

    /**
     * Called inside a DB transaction to persist [deduplicationId].
     */
    fun persistDeduplicationId(deduplicationId: DeduplicationId) {
        processedMessages[deduplicationId] = beingProcessedMessages[deduplicationId]!!
    }

    /**
     * Called after the DB transaction persisting [deduplicationId] has committed.
     * Any subsequent redelivery will be deduplicated using the DB.
     */
    fun signalMessageProcessFinish(deduplicationId: DeduplicationId) {
        beingProcessedMessages.remove(deduplicationId)
    }

    @Entity
    @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_ids")
    class ProcessedMessage(
            @Id
            @Column(name = "message_id", length = 64)
            var id: String = "",

            @Column(name = "insertion_time")
            var insertionTime: Instant = Instant.now(),

            @Column(name = "sender", length = 64)
            var hash: String? = "",

            @Column(name = "sequence_number")
            var seqNo: Long? = null
    ) : Serializable

    private data class MessageMeta(val insertionTime: Instant, val senderHash: String?, val senderSeqNo: Long?)

    private data class SenderKey(val senderUUID: String, val peer: CordaX500Name, val isSessionInit: Boolean)
}
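
A hypothetical walk-through of the intended call order, assuming it sits in the same package as the classes above; `runInTransaction`, `sideEffect` and `acknowledge` are placeholders for the node's database transaction and messaging plumbing:

fun processIncoming(deduplicator: P2PMessageDeduplicator,
                    msg: ReceivedMessage,
                    runInTransaction: (() -> Unit) -> Unit,
                    sideEffect: () -> Unit,
                    acknowledge: () -> Unit) {
    if (deduplicator.isDuplicate(msg)) {
        acknowledge()                                          // already processed, just ack the redelivery
        return
    }
    deduplicator.signalMessageProcessStart(msg)                // in-memory guard until the ID is durable
    runInTransaction {
        sideEffect()                                           // the work triggered by the message
        deduplicator.persistDeduplicationId(msg.uniqueMessageId)
    }
    deduplicator.signalMessageProcessFinish(msg.uniqueMessageId)
    acknowledge()                                              // safe: redelivery is now deduplicated via the DB
}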
@ -1,8 +1,11 @@
|
||||
package net.corda.node.services.messaging
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import com.codahale.metrics.MetricRegistry
|
||||
import net.corda.core.crypto.toStringShort
|
||||
import net.corda.core.identity.CordaX500Name
|
||||
import net.corda.core.internal.ThreadBox
|
||||
import net.corda.core.internal.concurrent.openFuture
|
||||
import net.corda.core.messaging.CordaRPCOps
|
||||
import net.corda.core.messaging.MessageRecipients
|
||||
import net.corda.core.messaging.SingleMessageRecipient
|
||||
@ -21,9 +24,8 @@ import net.corda.node.internal.artemis.ReactiveArtemisConsumer
|
||||
import net.corda.node.internal.artemis.ReactiveArtemisConsumer.Companion.multiplex
|
||||
import net.corda.node.services.api.NetworkMapCacheInternal
|
||||
import net.corda.node.services.config.NodeConfiguration
|
||||
import net.corda.node.services.statemachine.StateMachineManagerImpl
|
||||
import net.corda.node.services.statemachine.DeduplicationId
|
||||
import net.corda.node.utilities.AffinityExecutor
|
||||
import net.corda.node.utilities.AppendOnlyPersistentMap
|
||||
import net.corda.node.utilities.PersistentMap
|
||||
import net.corda.nodeapi.ArtemisTcpTransport
|
||||
import net.corda.nodeapi.ConnectionDirection
|
||||
@ -38,7 +40,8 @@ import net.corda.nodeapi.internal.bridging.BridgeEntry
|
||||
import net.corda.nodeapi.internal.persistence.CordaPersistence
|
||||
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
|
||||
import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException
|
||||
import org.apache.activemq.artemis.api.core.Message.*
|
||||
import org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID
|
||||
import org.apache.activemq.artemis.api.core.Message.HDR_VALIDATED_USER
|
||||
import org.apache.activemq.artemis.api.core.RoutingType
|
||||
import org.apache.activemq.artemis.api.core.SimpleString
|
||||
import org.apache.activemq.artemis.api.core.client.*
|
||||
@ -50,15 +53,16 @@ import java.io.Serializable
|
||||
import java.security.PublicKey
|
||||
import java.time.Instant
|
||||
import java.util.*
|
||||
import java.util.concurrent.*
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.CountDownLatch
|
||||
import java.util.concurrent.ScheduledFuture
|
||||
import java.util.concurrent.TimeUnit
|
||||
import javax.annotation.concurrent.ThreadSafe
|
||||
import javax.persistence.Column
|
||||
import javax.persistence.Entity
|
||||
import javax.persistence.Id
|
||||
import javax.persistence.Lob
|
||||
|
||||
// TODO: Stop the wallet explorer and other clients from using this class and get rid of persistentInbox
|
||||
|
||||
/**
|
||||
* This class implements the [MessagingService] API using Apache Artemis, the successor to their ActiveMQ product.
|
||||
* Artemis is a message queue broker and here we run a client connecting to the specified broker instance
|
||||
@ -85,7 +89,7 @@ import javax.persistence.Lob
|
||||
* @param maxMessageSize A bound applied to the message size.
|
||||
*/
|
||||
@ThreadSafe
|
||||
class P2PMessagingClient(private val config: NodeConfiguration,
|
||||
class P2PMessagingClient(val config: NodeConfiguration,
|
||||
private val versionInfo: VersionInfo,
|
||||
private val serverAddress: NetworkHostAndPort,
|
||||
private val myIdentity: PublicKey,
|
||||
@ -97,26 +101,11 @@ class P2PMessagingClient(private val config: NodeConfiguration,
|
||||
private val maxMessageSize: Int,
|
||||
private val isDrainingModeOn: () -> Boolean,
|
||||
private val drainingModeWasChangedEvents: Observable<Pair<Boolean, Boolean>>
|
||||
) : SingletonSerializeAsToken(), MessagingService, AutoCloseable {
|
||||
) : SingletonSerializeAsToken(), MessagingService, AddressToArtemisQueueResolver, AutoCloseable {
|
||||
companion object {
|
||||
private val log = contextLogger()
|
||||
private val amqDelayMillis = System.getProperty("amq.delivery.delay.ms", "0").toInt()
|
||||
private const val messageMaxRetryCount: Int = 3
|
||||
|
||||
fun createProcessedMessage(): AppendOnlyPersistentMap<String, Instant, ProcessedMessage, String> {
|
||||
return AppendOnlyPersistentMap(
|
||||
toPersistentEntityKey = { it },
|
||||
fromPersistentEntity = { Pair(it.uuid, it.insertionTime) },
|
||||
toPersistentEntity = { key: String, value: Instant ->
|
||||
ProcessedMessage().apply {
|
||||
uuid = key
|
||||
insertionTime = value
|
||||
}
|
||||
},
|
||||
persistentEntityClass = ProcessedMessage::class.java
|
||||
)
|
||||
}
|
||||
|
||||
fun createMessageToRedeliver(): PersistentMap<Long, Pair<Message, MessageRecipients>, RetryMessage, Long> {
|
||||
return PersistentMap(
|
||||
toPersistentEntityKey = { it },
|
||||
@ -137,7 +126,7 @@ class P2PMessagingClient(private val config: NodeConfiguration,
|
||||
)
|
||||
}
|
||||
|
||||
private class NodeClientMessage(override val topic: String, override val data: ByteSequence, override val uniqueMessageId: String) : Message {
|
||||
private class NodeClientMessage(override val topic: String, override val data: ByteSequence, override val uniqueMessageId: DeduplicationId, override val senderUUID: String?, override val additionalHeaders: Map<String, String>) : Message {
|
||||
override val debugTimestamp: Instant = Instant.now()
|
||||
override fun toString() = "$topic#${String(data.bytes)}"
|
||||
}
|
||||
@ -165,32 +154,17 @@ class P2PMessagingClient(private val config: NodeConfiguration,
|
||||
private val scheduledMessageRedeliveries = ConcurrentHashMap<Long, ScheduledFuture<*>>()
|
||||
|
||||
/** A registration to handle messages of different types */
|
||||
data class Handler(val topic: String,
|
||||
val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration
|
||||
|
||||
private val cordaVendor = SimpleString(versionInfo.vendor)
|
||||
private val releaseVersion = SimpleString(versionInfo.releaseVersion)
|
||||
/** An executor for sending messages */
|
||||
private val messagingExecutor = AffinityExecutor.ServiceAffinityExecutor("Messaging ${myIdentity.toStringShort()}", 1)
|
||||
data class HandlerRegistration(val topic: String, val callback: Any) : MessageHandlerRegistration
|
||||
|
||||
override val myAddress: SingleMessageRecipient = NodeAddress(myIdentity, advertisedAddress)
|
||||
private val messageRedeliveryDelaySeconds = config.messageRedeliveryDelaySeconds.toLong()
|
||||
private val state = ThreadBox(InnerState())
|
||||
private val knownQueues = Collections.newSetFromMap(ConcurrentHashMap<String, Boolean>())
|
||||
private val handlers = CopyOnWriteArrayList<Handler>()
|
||||
|
||||
private val processedMessages = createProcessedMessage()
|
||||
private val handlers = ConcurrentHashMap<String, MessageHandler>()
|
||||
|
||||
@Entity
|
||||
@javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_ids")
|
||||
class ProcessedMessage(
|
||||
@Id
|
||||
@Column(name = "message_id", length = 64)
|
||||
var uuid: String = "",
|
||||
|
||||
@Column(name = "insertion_time")
|
||||
var insertionTime: Instant = Instant.now()
|
||||
) : Serializable
|
||||
private val deduplicator = P2PMessageDeduplicator(database)
|
||||
internal var messagingExecutor: MessagingExecutor? = null
|
||||
|
||||
@Entity
|
||||
@javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_retry")
|
||||
@ -246,6 +220,14 @@ class P2PMessagingClient(private val config: NodeConfiguration,
|
||||
inboxes.forEach { createQueueIfAbsent(it, producerSession!!) }
|
||||
p2pConsumer = P2PMessagingConsumer(inboxes, createNewSession, isDrainingModeOn, drainingModeWasChangedEvents)
|
||||
|
||||
messagingExecutor = MessagingExecutor(
|
||||
producerSession!!,
|
||||
producer!!,
|
||||
versionInfo,
|
||||
this@P2PMessagingClient,
|
||||
ourSenderUUID = deduplicator.ourSenderUUID
|
||||
)
|
||||
|
||||
registerBridgeControl(bridgeSession!!, inboxes.toList())
|
||||
enumerateBridges(bridgeSession!!, inboxes.toList())
|
||||
}
|
||||
@ -255,7 +237,9 @@ class P2PMessagingClient(private val config: NodeConfiguration,
|
||||
|
||||
private fun InnerState.registerBridgeControl(session: ClientSession, inboxes: List<String>) {
|
||||
val bridgeNotifyQueue = "$BRIDGE_NOTIFY.${myIdentity.toStringShort()}"
|
||||
session.createTemporaryQueue(BRIDGE_NOTIFY, RoutingType.MULTICAST, bridgeNotifyQueue)
|
||||
if (!session.queueQuery(SimpleString(bridgeNotifyQueue)).isExists) {
|
||||
session.createTemporaryQueue(BRIDGE_NOTIFY, RoutingType.MULTICAST, bridgeNotifyQueue)
|
||||
}
|
||||
val bridgeConsumer = session.createConsumer(bridgeNotifyQueue)
|
||||
bridgeNotifyConsumer = bridgeConsumer
|
||||
bridgeConsumer.setMessageHandler { msg ->
|
||||
@ -273,7 +257,7 @@ class P2PMessagingClient(private val config: NodeConfiguration,
|
||||
networkChangeSubscription = networkMap.changed.subscribe { updateBridgesOnNetworkChange(it) }
|
||||
}
|
||||
|
||||
private fun sendBridgeControl(message: BridgeControl) {
|
||||
private fun sendBridgeControl(message: BridgeControl) {
|
||||
state.locked {
|
||||
val controlPacket = message.serialize(context = SerializationDefaults.P2P_CONTEXT).bytes
|
||||
val artemisMessage = producerSession!!.createMessage(false)
|
||||
@ -343,38 +327,35 @@ class P2PMessagingClient(private val config: NodeConfiguration,
|
||||
|
||||
private fun resumeMessageRedelivery() {
|
||||
messagesToRedeliver.forEach { retryId, (message, target) ->
|
||||
sendInternal(message, target, retryId)
|
||||
send(message, target, retryId)
|
||||
}
|
||||
}
|
||||
|
||||
private val shutdownLatch = CountDownLatch(1)
|
||||
|
||||
var runningFuture = openFuture<Unit>()
|
||||
|
||||
/**
|
||||
* Starts the p2p event loop: this method only returns once [stop] has been called.
|
||||
*/
|
||||
fun run() {
|
||||
|
||||
val latch = CountDownLatch(1)
|
||||
try {
|
||||
val consumer = state.locked {
|
||||
check(started) { "start must be called first" }
|
||||
check(!running) { "run can't be called twice" }
|
||||
running = true
|
||||
runningFuture.set(Unit)
|
||||
// If it's null, it means we already called stop, so return immediately.
|
||||
if (p2pConsumer == null) {
|
||||
return
|
||||
}
|
||||
eventsSubscription = p2pConsumer!!.messages
|
||||
.doOnError { error -> throw error }
|
||||
.doOnNext { artemisMessage ->
|
||||
val receivedMessage = artemisToCordaMessage(artemisMessage)
|
||||
receivedMessage?.let {
|
||||
deliver(it)
|
||||
}
|
||||
artemisMessage.acknowledge()
|
||||
}
|
||||
.doOnNext { message -> deliver(message) }
|
||||
// this `run()` method is semantically meant to block until the message consumption runs, hence the latch here
|
||||
.doOnCompleted(latch::countDown)
|
||||
.doOnError { error -> throw error }
|
||||
.subscribe()
|
||||
p2pConsumer!!
|
||||
}
|
||||
@ -391,10 +372,13 @@ class P2PMessagingClient(private val config: NodeConfiguration,
|
||||
val user = requireNotNull(message.getStringProperty(HDR_VALIDATED_USER)) { "Message is not authenticated" }
|
||||
val platformVersion = message.required(P2PMessagingHeaders.platformVersionProperty) { getIntProperty(it) }
|
||||
// Use the magic deduplication property built into Artemis as our message identity too
|
||||
val uuid = message.required(HDR_DUPLICATE_DETECTION_ID) { message.getStringProperty(it) }
|
||||
log.info("Received message from: ${message.address} user: $user topic: $topic uuid: $uuid")
|
||||
val uniqueMessageId = message.required(HDR_DUPLICATE_DETECTION_ID) { DeduplicationId(message.getStringProperty(it)) }
|
||||
val receivedSenderUUID = message.getStringProperty(P2PMessagingHeaders.senderUUID)
|
||||
val receivedSenderSeqNo = if (message.containsProperty(P2PMessagingHeaders.senderSeqNo)) message.getLongProperty(P2PMessagingHeaders.senderSeqNo) else null
|
||||
val isSessionInit = message.getStringProperty(P2PMessagingHeaders.Type.KEY) == P2PMessagingHeaders.Type.SESSION_INIT_VALUE
|
||||
log.trace { "Received message from: ${message.address} user: $user topic: $topic id: $uniqueMessageId senderUUID: $receivedSenderUUID senderSeqNo: $receivedSenderSeqNo isSessionInit: $isSessionInit" }
|
||||
|
||||
return ArtemisReceivedMessage(topic, CordaX500Name.parse(user), platformVersion, uuid, message)
|
||||
return ArtemisReceivedMessage(topic, CordaX500Name.parse(user), platformVersion, uniqueMessageId, receivedSenderUUID, receivedSenderSeqNo, isSessionInit, message)
|
||||
} catch (e: Exception) {
|
||||
log.error("Unable to process message, ignoring it: $message", e)
|
||||
return null
|
||||
@ -409,52 +393,55 @@ class P2PMessagingClient(private val config: NodeConfiguration,
|
||||
private class ArtemisReceivedMessage(override val topic: String,
|
||||
override val peer: CordaX500Name,
|
||||
override val platformVersion: Int,
|
||||
override val uniqueMessageId: String,
|
||||
override val uniqueMessageId: DeduplicationId,
|
||||
override val senderUUID: String?,
|
||||
override val senderSeqNo: Long?,
|
||||
override val isSessionInit: Boolean,
|
||||
private val message: ClientMessage) : ReceivedMessage {
|
||||
override val data: ByteSequence by lazy { OpaqueBytes(ByteArray(message.bodySize).apply { message.bodyBuffer.readBytes(this) }) }
|
||||
override val debugTimestamp: Instant get() = Instant.ofEpochMilli(message.timestamp)
|
||||
override val additionalHeaders: Map<String, String> = emptyMap()
|
||||
override fun toString() = "$topic#$data"
|
||||
}
|
||||
|
||||
private fun deliver(msg: ReceivedMessage): Boolean {
|
||||
state.checkNotLocked()
|
||||
// Because handlers is a COW list, the loop inside filter will operate on a snapshot. Handlers being added
|
||||
// or removed whilst the filter is executing will not affect anything.
|
||||
val deliverTo = handlers.filter { it.topic.isBlank() || it.topic == msg.topic }
|
||||
try {
|
||||
// This will perform a BLOCKING call onto the executor. Thus if the handlers are slow, we will
|
||||
// be slow, and Artemis can handle that case intelligently. We don't just invoke the handler
|
||||
// directly in order to ensure that we have the features of the AffinityExecutor class throughout
|
||||
// the bulk of the codebase and other non-messaging jobs can be scheduled onto the server executor
|
||||
// easily.
|
||||
//
|
||||
// Note that handlers may re-enter this class. We aren't holding any locks and methods like
|
||||
// start/run/stop have re-entrancy assertions at the top, so it is OK.
|
||||
nodeExecutor.fetchFrom {
|
||||
database.transaction {
|
||||
if (msg.uniqueMessageId in processedMessages) {
|
||||
log.trace { "Discard duplicate message ${msg.uniqueMessageId} for ${msg.topic}" }
|
||||
} else {
|
||||
if (deliverTo.isEmpty()) {
|
||||
// TODO: Implement dead letter queue, and send it there.
|
||||
log.warn("Received message ${msg.uniqueMessageId} for ${msg.topic} that doesn't have any registered handlers yet")
|
||||
} else {
|
||||
callHandlers(msg, deliverTo)
|
||||
}
|
||||
// TODO We will at some point need to decide a trimming policy for the id's
|
||||
processedMessages[msg.uniqueMessageId] = Instant.now()
|
||||
}
|
||||
}
|
||||
internal fun deliver(artemisMessage: ClientMessage) {
|
||||
artemisToCordaMessage(artemisMessage)?.let { cordaMessage ->
|
||||
if (!deduplicator.isDuplicate(cordaMessage)) {
|
||||
deduplicator.signalMessageProcessStart(cordaMessage)
|
||||
deliver(cordaMessage, artemisMessage)
|
||||
} else {
|
||||
log.trace { "Discard duplicate message ${cordaMessage.uniqueMessageId} for ${cordaMessage.topic}" }
|
||||
artemisMessage.individualAcknowledge()
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
log.error("Caught exception whilst executing message handler for ${msg.topic}", e)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
private fun callHandlers(msg: ReceivedMessage, deliverTo: List<Handler>) {
|
||||
for (handler in deliverTo) {
|
||||
handler.callback(msg, handler)
|
||||
private fun deliver(msg: ReceivedMessage, artemisMessage: ClientMessage) {
|
||||
state.checkNotLocked()
|
||||
val deliverTo = handlers[msg.topic]
|
||||
if (deliverTo != null) {
|
||||
try {
|
||||
deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), MessageDeduplicationHandler(artemisMessage, msg))
|
||||
} catch (e: Exception) {
|
||||
log.error("Caught exception whilst executing message handler for ${msg.topic}", e)
|
||||
}
|
||||
} else {
|
||||
log.warn("Received message ${msg.uniqueMessageId} for ${msg.topic} that doesn't have any registered handlers yet")
|
||||
}
|
||||
}
|
||||
|
||||
inner class MessageDeduplicationHandler(val artemisMessage: ClientMessage, val cordaMessage: ReceivedMessage) : DeduplicationHandler {
|
||||
override fun insideDatabaseTransaction() {
|
||||
deduplicator.persistDeduplicationId(cordaMessage.uniqueMessageId)
|
||||
}
|
||||
|
||||
override fun afterDatabaseTransaction() {
|
||||
deduplicator.signalMessageProcessFinish(cordaMessage.uniqueMessageId)
|
||||
messagingExecutor!!.acknowledge(artemisMessage)
|
||||
}
|
||||
|
||||
override fun toString(): String {
|
||||
return "${javaClass.simpleName}(${cordaMessage.uniqueMessageId})"
|
||||
}
|
||||
}
|
||||
|
||||
@ -470,6 +457,7 @@ class P2PMessagingClient(private val config: NodeConfiguration,
|
||||
check(started)
|
||||
val prevRunning = running
|
||||
running = false
|
||||
runningFuture = openFuture()
|
||||
networkChangeSubscription?.unsubscribe()
|
||||
require(p2pConsumer != null, { "stop can't be called twice" })
|
||||
require(producer != null, { "stop can't be called twice" })
|
||||
@ -507,75 +495,42 @@ class P2PMessagingClient(private val config: NodeConfiguration,
|
||||
|
||||
override fun close() = stop()
|
||||
|
||||
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, additionalHeaders: Map<String, String>) {
|
||||
sendInternal(message, target, retryId, additionalHeaders)
|
||||
}
|
||||
|
||||
private fun sendInternal(message: Message, target: MessageRecipients, retryId: Long?, additionalHeaders: Map<String, String> = emptyMap()) {
|
||||
// We have to perform sending on a different thread pool, since using the same pool for messaging and
|
||||
// fibers leads to Netty buffer memory leaks, caused by both Netty and Quasar fiddling with thread-locals.
|
||||
messagingExecutor.fetchFrom {
|
||||
state.locked {
|
||||
val mqAddress = getMQAddress(target)
|
||||
val artemisMessage = producerSession!!.createMessage(true).apply {
|
||||
putStringProperty(P2PMessagingHeaders.cordaVendorProperty, cordaVendor)
|
||||
putStringProperty(P2PMessagingHeaders.releaseVersionProperty, releaseVersion)
|
||||
putIntProperty(P2PMessagingHeaders.platformVersionProperty, versionInfo.platformVersion)
|
||||
putStringProperty(P2PMessagingHeaders.topicProperty, SimpleString(message.topic))
|
||||
writeBodyBufferBytes(message.data.bytes)
|
||||
// Use the magic deduplication property built into Artemis as our message identity too
|
||||
putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId))
|
||||
|
||||
// For demo purposes - if set then add a delay to messages in order to demonstrate that the flows are doing as intended
|
||||
if (amqDelayMillis > 0 && message.topic == StateMachineManagerImpl.sessionTopic) {
|
||||
putLongProperty(HDR_SCHEDULED_DELIVERY_TIME, System.currentTimeMillis() + amqDelayMillis)
|
||||
}
|
||||
additionalHeaders.forEach { key, value -> putStringProperty(key, value) }
|
||||
}
|
||||
log.trace {
|
||||
"Send to: $mqAddress topic: ${message.topic} uuid: ${message.uniqueMessageId}"
|
||||
}
|
||||
sendMessage(mqAddress, artemisMessage)
|
||||
retryId?.let {
|
||||
database.transaction {
|
||||
messagesToRedeliver.computeIfAbsent(it, { Pair(message, target) })
|
||||
}
|
||||
scheduledMessageRedeliveries[it] = messagingExecutor.schedule({
|
||||
sendWithRetry(0, mqAddress, artemisMessage, it)
|
||||
}, messageRedeliveryDelaySeconds, TimeUnit.SECONDS)
|
||||
|
||||
}
|
||||
@Suspendable
|
||||
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) {
|
||||
messagingExecutor!!.send(message, target)
|
||||
retryId?.let {
|
||||
database.transaction {
|
||||
messagesToRedeliver.computeIfAbsent(it, { Pair(message, target) })
|
||||
}
|
||||
scheduledMessageRedeliveries[it] = nodeExecutor.schedule({
|
||||
sendWithRetry(0, message, target, retryId)
|
||||
}, messageRedeliveryDelaySeconds, TimeUnit.SECONDS)
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
override fun send(addressedMessages: List<MessagingService.AddressedMessage>) {
|
||||
for ((message, target, retryId, sequenceKey) in addressedMessages) {
|
||||
send(message, target, retryId, sequenceKey)
|
||||
}
|
||||
}
|
||||
|
||||
private fun sendWithRetry(retryCount: Int, address: String, message: ClientMessage, retryId: Long) {
|
||||
fun ClientMessage.randomiseDuplicateId() {
|
||||
putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString()))
|
||||
}
|
||||
|
||||
private fun sendWithRetry(retryCount: Int, message: Message, target: MessageRecipients, retryId: Long) {
|
||||
log.trace { "Attempting to retry #$retryCount message delivery for $retryId" }
|
||||
if (retryCount >= messageMaxRetryCount) {
|
||||
log.warn("Reached the maximum number of retries ($messageMaxRetryCount) for message $message redelivery to $address")
|
||||
log.warn("Reached the maximum number of retries ($messageMaxRetryCount) for message $message redelivery to $target")
|
||||
scheduledMessageRedeliveries.remove(retryId)
|
||||
return
|
||||
}
|
||||
|
||||
message.randomiseDuplicateId()
|
||||
|
||||
state.locked {
|
||||
log.trace { "Retry #$retryCount sending message $message to $address for $retryId" }
|
||||
sendMessage(address, message)
|
||||
val messageWithRetryCount = object : Message by message {
|
||||
override val uniqueMessageId = DeduplicationId("${message.uniqueMessageId.toString}-$retryCount")
|
||||
}
|
||||
|
||||
scheduledMessageRedeliveries[retryId] = messagingExecutor.schedule({
|
||||
sendWithRetry(retryCount + 1, address, message, retryId)
|
||||
messagingExecutor!!.send(messageWithRetryCount, target)
|
||||
|
||||
scheduledMessageRedeliveries[retryId] = nodeExecutor.schedule({
|
||||
sendWithRetry(retryCount + 1, message, target, retryId)
|
||||
}, messageRedeliveryDelaySeconds, TimeUnit.SECONDS)
|
||||
}
|
||||
|
||||
@ -590,18 +545,14 @@ class P2PMessagingClient(private val config: NodeConfiguration,
|
||||
}
|
||||
}
|
||||
|
||||
private fun Pair<ClientMessage, ReceivedMessage?>.deliver() = deliver(second!!)
|
||||
private fun Pair<ClientMessage, ReceivedMessage?>.acknowledge() = first.acknowledge()
|
||||
|
||||
private fun getMQAddress(target: MessageRecipients): String {
|
||||
return if (target == myAddress) {
|
||||
override fun resolveTargetToArtemisQueue(address: MessageRecipients): String {
|
||||
return if (address == myAddress) {
|
||||
// If we are sending to ourselves then route the message directly to our P2P queue.
|
||||
RemoteInboxAddress(myIdentity).queueName
|
||||
} else {
|
||||
// Otherwise we send the message to an internal queue for the target residing on our broker. It's then the
|
||||
// broker's job to route the message to the target's P2P queue.
|
||||
val internalTargetQueue = (target as? ArtemisAddress)?.queueName
|
||||
?: throw IllegalArgumentException("Not an Artemis address")
|
||||
val internalTargetQueue = (address as? ArtemisAddress)?.queueName ?: throw IllegalArgumentException("Not an Artemis address")
|
||||
state.locked {
|
||||
createQueueIfAbsent(internalTargetQueue, producerSession!!)
|
||||
}
|
||||
@ -630,24 +581,26 @@ class P2PMessagingClient(private val config: NodeConfiguration,
|
||||
}
|
||||
}
|
||||
|
||||
override fun addMessageHandler(topic: String,
|
||||
callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
|
||||
override fun addMessageHandler(topic: String, callback: MessageHandler): MessageHandlerRegistration {
|
||||
require(!topic.isBlank()) { "Topic must not be blank, as the empty topic is a special case." }
|
||||
val handler = Handler(topic, callback)
|
||||
handlers.add(handler)
|
||||
return handler
|
||||
handlers.compute(topic) { _, handler ->
|
||||
if (handler != null) {
|
||||
throw IllegalStateException("Cannot add another acking handler for $topic, there is already an acking one")
|
||||
}
|
||||
callback
|
||||
}
|
||||
return HandlerRegistration(topic, callback)
|
||||
}
|
||||
|
||||
override fun removeMessageHandler(registration: MessageHandlerRegistration) {
|
||||
handlers.remove(registration)
|
||||
registration as HandlerRegistration
|
||||
handlers.remove(registration.topic)
|
||||
}
|
||||
|
||||
override fun createMessage(topic: String, data: ByteArray, deduplicationId: String): Message {
|
||||
// TODO: We could write an object that proxies directly to an underlying MQ message here and avoid copying.
|
||||
return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId)
|
||||
override fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId, additionalHeaders: Map<String, String>): Message {
|
||||
return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId, deduplicator.ourSenderUUID, additionalHeaders)
|
||||
}
|
||||
|
||||
// TODO Rethink PartyInfo idea and merging PeerAddress/ServiceAddress (the only difference is that Service address doesn't hold host and port)
|
||||
override fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients {
|
||||
return when (partyInfo) {
|
||||
is PartyInfo.SingleNode -> NodeAddress(partyInfo.party.owningKey, partyInfo.addresses.single())
|
||||
@ -720,4 +673,4 @@ private fun ReactiveArtemisConsumer.switchTo(other: ReactiveArtemisConsumer) {
|
||||
!other.started -> other.start()
|
||||
!other.connected -> other.connect()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,11 +1,16 @@
|
||||
package net.corda.node.services.persistence
|
||||
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.node.services.api.Checkpoint
|
||||
import net.corda.core.utilities.debug
|
||||
import net.corda.node.services.api.CheckpointStorage
|
||||
import net.corda.node.services.statemachine.Checkpoint
|
||||
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
|
||||
import net.corda.nodeapi.internal.persistence.currentDBSession
|
||||
import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY
|
||||
import org.slf4j.LoggerFactory
|
||||
import java.util.*
|
||||
import java.util.stream.Stream
|
||||
import java.io.Serializable
|
||||
import javax.persistence.Column
|
||||
import javax.persistence.Entity
|
||||
@ -16,6 +21,7 @@ import javax.persistence.Lob
|
||||
* Simple checkpoint key value storage in DB.
|
||||
*/
|
||||
class DBCheckpointStorage : CheckpointStorage {
|
||||
val log = LoggerFactory.getLogger(this::class.java)
|
||||
|
||||
@Entity
|
||||
@javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}checkpoints")
|
||||
@ -29,32 +35,30 @@ class DBCheckpointStorage : CheckpointStorage {
|
||||
var checkpoint: ByteArray = EMPTY_BYTE_ARRAY
|
||||
) : Serializable
|
||||
|
||||
override fun addCheckpoint(checkpoint: Checkpoint) {
|
||||
currentDBSession().save(DBCheckpoint().apply {
|
||||
checkpointId = checkpoint.id.toString()
|
||||
this.checkpoint = checkpoint.serializedFiber.bytes // XXX: Is copying the byte array necessary?
|
||||
override fun addCheckpoint(id: StateMachineRunId, checkpoint: SerializedBytes<Checkpoint>) {
|
||||
currentDBSession().saveOrUpdate(DBCheckpoint().apply {
|
||||
checkpointId = id.uuid.toString()
|
||||
this.checkpoint = checkpoint.bytes
|
||||
log.debug { "Checkpoint $checkpointId, size=${this.checkpoint.size}" }
|
||||
})
|
||||
}
|
||||
|
||||
override fun removeCheckpoint(checkpoint: Checkpoint) {
|
||||
override fun removeCheckpoint(id: StateMachineRunId): Boolean {
|
||||
val session = currentDBSession()
|
||||
val criteriaBuilder = session.criteriaBuilder
|
||||
val delete = criteriaBuilder.createCriteriaDelete(DBCheckpoint::class.java)
|
||||
val root = delete.from(DBCheckpoint::class.java)
|
||||
delete.where(criteriaBuilder.equal(root.get<String>(DBCheckpoint::checkpointId.name), checkpoint.id.toString()))
|
||||
session.createQuery(delete).executeUpdate()
|
||||
delete.where(criteriaBuilder.equal(root.get<String>(DBCheckpoint::checkpointId.name), id.uuid.toString()))
|
||||
return session.createQuery(delete).executeUpdate() > 0
|
||||
}
|
||||
|
||||
override fun forEach(block: (Checkpoint) -> Boolean) {
|
||||
override fun getAllCheckpoints(): Stream<Pair<StateMachineRunId, SerializedBytes<Checkpoint>>> {
|
||||
val session = currentDBSession()
|
||||
val criteriaQuery = session.criteriaBuilder.createQuery(DBCheckpoint::class.java)
|
||||
val root = criteriaQuery.from(DBCheckpoint::class.java)
|
||||
criteriaQuery.select(root)
|
||||
for (row in session.createQuery(criteriaQuery).resultList) {
|
||||
val checkpoint = Checkpoint(SerializedBytes(row.checkpoint))
|
||||
if (!block(checkpoint)) {
|
||||
break
|
||||
}
|
||||
return session.createQuery(criteriaQuery).stream().map {
|
||||
StateMachineRunId(UUID.fromString(it.checkpointId)) to SerializedBytes<Checkpoint>(it.checkpoint)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,15 +1,19 @@
|
||||
package net.corda.node.services.persistence
|
||||
import net.corda.core.internal.VisibleForTesting
|
||||
import net.corda.core.internal.bufferUntilSubscribed
|
||||
import net.corda.core.concurrent.CordaFuture
|
||||
import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.crypto.TransactionSignature
|
||||
import net.corda.core.internal.ThreadBox
|
||||
import net.corda.core.internal.VisibleForTesting
|
||||
import net.corda.core.internal.bufferUntilSubscribed
|
||||
import net.corda.core.internal.concurrent.doneFuture
|
||||
import net.corda.core.messaging.DataFeed
|
||||
import net.corda.core.serialization.*
|
||||
import net.corda.core.toFuture
|
||||
import net.corda.core.transactions.CoreTransaction
|
||||
import net.corda.core.transactions.SignedTransaction
|
||||
import net.corda.node.services.api.WritableTransactionStorage
|
||||
import net.corda.node.utilities.*
|
||||
import net.corda.node.utilities.AppendOnlyPersistentMapBase
|
||||
import net.corda.node.utilities.WeightBasedAppendOnlyPersistentMap
|
||||
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
|
||||
import net.corda.nodeapi.internal.persistence.bufferUntilDatabaseCommit
|
||||
import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction
|
||||
@ -92,7 +96,18 @@ class DBTransactionStorage(cacheSizeBytes: Long) : WritableTransactionStorage, S
|
||||
|
||||
override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> {
|
||||
return txStorage.locked {
|
||||
DataFeed(allPersisted().map { it.second.toSignedTx() }.toList(), updatesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction())
|
||||
DataFeed(allPersisted().map { it.second.toSignedTx() }.toList(), updates.bufferUntilSubscribed().wrapWithDatabaseTransaction())
|
||||
}
|
||||
}
|
||||
|
||||
override fun trackTransaction(id: SecureHash): CordaFuture<SignedTransaction> {
|
||||
return txStorage.locked {
|
||||
val existingTransaction = get(id)
|
||||
if (existingTransaction == null) {
|
||||
updates.filter { it.id == id }.toFuture()
|
||||
} else {
|
||||
doneFuture(existingTransaction.toSignedTx())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -14,6 +14,7 @@ import net.corda.node.services.api.SchemaService.SchemaOptions
|
||||
import net.corda.node.services.events.NodeSchedulerService
|
||||
import net.corda.node.services.identity.PersistentIdentityService
|
||||
import net.corda.node.services.keys.PersistentKeyManagementService
|
||||
import net.corda.node.services.messaging.P2PMessageDeduplicator
|
||||
import net.corda.node.services.messaging.P2PMessagingClient
|
||||
import net.corda.node.services.persistence.DBCheckpointStorage
|
||||
import net.corda.node.services.persistence.DBTransactionMappingStorage
|
||||
@ -43,7 +44,7 @@ class NodeSchemaService(extraSchemas: Set<MappedSchema> = emptySet(), includeNot
|
||||
PersistentKeyManagementService.PersistentKey::class.java,
|
||||
NodeSchedulerService.PersistentScheduledState::class.java,
|
||||
NodeAttachmentService.DBAttachment::class.java,
|
||||
P2PMessagingClient.ProcessedMessage::class.java,
|
||||
P2PMessageDeduplicator.ProcessedMessage::class.java,
|
||||
P2PMessagingClient.RetryMessage::class.java,
|
||||
PersistentIdentityService.PersistentIdentity::class.java,
|
||||
PersistentIdentityService.PersistentIdentityNames::class.java,
|
||||
|
@ -0,0 +1,138 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.internal.FlowAsyncOperation
|
||||
import net.corda.node.services.messaging.DeduplicationHandler
|
||||
import java.time.Instant
|
||||
import java.util.*
|
||||
|
||||
/**
|
||||
* [Action]s are reified IO actions to execute as part of state machine transitions.
|
||||
*/
|
||||
sealed class Action {
|
||||
|
||||
/**
|
||||
* Track a transaction hash and notify the state machine once the corresponding transaction has committed.
|
||||
*/
|
||||
data class TrackTransaction(val hash: SecureHash) : Action()
|
||||
|
||||
/**
|
||||
* Send an initial session message to [party].
|
||||
*/
|
||||
data class SendInitial(
|
||||
val party: Party,
|
||||
val initialise: InitialSessionMessage,
|
||||
val deduplicationId: DeduplicationId
|
||||
) : Action()
|
||||
|
||||
/**
|
||||
* Send a session message to a [peerParty] with which we have an established session.
|
||||
*/
|
||||
data class SendExisting(
|
||||
val peerParty: Party,
|
||||
val message: ExistingSessionMessage,
|
||||
val deduplicationId: DeduplicationId
|
||||
) : Action()
|
||||
|
||||
/**
|
||||
* Persist the specified [checkpoint].
|
||||
*/
|
||||
data class PersistCheckpoint(val id: StateMachineRunId, val checkpoint: Checkpoint) : Action()
|
||||
|
||||
/**
|
||||
* Remove the checkpoint corresponding to [id].
|
||||
*/
|
||||
data class RemoveCheckpoint(val id: StateMachineRunId) : Action()
|
||||
|
||||
/**
|
||||
* Persist the deduplication facts of [deduplicationHandlers].
|
||||
*/
|
||||
data class PersistDeduplicationFacts(val deduplicationHandlers: List<DeduplicationHandler>) : Action()
|
||||
|
||||
/**
|
||||
* Acknowledge messages in [deduplicationHandlers].
|
||||
*/
|
||||
data class AcknowledgeMessages(val deduplicationHandlers: List<DeduplicationHandler>) : Action()
|
||||
|
||||
/**
|
||||
* Propagate [errorMessages] to [sessions].
|
||||
* @param sessions a map from source session IDs to initiated sessions.
|
||||
*/
|
||||
data class PropagateErrors(
|
||||
val errorMessages: List<ErrorSessionMessage>,
|
||||
val sessions: List<SessionState.Initiated>
|
||||
) : Action()
|
||||
|
||||
/**
|
||||
* Create a session binding from [sessionId] to [flowId] to allow routing of incoming messages.
|
||||
*/
|
||||
data class AddSessionBinding(val flowId: StateMachineRunId, val sessionId: SessionId) : Action()
|
||||
|
||||
/**
|
||||
* Remove the session bindings corresponding to [sessionIds].
|
||||
*/
|
||||
data class RemoveSessionBindings(val sessionIds: Set<SessionId>) : Action()
|
||||
|
||||
/**
|
||||
* Signal that the flow corresponding to [flowId] is considered started.
|
||||
*/
|
||||
data class SignalFlowHasStarted(val flowId: StateMachineRunId) : Action()
|
||||
|
||||
/**
|
||||
* Remove the flow corresponding to [flowId].
|
||||
*/
|
||||
data class RemoveFlow(
|
||||
val flowId: StateMachineRunId,
|
||||
val removalReason: FlowRemovalReason,
|
||||
val lastState: StateMachineState
|
||||
) : Action()
|
||||
|
||||
/**
|
||||
* Schedule [event] to self.
|
||||
*/
|
||||
data class ScheduleEvent(val event: Event) : Action()
|
||||
|
||||
/**
|
||||
* Sleep until [time].
|
||||
*/
|
||||
data class SleepUntil(val time: Instant) : Action()
|
||||
|
||||
/**
|
||||
* Create a new database transaction.
|
||||
*/
|
||||
object CreateTransaction : Action() { override fun toString() = "CreateTransaction" }
|
||||
|
||||
/**
|
||||
* Roll back the current database transaction.
|
||||
*/
|
||||
object RollbackTransaction : Action() { override fun toString() = "RollbackTransaction" }
|
||||
|
||||
/**
|
||||
* Commit the current database transaction.
|
||||
*/
|
||||
object CommitTransaction : Action() { override fun toString() = "CommitTransaction" }
|
||||
|
||||
/**
|
||||
* Execute the specified [operation].
|
||||
*/
|
||||
data class ExecuteAsyncOperation(val operation: FlowAsyncOperation<*>) : Action()
|
||||
|
||||
/**
|
||||
* Release soft locks associated with given ID (currently the flow ID).
|
||||
*/
|
||||
data class ReleaseSoftLocks(val uuid: UUID?) : Action()
|
||||
}
|
||||
|
||||
/**
|
||||
* Reason for flow removal.
|
||||
*/
|
||||
sealed class FlowRemovalReason {
|
||||
data class OrderlyFinish(val flowReturnValue: Any?) : FlowRemovalReason()
|
||||
data class ErrorFinish(val flowErrors: List<FlowError>) : FlowRemovalReason()
|
||||
object SoftShutdown : FlowRemovalReason() { override fun toString() = "SoftShutdown" }
|
||||
// TODO Should we remove errored flows? How will the flow hospital work? Perhaps keep them in memory for a while, flush
|
||||
// them after a timeout, reload them on flow hospital request. In any case if we ever want to remove them
|
||||
// (e.g. temporarily) then add a case for that here.
|
||||
}
|
@ -0,0 +1,14 @@
package net.corda.node.services.statemachine

import co.paralleluniverse.fibers.Suspendable

/**
 * An executor of a single [Action].
 */
interface ActionExecutor {
    /**
     * Execute [action] by [fiber].
     */
    @Suspendable
    fun executeAction(fiber: FlowFiber, action: Action)
}
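
Since [ActionExecutor] is a single-method interface, cross-cutting concerns such as tracing can in principle be layered on with a decorator; the following `LoggingActionExecutor` is a hypothetical illustration, not part of this change:

import co.paralleluniverse.fibers.Suspendable
import net.corda.core.utilities.contextLogger

class LoggingActionExecutor(private val delegate: ActionExecutor) : ActionExecutor {
    private companion object {
        val log = contextLogger()
    }

    @Suspendable
    override fun executeAction(fiber: FlowFiber, action: Action) {
        // Log each transition side-effect before delegating to the real executor.
        log.debug("Flow ${fiber.id} executing $action")
        delegate.executeAction(fiber, action)
    }
}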
@ -0,0 +1,233 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import co.paralleluniverse.fibers.Fiber
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import com.codahale.metrics.*
|
||||
import net.corda.core.internal.concurrent.thenMatch
|
||||
import net.corda.core.serialization.SerializationContext
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.serialization.serialize
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.core.utilities.trace
|
||||
import net.corda.node.services.api.CheckpointStorage
|
||||
import net.corda.node.services.api.ServiceHubInternal
|
||||
import net.corda.nodeapi.internal.persistence.contextDatabase
|
||||
import net.corda.nodeapi.internal.persistence.contextTransaction
|
||||
import net.corda.nodeapi.internal.persistence.contextTransactionOrNull
|
||||
import java.time.Duration
|
||||
import java.time.Instant
|
||||
import java.util.concurrent.TimeUnit
|
||||
import java.util.concurrent.atomic.AtomicLong
|
||||
|
||||
/**
|
||||
* This is the bottom execution engine of flow side-effects.
|
||||
*/
|
||||
class ActionExecutorImpl(
|
||||
private val services: ServiceHubInternal,
|
||||
private val checkpointStorage: CheckpointStorage,
|
||||
private val flowMessaging: FlowMessaging,
|
||||
private val stateMachineManager: StateMachineManagerInternal,
|
||||
private val checkpointSerializationContext: SerializationContext,
|
||||
metrics: MetricRegistry
|
||||
) : ActionExecutor {
|
||||
|
||||
private companion object {
|
||||
val log = contextLogger()
|
||||
}
|
||||
|
||||
/**
|
||||
* This [Gauge] just reports the sum of the bytes checkpointed during the last second.
|
||||
*/
|
||||
private class LatchedGauge(private val reservoir: Reservoir) : Gauge<Long> {
|
||||
override fun getValue(): Long {
|
||||
return reservoir.snapshot.values.sum()
|
||||
}
|
||||
}
|
||||
|
||||
private val checkpointingMeter = metrics.meter("Flows.Checkpointing Rate")
|
||||
private val checkpointSizesThisSecond = SlidingTimeWindowReservoir(1, TimeUnit.SECONDS)
|
||||
private val lastBandwidthUpdate = AtomicLong(0)
|
||||
private val checkpointBandwidthHist = metrics.register("Flows.CheckpointVolumeBytesPerSecondHist", Histogram(SlidingTimeWindowArrayReservoir(1, TimeUnit.DAYS)))
|
||||
private val checkpointBandwidth = metrics.register("Flows.CheckpointVolumeBytesPerSecondCurrent", LatchedGauge(checkpointSizesThisSecond))
|
||||
|
||||
@Suspendable
|
||||
override fun executeAction(fiber: FlowFiber, action: Action) {
|
||||
log.trace { "Flow ${fiber.id} executing $action" }
|
||||
return when (action) {
|
||||
is Action.TrackTransaction -> executeTrackTransaction(fiber, action)
|
||||
is Action.PersistCheckpoint -> executePersistCheckpoint(action)
|
||||
is Action.PersistDeduplicationFacts -> executePersistDeduplicationIds(action)
|
||||
is Action.AcknowledgeMessages -> executeAcknowledgeMessages(action)
|
||||
is Action.PropagateErrors -> executePropagateErrors(action)
|
||||
is Action.ScheduleEvent -> executeScheduleEvent(fiber, action)
|
||||
is Action.SleepUntil -> executeSleepUntil(action)
|
||||
is Action.RemoveCheckpoint -> executeRemoveCheckpoint(action)
|
||||
is Action.SendInitial -> executeSendInitial(action)
|
||||
is Action.SendExisting -> executeSendExisting(action)
|
||||
is Action.AddSessionBinding -> executeAddSessionBinding(action)
|
||||
is Action.RemoveSessionBindings -> executeRemoveSessionBindings(action)
|
||||
is Action.SignalFlowHasStarted -> executeSignalFlowHasStarted(action)
|
||||
is Action.RemoveFlow -> executeRemoveFlow(action)
|
||||
is Action.CreateTransaction -> executeCreateTransaction()
|
||||
is Action.RollbackTransaction -> executeRollbackTransaction()
|
||||
is Action.CommitTransaction -> executeCommitTransaction()
|
||||
is Action.ExecuteAsyncOperation -> executeAsyncOperation(fiber, action)
|
||||
is Action.ReleaseSoftLocks -> executeReleaseSoftLocks(action)
|
||||
}
|
||||
}
|
||||
|
||||
private fun executeReleaseSoftLocks(action: Action.ReleaseSoftLocks) {
|
||||
if (action.uuid != null) services.vaultService.softLockRelease(action.uuid)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executeTrackTransaction(fiber: FlowFiber, action: Action.TrackTransaction) {
|
||||
services.validatedTransactions.trackTransaction(action.hash).thenMatch(
|
||||
success = { transaction ->
|
||||
fiber.scheduleEvent(Event.TransactionCommitted(transaction))
|
||||
},
|
||||
failure = { exception ->
|
||||
fiber.scheduleEvent(Event.Error(exception))
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executePersistCheckpoint(action: Action.PersistCheckpoint) {
|
||||
val checkpointBytes = serializeCheckpoint(action.checkpoint)
|
||||
checkpointStorage.addCheckpoint(action.id, checkpointBytes)
|
||||
checkpointingMeter.mark()
|
||||
checkpointSizesThisSecond.update(checkpointBytes.size.toLong())
|
||||
var lastUpdateTime = lastBandwidthUpdate.get()
|
||||
while (System.nanoTime() - lastUpdateTime > TimeUnit.SECONDS.toNanos(1)) {
|
||||
if (lastBandwidthUpdate.compareAndSet(lastUpdateTime, System.nanoTime())) {
|
||||
val checkpointVolume = checkpointSizesThisSecond.snapshot.values.sum()
|
||||
checkpointBandwidthHist.update(checkpointVolume)
|
||||
}
|
||||
lastUpdateTime = lastBandwidthUpdate.get()
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executePersistDeduplicationIds(action: Action.PersistDeduplicationFacts) {
|
||||
for (handle in action.deduplicationHandlers) {
|
||||
handle.insideDatabaseTransaction()
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executeAcknowledgeMessages(action: Action.AcknowledgeMessages) {
|
||||
action.deduplicationHandlers.forEach {
|
||||
it.afterDatabaseTransaction()
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executePropagateErrors(action: Action.PropagateErrors) {
|
||||
action.errorMessages.forEach { error ->
|
||||
val exception = error.flowException
|
||||
log.debug("Propagating error", exception)
|
||||
}
|
||||
for (sessionState in action.sessions) {
|
||||
// We cannot propagate if the session isn't live.
|
||||
if (sessionState.initiatedState !is InitiatedSessionState.Live) {
|
||||
continue
|
||||
}
|
||||
// Don't propagate errors to the originating session
|
||||
for (errorMessage in action.errorMessages) {
|
||||
val sinkSessionId = sessionState.initiatedState.peerSinkSessionId
|
||||
val existingMessage = ExistingSessionMessage(sinkSessionId, errorMessage)
|
||||
val deduplicationId = DeduplicationId.createForError(errorMessage.errorId, sinkSessionId)
|
||||
flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, deduplicationId)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executeScheduleEvent(fiber: FlowFiber, action: Action.ScheduleEvent) {
|
||||
fiber.scheduleEvent(action.event)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executeSleepUntil(action: Action.SleepUntil) {
|
||||
// TODO introduce explicit sleep state + wakeup event instead of relying on Fiber.sleep. This is so shutdown
|
||||
// conditions may "interrupt" the sleep instead of waiting until wakeup.
|
||||
val duration = Duration.between(Instant.now(), action.time)
|
||||
Fiber.sleep(duration.toNanos(), TimeUnit.NANOSECONDS)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executeRemoveCheckpoint(action: Action.RemoveCheckpoint) {
|
||||
checkpointStorage.removeCheckpoint(action.id)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executeSendInitial(action: Action.SendInitial) {
|
||||
flowMessaging.sendSessionMessage(action.party, action.initialise, action.deduplicationId)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executeSendExisting(action: Action.SendExisting) {
|
||||
flowMessaging.sendSessionMessage(action.peerParty, action.message, action.deduplicationId)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executeAddSessionBinding(action: Action.AddSessionBinding) {
|
||||
stateMachineManager.addSessionBinding(action.flowId, action.sessionId)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executeRemoveSessionBindings(action: Action.RemoveSessionBindings) {
|
||||
stateMachineManager.removeSessionBindings(action.sessionIds)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executeSignalFlowHasStarted(action: Action.SignalFlowHasStarted) {
|
||||
stateMachineManager.signalFlowHasStarted(action.flowId)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executeRemoveFlow(action: Action.RemoveFlow) {
|
||||
stateMachineManager.removeFlow(action.flowId, action.removalReason, action.lastState)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executeCreateTransaction() {
|
||||
if (contextTransactionOrNull != null) {
|
||||
throw IllegalStateException("Refusing to create a second transaction")
|
||||
}
|
||||
contextDatabase.newTransaction()
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executeRollbackTransaction() {
|
||||
contextTransactionOrNull?.close()
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executeCommitTransaction() {
|
||||
try {
|
||||
contextTransaction.commit()
|
||||
} finally {
|
||||
contextTransaction.close()
|
||||
contextTransactionOrNull = null
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executeAsyncOperation(fiber: FlowFiber, action: Action.ExecuteAsyncOperation) {
|
||||
val operationFuture = action.operation.execute()
|
||||
operationFuture.thenMatch(
|
||||
success = { result ->
|
||||
fiber.scheduleEvent(Event.AsyncOperationCompletion(result))
|
||||
},
|
||||
failure = { exception ->
|
||||
fiber.scheduleEvent(Event.Error(exception))
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
private fun serializeCheckpoint(checkpoint: Checkpoint): SerializedBytes<Checkpoint> {
|
||||
return checkpoint.serialize(context = checkpointSerializationContext)
|
||||
}
|
||||
}
|
@ -0,0 +1,66 @@
package net.corda.node.services.statemachine

import co.paralleluniverse.strands.concurrent.AbstractQueuedSynchronizer
import co.paralleluniverse.fibers.Suspendable

/**
 * Quasar-compatible latch that may be incremented.
 */
class CountUpDownLatch(initialValue: Int) {

    // See quasar CountDownLatch
    private class Sync(initialValue: Int) : AbstractQueuedSynchronizer() {
        init {
            state = initialValue
        }

        override fun tryAcquireShared(arg: Int): Int {
            if (arg >= 0) {
                return if (state == arg) 1 else -1
            } else {
                return if (state <= -arg) 1 else -1
            }
        }

        override fun tryReleaseShared(arg: Int): Boolean {
            while (true) {
                val c = state
                if (c == 0)
                    return false
                val nextc = c - Math.min(c, arg)
                if (compareAndSetState(c, nextc))
                    return nextc == 0
            }
        }

        fun increment() {
            while (true) {
                val c = state
                val nextc = c + 1
                if (compareAndSetState(c, nextc))
                    return
            }
        }
    }

    private val sync = Sync(initialValue)

    @Suspendable
    fun await() {
        sync.acquireSharedInterruptibly(0)
    }

    @Suspendable
    fun awaitLessThanOrEqual(number: Int) {
        sync.acquireSharedInterruptibly(number)
    }

    fun countDown(number: Int = 1) {
        require(number > 0)
        sync.releaseShared(number)
    }

    fun countUp() {
        sync.increment()
    }
}
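The latch above is easiest to see in a small usage sketch. The snippet below is illustrative only and not part of the commit; the names startWork and drain are hypothetical, and it simply shows a caller blocking until every unit of in-flight work has counted back down.

// Illustrative sketch only: using CountUpDownLatch to wait for in-flight work to finish.
val liveWork = CountUpDownLatch(0)

fun startWork(task: () -> Unit) {
    liveWork.countUp()            // one more unit of work in flight
    try {
        task()
    } finally {
        liveWork.countDown()      // work finished, release the latch by one
    }
}

@Suspendable
fun drain() {
    liveWork.await()              // resumes once the count is back down to zero
}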
@ -0,0 +1,47 @@
package net.corda.node.services.statemachine

import java.security.SecureRandom

/**
 * A deduplication ID of a flow message.
 */
data class DeduplicationId(val toString: String) {
    companion object {
        /**
         * Create a random deduplication ID. Note that this isn't deterministic, which means we will never dedupe it,
         * unless we persist the ID somehow.
         */
        fun createRandom(random: SecureRandom) = DeduplicationId("R-${random.nextLong()}")

        /**
         * Create a deduplication ID for a normal clean state message. This is used to have a deterministic way of
         * creating IDs in case the message-generating flow logic is replayed on hard failure.
         *
         * A normal deduplication ID consists of:
         * 1. A deduplication seed set per flow. This is either the flow's ID or, in case of an initiated flow, the
         *    initiator's session ID.
         * 2. The number of *clean* suspends since the start of the flow.
         * 3. An optional additional index, for cases where several messages are sent as part of the state transition.
         *    Note that care must be taken with this index; it must be a deterministic counter. For example a naive
         *    iteration over a HashMap will produce a different list of indices than a previous run, causing the
         *    message-id map to change, which means deduplication will not happen correctly.
         */
        fun createForNormal(checkpoint: Checkpoint, index: Int): DeduplicationId {
            return DeduplicationId("N-${checkpoint.deduplicationSeed}-${checkpoint.numberOfSuspends}-$index")
        }

        /**
         * Create a deduplication ID for an error message. Note that these IDs live in a different namespace than normal
         * IDs, as we don't want error conditions to affect the determinism of clean deduplication IDs. This allows the
         * dirtiness state to be thrown away for resumption.
         *
         * An error deduplication ID consists of:
         * 1. The error's ID. This is a unique value per "source" of error and is propagated.
         *    See [net.corda.core.flows.IdentifiableException].
         * 2. The recipient's session ID.
         */
        fun createForError(errorId: Long, recipientSessionId: SessionId): DeduplicationId {
            return DeduplicationId("E-$errorId-${recipientSessionId.toLong}")
        }
    }
}
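To make the determinism argument concrete, here is a small illustrative sketch (not part of the commit, and it deliberately avoids the internal Checkpoint class): building the same "N-seed-suspends-index" string from the same replayed inputs yields an equal DeduplicationId, which is what lets the messaging layer drop the resent message.

// Illustrative only: mirrors the "N-<seed>-<suspendCount>-<index>" format with plain values.
fun normalId(seed: String, suspendCount: Int, index: Int) =
        DeduplicationId("N-$seed-$suspendCount-$index")

fun main() {
    val firstRun = normalId("flow-1234", 3, 0)
    val replay = normalId("flow-1234", 3, 0)   // same checkpoint state replayed after a crash
    check(firstRun == replay)                  // equal IDs, so the duplicate send is dropped
}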
@ -0,0 +1,128 @@
package net.corda.node.services.statemachine

import net.corda.core.flows.FlowLogic
import net.corda.core.identity.Party
import net.corda.core.internal.FlowIORequest
import net.corda.core.serialization.SerializedBytes
import net.corda.core.transactions.SignedTransaction
import net.corda.node.services.messaging.DeduplicationHandler
import java.util.*

/**
 * Transitions in the flow state machine are triggered by [Event]s that may originate from the flow itself or from
 * outside (e.g. in case of message delivery or external event).
 */
sealed class Event {
    /**
     * Check the current state for pending work. For example if the flow is waiting for a message from a particular
     * session this event may cause a flow resume if we have a corresponding message. In general the state machine
     * should be idempotent in the [DoRemainingWork] event, meaning a second subsequent event shouldn't modify the state
     * or produce [Action]s.
     */
    object DoRemainingWork : Event() { override fun toString() = "DoRemainingWork" }

    /**
     * Deliver a session message.
     * @param sessionMessage the message itself.
     * @param deduplicationHandler the handle to acknowledge the message after checkpointing.
     * @param sender the sender [Party].
     */
    data class DeliverSessionMessage(
            val sessionMessage: ExistingSessionMessage,
            val deduplicationHandler: DeduplicationHandler,
            val sender: Party
    ) : Event()

    /**
     * Signal that an error has happened. This may be due to an uncaught exception in the flow or some external error.
     * @param exception the exception itself.
     */
    data class Error(val exception: Throwable) : Event()

    /**
     * Signal that a ledger transaction has committed. This is an event completing a [FlowIORequest.WaitForLedgerCommit]
     * suspension.
     * @param transaction the transaction that was committed.
     */
    data class TransactionCommitted(val transaction: SignedTransaction) : Event()

    /**
     * Trigger a soft shutdown, removing the flow as soon as possible. This causes the flow to be removed as soon as
     * this event is processed. Note that on restart the flow will resume as normal.
     */
    object SoftShutdown : Event() { override fun toString() = "SoftShutdown" }

    /**
     * Start error propagation on an errored flow. This may be triggered by e.g. a [FlowHospital].
     */
    object StartErrorPropagation : Event() { override fun toString() = "StartErrorPropagation" }

    /**
     * Initiate a flow. This causes a new session object to be created and returned to the flow. Note that no actual
     * communication takes place at this time, only on the first send/receive operation on the session.
     *
     * Scheduled by the flow.
     *
     * @param party the [Party] to create a session with.
     */
    data class InitiateFlow(val party: Party) : Event()

    /**
     * Signal the entering into a subflow.
     *
     * Scheduled and executed by the flow.
     *
     * @param subFlowClass the [Class] of the subflow, to be used to determine whether it's Initiating or inlined.
     */
    data class EnterSubFlow(val subFlowClass: Class<FlowLogic<*>>) : Event()

    /**
     * Signal the leaving of a subflow.
     *
     * Scheduled by the flow.
     */
    object LeaveSubFlow : Event() { override fun toString() = "LeaveSubFlow" }

    /**
     * Signal a flow suspension. This causes the flow's stack and the state machine's state together with the suspending
     * IO request to be persisted into the database.
     *
     * Scheduled by the flow and executed inside the park closure.
     *
     * @param ioRequest the request triggering the suspension.
     * @param maySkipCheckpoint indicates whether the persistence may be skipped.
     * @param fiber the serialised stack of the flow.
     */
    data class Suspend(
            val ioRequest: FlowIORequest<*>,
            val maySkipCheckpoint: Boolean,
            val fiber: SerializedBytes<FlowStateMachineImpl<*>>
    ) : Event() {
        override fun toString() =
                "Suspend(" +
                        "ioRequest=$ioRequest, " +
                        "maySkipCheckpoint=$maySkipCheckpoint, " +
                        "fiber=${fiber.hash}, " +
                        ")"
    }

    /**
     * Signals clean flow finish.
     *
     * Scheduled by the flow.
     *
     * @param returnValue the return value of the flow.
     * @param softLocksId the flow ID of the flow if it is holding soft locks, else null.
     */
    data class FlowFinish(val returnValue: Any?, val softLocksId: UUID?) : Event()

    /**
     * Signals the completion of a [FlowAsyncOperation].
     *
     * Scheduling is triggered by the service that completes the future returned by the async operation.
     *
     * @param returnValue the result of the operation.
     */
    data class AsyncOperationCompletion(val returnValue: Any?) : Event()
}
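The AsyncOperationCompletion event is the hand-off point between an external service and the fiber. As a hedged sketch (the helper name scheduleWhenComplete is hypothetical; FlowFiber, Event and thenMatch are taken from this commit), this is roughly how a completed CordaFuture feeds its result back into the fiber's event queue, mirroring executeAsyncOperation earlier in the diff:

// Sketch only: bridge an externally completed future back into the fiber's event queue.
import net.corda.core.concurrent.CordaFuture
import net.corda.core.internal.concurrent.thenMatch

fun scheduleWhenComplete(fiber: FlowFiber, future: CordaFuture<Any?>) {
    future.thenMatch(
            success = { result -> fiber.scheduleEvent(Event.AsyncOperationCompletion(result)) },
            failure = { exception -> fiber.scheduleEvent(Event.Error(exception)) }
    )
}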
@ -0,0 +1,18 @@
package net.corda.node.services.statemachine

import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.StateMachineRunId
import net.corda.node.services.statemachine.transitions.StateMachine

/**
 * An interface wrapping a fiber running a flow.
 */
interface FlowFiber {
    val id: StateMachineRunId
    val stateMachine: StateMachine

    @Suspendable
    fun scheduleEvent(event: Event)

    fun snapshot(): StateMachineState
}
@ -0,0 +1,18 @@
package net.corda.node.services.statemachine

/**
 * A flow hospital is a class that is notified when a flow transitions into an error state due to an uncaught exception
 * or internal error condition, and when it becomes clean again (e.g. due to a resume).
 * Also see [net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor].
 */
interface FlowHospital {
    /**
     * The flow running in [flowFiber] has errored.
     */
    fun flowErrored(flowFiber: FlowFiber)

    /**
     * The flow running in [flowFiber] has been cleaned, possibly as a result of a flow hospital resume.
     */
    fun flowCleaned(flowFiber: FlowFiber)
}
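As a minimal sketch of how this interface might be satisfied (purely illustrative; the object name and logger name below are hypothetical and not the hospital actually wired into the state machine), a logging-only implementation looks like this:

// Illustrative only: a FlowHospital that just records admissions and discharges.
import org.slf4j.LoggerFactory

object LoggingFlowHospital : FlowHospital {
    private val log = LoggerFactory.getLogger("net.corda.flow.hospital")

    override fun flowErrored(flowFiber: FlowFiber) {
        log.warn("Flow ${flowFiber.id} has errored and been admitted")
    }

    override fun flowCleaned(flowFiber: FlowFiber) {
        log.info("Flow ${flowFiber.id} is clean again and has been discharged")
    }
}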
@ -1,121 +0,0 @@
package net.corda.node.services.statemachine

import co.paralleluniverse.fibers.Suspendable
import net.corda.core.crypto.SecureHash
import net.corda.core.identity.Party
import java.time.Instant

interface FlowIORequest {
    // This is used to identify where we suspended, in case of message mismatch errors and other things where we
    // don't have the original stack trace because it's in a suspended fiber.
    val stackTraceInCaseOfProblems: StackSnapshot
}

interface WaitingRequest : FlowIORequest {
    fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean
}

interface SessionedFlowIORequest : FlowIORequest {
    val session: FlowSessionInternal
}

interface SendRequest : SessionedFlowIORequest {
    val message: SessionMessage
}

interface ReceiveRequest : SessionedFlowIORequest, WaitingRequest {
    val userReceiveType: Class<*>?

    override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = this.session === session
}

data class SendAndReceive(override val session: FlowSessionInternal,
                          override val message: SessionMessage,
                          override val userReceiveType: Class<*>?) : SendRequest, ReceiveRequest {
    @Transient
    override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}

data class ReceiveOnly(
        override val session: FlowSessionInternal,
        override val userReceiveType: Class<*>?
) : ReceiveRequest {
    @Transient
    override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}

class ReceiveAll(val requests: List<ReceiveRequest>) : WaitingRequest {
    @Transient
    override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()

    private fun isComplete(received: LinkedHashMap<FlowSessionInternal, RequestMessage>): Boolean {
        return received.keys == requests.map { it.session }.toSet()
    }
    private fun shouldResumeIfRelevant() = requests.all { hasSuccessfulEndMessage(it) }

    private fun hasSuccessfulEndMessage(it: ReceiveRequest): Boolean {
        return it.session.receivedMessages.map { it.message.payload }.any { it is DataSessionMessage || it is EndSessionMessage }
    }

    @Suspendable
    fun suspendAndExpectReceive(suspend: Suspend): Map<FlowSessionInternal, RequestMessage> {
        val receivedMessages = LinkedHashMap<FlowSessionInternal, RequestMessage>()

        poll(receivedMessages)
        return if (isComplete(receivedMessages)) {
            receivedMessages
        } else {
            suspend(this)
            poll(receivedMessages)
            if (isComplete(receivedMessages)) {
                receivedMessages
            } else {
                throw IllegalStateException(requests.filter { it.session !in receivedMessages.keys }.map { "Was expecting a message but instead got nothing for $it." }.joinToString { "\n" })
            }
        }
    }

    interface Suspend {
        @Suspendable
        operator fun invoke(request: FlowIORequest)
    }

    @Suspendable
    private fun poll(receivedMessages: LinkedHashMap<FlowSessionInternal, RequestMessage>) {
        return requests.filter { it.session !in receivedMessages.keys }.forEach { request ->
            poll(request)?.let {
                receivedMessages[request.session] = RequestMessage(request, it)
            }
        }
    }

    @Suspendable
    private fun poll(request: ReceiveRequest): ExistingSessionMessage? {
        return request.session.receivedMessages.poll()?.message
    }

    override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = isRelevant(session) && shouldResumeIfRelevant()

    private fun isRelevant(session: FlowSessionInternal) = requests.any { it.session === session }

    data class RequestMessage(val request: ReceiveRequest, val message: ExistingSessionMessage)
}

data class SendOnly(override val session: FlowSessionInternal, override val message: SessionMessage) : SendRequest {
    @Transient
    override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}

data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachineImpl<*>) : WaitingRequest {
    @Transient
    override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()

    override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = message.payload is ErrorSessionMessage
}

data class Sleep(val until: Instant, val fiber: FlowStateMachineImpl<*>) : FlowIORequest {
    @Transient
    override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}

class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")
@ -156,4 +156,4 @@ class FlowLogicRefFactoryImpl(private val classloader: ClassLoader) : SingletonS
                return false
            }
        }
    }
}
@ -0,0 +1,89 @@
package net.corda.node.services.statemachine

import co.paralleluniverse.fibers.Suspendable
import com.esotericsoftware.kryo.KryoException
import net.corda.core.context.InvocationOrigin
import net.corda.core.flows.FlowException
import net.corda.core.identity.Party
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.serialize
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.trace
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.ReceivedMessage
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders
import java.io.NotSerializableException

/**
 * A wrapper interface around flow messaging.
 */
interface FlowMessaging {
    /**
     * Send [message] to [party] using [deduplicationId]. Optionally [acknowledgementHandler] may be specified to
     * listen on the send acknowledgement.
     */
    @Suspendable
    fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId)

    /**
     * Start the messaging using the [onMessage] message handler.
     */
    fun start(onMessage: (ReceivedMessage, deduplicationHandler: DeduplicationHandler) -> Unit)
}

/**
 * Implementation of [FlowMessaging] using a [ServiceHubInternal] to do the messaging and routing.
 */
class FlowMessagingImpl(val serviceHub: ServiceHubInternal) : FlowMessaging {
    companion object {
        val log = contextLogger()

        val sessionTopic = "platform.session"
    }

    override fun start(onMessage: (ReceivedMessage, deduplicationHandler: DeduplicationHandler) -> Unit) {
        serviceHub.networkService.addMessageHandler(sessionTopic) { receivedMessage, _, deduplicationHandler ->
            onMessage(receivedMessage, deduplicationHandler)
        }
    }

    @Suspendable
    override fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId) {
        log.trace { "Sending message $deduplicationId $message to party $party" }
        val networkMessage = serviceHub.networkService.createMessage(sessionTopic, serializeSessionMessage(message).bytes, deduplicationId, message.additionalHeaders(party))
        val partyInfo = serviceHub.networkMapCache.getPartyInfo(party) ?: throw IllegalArgumentException("Don't know about $party")
        val address = serviceHub.networkService.getAddressOfParty(partyInfo)
        val sequenceKey = when (message) {
            is InitialSessionMessage -> message.initiatorSessionId
            is ExistingSessionMessage -> message.recipientSessionId
        }
        serviceHub.networkService.send(networkMessage, address, sequenceKey = sequenceKey)
    }

    private fun SessionMessage.additionalHeaders(target: Party): Map<String, String> {
        // This prevents a "deadlock" in case an initiated flow tries to start a session against a draining node that is also the initiator.
        // It does not help in case more than 2 nodes are involved in a circle, so the kill switch via RPC should be used in that case.
        val mightDeadlockDrainingTarget = FlowStateMachineImpl.currentStateMachine()?.context?.origin.let { it is InvocationOrigin.Peer && it.party == target.name }
        return when {
            this !is InitialSessionMessage || mightDeadlockDrainingTarget -> emptyMap()
            else -> mapOf(P2PMessagingHeaders.Type.KEY to P2PMessagingHeaders.Type.SESSION_INIT_VALUE)
        }
    }

    private fun serializeSessionMessage(message: SessionMessage): SerializedBytes<SessionMessage> {
        return try {
            message.serialize()
        } catch (exception: Exception) {
            // Handling Kryo and AMQP serialization problems. Unfortunately the two exception types do not share much of a common exception interface.
            if ((exception is KryoException || exception is NotSerializableException)
                    && message is ExistingSessionMessage && message.payload is ErrorSessionMessage) {
                val error = message.payload.flowException
                val rewrappedError = FlowException(error?.message)
                message.copy(payload = message.payload.copy(flowException = rewrappedError)).serialize()
            } else {
                throw exception
            }
        }
    }
}
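For unit-style testing of code that depends on the interface above, a recording stub is often enough. The sketch below is illustrative only; the class name RecordingFlowMessaging and its fields are hypothetical and not part of this commit.

// Illustrative only: a FlowMessaging test double that records outgoing session messages.
class RecordingFlowMessaging : FlowMessaging {
    data class Sent(val party: Party, val message: SessionMessage, val id: DeduplicationId)

    val sent = mutableListOf<Sent>()

    override fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId) {
        sent += Sent(party, message, deduplicationId)   // no network involved, just capture the call
    }

    override fun start(onMessage: (ReceivedMessage, deduplicationHandler: DeduplicationHandler) -> Unit) {
        // nothing to start for the stub
    }
}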
@ -1,20 +1,40 @@
package net.corda.node.services.statemachine

import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.identity.Party
import net.corda.core.internal.FlowIORequest
import net.corda.core.internal.FlowStateMachine
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.serialize
import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.UntrustworthyData
import net.corda.core.utilities.checkPayloadIs

class FlowSessionImpl(override val counterparty: Party) : FlowSession() {
    internal lateinit var stateMachine: FlowStateMachine<*>
    internal lateinit var sessionFlow: FlowLogic<*>
class FlowSessionImpl(
        override val counterparty: Party,
        val sourceSessionId: SessionId
) : FlowSession() {

    override fun toString() = "FlowSessionImpl(counterparty=$counterparty, sourceSessionId=$sourceSessionId)"

    override fun equals(other: Any?): Boolean {
        return (other as? FlowSessionImpl)?.sourceSessionId == sourceSessionId
    }

    override fun hashCode() = sourceSessionId.hashCode()

    private fun getFlowStateMachine(): FlowStateMachine<*> {
        return Fiber.currentFiber() as FlowStateMachine<*>
    }

    @Suspendable
    override fun getCounterpartyFlowInfo(maySkipCheckpoint: Boolean): FlowInfo {
        return stateMachine.getFlowInfo(counterparty, sessionFlow, maySkipCheckpoint)
        val request = FlowIORequest.GetFlowInfo(NonEmptySet.of(this))
        return getFlowStateMachine().suspend(request, maySkipCheckpoint)[this]!!
    }

    @Suspendable
@ -26,14 +46,15 @@ class FlowSessionImpl(override val counterparty: Party) : FlowSession() {
            payload: Any,
            maySkipCheckpoint: Boolean
    ): UntrustworthyData<R> {
        return stateMachine.sendAndReceive(
                receiveType,
                counterparty,
                payload,
                sessionFlow,
                retrySend = false,
                maySkipCheckpoint = maySkipCheckpoint
        enforceNotPrimitive(receiveType)
        val request = FlowIORequest.SendAndReceive(
                sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT)),
                shouldRetrySend = false
        )
        val responseValues: Map<FlowSession, SerializedBytes<Any>> = getFlowStateMachine().suspend(request, maySkipCheckpoint)
        val responseForCurrentSession = responseValues[this]!!

        return responseForCurrentSession.checkPayloadIs(receiveType)
    }

    @Suspendable
@ -41,7 +62,9 @@ class FlowSessionImpl(override val counterparty: Party) : FlowSession() {

    @Suspendable
    override fun <R : Any> receive(receiveType: Class<R>, maySkipCheckpoint: Boolean): UntrustworthyData<R> {
        return stateMachine.receive(receiveType, counterparty, sessionFlow, maySkipCheckpoint)
        enforceNotPrimitive(receiveType)
        val request = FlowIORequest.Receive(NonEmptySet.of(this))
        return getFlowStateMachine().suspend(request, maySkipCheckpoint)[this]!!.checkPayloadIs(receiveType)
    }

    @Suspendable
@ -49,12 +72,17 @@ class FlowSessionImpl(override val counterparty: Party) : FlowSession() {

    @Suspendable
    override fun send(payload: Any, maySkipCheckpoint: Boolean) {
        return stateMachine.send(counterparty, payload, sessionFlow, maySkipCheckpoint)
        val request = FlowIORequest.Send(
                sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT)),
                shouldRetrySend = false
        )
        return getFlowStateMachine().suspend(request, maySkipCheckpoint)
    }

    @Suspendable
    override fun send(payload: Any) = send(payload, maySkipCheckpoint = false)

    override fun toString() = "Flow session with $counterparty"
    private fun enforceNotPrimitive(type: Class<*>) {
        require(!type.isPrimitive) { "Cannot receive primitive type $type" }
    }
}

@ -1,66 +0,0 @@
package net.corda.node.services.statemachine

import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.identity.Party
import net.corda.node.services.statemachine.FlowSessionState.Initiated
import net.corda.node.services.statemachine.FlowSessionState.Initiating
import java.util.concurrent.ConcurrentLinkedQueue

/**
 * @param retryable Indicates that the session initialisation should be retried until an expected [SessionData] response
 * is received. Note that this requires the party on the other end to be a distributed service and run an idempotent flow
 * that only sends back a single [SessionData] message before termination.
 */
// TODO rename this
class FlowSessionInternal(
        val flow: FlowLogic<*>,
        val flowSession: FlowSession,
        val ourSessionId: SessionId,
        val initiatingParty: Party?,
        var state: FlowSessionState,
        var retryable: Boolean = false) {
    val receivedMessages = ConcurrentLinkedQueue<ReceivedSessionMessage>()
    val fiber: FlowStateMachineImpl<*> get() = flow.stateMachine as FlowStateMachineImpl<*>

    override fun toString(): String {
        return "${javaClass.simpleName}(flow=$flow, ourSessionId=$ourSessionId, initiatingParty=$initiatingParty, state=$state)"
    }

    fun getPeerSessionId(): SessionId {
        val sessionState = state
        return when (sessionState) {
            is FlowSessionState.Initiated -> sessionState.peerSessionId
            else -> throw IllegalStateException("We've somehow held onto a non-initiated session: $this")
        }
    }
}

data class ReceivedSessionMessage(val peerParty: Party, val message: ExistingSessionMessage)

/**
 * [FlowSessionState] describes the session's state.
 *
 * [Uninitiated] is pre-handshake, where no communication has happened. [Initiating.otherParty] at this point holds a
 * [Party] corresponding to either a specific peer or a service.
 * [Initiating] is pre-handshake, where the initiating message has been sent.
 * [Initiated] is post-handshake. At this point [Initiating.otherParty] will have been resolved to a specific peer
 * [Initiated.peerParty], and the peer's sessionId has been initialised.
 */
sealed class FlowSessionState {
    abstract val sendToParty: Party

    data class Uninitiated(val otherParty: Party) : FlowSessionState() {
        override val sendToParty: Party get() = otherParty
    }

    /** [otherParty] may be a specific peer or a service party */
    data class Initiating(val otherParty: Party) : FlowSessionState() {
        override val sendToParty: Party get() = otherParty
    }

    data class Initiated(val peerParty: Party, val peerSessionId: SessionId, val context: FlowInfo) : FlowSessionState() {
        override val sendToParty: Party get() = peerParty
    }
}
@ -5,262 +5,251 @@ import co.paralleluniverse.fibers.Fiber.parkAndSerialize
import co.paralleluniverse.fibers.FiberScheduler
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand
import com.google.common.primitives.Primitives
import co.paralleluniverse.strands.channels.Channel
import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.newSecureRandom
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.identity.PartyAndCertificate
import net.corda.core.internal.*
import net.corda.core.internal.concurrent.OpenFuture
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.serialize
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.*
import net.corda.core.utilities.Try
import net.corda.core.utilities.debug
import net.corda.core.utilities.trace
import net.corda.node.services.api.FlowAppAuditEvent
import net.corda.node.services.api.FlowPermissionAuditEvent
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.logging.pushToLoggingContext
import net.corda.node.services.statemachine.FlowSessionState.Initiating
import net.corda.node.services.statemachine.transitions.FlowContinuation
import net.corda.node.services.statemachine.transitions.StateMachine
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseTransaction
import net.corda.nodeapi.internal.persistence.contextTransaction
import net.corda.nodeapi.internal.persistence.contextTransactionOrNull
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.io.IOException
import java.sql.SQLException
import java.time.Duration
import java.time.Instant
import java.util.*
import java.util.concurrent.TimeUnit
import kotlin.reflect.KProperty1

class FlowPermissionException(message: String) : FlowException(message)

class TransientReference<out A>(@Transient val value: A)

class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
                              override val logic: FlowLogic<R>,
                              scheduler: FiberScheduler,
                              val ourIdentity: Party,
                              override val context: InvocationContext) : Fiber<Unit>(id.toString(), scheduler), FlowStateMachine<R> {

                              scheduler: FiberScheduler
) : Fiber<Unit>(id.toString(), scheduler), FlowStateMachine<R>, FlowFiber {
    companion object {
        // Used to work around a small limitation in Quasar.
        private val QUASAR_UNBLOCKER = Fiber::class.staticField<Any>("SERIALIZER_BLOCKER").value

        /**
         * Return the current [FlowStateMachineImpl] or null if executing outside of one.
         */
        fun currentStateMachine(): FlowStateMachineImpl<*>? = Strand.currentStrand() as? FlowStateMachineImpl<*>

        private val log: Logger = LoggerFactory.getLogger("net.corda.flow")
    }

    // These fields shouldn't be serialised, so they are marked @Transient.
    @Transient override lateinit var serviceHub: ServiceHubInternal
    @Transient override lateinit var ourIdentityAndCert: PartyAndCertificate
    @Transient internal lateinit var database: CordaPersistence
    @Transient internal lateinit var actionOnSuspend: (FlowIORequest) -> Unit
    @Transient internal lateinit var actionOnEnd: (Try<R>, Boolean) -> Unit
    @Transient internal var fromCheckpoint: Boolean = false
    @Transient private var txTrampoline: DatabaseTransaction? = null
    override val serviceHub get() = getTransientField(TransientValues::serviceHub)

    data class TransientValues(
            val eventQueue: Channel<Event>,
            val resultFuture: CordaFuture<Any?>,
            val database: CordaPersistence,
            val transitionExecutor: TransitionExecutor,
            val actionExecutor: ActionExecutor,
            val stateMachine: StateMachine,
            val serviceHub: ServiceHubInternal,
            val checkpointSerializationContext: SerializationContext
    )

    internal var transientValues: TransientReference<TransientValues>? = null
    internal var transientState: TransientReference<StateMachineState>? = null

    private fun <A> getTransientField(field: KProperty1<TransientValues, A>): A {
        val suppliedValues = transientValues ?: throw IllegalStateException("${field.name} wasn't supplied!")
        return field.get(suppliedValues.value)
    }

    private fun extractThreadLocalTransaction(): TransientReference<DatabaseTransaction> {
        val transaction = contextTransaction
        contextTransactionOrNull = null
        return TransientReference(transaction)
    }

    /**
     * Return the logger for this state machine. The logger name incorporates [id] and so including it in the log message
     * is not necessary.
     */
    override val logger: Logger = LoggerFactory.getLogger("net.corda.flow.$id")
    @Transient private var resultFutureTransient: OpenFuture<R>? = openFuture()
    private val _resultFuture get() = resultFutureTransient ?: openFuture<R>().also { resultFutureTransient = it }
    /** This future will complete when the call method returns. */
    override val resultFuture: CordaFuture<R> get() = _resultFuture
    // This state IS serialised, as we need it to know what the fiber is waiting for.
    internal val openSessions = HashMap<Pair<FlowLogic<*>, Party>, FlowSessionInternal>()
    internal var waitingForResponse: WaitingRequest? = null
    override val logger = log
    override val resultFuture: CordaFuture<R> get() = uncheckedCast(getTransientField(TransientValues::resultFuture))
    override val context: InvocationContext get() = transientState!!.value.checkpoint.invocationContext
    override val ourIdentity: Party get() = transientState!!.value.checkpoint.ourIdentity
    internal var hasSoftLockedStates: Boolean = false
        set(value) {
            if (value) field = value else throw IllegalArgumentException("Can only set to true")
        }

    init {
        logic.stateMachine = this
    /**
     * Processes an event by creating the associated transition and executing it using the given executor.
     * Try to avoid using this directly; use [processEventsUntilFlowIsResumed] or [processEventImmediately] instead.
     */
    @Suspendable
    private fun processEvent(transitionExecutor: TransitionExecutor, event: Event): FlowContinuation {
        val stateMachine = getTransientField(TransientValues::stateMachine)
        val oldState = transientState!!.value
        val actionExecutor = getTransientField(TransientValues::actionExecutor)
        val transition = stateMachine.transition(event, oldState)
        val (continuation, newState) = transitionExecutor.executeTransition(this, oldState, event, transition, actionExecutor)
        transientState = TransientReference(newState)
        return continuation
    }

    /**
     * Processes the events in the event queue until a transition indicates that control should be returned to user code
     * in the form of a regular resume or a throw of an exception. Alternatively the transition may abort the fiber
     * completely.
     *
     * @param isDbTransactionOpenOnEntry indicates whether a DB transaction is expected to be present before the
     *   processing of the eventloop. Purely used for internal invariant checks.
     * @param isDbTransactionOpenOnExit indicates whether a DB transaction is expected to be present once the eventloop
     *   processing finished. Purely used for internal invariant checks.
     */
    @Suspendable
    private fun processEventsUntilFlowIsResumed(isDbTransactionOpenOnEntry: Boolean, isDbTransactionOpenOnExit: Boolean): Any? {
        checkDbTransaction(isDbTransactionOpenOnEntry)
        val transitionExecutor = getTransientField(TransientValues::transitionExecutor)
        val eventQueue = getTransientField(TransientValues::eventQueue)
        try {
            eventLoop@while (true) {
                val nextEvent = eventQueue.receive()
                val continuation = processEvent(transitionExecutor, nextEvent)
                when (continuation) {
                    is FlowContinuation.Resume -> return continuation.result
                    is FlowContinuation.Throw -> {
                        continuation.throwable.fillInStackTrace()
                        throw continuation.throwable
                    }
                    FlowContinuation.ProcessEvents -> continue@eventLoop
                    FlowContinuation.Abort -> abortFiber()
                }
            }
        } finally {
            checkDbTransaction(isDbTransactionOpenOnExit)
        }
    }

    /**
     * Immediately processes the passed in event. Always called with an open database transaction.
     *
     * @param event the event to be processed.
     * @param isDbTransactionOpenOnEntry indicates whether a DB transaction is expected to be present before the
     *   processing of the event. Purely used for internal invariant checks.
     * @param isDbTransactionOpenOnExit indicates whether a DB transaction is expected to be present once the event
     *   processing finished. Purely used for internal invariant checks.
     */
    @Suspendable
    private fun processEventImmediately(event: Event, isDbTransactionOpenOnEntry: Boolean, isDbTransactionOpenOnExit: Boolean): FlowContinuation {
        checkDbTransaction(isDbTransactionOpenOnEntry)
        val transitionExecutor = getTransientField(TransientValues::transitionExecutor)
        val continuation = processEvent(transitionExecutor, event)
        checkDbTransaction(isDbTransactionOpenOnExit)
        return continuation
    }

    private fun checkDbTransaction(isPresent: Boolean) {
        if (isPresent) {
            requireNotNull(contextTransactionOrNull)
        } else {
            require(contextTransactionOrNull == null)
        }
    }

    @Suspendable
    override fun run() {
        createTransaction()
        logic.stateMachine = this

        context.pushToLoggingContext()

        initialiseFlow()

        logger.debug { "Calling flow: $logic" }
        val startTime = System.nanoTime()
        val result = try {
            val r = logic.call()
            // Only sessions which have done a single send and nothing else will block here
            openSessions.values
                    .filter { it.state is Initiating }
                    .forEach { it.waitForConfirmation() }
            r
        } catch (e: FlowException) {
            recordDuration(startTime, success = false)
            // Check if the FlowException was propagated by looking at where the stack trace originates (see suspendAndExpectReceive).
            val propagated = e.stackTrace[0].className == javaClass.name
            processException(e, propagated)
            logger.warn(if (propagated) "Flow ended due to receiving exception" else "Flow finished with exception", e)
            return
        } catch (t: Throwable) {
            recordDuration(startTime, success = false)
            logger.warn("Terminated by unexpected exception", t)
            processException(t, false)
            return
        val resultOrError = try {
            val result = logic.call()
            suspend(FlowIORequest.WaitForSessionConfirmations, maySkipCheckpoint = true)
            Try.Success(result)
        } catch (throwable: Throwable) {
            logger.warn("Flow threw exception", throwable)
            Try.Failure<R>(throwable)
        }
        val softLocksId = if (hasSoftLockedStates) logic.runId.uuid else null
        val finalEvent = when (resultOrError) {
            is Try.Success -> {
                Event.FlowFinish(resultOrError.value, softLocksId)
            }
            is Try.Failure -> {
                Event.Error(resultOrError.exception)
            }
        }
        // Immediately process the last event. This is to make sure the transition can assume that it has an open
        // database transaction.
        val continuation = processEventImmediately(
                finalEvent,
                isDbTransactionOpenOnEntry = true,
                isDbTransactionOpenOnExit = false
        )
        if (continuation == FlowContinuation.ProcessEvents) {
            // This can happen in case there was an error and there are further things to do e.g. to propagate it.
            processEventsUntilFlowIsResumed(
                    isDbTransactionOpenOnEntry = false,
                    isDbTransactionOpenOnExit = false
            )
        }

        recordDuration(startTime)
        // This is to prevent actionOnEnd being called twice if it throws an exception
        actionOnEnd(Try.Success(result), false)
        _resultFuture.set(result)
        logic.progressTracker?.currentStep = ProgressTracker.DONE
        logger.debug { "Flow finished with result ${result.toString().abbreviate(300)}" }
    }

    private fun createTransaction() {
        // Make sure we have a database transaction
        database.createTransaction()
        logger.trace { "Starting database transaction $contextTransactionOrNull on ${Strand.currentStrand()}" }
    }

    private fun processException(exception: Throwable, propagated: Boolean) {
        actionOnEnd(Try.Failure(exception), propagated)
        _resultFuture.setException(exception)
        logic.progressTracker?.endWithError(exception)
    }

    internal fun commitTransaction() {
        val transaction = contextTransaction
        try {
            logger.trace { "Committing database transaction $transaction on ${Strand.currentStrand()}." }
            transaction.commit()
        } catch (e: SQLException) {
            // TODO: we will get here if the database is not available. Think about how to shutdown and restart cleanly.
            logger.error("Transaction commit failed: ${e.message}", e)
            System.exit(1)
        } finally {
            transaction.close()
        }
    }

    @Suspendable
    override fun initiateFlow(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSession {
        val sessionKey = Pair(sessionFlow, otherParty)
        if (openSessions.containsKey(sessionKey)) {
            throw IllegalStateException(
                    "Attempted to initiateFlow() twice in the same InitiatingFlow $sessionFlow for the same party " +
                            "$otherParty. This isn't supported in this version of Corda. Alternatively you may " +
                            "initiate a new flow by calling initiateFlow() in an " +
                            "@${InitiatingFlow::class.java.simpleName} sub-flow."
    private fun initialiseFlow() {
        processEventsUntilFlowIsResumed(
                isDbTransactionOpenOnEntry = false,
                isDbTransactionOpenOnExit = true
        )
    }

    @Suspendable
    override fun <R> subFlow(subFlow: FlowLogic<R>): R {
        processEventImmediately(
                Event.EnterSubFlow(subFlow.javaClass),
                isDbTransactionOpenOnEntry = true,
                isDbTransactionOpenOnExit = true
        )
        return try {
            subFlow.call()
        } finally {
            processEventImmediately(
                    Event.LeaveSubFlow,
                    isDbTransactionOpenOnEntry = true,
                    isDbTransactionOpenOnExit = true
            )
        }
        val flowSession = FlowSessionImpl(otherParty)
        createNewSession(otherParty, flowSession, sessionFlow)
        flowSession.stateMachine = this
        flowSession.sessionFlow = sessionFlow
        return flowSession
    }

    @Suspendable
    override fun getFlowInfo(otherParty: Party, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): FlowInfo {
        val state = getConfirmedSession(otherParty, sessionFlow).state as FlowSessionState.Initiated
        return state.context
    override fun initiateFlow(party: Party): FlowSession {
        val resume = processEventImmediately(
                Event.InitiateFlow(party),
                isDbTransactionOpenOnEntry = true,
                isDbTransactionOpenOnExit = true
        ) as FlowContinuation.Resume
        return resume.result as FlowSession
    }

    @Suspendable
    override fun <T : Any> sendAndReceive(receiveType: Class<T>,
                                          otherParty: Party,
                                          payload: Any,
                                          sessionFlow: FlowLogic<*>,
                                          retrySend: Boolean,
                                          maySkipCheckpoint: Boolean): UntrustworthyData<T> {
        requireNonPrimitive(receiveType)
        logger.debug { "sendAndReceive(${receiveType.name}, $otherParty, ${payload.toString().abbreviate(300)}) ..." }
        val session = getConfirmedSessionIfPresent(otherParty, sessionFlow)
        val receivedSessionMessage: ReceivedSessionMessage = if (session == null) {
            val newSession = initiateSession(otherParty, sessionFlow, payload, waitForConfirmation = true, retryable = retrySend)
            // Only do a receive here as the session init has carried the payload
            receiveInternal(newSession, receiveType)
        } else {
            val sendData = createSessionData(session, payload)
            sendAndReceiveInternal(session, sendData, receiveType)
    private fun abortFiber(): Nothing {
        while (true) {
            Fiber.park()
        }
        val sessionData = receivedSessionMessage.message.checkDataSessionMessage()
        logger.debug { "Received ${sessionData.payload.toString().abbreviate(300)}" }
        return sessionData.checkPayloadIs(receiveType)
    }

    private fun ExistingSessionMessage.checkDataSessionMessage(): DataSessionMessage {
        when (payload) {
            is DataSessionMessage -> {
                return payload
            }
            else -> {
                throw IllegalStateException("Was expecting ${DataSessionMessage::class.java.simpleName} but got ${payload.javaClass.simpleName} instead")
            }
        }
    }

    @Suspendable
    override fun <T : Any> receive(receiveType: Class<T>,
                                   otherParty: Party,
                                   sessionFlow: FlowLogic<*>,
                                   maySkipCheckpoint: Boolean): UntrustworthyData<T> {
        requireNonPrimitive(receiveType)
        logger.debug { "receive(${receiveType.name}, $otherParty) ..." }
        val session = getConfirmedSession(otherParty, sessionFlow)
        val receivedSessionMessage = receiveInternal(session, receiveType).message.checkDataSessionMessage()
        logger.debug { "Received ${receivedSessionMessage.payload.toString().abbreviate(300)}" }
        return receivedSessionMessage.checkPayloadIs(receiveType)
    }

    private fun requireNonPrimitive(receiveType: Class<*>) {
        require(!receiveType.isPrimitive) {
            "Use the wrapper type ${Primitives.wrap(receiveType).name} instead of the primitive $receiveType.class"
        }
    }

    @Suspendable
    override fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean) {
        logger.debug { "send($otherParty, ${payload.toString().abbreviate(300)})" }
        val session = getConfirmedSessionIfPresent(otherParty, sessionFlow)
        if (session == null) {
            // Don't send the payload again if it was already piggy-backed on a session init
            initiateSession(otherParty, sessionFlow, payload, waitForConfirmation = false)
        } else {
            sendInternal(session, ExistingSessionMessage(session.getPeerSessionId(), createSessionData(session, payload)))
        }
    }

    @Suspendable
    override fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): SignedTransaction {
        logger.debug { "waitForLedgerCommit($hash) ..." }
        suspend(WaitForLedgerCommit(hash, sessionFlow.stateMachine as FlowStateMachineImpl<*>))
        val stx = serviceHub.validatedTransactions.getTransaction(hash)
        if (stx != null) {
            logger.debug { "Transaction $hash committed to ledger" }
            return stx
        }

        // If the tx isn't committed then we may have been resumed due to a session ending in an error
        for (session in openSessions.values) {
            for (receivedMessage in session.receivedMessages) {
                if (receivedMessage.message.payload is ErrorSessionMessage) {
                    session.erroredEnd(receivedMessage.message.payload.flowException)
                }
            }
        }
        throw IllegalStateException("We were resumed after waiting for $hash but it wasn't found in our local storage")
    }

    // Provide a mechanism to sleep within a Strand without locking any transactional state.
    // This checkpoints, since we cannot undo any database writes up to this point.
    @Suspendable
    override fun sleepUntil(until: Instant) {
        suspend(Sleep(until, this))
    }

    // TODO Dummy implementation of access to application specific permission controls and audit logging
@ -305,242 +294,50 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
override fun receiveAll(sessions: Map<FlowSession, Class<out Any>>, sessionFlow: FlowLogic<*>): Map<FlowSession, UntrustworthyData<Any>> {
|
||||
val requests = ArrayList<ReceiveOnly>()
|
||||
for ((session, receiveType) in sessions) {
|
||||
val sessionInternal = getConfirmedSession(session.counterparty, sessionFlow)
|
||||
requests.add(ReceiveOnly(sessionInternal, receiveType))
|
||||
}
|
||||
val receivedMessages = ReceiveAll(requests).suspendAndExpectReceive(suspend)
|
||||
val result = LinkedHashMap<FlowSession, UntrustworthyData<Any>>()
|
||||
for ((sessionInternal, requestAndMessage) in receivedMessages) {
|
||||
val message = requestAndMessage.message.confirmNoError(requestAndMessage.request.session)
|
||||
result[sessionInternal.flowSession] = message.checkDataSessionMessage().checkPayloadIs(
|
||||
requestAndMessage.request.userReceiveType as Class<out Any>
|
||||
)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
internal fun pushToLoggingContext() = context.pushToLoggingContext()
|
||||
|
||||
/**
|
||||
* This method will suspend the state machine and wait for incoming session init response from other party.
|
||||
*/
|
||||
@Suspendable
|
||||
private fun FlowSessionInternal.waitForConfirmation() {
|
||||
val sessionInitResponse = receiveInternal(this, null)
|
||||
val payload = sessionInitResponse.message.payload
|
||||
when (payload) {
|
||||
is ConfirmSessionMessage -> {
|
||||
state = FlowSessionState.Initiated(
|
||||
sessionInitResponse.
|
||||
peerParty,
|
||||
payload.initiatedSessionId,
|
||||
payload.initiatedFlowInfo)
|
||||
}
|
||||
is RejectSessionMessage -> {
|
||||
throw UnexpectedFlowEndException("Party ${state.sendToParty} rejected session request: ${payload.message}")
|
||||
}
|
||||
else -> {
|
||||
throw IllegalStateException("Was expecting ${ConfirmSessionMessage::class.java.simpleName} but got ${payload.javaClass.simpleName} instead")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun createSessionData(session: FlowSessionInternal, payload: Any): DataSessionMessage {
|
||||
return DataSessionMessage(payload.serialize(context = SerializationDefaults.P2P_CONTEXT))
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun sendInternal(session: FlowSessionInternal, message: SessionMessage) = suspend(SendOnly(session, message))
|
||||
|
||||
@Suspendable
|
||||
private fun receiveInternal(
|
||||
session: FlowSessionInternal,
|
||||
userReceiveType: Class<*>?): ReceivedSessionMessage {
|
||||
return waitForMessage(ReceiveOnly(session, userReceiveType))
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun sendAndReceiveInternal(
|
||||
session: FlowSessionInternal,
|
||||
message: DataSessionMessage,
|
||||
userReceiveType: Class<*>?): ReceivedSessionMessage {
|
||||
val sessionMessage = ExistingSessionMessage(session.getPeerSessionId(), message)
|
||||
return waitForMessage(SendAndReceive(session, sessionMessage, userReceiveType))
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun getConfirmedSessionIfPresent(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSessionInternal? {
|
||||
val session = openSessions[Pair(sessionFlow, otherParty)] ?: return null
|
||||
return when (session.state) {
|
||||
is FlowSessionState.Uninitiated -> null
|
||||
is FlowSessionState.Initiating -> {
|
||||
session.waitForConfirmation()
|
||||
session
|
||||
}
|
||||
is FlowSessionState.Initiated -> session
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun getConfirmedSession(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSessionInternal {
|
||||
return getConfirmedSessionIfPresent(otherParty, sessionFlow) ?:
|
||||
initiateSession(otherParty, sessionFlow, null, waitForConfirmation = true)
|
||||
}
|
||||
|
||||
private fun createNewSession(
|
||||
otherParty: Party,
|
||||
flowSession: FlowSession,
|
||||
sessionFlow: FlowLogic<*>
|
||||
) {
|
||||
logger.trace { "Creating a new session with $otherParty" }
|
||||
val session = FlowSessionInternal(sessionFlow, flowSession, SessionId.createRandom(newSecureRandom()), null, FlowSessionState.Uninitiated(otherParty))
|
||||
openSessions[Pair(sessionFlow, otherParty)] = session
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun initiateSession(
|
||||
otherParty: Party,
|
||||
sessionFlow: FlowLogic<*>,
|
||||
firstPayload: Any?,
|
||||
waitForConfirmation: Boolean,
|
||||
retryable: Boolean = false
|
||||
): FlowSessionInternal {
|
||||
val session = openSessions[Pair(sessionFlow, otherParty)] ?: throw IllegalStateException("Expected an Uninitiated session for $otherParty")
|
||||
val state = session.state as? FlowSessionState.Uninitiated ?: throw IllegalStateException("Tried to initiate a session $session, but it's already initiating/initiated")
|
||||
logger.trace { "Initiating a new session with ${state.otherParty}" }
|
||||
session.state = FlowSessionState.Initiating(state.otherParty)
|
||||
session.retryable = retryable
|
||||
val (version, initiatingFlowClass) = session.flow.javaClass.flowVersionAndInitiatingClass
|
||||
val payloadBytes = firstPayload?.serialize(context = SerializationDefaults.P2P_CONTEXT)
|
||||
logger.info("Initiating flow session with party ${otherParty.name}. Session id for tracing purposes is ${session.ourSessionId}.")
|
||||
val sessionInit = InitialSessionMessage(session.ourSessionId, newSecureRandom().nextLong(), initiatingFlowClass.name, version, session.flow.javaClass.appName, payloadBytes)
|
||||
sendInternal(session, sessionInit)
|
||||
if (waitForConfirmation) {
|
||||
session.waitForConfirmation()
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun waitForMessage(receiveRequest: ReceiveRequest): ReceivedSessionMessage {
|
||||
val receivedMessage = receiveRequest.suspendAndExpectReceive()
|
||||
receivedMessage.message.confirmNoError(receiveRequest.session)
|
||||
return receivedMessage
|
||||
}
|
||||
|
||||
private val suspend : ReceiveAll.Suspend = object : ReceiveAll.Suspend {
|
||||
@Suspendable
|
||||
override fun invoke(request: FlowIORequest) {
|
||||
suspend(request)
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun ReceiveRequest.suspendAndExpectReceive(): ReceivedSessionMessage {
|
||||
val polledMessage = session.receivedMessages.poll()
|
||||
return if (polledMessage != null) {
|
||||
if (this is SendAndReceive) {
|
||||
// Since we've already received the message, we downgrade to a send only to get the payload out and not
|
||||
// inadvertently block
|
||||
suspend(SendOnly(session, message))
|
||||
}
|
||||
polledMessage
|
||||
} else {
|
||||
// Suspend while we wait for a receive
|
||||
suspend(this)
|
||||
session.receivedMessages.poll() ?:
|
||||
throw IllegalStateException("Was expecting a message but instead got nothing for $this")
|
||||
}
|
||||
}
|
||||
|
||||
private fun ExistingSessionMessage.confirmNoError(session: FlowSessionInternal): ExistingSessionMessage {
|
||||
when (payload) {
|
||||
is ConfirmSessionMessage,
|
||||
is DataSessionMessage -> {
|
||||
return this
|
||||
}
|
||||
is ErrorSessionMessage -> {
|
||||
openSessions.values.remove(session)
|
||||
session.erroredEnd(payload.flowException)
|
||||
}
|
||||
is RejectSessionMessage -> {
|
||||
session.erroredEnd(UnexpectedFlowEndException("Counterparty sent session rejection message at unexpected time with message ${payload.message}"))
|
||||
}
|
||||
EndSessionMessage -> {
|
||||
openSessions.values.remove(session)
|
||||
throw UnexpectedFlowEndException("Counterparty flow on ${session.state.sendToParty} has completed without " +
|
||||
"sending data")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun FlowSessionInternal.erroredEnd(exception: Throwable?): Nothing {
|
||||
if (exception != null) {
|
||||
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
|
||||
exception.fillInStackTrace()
|
||||
throw exception
|
||||
} else {
|
||||
throw UnexpectedFlowEndException("Counterparty flow on ${state.sendToParty} had an internal error and has terminated")
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun suspend(ioRequest: FlowIORequest) {
|
||||
// We have to pass the thread local database transaction across via a transient field as the fiber park
|
||||
// swaps them out.
|
||||
txTrampoline = contextTransactionOrNull
|
||||
contextTransactionOrNull = null
|
||||
if (ioRequest is WaitingRequest)
|
||||
waitingForResponse = ioRequest
|
||||
|
||||
var exceptionDuringSuspend: Throwable? = null
|
||||
override fun <R : Any> suspend(ioRequest: FlowIORequest<R>, maySkipCheckpoint: Boolean): R {
|
||||
val serializationContext = TransientReference(getTransientField(TransientValues::checkpointSerializationContext))
|
||||
val transaction = extractThreadLocalTransaction()
|
||||
parkAndSerialize { _, _ ->
|
||||
logger.trace { "Suspended on $ioRequest" }
|
||||
// restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB
|
||||
try {
|
||||
contextTransactionOrNull = txTrampoline
|
||||
txTrampoline = null
|
||||
actionOnSuspend(ioRequest)
|
||||
} catch (t: Throwable) {
|
||||
// Quasar does not terminate the fiber properly if an exception occurs during a suspend. We have to
|
||||
// resume the fiber just so that we can throw it when it's running.
|
||||
exceptionDuringSuspend = t
|
||||
logger.trace("Resuming so fiber can it terminate with the exception thrown during suspend process", t)
|
||||
resume(scheduler)
|
||||
}
|
||||
}
|
||||
|
||||
if (exceptionDuringSuspend == null && ioRequest is Sleep) {
|
||||
// Sleep on the fiber. This will not sleep if it's in the past.
|
||||
Strand.sleep(Duration.between(Instant.now(), ioRequest.until).toNanos(), TimeUnit.NANOSECONDS)
|
||||
contextTransactionOrNull = transaction.value
|
||||
val event = try {
|
||||
Event.Suspend(
|
||||
ioRequest = ioRequest,
|
||||
maySkipCheckpoint = maySkipCheckpoint,
|
||||
fiber = this.serialize(context = serializationContext.value)
|
||||
)
|
||||
} catch (throwable: Throwable) {
|
||||
Event.Error(throwable)
|
||||
}
|
||||
|
||||
// We must commit the database transaction before returning from this closure otherwise Quasar may schedule
|
||||
// other fibers, so we process the event immediately
|
||||
val continuation = processEventImmediately(
|
||||
event,
|
||||
isDbTransactionOpenOnEntry = true,
|
||||
isDbTransactionOpenOnExit = false
|
||||
)
|
||||
require(continuation == FlowContinuation.ProcessEvents)
|
||||
Fiber.unparkDeserialized(this, scheduler)
|
||||
}
|
||||
createTransaction()
|
||||
// TODO Now that we're throwing outside of the suspend the FlowLogic can catch it. We need Quasar to terminate
|
||||
// the fiber when exceptions occur inside a suspend.
|
||||
exceptionDuringSuspend?.let { throw it }
|
||||
logger.trace { "Resumed from $ioRequest" }
|
||||
return uncheckedCast(processEventsUntilFlowIsResumed(
|
||||
isDbTransactionOpenOnEntry = false,
|
||||
isDbTransactionOpenOnExit = true
|
||||
))
|
||||
}
|
||||
|
||||
internal fun resume(scheduler: FiberScheduler) {
|
||||
try {
|
||||
if (fromCheckpoint) {
|
||||
logger.info("Resumed from checkpoint")
|
||||
fromCheckpoint = false
|
||||
Fiber.unparkDeserialized(this, scheduler)
|
||||
} else if (state == State.NEW) {
|
||||
logger.trace("Started")
|
||||
start()
|
||||
} else {
|
||||
Fiber.unpark(this, QUASAR_UNBLOCKER)
|
||||
}
|
||||
} catch (t: Throwable) {
|
||||
logger.error("Error during resume", t)
|
||||
}
|
||||
@Suspendable
|
||||
override fun scheduleEvent(event: Event) {
|
||||
getTransientField(TransientValues::eventQueue).send(event)
|
||||
}
|
||||
|
||||
override fun snapshot(): StateMachineState {
|
||||
return transientState!!.value
|
||||
}
|
||||
|
||||
    override val stateMachine get() = getTransientField(TransientValues::stateMachine)

    /**
     * Records the duration of this flow – from call() to completion or failure.
     * Note that the duration will include the time the flow spent being parked, and not just the total
@ -582,15 +379,3 @@ val Class<out FlowLogic<*>>.appName: String
            "<unknown>"
        }
    }

fun <T : Any> DataSessionMessage.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
    val payloadData: T = try {
        val serializer = SerializationDefaults.SERIALIZATION_FACTORY
        serializer.deserialize(payload, type, SerializationDefaults.P2P_CONTEXT)
    } catch (ex: Exception) {
        throw IOException("Payload invalid", ex)
    }
    return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?:
            throw UnexpectedFlowEndException("We were expecting a ${type.name} but we instead got a " +
                    "${payloadData.javaClass.name} (${payloadData})")
}

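// Standalone sketch of the defensive cast performed by checkPayloadIs above: after deserialising an
// untrusted payload, verify its runtime type before handing it to flow code. `Untrusted` stands in
// for UntrustworthyData and `castIfPossible` mirrors the helper used above; names are illustrative.
class Untrusted<out T : Any>(private val value: T) {
    fun <R> unwrap(validator: (T) -> R): R = validator(value)
}

fun <T : Any> Class<T>.castIfPossible(obj: Any): T? = if (isInstance(obj)) cast(obj) else null

fun <T : Any> checkIs(payload: Any, type: Class<T>): Untrusted<T> =
        type.castIfPossible(payload)?.let { Untrusted(it) }
                ?: throw IllegalArgumentException("Expected ${type.name} but got ${payload.javaClass.name} ($payload)")

fun main() {
    println(checkIs("hello", String::class.java).unwrap { it.length }) // 5
    // checkIs(42, String::class.java) would throw IllegalArgumentException
}
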
@ -0,0 +1,20 @@
package net.corda.node.services.statemachine

import net.corda.core.utilities.debug
import net.corda.core.utilities.loggerFor

/**
 * A simple [FlowHospital] implementation that immediately triggers error propagation when a flow dirties.
 */
object PropagatingFlowHospital : FlowHospital {
    private val log = loggerFor<PropagatingFlowHospital>()

    override fun flowErrored(flowFiber: FlowFiber) {
        log.debug { "Flow ${flowFiber.id} dirtied ${flowFiber.snapshot().checkpoint.errorState}" }
        flowFiber.scheduleEvent(Event.StartErrorPropagation)
    }

    override fun flowCleaned(flowFiber: FlowFiber) {
        throw IllegalStateException("Flow ${flowFiber.id} cleaned after error propagation triggered")
    }
}
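
// Hypothetical sketch (not part of this change) showing FlowHospital as an extension point: a
// variant that keeps a simple admission count before triggering the same error propagation.
// It only uses callbacks and members visible in this diff (flowErrored/flowCleaned, FlowFiber.id,
// scheduleEvent) and assumes the same imports as PropagatingFlowHospital (loggerFor, debug).
object MeteringFlowHospital : FlowHospital {
    private val log = loggerFor<MeteringFlowHospital>()
    private val admissions = java.util.concurrent.atomic.AtomicInteger(0)

    override fun flowErrored(flowFiber: FlowFiber) {
        log.debug { "Admission #${admissions.incrementAndGet()} for flow ${flowFiber.id}" }
        // Same outcome as PropagatingFlowHospital: kick off error propagation straight away.
        flowFiber.scheduleEvent(Event.StartErrorPropagation)
    }

    override fun flowCleaned(flowFiber: FlowFiber) {
        log.debug { "Flow ${flowFiber.id} reported clean" }
    }
}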
@ -0,0 +1,8 @@
package net.corda.node.services.statemachine

import net.corda.core.CordaException

/**
 * An exception propagated and thrown in case a session initiation fails.
 */
class SessionRejectException(reason: String) : CordaException(reason)
@ -0,0 +1,689 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import co.paralleluniverse.fibers.Fiber
|
||||
import co.paralleluniverse.fibers.FiberExecutorScheduler
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import co.paralleluniverse.fibers.instrument.SuspendableHelper
|
||||
import co.paralleluniverse.strands.channels.Channels
|
||||
import com.codahale.metrics.Gauge
|
||||
import net.corda.core.concurrent.CordaFuture
|
||||
import net.corda.core.context.InvocationContext
|
||||
import net.corda.core.flows.FlowException
|
||||
import net.corda.core.flows.FlowInfo
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.internal.FlowStateMachine
|
||||
import net.corda.core.internal.ThreadBox
|
||||
import net.corda.core.internal.bufferUntilSubscribed
|
||||
import net.corda.core.internal.castIfPossible
|
||||
import net.corda.core.internal.concurrent.OpenFuture
|
||||
import net.corda.core.internal.concurrent.map
|
||||
import net.corda.core.internal.concurrent.openFuture
|
||||
import net.corda.core.messaging.DataFeed
|
||||
import net.corda.core.serialization.*
|
||||
import net.corda.core.utilities.ProgressTracker
|
||||
import net.corda.core.utilities.Try
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.core.utilities.debug
|
||||
import net.corda.node.internal.InitiatedFlowFactory
|
||||
import net.corda.node.services.api.CheckpointStorage
|
||||
import net.corda.node.services.api.ServiceHubInternal
|
||||
import net.corda.node.services.config.shouldCheckCheckpoints
|
||||
import net.corda.node.services.messaging.DeduplicationHandler
|
||||
import net.corda.node.services.messaging.ReceivedMessage
|
||||
import net.corda.node.services.statemachine.interceptors.*
|
||||
import net.corda.node.services.statemachine.transitions.StateMachine
|
||||
import net.corda.node.services.statemachine.transitions.StateMachineConfiguration
|
||||
import net.corda.node.utilities.AffinityExecutor
|
||||
import net.corda.nodeapi.internal.persistence.CordaPersistence
|
||||
import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl
|
||||
import net.corda.nodeapi.internal.serialization.withTokenContext
|
||||
import org.apache.activemq.artemis.utils.ReusableLatch
|
||||
import rx.Observable
|
||||
import rx.subjects.PublishSubject
|
||||
import java.security.SecureRandom
|
||||
import java.util.*
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.ExecutorService
|
||||
import javax.annotation.concurrent.ThreadSafe
|
||||
import kotlin.collections.ArrayList
|
||||
import kotlin.streams.toList
|
||||
|
||||
/**
 * The [SingleThreadedStateMachineManager] will always invoke the flow fibers on the given executor,
 * regardless of which thread actually starts them via [startFlow].
 */
@ThreadSafe
class SingleThreadedStateMachineManager(
        val serviceHub: ServiceHubInternal,
        val checkpointStorage: CheckpointStorage,
        val executor: ExecutorService,
        val database: CordaPersistence,
        val secureRandom: SecureRandom,
        private val unfinishedFibers: ReusableLatch = ReusableLatch(),
        private val classloader: ClassLoader = SingleThreadedStateMachineManager::class.java.classLoader
) : StateMachineManager, StateMachineManagerInternal {
    companion object {
        private val logger = contextLogger()
    }

    private class Flow(val fiber: FlowStateMachineImpl<*>, val resultFuture: OpenFuture<Any?>)

// A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines
|
||||
// property.
|
||||
private class InnerState {
|
||||
val changesPublisher = PublishSubject.create<StateMachineManager.Change>()!!
|
||||
// True if we're shutting down, so don't resume anything.
|
||||
var stopping = false
|
||||
val flows = HashMap<StateMachineRunId, Flow>()
|
||||
val startedFutures = HashMap<StateMachineRunId, OpenFuture<Unit>>()
|
||||
}
|
||||
|
||||
private val mutex = ThreadBox(InnerState())
|
||||
private val scheduler = FiberExecutorScheduler("Same thread scheduler", executor)
|
||||
// How many Fibers are running and not suspended. If zero and stopping is true, then we are halted.
|
||||
private val liveFibers = ReusableLatch()
|
||||
// Monitoring support.
|
||||
private val metrics = serviceHub.monitoringService.metrics
|
||||
private val sessionToFlow = ConcurrentHashMap<SessionId, StateMachineRunId>()
|
||||
private val flowMessaging: FlowMessaging = FlowMessagingImpl(serviceHub)
|
||||
private val fiberDeserializationChecker = if (serviceHub.configuration.shouldCheckCheckpoints()) FiberDeserializationChecker() else null
|
||||
private val transitionExecutor = makeTransitionExecutor()
|
||||
|
||||
private var checkpointSerializationContext: SerializationContext? = null
|
||||
private var tokenizableServices: List<Any>? = null
|
||||
private var actionExecutor: ActionExecutor? = null
|
||||
|
||||
override val allStateMachines: List<FlowLogic<*>>
|
||||
get() = mutex.locked { flows.values.map { it.fiber.logic } }
|
||||
|
||||
|
||||
private val totalStartedFlows = metrics.counter("Flows.Started")
|
||||
private val totalFinishedFlows = metrics.counter("Flows.Finished")
|
||||
|
||||
/**
|
||||
* An observable that emits triples of the changing flow, the type of change, and a process-specific ID number
|
||||
* which may change across restarts.
|
||||
*
|
||||
* We use assignment here so that multiple subscribers share the same wrapped Observable.
|
||||
*/
|
||||
override val changes: Observable<StateMachineManager.Change> = mutex.content.changesPublisher
|
||||
|
||||
override fun start(tokenizableServices: List<Any>) {
|
||||
checkQuasarJavaAgentPresence()
|
||||
this.tokenizableServices = tokenizableServices
|
||||
val checkpointSerializationContext = SerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
|
||||
SerializeAsTokenContextImpl(tokenizableServices, SerializationDefaults.SERIALIZATION_FACTORY, SerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
|
||||
)
|
||||
this.checkpointSerializationContext = checkpointSerializationContext
|
||||
this.actionExecutor = makeActionExecutor(checkpointSerializationContext)
|
||||
fiberDeserializationChecker?.start(checkpointSerializationContext)
|
||||
val fibers = restoreFlowsFromCheckpoints()
|
||||
metrics.register("Flows.InFlight", Gauge<Int> { mutex.content.flows.size })
|
||||
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
|
||||
(fiber as FlowStateMachineImpl<*>).logger.warn("Caught exception from flow", throwable)
|
||||
}
|
||||
serviceHub.networkMapCache.nodeReady.then {
|
||||
resumeRestoredFlows(fibers)
|
||||
flowMessaging.start { receivedMessage, deduplicationHandler ->
|
||||
executor.execute {
|
||||
onSessionMessage(receivedMessage, deduplicationHandler)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
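
// Standalone sketch (plain JDK, no Corda types) of the start-up gating used in start() above: flows
// restored from checkpoints are only resumed, and session messages only drained, once the
// network-map readiness future completes.
import java.util.concurrent.CompletableFuture
import java.util.concurrent.Executors

fun main() {
    val executor = Executors.newSingleThreadExecutor()
    val nodeReady = CompletableFuture<Unit>()
    val restored = listOf("flow-1", "flow-2") // stands in for restoreFlowsFromCheckpoints()

    nodeReady.thenRun {
        restored.forEach { println("resuming $it") }
        executor.execute { println("processing first session message") }
    }

    println("network map cache populated")
    nodeReady.complete(Unit) // the callback above only runs now
    executor.shutdown()
}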
|
||||
|
||||
override fun resume() {
|
||||
fiberDeserializationChecker?.start(checkpointSerializationContext!!)
|
||||
val fibers = restoreFlowsFromCheckpoints()
|
||||
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
|
||||
(fiber as FlowStateMachineImpl<*>).logger.warn("Caught exception from flow", throwable)
|
||||
}
|
||||
serviceHub.networkMapCache.nodeReady.then {
|
||||
resumeRestoredFlows(fibers)
|
||||
}
|
||||
mutex.locked {
|
||||
stopping = false
|
||||
}
|
||||
}
|
||||
|
||||
override fun <A : FlowLogic<*>> findStateMachines(flowClass: Class<A>): List<Pair<A, CordaFuture<*>>> {
|
||||
return mutex.locked {
|
||||
flows.values.mapNotNull {
|
||||
flowClass.castIfPossible(it.fiber.logic)?.let { it to it.stateMachine.resultFuture }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Start the shutdown process, bringing the [SingleThreadedStateMachineManager] to a controlled stop. When this method returns,
|
||||
* all Fibers have been suspended and checkpointed, or have completed.
|
||||
*
|
||||
* @param allowedUnsuspendedFiberCount Optional parameter is used in some tests.
|
||||
*/
|
||||
override fun stop(allowedUnsuspendedFiberCount: Int) {
|
||||
require(allowedUnsuspendedFiberCount >= 0)
|
||||
mutex.locked {
|
||||
if (stopping) throw IllegalStateException("Already stopping!")
|
||||
stopping = true
|
||||
for ((_, flow) in flows) {
|
||||
flow.fiber.scheduleEvent(Event.SoftShutdown)
|
||||
}
|
||||
}
|
||||
// Account for any expected Fibers in a test scenario.
|
||||
liveFibers.countDown(allowedUnsuspendedFiberCount)
|
||||
liveFibers.await()
|
||||
fiberDeserializationChecker?.let {
|
||||
val foundUnrestorableFibers = it.stop()
|
||||
check(!foundUnrestorableFibers) { "Unrestorable checkpoints were created, please check the logs for details." }
|
||||
}
|
||||
}
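
// Standalone sketch of the live-fiber accounting used by stop(): a reusable counting latch that is
// incremented for every running fiber and awaited during shutdown. This only illustrates how
// Artemis' ReusableLatch is used above; it is not the real class.
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock

class CountingLatch {
    private val lock = ReentrantLock()
    private val zero = lock.newCondition()
    private var count = 0

    fun countUp() = lock.withLock { count++ }

    fun countDown(n: Int = 1) = lock.withLock {
        count -= n
        if (count <= 0) zero.signalAll()
    }

    fun await() = lock.withLock {
        while (count > 0) zero.await()
    }
}

fun main() {
    val liveFibers = CountingLatch()
    liveFibers.countUp()                      // a fiber starts running
    Thread { liveFibers.countDown() }.start() // it suspends or finishes on another thread
    liveFibers.countDown(0)                   // allowedUnsuspendedFiberCount is 0 outside of tests
    liveFibers.await()                        // stop() blocks here until the count reaches zero
    println("all fibers suspended or finished")
}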
|
||||
|
||||
/**
|
||||
* Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and
|
||||
* calls to [allStateMachines]
|
||||
*/
|
||||
override fun track(): DataFeed<List<FlowLogic<*>>, StateMachineManager.Change> {
|
||||
return mutex.locked {
|
||||
DataFeed(flows.values.map { it.fiber.logic }, changesPublisher.bufferUntilSubscribed())
|
||||
}
|
||||
}
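
// Standalone sketch of the track() contract described above ("atomic get snapshot + subscribe"):
// take the snapshot and register the subscriber under the same lock, so no update can slip between
// the two. Plain Kotlin, no rx/DataFeed types.
class Tracker<T, C> {
    private val lock = Any()
    private val items = mutableListOf<T>()
    private val subscribers = mutableListOf<(C) -> Unit>()

    fun update(item: T, change: C) = synchronized(lock) {
        items += item
        subscribers.forEach { it(change) }
    }

    /** Atomic "snapshot + subscribe": the caller misses no change emitted after its snapshot. */
    fun track(subscriber: (C) -> Unit): List<T> = synchronized(lock) {
        subscribers += subscriber
        items.toList()
    }
}

fun main() {
    val tracker = Tracker<String, String>()
    tracker.update("flow-1", "added flow-1")
    val snapshot = tracker.track { change -> println("change: $change") }
    println("snapshot: $snapshot")           // [flow-1]
    tracker.update("flow-2", "added flow-2") // printed by the subscriber
}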
|
||||
|
||||
override fun <A> startFlow(
|
||||
flowLogic: FlowLogic<A>,
|
||||
context: InvocationContext,
|
||||
ourIdentity: Party?,
|
||||
deduplicationHandler: DeduplicationHandler?
|
||||
): CordaFuture<FlowStateMachine<A>> {
|
||||
return startFlowInternal(
|
||||
invocationContext = context,
|
||||
flowLogic = flowLogic,
|
||||
flowStart = FlowStart.Explicit,
|
||||
ourIdentity = ourIdentity ?: getOurFirstIdentity(),
|
||||
deduplicationHandler = deduplicationHandler,
|
||||
isStartIdempotent = false
|
||||
)
|
||||
}
|
||||
|
||||
override fun killFlow(id: StateMachineRunId): Boolean {
|
||||
|
||||
return mutex.locked {
|
||||
val flow = flows.remove(id)
|
||||
if (flow != null) {
|
||||
logger.debug("Killing flow known to physical node.")
|
||||
decrementLiveFibers()
|
||||
totalFinishedFlows.inc()
|
||||
unfinishedFibers.countDown()
|
||||
try {
|
||||
flow.fiber.interrupt()
|
||||
true
|
||||
} finally {
|
||||
database.transaction {
|
||||
checkpointStorage.removeCheckpoint(id)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// TODO replace with a clustered delete after we'll support clustered nodes
|
||||
logger.debug("Unable to kill a flow unknown to physical node. Might be processed by another physical node.")
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun addSessionBinding(flowId: StateMachineRunId, sessionId: SessionId) {
|
||||
val previousFlowId = sessionToFlow.put(sessionId, flowId)
|
||||
if (previousFlowId != null) {
|
||||
if (previousFlowId == flowId) {
|
||||
logger.warn("Session binding from $sessionId to $flowId re-added")
|
||||
} else {
|
||||
throw IllegalStateException(
|
||||
"Attempted to add session binding from session $sessionId to flow $flowId, " +
|
||||
"however there was already a binding to $previousFlowId"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
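
// Standalone sketch of the session-binding check above: Map.put returns the previous value, which
// distinguishes a harmless replay (the same flow re-binding the session, tolerated because these
// operations must be idempotent) from a genuine conflict.
import java.util.concurrent.ConcurrentHashMap

fun main() {
    val sessionToFlow = ConcurrentHashMap<Long, String>()

    fun addBinding(sessionId: Long, flowId: String) {
        val previous = sessionToFlow.put(sessionId, flowId)
        when {
            previous == null -> println("bound session $sessionId to $flowId")
            previous == flowId -> println("WARN: binding $sessionId -> $flowId re-added (idempotent replay)")
            else -> throw IllegalStateException("session $sessionId already bound to $previous, not $flowId")
        }
    }

    addBinding(1L, "flow-A")
    addBinding(1L, "flow-A") // replay: warns
    // addBinding(1L, "flow-B") would throw
}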
|
||||
|
||||
override fun removeSessionBindings(sessionIds: Set<SessionId>) {
|
||||
val reRemovedSessionIds = HashSet<SessionId>()
|
||||
for (sessionId in sessionIds) {
|
||||
val flowId = sessionToFlow.remove(sessionId)
|
||||
if (flowId == null) {
|
||||
reRemovedSessionIds.add(sessionId)
|
||||
}
|
||||
}
|
||||
if (reRemovedSessionIds.isNotEmpty()) {
|
||||
logger.warn("Session binding from $reRemovedSessionIds re-removed")
|
||||
}
|
||||
}
|
||||
|
||||
override fun removeFlow(flowId: StateMachineRunId, removalReason: FlowRemovalReason, lastState: StateMachineState) {
|
||||
mutex.locked {
|
||||
val flow = flows.remove(flowId)
|
||||
if (flow != null) {
|
||||
decrementLiveFibers()
|
||||
totalFinishedFlows.inc()
|
||||
unfinishedFibers.countDown()
|
||||
return when (removalReason) {
|
||||
is FlowRemovalReason.OrderlyFinish -> removeFlowOrderly(flow, removalReason, lastState)
|
||||
is FlowRemovalReason.ErrorFinish -> removeFlowError(flow, removalReason, lastState)
|
||||
FlowRemovalReason.SoftShutdown -> flow.fiber.scheduleEvent(Event.SoftShutdown)
|
||||
}
|
||||
} else {
|
||||
logger.warn("Flow $flowId re-finished")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun signalFlowHasStarted(flowId: StateMachineRunId) {
|
||||
mutex.locked {
|
||||
startedFutures.remove(flowId)?.set(Unit)
|
||||
flows[flowId]?.let { flow ->
|
||||
changesPublisher.onNext(StateMachineManager.Change.Add(flow.fiber.logic))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private val stateMachineConfiguration = StateMachineConfiguration.default
|
||||
|
||||
private fun checkQuasarJavaAgentPresence() {
|
||||
check(SuspendableHelper.isJavaAgentActive(), {
|
||||
"""Missing the '-javaagent' JVM argument. Make sure you run the tests with the Quasar java agent attached to your JVM.
|
||||
#See https://docs.corda.net/troubleshooting.html - 'Fiber classes not instrumented' for more details.""".trimMargin("#")
|
||||
})
|
||||
}
|
||||
|
||||
private fun decrementLiveFibers() {
|
||||
liveFibers.countDown()
|
||||
}
|
||||
|
||||
private fun incrementLiveFibers() {
|
||||
liveFibers.countUp()
|
||||
}
|
||||
|
||||
private fun restoreFlowsFromCheckpoints(): List<Flow> {
|
||||
return checkpointStorage.getAllCheckpoints().map { (id, serializedCheckpoint) ->
|
||||
// If a flow is added before start() then don't attempt to restore it
|
||||
mutex.locked { if (flows.containsKey(id)) return@map null }
|
||||
val checkpoint = deserializeCheckpoint(serializedCheckpoint)
|
||||
if (checkpoint == null) return@map null
|
||||
createFlowFromCheckpoint(
|
||||
id = id,
|
||||
checkpoint = checkpoint,
|
||||
initialDeduplicationHandler = null,
|
||||
isAnyCheckpointPersisted = true,
|
||||
isStartIdempotent = false
|
||||
)
|
||||
}.toList().filterNotNull()
|
||||
}
|
||||
|
||||
private fun resumeRestoredFlows(flows: List<Flow>) {
|
||||
for (flow in flows) {
|
||||
addAndStartFlow(flow.fiber.id, flow)
|
||||
}
|
||||
}
|
||||
|
||||
private fun onSessionMessage(message: ReceivedMessage, deduplicationHandler: DeduplicationHandler) {
|
||||
val peer = message.peer
|
||||
val sessionMessage = try {
|
||||
message.data.deserialize<SessionMessage>()
|
||||
} catch (ex: Exception) {
|
||||
logger.error("Received corrupt SessionMessage data from $peer")
|
||||
deduplicationHandler.afterDatabaseTransaction()
|
||||
return
|
||||
}
|
||||
val sender = serviceHub.networkMapCache.getPeerByLegalName(peer)
|
||||
if (sender != null) {
|
||||
when (sessionMessage) {
|
||||
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, deduplicationHandler, sender)
|
||||
is InitialSessionMessage -> onSessionInit(sessionMessage, message.platformVersion, deduplicationHandler, sender)
|
||||
}
|
||||
} else {
|
||||
logger.error("Unknown peer $peer in $sessionMessage")
|
||||
}
|
||||
}
|
||||
|
||||
private fun onExistingSessionMessage(sessionMessage: ExistingSessionMessage, deduplicationHandler: DeduplicationHandler, sender: Party) {
|
||||
try {
|
||||
val recipientId = sessionMessage.recipientSessionId
|
||||
val flowId = sessionToFlow[recipientId]
|
||||
if (flowId == null) {
|
||||
deduplicationHandler.afterDatabaseTransaction()
|
||||
if (sessionMessage.payload is EndSessionMessage) {
|
||||
logger.debug {
|
||||
"Got ${EndSessionMessage::class.java.simpleName} for " +
|
||||
"unknown session $recipientId, discarding..."
|
||||
}
|
||||
} else {
|
||||
throw IllegalArgumentException("Cannot find flow corresponding to session ID $recipientId")
|
||||
}
|
||||
} else {
|
||||
val flow = mutex.locked { flows[flowId] } ?: throw IllegalStateException("Cannot find fiber corresponding to ID $flowId")
|
||||
flow.fiber.scheduleEvent(Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender))
|
||||
}
|
||||
} catch (exception: Exception) {
|
||||
logger.error("Exception while routing $sessionMessage", exception)
|
||||
throw exception
|
||||
}
|
||||
}
|
||||
|
||||
private fun onSessionInit(sessionMessage: InitialSessionMessage, senderPlatformVersion: Int, deduplicationHandler: DeduplicationHandler, sender: Party) {
|
||||
fun createErrorMessage(initiatorSessionId: SessionId, message: String): ExistingSessionMessage {
|
||||
val errorId = secureRandom.nextLong()
|
||||
val payload = RejectSessionMessage(message, errorId)
|
||||
return ExistingSessionMessage(initiatorSessionId, payload)
|
||||
}
|
||||
val replyError = try {
|
||||
val initiatedFlowFactory = getInitiatedFlowFactory(sessionMessage)
|
||||
val initiatedSessionId = SessionId.createRandom(secureRandom)
|
||||
val senderSession = FlowSessionImpl(sender, initiatedSessionId)
|
||||
val flowLogic = initiatedFlowFactory.createFlow(senderSession)
|
||||
val initiatedFlowInfo = when (initiatedFlowFactory) {
|
||||
is InitiatedFlowFactory.Core -> FlowInfo(serviceHub.myInfo.platformVersion, "corda")
|
||||
is InitiatedFlowFactory.CorDapp -> FlowInfo(initiatedFlowFactory.flowVersion, initiatedFlowFactory.appName)
|
||||
}
|
||||
val senderCoreFlowVersion = when (initiatedFlowFactory) {
|
||||
is InitiatedFlowFactory.Core -> senderPlatformVersion
|
||||
is InitiatedFlowFactory.CorDapp -> null
|
||||
}
|
||||
startInitiatedFlow(flowLogic, deduplicationHandler, senderSession, initiatedSessionId, sessionMessage, senderCoreFlowVersion, initiatedFlowInfo)
|
||||
null
|
||||
} catch (exception: Exception) {
|
||||
logger.warn("Exception while creating initiated flow", exception)
|
||||
createErrorMessage(
|
||||
sessionMessage.initiatorSessionId,
|
||||
(exception as? SessionRejectException)?.message ?: "Unable to establish session"
|
||||
)
|
||||
}
|
||||
|
||||
if (replyError != null) {
|
||||
flowMessaging.sendSessionMessage(sender, replyError, DeduplicationId.createRandom(secureRandom))
|
||||
deduplicationHandler.afterDatabaseTransaction()
|
||||
}
|
||||
}
|
||||
|
||||
// TODO this is a temporary hack until we figure out multiple identities
|
||||
private fun getOurFirstIdentity(): Party {
|
||||
return serviceHub.myInfo.legalIdentities[0]
|
||||
}
|
||||
|
||||
private fun getInitiatedFlowFactory(message: InitialSessionMessage): InitiatedFlowFactory<*> {
|
||||
val initiatingFlowClass = try {
|
||||
Class.forName(message.initiatorFlowClassName, true, classloader).asSubclass(FlowLogic::class.java)
|
||||
} catch (e: ClassNotFoundException) {
|
||||
throw SessionRejectException("Don't know ${message.initiatorFlowClassName}")
|
||||
} catch (e: ClassCastException) {
|
||||
throw SessionRejectException("${message.initiatorFlowClassName} is not a flow")
|
||||
}
|
||||
return serviceHub.getFlowFactory(initiatingFlowClass) ?:
|
||||
throw SessionRejectException("$initiatingFlowClass is not registered")
|
||||
}
|
||||
|
||||
private fun <A> startInitiatedFlow(
|
||||
flowLogic: FlowLogic<A>,
|
||||
initiatingMessageDeduplicationHandler: DeduplicationHandler,
|
||||
peerSession: FlowSessionImpl,
|
||||
initiatedSessionId: SessionId,
|
||||
initiatingMessage: InitialSessionMessage,
|
||||
senderCoreFlowVersion: Int?,
|
||||
initiatedFlowInfo: FlowInfo
|
||||
) {
|
||||
val flowStart = FlowStart.Initiated(peerSession, initiatedSessionId, initiatingMessage, senderCoreFlowVersion, initiatedFlowInfo)
|
||||
val ourIdentity = getOurFirstIdentity()
|
||||
startFlowInternal(
|
||||
InvocationContext.peer(peerSession.counterparty.name), flowLogic, flowStart, ourIdentity,
|
||||
initiatingMessageDeduplicationHandler,
|
||||
isStartIdempotent = false
|
||||
)
|
||||
}
|
||||
|
||||
private fun <A> startFlowInternal(
|
||||
invocationContext: InvocationContext,
|
||||
flowLogic: FlowLogic<A>,
|
||||
flowStart: FlowStart,
|
||||
ourIdentity: Party,
|
||||
deduplicationHandler: DeduplicationHandler?,
|
||||
isStartIdempotent: Boolean
|
||||
): CordaFuture<FlowStateMachine<A>> {
|
||||
val flowId = StateMachineRunId.createRandom()
|
||||
val deduplicationSeed = when (flowStart) {
|
||||
FlowStart.Explicit -> flowId.uuid.toString()
|
||||
is FlowStart.Initiated ->
|
||||
"${flowStart.initiatingMessage.initiatorSessionId.toLong}-" +
|
||||
"${flowStart.initiatingMessage.initiationEntropy}"
|
||||
}
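
// Standalone sketch of the deduplication-seed rule above, with hypothetical stand-in types: an
// explicit start seeds from the fresh flow id, while an initiated start seeds from values chosen by
// the initiator, so a redelivered initiation message maps to the same seed again.
import java.util.UUID

sealed class Start {
    data class Explicit(val flowId: UUID) : Start()
    data class Initiated(val initiatorSessionId: Long, val initiationEntropy: Long) : Start()
}

fun deduplicationSeed(start: Start): String = when (start) {
    is Start.Explicit -> start.flowId.toString()
    is Start.Initiated -> "${start.initiatorSessionId}-${start.initiationEntropy}"
}

fun main() {
    println(deduplicationSeed(Start.Explicit(UUID.randomUUID())))
    // Redelivering the same initiation message yields the same seed:
    println(deduplicationSeed(Start.Initiated(42L, 7L))) // 42-7
    println(deduplicationSeed(Start.Initiated(42L, 7L))) // 42-7
}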
|
||||
|
||||
// Before we construct the state machine state by freezing the FlowLogic we need to make sure that lazy properties
|
||||
// have access to the fiber (and thereby the service hub)
|
||||
val flowStateMachineImpl = FlowStateMachineImpl(flowId, flowLogic, scheduler)
|
||||
val resultFuture = openFuture<Any?>()
|
||||
flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture))
|
||||
flowLogic.stateMachine = flowStateMachineImpl
|
||||
val frozenFlowLogic = (flowLogic as FlowLogic<*>).serialize(context = checkpointSerializationContext!!)
|
||||
|
||||
val initialCheckpoint = Checkpoint.create(invocationContext, flowStart, flowLogic.javaClass, frozenFlowLogic, ourIdentity, deduplicationSeed).getOrThrow()
|
||||
val startedFuture = openFuture<Unit>()
|
||||
val initialState = StateMachineState(
|
||||
checkpoint = initialCheckpoint,
|
||||
pendingDeduplicationHandlers = deduplicationHandler?.let { listOf(it) } ?: emptyList(),
|
||||
isFlowResumed = false,
|
||||
isTransactionTracked = false,
|
||||
isAnyCheckpointPersisted = false,
|
||||
isStartIdempotent = isStartIdempotent,
|
||||
isRemoved = false,
|
||||
flowLogic = flowLogic
|
||||
)
|
||||
flowStateMachineImpl.transientState = TransientReference(initialState)
|
||||
mutex.locked {
|
||||
startedFutures[flowId] = startedFuture
|
||||
}
|
||||
totalStartedFlows.inc()
|
||||
addAndStartFlow(flowId, Flow(flowStateMachineImpl, resultFuture))
|
||||
return startedFuture.map { flowStateMachineImpl as FlowStateMachine<A> }
|
||||
}
|
||||
|
||||
private fun deserializeCheckpoint(serializedCheckpoint: SerializedBytes<Checkpoint>): Checkpoint? {
|
||||
return try {
|
||||
serializedCheckpoint.deserialize(context = checkpointSerializationContext!!)
|
||||
} catch (exception: Throwable) {
|
||||
logger.error("Encountered unrestorable checkpoint!", exception)
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
private fun verifyFlowLogicIsSuspendable(logic: FlowLogic<Any?>) {
|
||||
// Quasar requires (in Java 8) that at least the call method be annotated suspendable. Unfortunately, it's
|
||||
// easy to forget to add this when creating a new flow, so we check here to give the user a better error.
|
||||
//
|
||||
// The Kotlin compiler can sometimes generate a synthetic bridge method from a single call declaration, which
|
||||
// forwards to the void method and then returns Unit. However annotations do not get copied across to this
|
||||
// bridge, so we have to do a more complex scan here.
|
||||
val call = logic.javaClass.methods.first { !it.isSynthetic && it.name == "call" && it.parameterCount == 0 }
|
||||
if (call.getAnnotation(Suspendable::class.java) == null) {
|
||||
throw FlowException("${logic.javaClass.name}.call() is not annotated as @Suspendable. Please fix this.")
|
||||
}
|
||||
}
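
// Standalone sketch of the reflection check above, using a stand-in annotation instead of Quasar's
// @Suspendable: find the declared zero-argument call() while skipping any synthetic bridge method
// the Kotlin compiler may generate (annotations are not copied onto bridges), then verify the
// annotation is present. Names here are illustrative only.
@Target(AnnotationTarget.FUNCTION)
@Retention(AnnotationRetention.RUNTIME)
annotation class MustSuspend

abstract class Logic<out T> {
    abstract fun call(): T
}

class GoodLogic : Logic<String>() {
    @MustSuspend override fun call(): String = "ok"
}

fun verifyCallIsAnnotated(logic: Logic<Any?>) {
    val call = logic.javaClass.methods.first { !it.isSynthetic && it.name == "call" && it.parameterCount == 0 }
    require(call.getAnnotation(MustSuspend::class.java) != null) {
        "${logic.javaClass.name}.call() is not annotated with @MustSuspend"
    }
}

fun main() {
    verifyCallIsAnnotated(GoodLogic()) // passes; an unannotated subclass would fail the require()
}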
|
||||
|
||||
private fun createTransientValues(id: StateMachineRunId, resultFuture: CordaFuture<Any?>): FlowStateMachineImpl.TransientValues {
|
||||
return FlowStateMachineImpl.TransientValues(
|
||||
eventQueue = Channels.newChannel(stateMachineConfiguration.eventQueueSize, Channels.OverflowPolicy.BLOCK),
|
||||
resultFuture = resultFuture,
|
||||
database = database,
|
||||
transitionExecutor = transitionExecutor,
|
||||
actionExecutor = actionExecutor!!,
|
||||
stateMachine = StateMachine(id, stateMachineConfiguration, secureRandom),
|
||||
serviceHub = serviceHub,
|
||||
checkpointSerializationContext = checkpointSerializationContext!!
|
||||
)
|
||||
}
|
||||
|
||||
private fun createFlowFromCheckpoint(
|
||||
id: StateMachineRunId,
|
||||
checkpoint: Checkpoint,
|
||||
isAnyCheckpointPersisted: Boolean,
|
||||
isStartIdempotent: Boolean,
|
||||
initialDeduplicationHandler: DeduplicationHandler?
|
||||
): Flow {
|
||||
val flowState = checkpoint.flowState
|
||||
val resultFuture = openFuture<Any?>()
|
||||
val fiber = when (flowState) {
|
||||
is FlowState.Unstarted -> {
|
||||
val logic = flowState.frozenFlowLogic.deserialize(context = checkpointSerializationContext!!)
|
||||
val state = StateMachineState(
|
||||
checkpoint = checkpoint,
|
||||
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
|
||||
isFlowResumed = false,
|
||||
isTransactionTracked = false,
|
||||
isAnyCheckpointPersisted = isAnyCheckpointPersisted,
|
||||
isStartIdempotent = isStartIdempotent,
|
||||
isRemoved = false,
|
||||
flowLogic = logic
|
||||
)
|
||||
val fiber = FlowStateMachineImpl(id, logic, scheduler)
|
||||
fiber.transientValues = TransientReference(createTransientValues(id, resultFuture))
|
||||
fiber.transientState = TransientReference(state)
|
||||
fiber.logic.stateMachine = fiber
|
||||
fiber
|
||||
}
|
||||
is FlowState.Started -> {
|
||||
val fiber = flowState.frozenFiber.deserialize(context = checkpointSerializationContext!!)
|
||||
val state = StateMachineState(
|
||||
checkpoint = checkpoint,
|
||||
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
|
||||
isFlowResumed = false,
|
||||
isTransactionTracked = false,
|
||||
isAnyCheckpointPersisted = isAnyCheckpointPersisted,
|
||||
isStartIdempotent = isStartIdempotent,
|
||||
isRemoved = false,
|
||||
flowLogic = fiber.logic
|
||||
)
|
||||
fiber.transientValues = TransientReference(createTransientValues(id, resultFuture))
|
||||
fiber.transientState = TransientReference(state)
|
||||
fiber.logic.stateMachine = fiber
|
||||
fiber
|
||||
}
|
||||
}
|
||||
|
||||
verifyFlowLogicIsSuspendable(fiber.logic)
|
||||
|
||||
return Flow(fiber, resultFuture)
|
||||
}
|
||||
|
||||
private fun addAndStartFlow(id: StateMachineRunId, flow: Flow) {
|
||||
val checkpoint = flow.fiber.snapshot().checkpoint
|
||||
for (sessionId in getFlowSessionIds(checkpoint)) {
|
||||
sessionToFlow.put(sessionId, id)
|
||||
}
|
||||
mutex.locked {
|
||||
if (stopping) {
|
||||
startedFutures[id]?.setException(IllegalStateException("Will not start flow as SMM is stopping"))
|
||||
logger.trace("Not resuming as SMM is stopping.")
|
||||
} else {
|
||||
incrementLiveFibers()
|
||||
unfinishedFibers.countUp()
|
||||
flows.put(id, flow)
|
||||
flow.fiber.scheduleEvent(Event.DoRemainingWork)
|
||||
when (checkpoint.flowState) {
|
||||
is FlowState.Unstarted -> {
|
||||
flow.fiber.start()
|
||||
}
|
||||
is FlowState.Started -> {
|
||||
Fiber.unparkDeserialized(flow.fiber, scheduler)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun getFlowSessionIds(checkpoint: Checkpoint): Set<SessionId> {
|
||||
val initiatedFlowStart = (checkpoint.flowState as? FlowState.Unstarted)?.flowStart as? FlowStart.Initiated
|
||||
return if (initiatedFlowStart == null) {
|
||||
checkpoint.sessions.keys
|
||||
} else {
|
||||
checkpoint.sessions.keys + initiatedFlowStart.initiatedSessionId
|
||||
}
|
||||
}
|
||||
|
||||
private fun makeActionExecutor(checkpointSerializationContext: SerializationContext): ActionExecutor {
|
||||
return ActionExecutorImpl(
|
||||
serviceHub,
|
||||
checkpointStorage,
|
||||
flowMessaging,
|
||||
this,
|
||||
checkpointSerializationContext,
|
||||
metrics
|
||||
)
|
||||
}

    private fun makeTransitionExecutor(): TransitionExecutor {
        val interceptors = ArrayList<TransitionInterceptor>()
        interceptors.add { HospitalisingInterceptor(PropagatingFlowHospital, it) }
        if (serviceHub.configuration.devMode) {
            interceptors.add { DumpHistoryOnErrorInterceptor(it) }
        }
        if (serviceHub.configuration.shouldCheckCheckpoints()) {
            interceptors.add { FiberDeserializationCheckingInterceptor(fiberDeserializationChecker!!, it) }
        }
        if (logger.isDebugEnabled) {
            interceptors.add { PrintingInterceptor(it) }
        }
        val transitionExecutor: TransitionExecutor = TransitionExecutorImpl(secureRandom, database)
        return interceptors.fold(transitionExecutor) { executor, interceptor -> interceptor(executor) }
    }
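
// Standalone sketch of the interceptor chain built above: each interceptor is presumably a function
// that wraps one executor in another, and folding the list over the base executor means the last
// interceptor added becomes the outermost layer (invoked first). Types here are illustrative.
fun interface Executor {
    fun execute(transition: String): String
}

typealias Interceptor = (Executor) -> Executor

fun main() {
    val base = Executor { transition -> "base($transition)" }
    val interceptors = listOf<Interceptor>(
            { inner -> Executor { t -> "hospital(${inner.execute(t)})" } },
            { inner -> Executor { t -> "printing(${inner.execute(t)})" } }
    )
    val chained = interceptors.fold(base) { executor, interceptor -> interceptor(executor) }
    println(chained.execute("event"))
    // prints: printing(hospital(base(event)))
}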
|
||||
|
||||
private fun InnerState.removeFlowOrderly(
|
||||
flow: Flow,
|
||||
removalReason: FlowRemovalReason.OrderlyFinish,
|
||||
lastState: StateMachineState
|
||||
) {
|
||||
drainFlowEventQueue(flow)
|
||||
// final sanity checks
|
||||
require(lastState.pendingDeduplicationHandlers.isEmpty())
|
||||
require(lastState.isRemoved)
|
||||
require(lastState.checkpoint.subFlowStack.size == 1)
|
||||
sessionToFlow.none { it.value == flow.fiber.id }
|
||||
flow.resultFuture.set(removalReason.flowReturnValue)
|
||||
lastState.flowLogic.progressTracker?.currentStep = ProgressTracker.DONE
|
||||
changesPublisher.onNext(StateMachineManager.Change.Removed(lastState.flowLogic, Try.Success(removalReason.flowReturnValue)))
|
||||
}
|
||||
|
||||
private fun InnerState.removeFlowError(
|
||||
flow: Flow,
|
||||
removalReason: FlowRemovalReason.ErrorFinish,
|
||||
lastState: StateMachineState
|
||||
) {
|
||||
drainFlowEventQueue(flow)
|
||||
val flowError = removalReason.flowErrors[0] // TODO what to do with several?
|
||||
val exception = flowError.exception
|
||||
(exception as? FlowException)?.originalErrorId = flowError.errorId
|
||||
flow.resultFuture.setException(exception)
|
||||
lastState.flowLogic.progressTracker?.endWithError(exception)
|
||||
changesPublisher.onNext(StateMachineManager.Change.Removed(lastState.flowLogic, Try.Failure<Nothing>(exception)))
|
||||
}
|
||||
|
||||
// The flow's event queue may be non-empty in case it shut down abruptly. We handle outstanding events here.
|
||||
private fun drainFlowEventQueue(flow: Flow) {
|
||||
while (true) {
|
||||
val event = flow.fiber.transientValues!!.value.eventQueue.tryReceive() ?: return
|
||||
when (event) {
|
||||
is Event.DoRemainingWork -> {}
|
||||
is Event.DeliverSessionMessage -> {
|
||||
// Acknowledge the message so it doesn't leak in the broker.
|
||||
event.deduplicationHandler.afterDatabaseTransaction()
|
||||
when (event.sessionMessage.payload) {
|
||||
EndSessionMessage -> {
|
||||
logger.debug { "Unhandled message ${event.sessionMessage} due to flow shutting down" }
|
||||
}
|
||||
else -> {
|
||||
logger.warn("Unhandled message ${event.sessionMessage} due to flow shutting down")
|
||||
}
|
||||
}
|
||||
}
|
||||
else -> {
|
||||
logger.warn("Unhandled event $event due to flow shutting down")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
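
// Standalone sketch of the drain loop above: poll the event queue without blocking until it is
// empty, acknowledging (here: just printing) anything that would otherwise leak after an abrupt
// shutdown. Pure Kotlin; the event types are stand-ins.
sealed class QueuedEvent {
    object DoWork : QueuedEvent()
    data class Deliver(val message: String) : QueuedEvent()
}

fun drain(queue: ArrayDeque<QueuedEvent>) {
    while (true) {
        val event = queue.removeFirstOrNull() ?: return
        when (event) {
            is QueuedEvent.DoWork -> {} // nothing outstanding to acknowledge
            is QueuedEvent.Deliver -> println("acknowledging unhandled message: ${event.message}")
        }
    }
}

fun main() {
    drain(ArrayDeque(listOf(QueuedEvent.DoWork, QueuedEvent.Deliver("end-session"))))
}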
|
||||
}
|
@ -1,11 +1,14 @@
package net.corda.node.services.statemachine

import net.corda.core.concurrent.CordaFuture
import net.corda.core.flows.FlowLogic
import net.corda.core.internal.FlowStateMachine
import net.corda.core.context.InvocationContext
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.Party
import net.corda.core.internal.FlowStateMachine
import net.corda.core.messaging.DataFeed
import net.corda.core.utilities.Try
import net.corda.node.services.messaging.DeduplicationHandler
import rx.Observable

/**
@ -23,7 +26,6 @@ import rx.Observable
 * TODO: Think about how to bring the system to a clean stop so it can be upgraded without any serialised stacks on disk
 * TODO: Timeouts
 * TODO: Surfacing of exceptions via an API and/or management UI
 * TODO: Ability to control checkpointing explicitly, for cases where you know replaying a message can't hurt
 * TODO: Don't store all active flows in memory, load from the database on demand.
 */
interface StateMachineManager {
@ -37,13 +39,25 @@ interface StateMachineManager {
     */
    fun stop(allowedUnsuspendedFiberCount: Int)

    /**
     * Resume state machine manager after having called [stop].
     */
    fun resume()

    /**
     * Starts a new flow.
     *
     * @param flowLogic The flow's code.
     * @param context The context of the flow.
     * @param ourIdentity The identity to use for the flow.
     * @param deduplicationHandler Allows exactly-once start of the flow, see [DeduplicationHandler].
     */
    fun <A> startFlow(flowLogic: FlowLogic<A>, context: InvocationContext): CordaFuture<FlowStateMachine<A>>
    fun <A> startFlow(
            flowLogic: FlowLogic<A>,
            context: InvocationContext,
            ourIdentity: Party?,
            deduplicationHandler: DeduplicationHandler?
    ): CordaFuture<FlowStateMachine<A>>

    /**
     * Represents an addition/removal of a state machine.
@ -73,4 +87,20 @@ interface StateMachineManager {
     * Returns all currently live flows.
     */
    val allStateMachines: List<FlowLogic<*>>

    /**
     * Attempts to kill a flow. This is not a clean termination and should be reserved for exceptional cases such as stuck fibers.
     *
     * @return whether the flow existed and was killed.
     */
    fun killFlow(id: StateMachineRunId): Boolean
}

// These must be idempotent! A later failure in the state transition may error the flow state, and a replay may call
// these functions again
interface StateMachineManagerInternal {
    fun signalFlowHasStarted(flowId: StateMachineRunId)
    fun addSessionBinding(flowId: StateMachineRunId, sessionId: SessionId)
    fun removeSessionBindings(sessionIds: Set<SessionId>)
    fun removeFlow(flowId: StateMachineRunId, removalReason: FlowRemovalReason, lastState: StateMachineState)
}

@ -1,666 +0,0 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import co.paralleluniverse.fibers.Fiber
|
||||
import co.paralleluniverse.fibers.FiberExecutorScheduler
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import co.paralleluniverse.fibers.instrument.SuspendableHelper
|
||||
import co.paralleluniverse.strands.Strand
|
||||
import com.codahale.metrics.Gauge
|
||||
import com.esotericsoftware.kryo.KryoException
|
||||
import com.google.common.collect.HashMultimap
|
||||
import com.google.common.util.concurrent.MoreExecutors
|
||||
import net.corda.core.CordaException
|
||||
import net.corda.core.concurrent.CordaFuture
|
||||
import net.corda.core.context.InvocationContext
|
||||
import net.corda.core.context.InvocationOrigin
|
||||
import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.crypto.newSecureRandom
|
||||
import net.corda.core.flows.FlowException
|
||||
import net.corda.core.flows.FlowInfo
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.internal.*
|
||||
import net.corda.core.internal.concurrent.doneFuture
|
||||
import net.corda.core.messaging.DataFeed
|
||||
import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT
|
||||
import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.serialization.deserialize
|
||||
import net.corda.core.serialization.serialize
|
||||
import net.corda.core.utilities.Try
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.core.utilities.debug
|
||||
import net.corda.core.utilities.trace
|
||||
import net.corda.node.internal.InitiatedFlowFactory
|
||||
import net.corda.node.services.api.Checkpoint
|
||||
import net.corda.node.services.api.CheckpointStorage
|
||||
import net.corda.node.services.api.ServiceHubInternal
|
||||
import net.corda.node.services.config.shouldCheckCheckpoints
|
||||
import net.corda.node.services.messaging.ReceivedMessage
|
||||
import net.corda.node.utilities.AffinityExecutor
|
||||
import net.corda.node.utilities.newNamedSingleThreadExecutor
|
||||
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders
|
||||
import net.corda.nodeapi.internal.persistence.CordaPersistence
|
||||
import net.corda.nodeapi.internal.persistence.bufferUntilDatabaseCommit
|
||||
import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction
|
||||
import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl
|
||||
import net.corda.nodeapi.internal.serialization.withTokenContext
|
||||
import org.apache.activemq.artemis.utils.ReusableLatch
|
||||
import org.slf4j.Logger
|
||||
import rx.Observable
|
||||
import rx.subjects.PublishSubject
|
||||
import java.io.NotSerializableException
|
||||
import java.util.*
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.TimeUnit.SECONDS
|
||||
import javax.annotation.concurrent.ThreadSafe
|
||||
|
||||
/**
|
||||
* The StateMachineManagerImpl will always invoke the flow fibers on the given [AffinityExecutor], regardless of which
|
||||
* thread actually starts them via [startFlow].
|
||||
*/
|
||||
@ThreadSafe
|
||||
class StateMachineManagerImpl(
|
||||
val serviceHub: ServiceHubInternal,
|
||||
val checkpointStorage: CheckpointStorage,
|
||||
val executor: AffinityExecutor,
|
||||
val database: CordaPersistence,
|
||||
private val unfinishedFibers: ReusableLatch = ReusableLatch(),
|
||||
private val classloader: ClassLoader = StateMachineManagerImpl::class.java.classLoader
|
||||
) : StateMachineManager {
|
||||
inner class FiberScheduler : FiberExecutorScheduler("Same thread scheduler", executor)
|
||||
|
||||
companion object {
|
||||
private val logger = contextLogger()
|
||||
internal val sessionTopic = "platform.session"
|
||||
|
||||
init {
|
||||
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
|
||||
(fiber as FlowStateMachineImpl<*>).logger.warn("Caught exception from flow", throwable)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines
|
||||
// property.
|
||||
private class InnerState {
|
||||
var started = false
|
||||
val stateMachines = LinkedHashMap<FlowStateMachineImpl<*>, Checkpoint>()
|
||||
val changesPublisher = PublishSubject.create<StateMachineManager.Change>()!!
|
||||
val fibersWaitingForLedgerCommit = HashMultimap.create<SecureHash, FlowStateMachineImpl<*>>()!!
|
||||
|
||||
fun notifyChangeObservers(change: StateMachineManager.Change) {
|
||||
changesPublisher.bufferUntilDatabaseCommit().onNext(change)
|
||||
}
|
||||
}
|
||||
|
||||
private val scheduler = FiberScheduler()
|
||||
private val mutex = ThreadBox(InnerState())
|
||||
// This thread (only enabled in dev mode) deserialises checkpoints in the background to shake out bugs in checkpoint restore.
|
||||
private val checkpointCheckerThread = if (serviceHub.configuration.shouldCheckCheckpoints()) {
|
||||
newNamedSingleThreadExecutor("CheckpointChecker")
|
||||
} else {
|
||||
null
|
||||
}
|
||||
|
||||
@Volatile private var unrestorableCheckpoints = false
|
||||
|
||||
// True if we're shutting down, so don't resume anything.
|
||||
@Volatile private var stopping = false
|
||||
// How many Fibers are running and not suspended. If zero and stopping is true, then we are halted.
|
||||
private val liveFibers = ReusableLatch()
|
||||
|
||||
// Monitoring support.
|
||||
private val metrics = serviceHub.monitoringService.metrics
|
||||
|
||||
init {
|
||||
metrics.register("Flows.InFlight", Gauge<Int> { mutex.content.stateMachines.size })
|
||||
}
|
||||
|
||||
private val checkpointingMeter = metrics.meter("Flows.Checkpointing Rate")
|
||||
private val totalStartedFlows = metrics.counter("Flows.Started")
|
||||
private val totalFinishedFlows = metrics.counter("Flows.Finished")
|
||||
|
||||
private val openSessions = ConcurrentHashMap<SessionId, FlowSessionInternal>()
|
||||
private val recentlyClosedSessions = ConcurrentHashMap<SessionId, Party>()
|
||||
|
||||
// Context for tokenized services in checkpoints
|
||||
private lateinit var tokenizableServices: List<Any>
|
||||
private val serializationContext by lazy {
|
||||
SerializeAsTokenContextImpl(tokenizableServices, SERIALIZATION_FACTORY, CHECKPOINT_CONTEXT, serviceHub)
|
||||
}
|
||||
|
||||
/** Returns a list of all state machines executing the given flow logic at the top level (subflows do not count) */
|
||||
override fun <A : FlowLogic<*>> findStateMachines(flowClass: Class<A>): List<Pair<A, CordaFuture<*>>> {
|
||||
return mutex.locked {
|
||||
stateMachines.keys.mapNotNull {
|
||||
flowClass.castIfPossible(it.logic)?.let { it to uncheckedCast<FlowStateMachine<*>, FlowStateMachineImpl<*>>(it.stateMachine).resultFuture }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override val allStateMachines: List<FlowLogic<*>>
|
||||
get() = mutex.locked { stateMachines.keys.map { it.logic } }
|
||||
|
||||
/**
|
||||
* An observable that emits triples of the changing flow, the type of change, and a process-specific ID number
|
||||
* which may change across restarts.
|
||||
*
|
||||
* We use assignment here so that multiple subscribers share the same wrapped Observable.
|
||||
*/
|
||||
override val changes: Observable<StateMachineManager.Change> = mutex.content.changesPublisher.wrapWithDatabaseTransaction()
|
||||
|
||||
override fun start(tokenizableServices: List<Any>) {
|
||||
this.tokenizableServices = tokenizableServices
|
||||
checkQuasarJavaAgentPresence()
|
||||
restoreFibersFromCheckpoints()
|
||||
listenToLedgerTransactions()
|
||||
serviceHub.networkMapCache.nodeReady.then { executor.execute(this::resumeRestoredFibers) }
|
||||
}
|
||||
|
||||
private fun checkQuasarJavaAgentPresence() {
|
||||
check(SuspendableHelper.isJavaAgentActive(), {
|
||||
"""Missing the '-javaagent' JVM argument. Make sure you run the tests with the Quasar java agent attached to your JVM.
|
||||
#See https://docs.corda.net/troubleshooting.html - 'Fiber classes not instrumented' for more details.""".trimMargin("#")
|
||||
})
|
||||
}
|
||||
|
||||
private fun listenToLedgerTransactions() {
|
||||
// Observe the stream of committed, validated transactions and resume fibers that are waiting for them.
|
||||
serviceHub.validatedTransactions.updates.subscribe { stx ->
|
||||
val hash = stx.id
|
||||
val fibers: Set<FlowStateMachineImpl<*>> = mutex.locked { fibersWaitingForLedgerCommit.removeAll(hash) }
|
||||
if (fibers.isNotEmpty()) {
|
||||
executor.executeASAP {
|
||||
for (fiber in fibers) {
|
||||
fiber.logger.trace { "Transaction $hash has committed to the ledger, resuming" }
|
||||
fiber.waitingForResponse = null
|
||||
resumeFiber(fiber)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun decrementLiveFibers() {
|
||||
liveFibers.countDown()
|
||||
}
|
||||
|
||||
private fun incrementLiveFibers() {
|
||||
liveFibers.countUp()
|
||||
}
|
||||
|
||||
/**
|
||||
* Start the shutdown process, bringing the [StateMachineManagerImpl] to a controlled stop. When this method returns,
|
||||
* all Fibers have been suspended and checkpointed, or have completed.
|
||||
*
|
||||
* @param allowedUnsuspendedFiberCount Optional parameter is used in some tests.
|
||||
*/
|
||||
override fun stop(allowedUnsuspendedFiberCount: Int) {
|
||||
require(allowedUnsuspendedFiberCount >= 0)
|
||||
mutex.locked {
|
||||
if (stopping) throw IllegalStateException("Already stopping!")
|
||||
stopping = true
|
||||
}
|
||||
// Account for any expected Fibers in a test scenario.
|
||||
liveFibers.countDown(allowedUnsuspendedFiberCount)
|
||||
liveFibers.await()
|
||||
checkpointCheckerThread?.let { MoreExecutors.shutdownAndAwaitTermination(it, 5, SECONDS) }
|
||||
check(!unrestorableCheckpoints) { "Unrestorable checkpoints where created, please check the logs for details." }
|
||||
scheduler.shutdown()
|
||||
}
|
||||
|
||||
/**
|
||||
* Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and
|
||||
* calls to [allStateMachines]
|
||||
*/
|
||||
override fun track(): DataFeed<List<FlowLogic<*>>, StateMachineManager.Change> {
|
||||
return mutex.locked {
|
||||
DataFeed(stateMachines.keys.map { it.logic }, changesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction())
|
||||
}
|
||||
}
|
||||
|
||||
private fun restoreFibersFromCheckpoints() {
|
||||
mutex.locked {
|
||||
checkpointStorage.forEach { checkpoint ->
|
||||
// If a flow is added before start() then don't attempt to restore it
|
||||
if (!stateMachines.containsValue(checkpoint)) {
|
||||
deserializeFiber(checkpoint, logger)?.let {
|
||||
initFiber(it)
|
||||
stateMachines[it] = checkpoint
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun resumeRestoredFibers() {
|
||||
mutex.locked {
|
||||
started = true
|
||||
stateMachines.keys.forEach { resumeRestoredFiber(it) }
|
||||
}
|
||||
serviceHub.networkService.addMessageHandler(sessionTopic) { message, _ ->
|
||||
executor.checkOnThread()
|
||||
onSessionMessage(message)
|
||||
}
|
||||
}
|
||||
|
||||
private fun resumeRestoredFiber(fiber: FlowStateMachineImpl<*>) {
|
||||
fiber.openSessions.values.forEach { openSessions[it.ourSessionId] = it }
|
||||
val waitingForResponse = fiber.waitingForResponse
|
||||
if (waitingForResponse != null) {
|
||||
if (waitingForResponse is WaitForLedgerCommit) {
|
||||
val stx = database.transaction {
|
||||
serviceHub.validatedTransactions.getTransaction(waitingForResponse.hash)
|
||||
}
|
||||
if (stx != null) {
|
||||
fiber.logger.info("Resuming fiber as tx ${waitingForResponse.hash} has committed")
|
||||
fiber.waitingForResponse = null
|
||||
resumeFiber(fiber)
|
||||
} else {
|
||||
fiber.logger.info("Restored, pending on ledger commit of ${waitingForResponse.hash}")
|
||||
mutex.locked { fibersWaitingForLedgerCommit.put(waitingForResponse.hash, fiber) }
|
||||
}
|
||||
} else {
|
||||
fiber.logger.info("Restored, pending on receive")
|
||||
}
|
||||
} else {
|
||||
resumeFiber(fiber)
|
||||
}
|
||||
}
|
||||
|
||||
private fun onSessionMessage(message: ReceivedMessage) {
|
||||
val peer = message.peer
|
||||
val sessionMessage = try {
|
||||
message.data.deserialize<SessionMessage>()
|
||||
} catch (ex: Exception) {
|
||||
logger.error("Received corrupt SessionMessage data from $peer")
|
||||
return
|
||||
}
|
||||
val sender = serviceHub.networkMapCache.getPeerByLegalName(peer)
|
||||
if (sender != null) {
|
||||
when (sessionMessage) {
|
||||
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender)
|
||||
is InitialSessionMessage -> onSessionInit(sessionMessage, message, sender)
|
||||
}
|
||||
} else {
|
||||
logger.error("Unknown peer $peer in $sessionMessage")
|
||||
}
|
||||
}
|
||||
|
||||
private fun onExistingSessionMessage(message: ExistingSessionMessage, sender: Party) {
|
||||
val session = openSessions[message.recipientSessionId]
|
||||
if (session != null) {
|
||||
session.fiber.pushToLoggingContext()
|
||||
session.fiber.logger.trace { "Received $message on $session from $sender" }
|
||||
if (session.retryable) {
|
||||
if (message.payload is ConfirmSessionMessage && session.state is FlowSessionState.Initiated) {
|
||||
session.fiber.logger.trace { "Ignoring duplicate confirmation for session ${session.ourSessionId} – session is idempotent" }
|
||||
return
|
||||
}
|
||||
if (message.payload !is ConfirmSessionMessage) {
|
||||
serviceHub.networkService.cancelRedelivery(session.ourSessionId.toLong)
|
||||
}
|
||||
}
|
||||
if (message.payload is EndSessionMessage || message.payload is ErrorSessionMessage) {
|
||||
openSessions.remove(message.recipientSessionId)
|
||||
}
|
||||
session.receivedMessages += ReceivedSessionMessage(sender, message)
|
||||
if (resumeOnMessage(message, session)) {
|
||||
// It's important that we reset here and not after the fiber's resumed, in case we receive another message
|
||||
// before then.
|
||||
session.fiber.waitingForResponse = null
|
||||
updateCheckpoint(session.fiber)
|
||||
session.fiber.logger.trace { "Resuming due to $message" }
|
||||
resumeFiber(session.fiber)
|
||||
}
|
||||
} else {
|
||||
val peerParty = recentlyClosedSessions.remove(message.recipientSessionId)
|
||||
if (peerParty != null) {
|
||||
if (message.payload is ConfirmSessionMessage) {
|
||||
logger.trace { "Received session confirmation but associated fiber has already terminated, so sending session end" }
|
||||
sendSessionMessage(peerParty, ExistingSessionMessage(message.payload.initiatedSessionId, EndSessionMessage))
|
||||
} else {
|
||||
logger.trace { "Ignoring session end message for already closed session: $message" }
|
||||
}
|
||||
} else {
|
||||
logger.warn("Received a session message for unknown session: $message, from $sender")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We resume the fiber if it's received a response for which it was waiting for or it's waiting for a ledger
|
||||
// commit but a counterparty flow has ended with an error (in which case our flow also has to end)
|
||||
private fun resumeOnMessage(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean {
|
||||
val waitingForResponse = session.fiber.waitingForResponse
|
||||
return waitingForResponse?.shouldResume(message, session) ?: false
|
||||
}
|
||||
|
||||
private fun onSessionInit(sessionInit: InitialSessionMessage, receivedMessage: ReceivedMessage, sender: Party) {
|
||||
|
||||
logger.trace { "Received $sessionInit from $sender" }
|
||||
val senderSessionId = sessionInit.initiatorSessionId
|
||||
|
||||
fun sendSessionReject(message: String) = sendSessionMessage(sender, ExistingSessionMessage(senderSessionId, RejectSessionMessage(message, errorId = sessionInit.initiatorSessionId.toLong)))
|
||||
|
||||
val (session, initiatedFlowFactory) = try {
|
||||
val initiatedFlowFactory = getInitiatedFlowFactory(sessionInit)
|
||||
val flowSession = FlowSessionImpl(sender)
|
||||
val flow = initiatedFlowFactory.createFlow(flowSession)
|
||||
val senderFlowVersion = when (initiatedFlowFactory) {
|
||||
is InitiatedFlowFactory.Core -> receivedMessage.platformVersion // The flow version for the core flows is the platform version
|
||||
is InitiatedFlowFactory.CorDapp -> sessionInit.flowVersion
|
||||
}
|
||||
val session = FlowSessionInternal(
|
||||
flow,
|
||||
flowSession,
|
||||
SessionId.createRandom(newSecureRandom()),
|
||||
sender,
|
||||
FlowSessionState.Initiated(sender, senderSessionId, FlowInfo(senderFlowVersion, sessionInit.appName)))
|
||||
if (sessionInit.firstPayload != null) {
|
||||
session.receivedMessages += ReceivedSessionMessage(sender, ExistingSessionMessage(session.ourSessionId, DataSessionMessage(sessionInit.firstPayload)))
|
||||
}
|
||||
openSessions[session.ourSessionId] = session
|
||||
val context = InvocationContext.peer(sender.name)
|
||||
val fiber = createFiber(flow, context)
|
||||
fiber.pushToLoggingContext()
|
||||
logger.info("Accepting flow session from party ${sender.name}. Session id for tracing purposes is ${sessionInit.initiatorSessionId}.")
|
||||
flowSession.sessionFlow = flow
|
||||
flowSession.stateMachine = fiber
|
||||
fiber.openSessions[Pair(flow, sender)] = session
|
||||
updateCheckpoint(fiber)
|
||||
session to initiatedFlowFactory
|
||||
} catch (e: SessionRejectException) {
|
||||
logger.warn("${e.logMessage}: $sessionInit")
|
||||
sendSessionReject(e.rejectMessage)
|
||||
return
|
||||
} catch (e: Exception) {
|
||||
logger.warn("Couldn't start flow session from $sessionInit", e)
|
||||
sendSessionReject("Unable to establish session")
|
||||
return
|
||||
}
|
||||
|
||||
val (ourFlowVersion, appName) = when (initiatedFlowFactory) {
|
||||
// The flow version for the core flows is the platform version
|
||||
is InitiatedFlowFactory.Core -> serviceHub.myInfo.platformVersion to "corda"
|
||||
is InitiatedFlowFactory.CorDapp -> initiatedFlowFactory.flowVersion to initiatedFlowFactory.appName
|
||||
}
|
||||
|
||||
sendSessionMessage(sender, ExistingSessionMessage(senderSessionId, ConfirmSessionMessage(session.ourSessionId, FlowInfo(ourFlowVersion, appName))), session.fiber)
|
||||
session.fiber.logger.debug { "Initiated by $sender using ${sessionInit.initiatorFlowClassName}" }
|
||||
session.fiber.logger.trace { "Initiated from $sessionInit on $session" }
|
||||
resumeFiber(session.fiber)
|
||||
}
|
||||
|
||||
private fun getInitiatedFlowFactory(sessionInit: InitialSessionMessage): InitiatedFlowFactory<*> {
|
||||
val initiatingFlowClass = try {
|
||||
Class.forName(sessionInit.initiatorFlowClassName, true, classloader).asSubclass(FlowLogic::class.java)
|
||||
} catch (e: ClassNotFoundException) {
|
||||
throw SessionRejectException("Don't know ${sessionInit.initiatorFlowClassName}")
|
||||
} catch (e: ClassCastException) {
|
||||
throw SessionRejectException("${sessionInit.initiatorFlowClassName} is not a flow")
|
||||
}
|
||||
return serviceHub.getFlowFactory(initiatingFlowClass) ?:
|
||||
throw SessionRejectException("$initiatingFlowClass is not registered")
|
||||
}
|
||||
|
||||
private fun serializeFiber(fiber: FlowStateMachineImpl<*>): SerializedBytes<FlowStateMachineImpl<*>> {
|
||||
return fiber.serialize(context = CHECKPOINT_CONTEXT.withTokenContext(serializationContext))
|
||||
}
|
||||
|
||||
private fun deserializeFiber(checkpoint: Checkpoint, logger: Logger): FlowStateMachineImpl<*>? {
|
||||
return try {
|
||||
checkpoint.serializedFiber.deserialize(context = CHECKPOINT_CONTEXT.withTokenContext(serializationContext)).apply {
|
||||
fromCheckpoint = true
|
||||
}
|
||||
} catch (t: Throwable) {
|
||||
logger.error("Encountered unrestorable checkpoint!", t)
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
private fun <T> createFiber(logic: FlowLogic<T>, context: InvocationContext, ourIdentity: Party? = null): FlowStateMachineImpl<T> {
|
||||
val fsm = FlowStateMachineImpl(
|
||||
StateMachineRunId.createRandom(),
|
||||
logic,
|
||||
scheduler,
|
||||
ourIdentity ?: serviceHub.myInfo.legalIdentities[0],
|
||||
context)
|
||||
initFiber(fsm)
|
||||
return fsm
|
||||
}
|
||||
|
||||
private fun initFiber(fiber: FlowStateMachineImpl<*>) {
|
||||
verifyFlowLogicIsSuspendable(fiber.logic)
|
||||
fiber.database = database
|
||||
fiber.serviceHub = serviceHub
|
||||
fiber.ourIdentityAndCert = serviceHub.myInfo.legalIdentitiesAndCerts.find { it.party == fiber.ourIdentity }
|
||||
?: throw IllegalStateException("Identity specified by ${fiber.id} (${fiber.ourIdentity.name}) is not one of ours!")
|
||||
fiber.actionOnSuspend = { ioRequest ->
|
||||
updateCheckpoint(fiber)
|
||||
// We commit on the fibers transaction that was copied across ThreadLocals during suspend
|
||||
// This will free up the ThreadLocal so on return the caller can carry on with other transactions
|
||||
fiber.commitTransaction()
|
||||
processIORequest(ioRequest)
|
||||
decrementLiveFibers()
|
||||
}
|
||||
fiber.actionOnEnd = { result, propagated ->
|
||||
try {
|
||||
mutex.locked {
|
||||
stateMachines.remove(fiber)?.let { checkpointStorage.removeCheckpoint(it) }
|
||||
notifyChangeObservers(StateMachineManager.Change.Removed(fiber.logic, result))
|
||||
}
|
||||
endAllFiberSessions(fiber, result, propagated)
|
||||
} finally {
|
||||
fiber.commitTransaction()
|
||||
decrementLiveFibers()
|
||||
totalFinishedFlows.inc()
|
||||
unfinishedFibers.countDown()
|
||||
}
|
||||
}
|
||||
mutex.locked {
|
||||
totalStartedFlows.inc()
|
||||
unfinishedFibers.countUp()
|
||||
notifyChangeObservers(StateMachineManager.Change.Add(fiber.logic))
|
||||
}
|
||||
}
|
||||
|
||||
private fun verifyFlowLogicIsSuspendable(logic: FlowLogic<Any?>) {
|
||||
// Quasar requires (in Java 8) that at least the call method be annotated suspendable. Unfortunately, it's
|
||||
// easy to forget to add this when creating a new flow, so we check here to give the user a better error.
|
||||
//
|
||||
// The Kotlin compiler can sometimes generate a synthetic bridge method from a single call declaration, which
|
||||
// forwards to the void method and then returns Unit. However annotations do not get copied across to this
|
||||
// bridge, so we have to do a more complex scan here.
|
||||
val call = logic.javaClass.methods.first { !it.isSynthetic && it.name == "call" && it.parameterCount == 0 }
|
||||
if (call.getAnnotation(Suspendable::class.java) == null) {
|
||||
throw FlowException("${logic.javaClass.name}.call() is not annotated as @Suspendable. Please fix this.")
|
||||
}
|
||||
}
|
||||
|
||||
private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>, result: Try<*>, propagated: Boolean) {
|
||||
openSessions.values.removeIf { session ->
|
||||
if (session.fiber == fiber) {
|
||||
session.endSession(fiber.context, (result as? Try.Failure)?.exception, propagated)
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun FlowSessionInternal.endSession(context: InvocationContext, exception: Throwable?, propagated: Boolean) {
|
||||
val initiatedState = state as? FlowSessionState.Initiated ?: return
|
||||
val sessionEnd = if (exception == null) {
|
||||
EndSessionMessage
|
||||
} else {
|
||||
val errorResponse = if (exception is FlowException && (!propagated || initiatingParty != null)) {
|
||||
// Only propagate this FlowException if our local flow threw it or it was propagated to us and we only
|
||||
// pass it down invocation chain to the flow that initiated us, not to flows we've started sessions with.
|
||||
exception
|
||||
} else {
|
||||
null
|
||||
}
|
||||
ErrorSessionMessage(errorResponse, 0)
|
||||
}
|
||||
sendSessionMessage(initiatedState.peerParty, ExistingSessionMessage(initiatedState.peerSessionId, sessionEnd), fiber)
|
||||
recentlyClosedSessions[ourSessionId] = initiatedState.peerParty
|
||||
}
|
||||
|
||||
/**
|
||||
* Kicks off a brand new state machine of the given class.
|
||||
* The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is
|
||||
* restarted with checkpointed state machines in the storage service.
|
||||
*
|
||||
* Note that you must be on the [executor] thread.
|
||||
*/
|
||||
override fun <A> startFlow(flowLogic: FlowLogic<A>, context: InvocationContext): CordaFuture<FlowStateMachine<A>> {
|
||||
// TODO: Check that logic has @Suspendable on its call method.
|
||||
executor.checkOnThread()
|
||||
val fiber = database.transaction {
|
||||
val fiber = createFiber(flowLogic, context)
|
||||
updateCheckpoint(fiber)
|
||||
fiber
|
||||
}
|
||||
// If we are not started then our checkpoint will be picked up during start
|
||||
mutex.locked {
|
||||
if (started) {
|
||||
resumeFiber(fiber)
|
||||
}
|
||||
}
|
||||
return doneFuture(fiber)
|
||||
}
|
||||
|
||||
private fun updateCheckpoint(fiber: FlowStateMachineImpl<*>) {
|
||||
check(fiber.state != Strand.State.RUNNING) { "Fiber cannot be running when checkpointing" }
|
||||
val newCheckpoint = Checkpoint(serializeFiber(fiber))
|
||||
val previousCheckpoint = mutex.locked { stateMachines.put(fiber, newCheckpoint) }
|
||||
if (previousCheckpoint != null) {
|
||||
checkpointStorage.removeCheckpoint(previousCheckpoint)
|
||||
}
|
||||
checkpointStorage.addCheckpoint(newCheckpoint)
|
||||
checkpointingMeter.mark()
|
||||
|
||||
checkpointCheckerThread?.execute {
|
||||
// Immediately check that the checkpoint is valid by deserialising it. The idea is to plug any holes we have
|
||||
// in our testing by failing any test where unrestorable checkpoints are created.
|
||||
if (deserializeFiber(newCheckpoint, fiber.logger) == null) {
|
||||
unrestorableCheckpoints = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun resumeFiber(fiber: FlowStateMachineImpl<*>) {
|
||||
// Avoid race condition when setting stopping to true and then checking liveFibers
|
||||
incrementLiveFibers()
|
||||
if (!stopping) {
|
||||
executor.executeASAP {
|
||||
fiber.resume(scheduler)
|
||||
}
|
||||
} else {
|
||||
fiber.logger.trace("Not resuming as SMM is stopping.")
|
||||
decrementLiveFibers()
|
||||
}
|
||||
}
|
||||
|
||||
private fun processIORequest(ioRequest: FlowIORequest) {
|
||||
executor.checkOnThread()
|
||||
when (ioRequest) {
|
||||
is SendRequest -> processSendRequest(ioRequest)
|
||||
is WaitForLedgerCommit -> processWaitForCommitRequest(ioRequest)
|
||||
is Sleep -> processSleepRequest(ioRequest)
|
||||
}
|
||||
}
|
||||
|
||||
private fun processSendRequest(ioRequest: SendRequest) {
|
||||
val retryId = if (ioRequest.message is InitialSessionMessage) {
|
||||
with(ioRequest.session) {
|
||||
openSessions[ourSessionId] = this
|
||||
if (retryable) ourSessionId.toLong else null
|
||||
}
|
||||
} else null
|
||||
sendSessionMessage(ioRequest.session.state.sendToParty, ioRequest.message, ioRequest.session.fiber, retryId)
|
||||
if (ioRequest !is ReceiveRequest) {
|
||||
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
|
||||
resumeFiber(ioRequest.session.fiber)
|
||||
}
|
||||
}
|
||||
|
||||
private fun processWaitForCommitRequest(ioRequest: WaitForLedgerCommit) {
|
||||
// Is it already committed?
|
||||
val stx = database.transaction {
|
||||
serviceHub.validatedTransactions.getTransaction(ioRequest.hash)
|
||||
}
|
||||
if (stx != null) {
|
||||
resumeFiber(ioRequest.fiber)
|
||||
} else {
|
||||
// No, then register to wait.
|
||||
//
|
||||
// We assume this code runs on the server thread, which is the only place transactions are committed
|
||||
// currently. When we liberalise our threading somewhat, handing of wait requests will need to be
|
||||
// reworked to make the wait atomic in another way. Otherwise there is a race between checking the
|
||||
// database and updating the waiting list.
|
||||
mutex.locked {
|
||||
fibersWaitingForLedgerCommit[ioRequest.hash] += ioRequest.fiber
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun processSleepRequest(ioRequest: Sleep) {
|
||||
// Resume the fiber now we have checkpointed, so we can sleep on the Fiber.
|
||||
resumeFiber(ioRequest.fiber)
|
||||
}
|
||||
|
||||
private fun sendSessionMessage(party: Party, message: SessionMessage, fiber: FlowStateMachineImpl<*>? = null, retryId: Long? = null) {
|
||||
val partyInfo = serviceHub.networkMapCache.getPartyInfo(party)
|
||||
?: throw IllegalArgumentException("Don't know about party $party")
|
||||
val address = serviceHub.networkService.getAddressOfParty(partyInfo)
|
||||
val logger = fiber?.logger ?: logger
|
||||
logger.trace { "Sending $message to party $party @ $address" + if (retryId != null) " with retry $retryId" else "" }
|
||||
|
||||
val serialized = try {
|
||||
message.serialize()
|
||||
} catch (e: Exception) {
|
||||
when (e) {
|
||||
// Handling Kryo and AMQP serialization problems. Unfortunately the two exception types do not share much of a common exception interface.
|
||||
is KryoException,
|
||||
is NotSerializableException -> {
|
||||
if (message is ExistingSessionMessage && message.payload is ErrorSessionMessage && message.payload.flowException != null) {
|
||||
logger.warn("Something in ${message.payload.flowException.javaClass.name} is not serialisable. " +
|
||||
"Instead sending back an exception which is serialisable to ensure session end occurs properly.", e)
|
||||
// The subclass may have overridden toString so we use that
|
||||
val exMessage = message.payload.flowException.message
|
||||
message.copy(payload = message.payload.copy(flowException = FlowException(exMessage))).serialize()
|
||||
} else {
|
||||
throw e
|
||||
}
|
||||
}
|
||||
else -> throw e
|
||||
}
|
||||
}
|
||||
|
||||
// This prevents a "deadlock" in case an initiated flow tries to start a session against a draining node that is also the initiator.
|
||||
// It does not help in case more than 2 nodes are involved in a circle, so the kill switch via RPC should be used in that case.
|
||||
val additionalHeaders = if (mightDeadlockDrainingSender(fiber, party)) emptyMap() else message.additionalHeaders()
|
||||
serviceHub.networkService.apply {
|
||||
send(createMessage(sessionTopic, serialized.bytes), address, retryId = retryId, additionalHeaders = additionalHeaders)
|
||||
}
|
||||
}
|
||||
|
||||
private fun mightDeadlockDrainingSender(fiber: FlowStateMachineImpl<*>?, target: Party): Boolean {
|
||||
return fiber?.context?.origin.let { it is InvocationOrigin.Peer && it.party == target.name }
|
||||
}
|
||||
}
|
||||
|
||||
private fun SessionMessage.additionalHeaders(): Map<String, String> {
|
||||
return when (this) {
|
||||
is InitialSessionMessage -> mapOf(P2PMessagingHeaders.Type.KEY to P2PMessagingHeaders.Type.SESSION_INIT_VALUE)
|
||||
else -> emptyMap()
|
||||
}
|
||||
}
|
||||
|
||||
class SessionRejectException(val rejectMessage: String, val logMessage: String) : CordaException(rejectMessage) {
|
||||
constructor(message: String) : this(message, message)
|
||||
}
|
@ -0,0 +1,231 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import net.corda.core.context.InvocationContext
|
||||
import net.corda.core.flows.FlowInfo
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.internal.FlowIORequest
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.utilities.Try
|
||||
import net.corda.node.services.messaging.DeduplicationHandler
|
||||
|
||||
/**
|
||||
* The state of the state machine, capturing the state of a flow. It consists of two parts, an *immutable* part that is
|
||||
* persisted to the database ([Checkpoint]), and the rest, which is an in-memory-only state.
|
||||
*
|
||||
* @param checkpoint the persisted part of the state.
|
||||
* @param flowLogic the [FlowLogic] associated with the flow. Note that this is mutable by the user.
|
||||
* @param pendingDeduplicationHandlers the list of incomplete deduplication handlers.
|
||||
* @param isFlowResumed true if the control is returned (or being returned) to "user-space" flow code. This is used
|
||||
* to make [Event.DoRemainingWork] idempotent.
|
||||
* @param isTransactionTracked true if a ledger transaction has been tracked as part of a
|
||||
* [FlowIORequest.WaitForLedgerCommit]. This used is to make tracking idempotent.
|
||||
* @param isAnyCheckpointPersisted true if at least a single checkpoint has been persisted. This is used to determine
|
||||
* whether we should DELETE the checkpoint at the end of the flow.
|
||||
* @param isStartIdempotent true if the start of the flow is idempotent, making the skipping of the initial checkpoint
|
||||
* possible.
|
||||
* @param isRemoved true if the flow has been removed from the state machine manager. This is used to avoid any further
|
||||
* work.
|
||||
*/
|
||||
// TODO perhaps add a read-only environment to the state machine for things that don't change over time?
|
||||
// TODO evaluate persistent datastructure libraries to replace the inefficient copying we currently do.
|
||||
data class StateMachineState(
|
||||
val checkpoint: Checkpoint,
|
||||
val flowLogic: FlowLogic<*>,
|
||||
val pendingDeduplicationHandlers: List<DeduplicationHandler>,
|
||||
val isFlowResumed: Boolean,
|
||||
val isTransactionTracked: Boolean,
|
||||
val isAnyCheckpointPersisted: Boolean,
|
||||
val isStartIdempotent: Boolean,
|
||||
val isRemoved: Boolean
|
||||
)
|
||||
|
||||
/**
|
||||
* @param invocationContext the initiator of the flow.
|
||||
* @param ourIdentity the identity the flow is run as.
|
||||
* @param sessions map of source session ID to session state.
|
||||
* @param subFlowStack the stack of currently executing subflows.
|
||||
* @param flowState the state of the flow itself, including the frozen fiber/FlowLogic.
|
||||
* @param errorState the "dirtiness" state including the involved errors and their propagation status.
|
||||
* @param numberOfSuspends the number of flow suspends due to IO API calls.
|
||||
* @param deduplicationSeed the basis seed for the deduplication ID. This is used to produce replayable IDs.
|
||||
*/
|
||||
data class Checkpoint(
|
||||
val invocationContext: InvocationContext,
|
||||
val ourIdentity: Party,
|
||||
val sessions: SessionMap, // This must preserve the insertion order!
|
||||
val subFlowStack: List<SubFlow>,
|
||||
val flowState: FlowState,
|
||||
val errorState: ErrorState,
|
||||
val numberOfSuspends: Int,
|
||||
val deduplicationSeed: String
|
||||
) {
|
||||
companion object {
|
||||
|
||||
fun create(
|
||||
invocationContext: InvocationContext,
|
||||
flowStart: FlowStart,
|
||||
flowLogicClass: Class<FlowLogic<*>>,
|
||||
frozenFlowLogic: SerializedBytes<FlowLogic<*>>,
|
||||
ourIdentity: Party,
|
||||
deduplicationSeed: String
|
||||
): Try<Checkpoint> {
|
||||
return SubFlow.create(flowLogicClass).map { topLevelSubFlow ->
|
||||
Checkpoint(
|
||||
invocationContext = invocationContext,
|
||||
ourIdentity = ourIdentity,
|
||||
sessions = emptyMap(),
|
||||
subFlowStack = listOf(topLevelSubFlow),
|
||||
flowState = FlowState.Unstarted(flowStart, frozenFlowLogic),
|
||||
errorState = ErrorState.Clean,
|
||||
numberOfSuspends = 0,
|
||||
deduplicationSeed = deduplicationSeed
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The state of a session.
|
||||
*/
|
||||
sealed class SessionState {
|
||||
|
||||
/**
|
||||
* We haven't yet sent the initialisation message
|
||||
*/
|
||||
data class Uninitiated(
|
||||
val party: Party,
|
||||
val initiatingSubFlow: SubFlow.Initiating
|
||||
) : SessionState()
|
||||
|
||||
/**
|
||||
* We have sent the initialisation message but have not yet received a confirmation.
|
||||
* @property rejectionError if non-null the initiation failed.
|
||||
*/
|
||||
data class Initiating(
|
||||
val bufferedMessages: List<Pair<DeduplicationId, ExistingSessionMessagePayload>>,
|
||||
val rejectionError: FlowError?
|
||||
) : SessionState()
|
||||
|
||||
/**
|
||||
* We have received a confirmation, the peer party and session id is resolved.
|
||||
* @property errors if not empty the session is in an errored state.
|
||||
*/
|
||||
data class Initiated(
|
||||
val peerParty: Party,
|
||||
val peerFlowInfo: FlowInfo,
|
||||
val receivedMessages: List<DataSessionMessage>,
|
||||
val initiatedState: InitiatedSessionState,
|
||||
val errors: List<FlowError>
|
||||
) : SessionState()
|
||||
}
|
||||
|
||||
typealias SessionMap = Map<SessionId, SessionState>
|
||||
|
||||
/**
|
||||
* Tracks whether an initiated session state is live or has ended. This is a separate state, as we still need the rest
|
||||
* of [SessionState.Initiated], even when the session has ended, for un-drained session messages and potential future
|
||||
* [FlowInfo] requests.
|
||||
*/
|
||||
sealed class InitiatedSessionState {
|
||||
data class Live(val peerSinkSessionId: SessionId) : InitiatedSessionState()
|
||||
object Ended : InitiatedSessionState() { override fun toString() = "Ended" }
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents the way the flow has started.
|
||||
*/
|
||||
sealed class FlowStart {
|
||||
/**
|
||||
* The flow was started explicitly e.g. through RPC or a scheduled state.
|
||||
*/
|
||||
object Explicit : FlowStart() { override fun toString() = "Explicit" }
|
||||
|
||||
/**
|
||||
* The flow was started implicitly as part of session initiation.
|
||||
*/
|
||||
data class Initiated(
|
||||
val peerSession: FlowSessionImpl,
|
||||
val initiatedSessionId: SessionId,
|
||||
val initiatingMessage: InitialSessionMessage,
|
||||
val senderCoreFlowVersion: Int?,
|
||||
val initiatedFlowInfo: FlowInfo
|
||||
) : FlowStart() { override fun toString() = "Initiated" }
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents the user-space related state of the flow.
|
||||
*/
|
||||
sealed class FlowState {
|
||||
|
||||
/**
|
||||
* The flow's unstarted state. We should always be able to start a fresh flow fiber from this datastructure.
|
||||
*
|
||||
* @param flowStart How the flow was started.
|
||||
* @param frozenFlowLogic The serialized user-provided [FlowLogic].
|
||||
*/
|
||||
data class Unstarted(
|
||||
val flowStart: FlowStart,
|
||||
val frozenFlowLogic: SerializedBytes<FlowLogic<*>>
|
||||
) : FlowState() {
|
||||
override fun toString() = "Unstarted(flowStart=$flowStart, frozenFlowLogic=${frozenFlowLogic.hash})"
|
||||
}
|
||||
|
||||
/**
|
||||
* The flow's started state, this means the user-code has suspended on an IO request.
|
||||
*
|
||||
* @param flowIORequest what IO request the flow has suspended on.
|
||||
* @param frozenFiber the serialized fiber itself.
|
||||
*/
|
||||
data class Started(
|
||||
val flowIORequest: FlowIORequest<*>,
|
||||
val frozenFiber: SerializedBytes<FlowStateMachineImpl<*>>
|
||||
) : FlowState() {
|
||||
override fun toString() = "Started(flowIORequest=$flowIORequest, frozenFiber=${frozenFiber.hash})"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param errorId the ID of the error. This is generated once for the source error and is propagated to neighbour
|
||||
* sessions.
|
||||
* @param exception the exception itself. Note that this may not contain information about the source error depending
|
||||
* on whether the source error was a FlowException or otherwise.
|
||||
*/
|
||||
data class FlowError(val errorId: Long, val exception: Throwable)
|
||||
|
||||
/**
|
||||
* The flow's error state.
|
||||
*/
|
||||
sealed class ErrorState {
|
||||
abstract fun addErrors(newErrors: List<FlowError>): ErrorState
|
||||
|
||||
/**
|
||||
* The flow is in a clean state.
|
||||
*/
|
||||
object Clean : ErrorState() {
|
||||
override fun addErrors(newErrors: List<FlowError>): ErrorState {
|
||||
return Errored(newErrors, 0, false)
|
||||
}
|
||||
override fun toString() = "Clean"
|
||||
}
|
||||
|
||||
/**
|
||||
* The flow has dirtied because of an uncaught exception from user code or other error condition during a state
|
||||
* transition.
|
||||
* @param errors the list of errors. Multiple errors may be associated with the errored flow e.g. when multiple
|
||||
* sessions are errored and have been waited on.
|
||||
* @param propagatedIndex the index of the first error that hasn't yet been propagated.
|
||||
* @param propagating true if error propagation was triggered. If this is set the dirtiness is permanent as the
|
||||
* sessions associated with the flow have been (or about to be) dirtied in counter-flows.
|
||||
*/
|
||||
data class Errored(
|
||||
val errors: List<FlowError>,
|
||||
val propagatedIndex: Int,
|
||||
val propagating: Boolean
|
||||
) : ErrorState() {
|
||||
override fun addErrors(newErrors: List<FlowError>): ErrorState {
|
||||
return copy(errors = errors + newErrors)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,74 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import net.corda.core.flows.FlowInfo
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.flows.InitiatingFlow
|
||||
import net.corda.core.utilities.Try
|
||||
|
||||
/**
|
||||
* A [SubFlow] contains metadata about a currently executing sub-flow. At any point the flow execution is
|
||||
* characterised with a stack of [SubFlow]s. This stack is used to determine the initiating-initiated flow mapping.
|
||||
*
|
||||
* Note that Initiat*ed*ness is an orthogonal property of the top-level subflow, so we don't store any information about
|
||||
* it here.
|
||||
*/
|
||||
sealed class SubFlow {
|
||||
abstract val flowClass: Class<out FlowLogic<*>>
|
||||
|
||||
/**
|
||||
* An inlined subflow.
|
||||
*/
|
||||
data class Inlined(override val flowClass: Class<FlowLogic<*>>) : SubFlow()
|
||||
|
||||
/**
|
||||
* An initiating subflow.
|
||||
* @param [flowClass] the concrete class of the subflow.
|
||||
* @param [classToInitiateWith] an ancestor class of [flowClass] with the [InitiatingFlow] annotation, to be sent
|
||||
* to the initiated side.
|
||||
* @param flowInfo the [FlowInfo] associated with the initiating flow.
|
||||
*/
|
||||
data class Initiating(
|
||||
override val flowClass: Class<FlowLogic<*>>,
|
||||
val classToInitiateWith: Class<in FlowLogic<*>>,
|
||||
val flowInfo: FlowInfo
|
||||
) : SubFlow()
|
||||
|
||||
companion object {
|
||||
fun create(flowClass: Class<FlowLogic<*>>): Try<SubFlow> {
|
||||
// Are we an InitiatingFlow?
|
||||
val initiatingAnnotations = getInitiatingFlowAnnotations(flowClass)
|
||||
return when (initiatingAnnotations.size) {
|
||||
0 -> {
|
||||
Try.Success(Inlined(flowClass))
|
||||
}
|
||||
1 -> {
|
||||
val initiatingAnnotation = initiatingAnnotations[0]
|
||||
val flowContext = FlowInfo(initiatingAnnotation.second.version, flowClass.appName)
|
||||
Try.Success(Initiating(flowClass, initiatingAnnotation.first, flowContext))
|
||||
}
|
||||
else -> {
|
||||
Try.Failure(IllegalArgumentException("${InitiatingFlow::class.java.name} can only be annotated " +
|
||||
"once, however the following classes all have the annotation: " +
|
||||
"${initiatingAnnotations.map { it.first }}"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun <C> getSuperClasses(clazz: Class<C>): List<Class<in C>> {
|
||||
var currentClass: Class<in C>? = clazz
|
||||
val result = ArrayList<Class<in C>>()
|
||||
while (currentClass != null) {
|
||||
result.add(currentClass)
|
||||
currentClass = currentClass.superclass
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
private fun getInitiatingFlowAnnotations(flowClass: Class<FlowLogic<*>>): List<Pair<Class<in FlowLogic<*>>, InitiatingFlow>> {
|
||||
return getSuperClasses(flowClass).mapNotNull { clazz ->
|
||||
val initiatingAnnotation = clazz.getDeclaredAnnotation(InitiatingFlow::class.java)
|
||||
initiatingAnnotation?.let { Pair(clazz, it) }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.node.services.statemachine.transitions.FlowContinuation
|
||||
import net.corda.node.services.statemachine.transitions.TransitionResult
|
||||
|
||||
/**
|
||||
* An executor of state machine transitions. This is mostly a wrapper interface around an [ActionExecutor], but can be
|
||||
* used to create interceptors of transitions.
|
||||
*/
|
||||
interface TransitionExecutor {
|
||||
@Suspendable
|
||||
fun executeTransition(
|
||||
fiber: FlowFiber,
|
||||
previousState: StateMachineState,
|
||||
event: Event,
|
||||
transition: TransitionResult,
|
||||
actionExecutor: ActionExecutor
|
||||
): Pair<FlowContinuation, StateMachineState>
|
||||
}
|
||||
|
||||
/**
|
||||
* An interceptor of a transition. These are currently explicitly hooked up in [SingleThreadedStateMachineManager].
|
||||
*/
|
||||
typealias TransitionInterceptor = (TransitionExecutor) -> TransitionExecutor
|
@ -0,0 +1,67 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.node.services.statemachine.transitions.FlowContinuation
|
||||
import net.corda.node.services.statemachine.transitions.TransitionResult
|
||||
import net.corda.nodeapi.internal.persistence.CordaPersistence
|
||||
import net.corda.nodeapi.internal.persistence.contextDatabase
|
||||
import net.corda.nodeapi.internal.persistence.contextTransactionOrNull
|
||||
import java.security.SecureRandom
|
||||
|
||||
/**
|
||||
* This [TransitionExecutor] runs the transition actions using the passed in [ActionExecutor] and manually dirties the
|
||||
* state on failure.
|
||||
*
|
||||
* If a failure happens when we're already transitioning into a errored state then the transition and the flow fiber is
|
||||
* completely aborted to avoid error loops.
|
||||
*/
|
||||
class TransitionExecutorImpl(
|
||||
val secureRandom: SecureRandom,
|
||||
val database: CordaPersistence
|
||||
) : TransitionExecutor {
|
||||
private companion object {
|
||||
val log = contextLogger()
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
override fun executeTransition(
|
||||
fiber: FlowFiber,
|
||||
previousState: StateMachineState,
|
||||
event: Event,
|
||||
transition: TransitionResult,
|
||||
actionExecutor: ActionExecutor
|
||||
): Pair<FlowContinuation, StateMachineState> {
|
||||
contextDatabase = database
|
||||
for (action in transition.actions) {
|
||||
try {
|
||||
actionExecutor.executeAction(fiber, action)
|
||||
} catch (exception: Throwable) {
|
||||
contextTransactionOrNull?.close()
|
||||
if (transition.newState.checkpoint.errorState is ErrorState.Errored) {
|
||||
// If we errored while transitioning to an error state then we cannot record the additional
|
||||
// error as that may result in an infinite loop, e.g. error propagation fails -> record error -> propagate fails again.
|
||||
// Instead we just keep around the old error state and wait for a new schedule, perhaps
|
||||
// triggered from a flow hospital
|
||||
log.error("Error while executing $action during transition to errored state, aborting transition", exception)
|
||||
return Pair(FlowContinuation.Abort, previousState.copy(isFlowResumed = false))
|
||||
} else {
|
||||
// Otherwise error the state manually keeping the old flow state and schedule a DoRemainingWork
|
||||
// to trigger error propagation
|
||||
log.error("Error while executing $action, erroring state", exception)
|
||||
val newState = previousState.copy(
|
||||
checkpoint = previousState.checkpoint.copy(
|
||||
errorState = previousState.checkpoint.errorState.addErrors(
|
||||
listOf(FlowError(secureRandom.nextLong(), exception))
|
||||
)
|
||||
),
|
||||
isFlowResumed = false
|
||||
)
|
||||
fiber.scheduleEvent(Event.DoRemainingWork)
|
||||
return Pair(FlowContinuation.ProcessEvents, newState)
|
||||
}
|
||||
}
|
||||
}
|
||||
return Pair(transition.continuation, transition.newState)
|
||||
}
|
||||
}
|
@ -0,0 +1,51 @@
|
||||
package net.corda.node.services.statemachine.interceptors
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.node.services.statemachine.*
|
||||
import net.corda.node.services.statemachine.transitions.FlowContinuation
|
||||
import net.corda.node.services.statemachine.transitions.TransitionResult
|
||||
import java.time.Instant
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
|
||||
/**
|
||||
* This interceptor records a trace of all of the flows' states and transitions. If the flow dirties it dumps the trace
|
||||
* transition to the logger.
|
||||
*/
|
||||
class DumpHistoryOnErrorInterceptor(val delegate: TransitionExecutor) : TransitionExecutor {
|
||||
companion object {
|
||||
private val log = contextLogger()
|
||||
}
|
||||
|
||||
private val records = ConcurrentHashMap<StateMachineRunId, ArrayList<TransitionDiagnosticRecord>>()
|
||||
|
||||
@Suspendable
|
||||
override fun executeTransition(
|
||||
fiber: FlowFiber,
|
||||
previousState: StateMachineState,
|
||||
event: Event,
|
||||
transition: TransitionResult,
|
||||
actionExecutor: ActionExecutor
|
||||
): Pair<FlowContinuation, StateMachineState> {
|
||||
val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor)
|
||||
val transitionRecord = TransitionDiagnosticRecord(Instant.now(), fiber.id, previousState, nextState, event, transition, continuation)
|
||||
val record = records.compute(fiber.id) { _, record ->
|
||||
(record ?: ArrayList()).apply { add(transitionRecord) }
|
||||
}
|
||||
|
||||
if (nextState.checkpoint.errorState is ErrorState.Errored) {
|
||||
log.warn("Flow ${fiber.id} errored, dumping all transitions:\n${record!!.joinToString("\n")}")
|
||||
for (error in nextState.checkpoint.errorState.errors) {
|
||||
log.warn("Flow ${fiber.id} error", error.exception)
|
||||
}
|
||||
}
|
||||
|
||||
if (nextState.isRemoved) {
|
||||
records.remove(fiber.id)
|
||||
}
|
||||
|
||||
return Pair(continuation, nextState)
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,95 @@
|
||||
package net.corda.node.services.statemachine.interceptors
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.core.serialization.SerializationContext
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.serialization.deserialize
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.node.services.statemachine.*
|
||||
import net.corda.node.services.statemachine.transitions.FlowContinuation
|
||||
import net.corda.node.services.statemachine.transitions.TransitionResult
|
||||
import java.util.concurrent.LinkedBlockingQueue
|
||||
import kotlin.concurrent.thread
|
||||
|
||||
/**
|
||||
* This interceptor checks whether a checkpointed fiber state can be deserialised in a separate thread.
|
||||
*/
|
||||
class FiberDeserializationCheckingInterceptor(
|
||||
val fiberDeserializationChecker: FiberDeserializationChecker,
|
||||
val delegate: TransitionExecutor
|
||||
) : TransitionExecutor {
|
||||
@Suspendable
|
||||
override fun executeTransition(
|
||||
fiber: FlowFiber,
|
||||
previousState: StateMachineState,
|
||||
event: Event,
|
||||
transition: TransitionResult,
|
||||
actionExecutor: ActionExecutor
|
||||
): Pair<FlowContinuation, StateMachineState> {
|
||||
val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor)
|
||||
val previousFlowState = previousState.checkpoint.flowState
|
||||
val nextFlowState = nextState.checkpoint.flowState
|
||||
if (nextFlowState is FlowState.Started) {
|
||||
if (previousFlowState !is FlowState.Started || previousFlowState.frozenFiber != nextFlowState.frozenFiber) {
|
||||
fiberDeserializationChecker.submitCheck(nextFlowState.frozenFiber)
|
||||
}
|
||||
}
|
||||
return Pair(continuation, nextState)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A fiber deserialisation checker thread. It checks the queued up serialised checkpoints to see if they can be
|
||||
* deserialised. This is only run in development mode to allow detecting of corrupt serialised checkpoints before they
|
||||
* are actually used.
|
||||
*/
|
||||
class FiberDeserializationChecker {
|
||||
companion object {
|
||||
val log = contextLogger()
|
||||
}
|
||||
|
||||
private sealed class Job {
|
||||
class Check(val serializedFiber: SerializedBytes<FlowStateMachineImpl<*>>) : Job()
|
||||
object Finish : Job()
|
||||
}
|
||||
|
||||
private var checkerThread: Thread? = null
|
||||
private val jobQueue = LinkedBlockingQueue<Job>()
|
||||
private var foundUnrestorableFibers: Boolean = false
|
||||
|
||||
fun start(checkpointSerializationContext: SerializationContext) {
|
||||
require(checkerThread == null)
|
||||
checkerThread = thread(name = "FiberDeserializationChecker") {
|
||||
while (true) {
|
||||
val job = jobQueue.take()
|
||||
when (job) {
|
||||
is Job.Check -> {
|
||||
try {
|
||||
job.serializedFiber.deserialize(context = checkpointSerializationContext)
|
||||
} catch (throwable: Throwable) {
|
||||
log.error("Encountered unrestorable checkpoint!", throwable)
|
||||
foundUnrestorableFibers = true
|
||||
}
|
||||
}
|
||||
Job.Finish -> {
|
||||
return@thread
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun submitCheck(serializedFiber: SerializedBytes<FlowStateMachineImpl<*>>) {
|
||||
jobQueue.add(Job.Check(serializedFiber))
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if some unrestorable checkpoints were encountered, false otherwise
|
||||
*/
|
||||
fun stop(): Boolean {
|
||||
jobQueue.add(Job.Finish)
|
||||
checkerThread?.join()
|
||||
checkerThread = null
|
||||
return foundUnrestorableFibers
|
||||
}
|
||||
}
|
@ -0,0 +1,46 @@
|
||||
package net.corda.node.services.statemachine.interceptors
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.node.services.statemachine.*
|
||||
import net.corda.node.services.statemachine.transitions.FlowContinuation
|
||||
import net.corda.node.services.statemachine.transitions.TransitionResult
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
|
||||
/**
|
||||
* This interceptor notifies the passed in [flowHospital] in case a flow went through a clean->errored or a errored->clean
|
||||
* transition.
|
||||
*/
|
||||
class HospitalisingInterceptor(
|
||||
private val flowHospital: FlowHospital,
|
||||
private val delegate: TransitionExecutor
|
||||
) : TransitionExecutor {
|
||||
private val hospitalisedFlows = ConcurrentHashMap<StateMachineRunId, FlowFiber>()
|
||||
|
||||
@Suspendable
|
||||
override fun executeTransition(
|
||||
fiber: FlowFiber,
|
||||
previousState: StateMachineState,
|
||||
event: Event,
|
||||
transition: TransitionResult,
|
||||
actionExecutor: ActionExecutor
|
||||
): Pair<FlowContinuation, StateMachineState> {
|
||||
val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor)
|
||||
when (nextState.checkpoint.errorState) {
|
||||
ErrorState.Clean -> {
|
||||
if (hospitalisedFlows.remove(fiber.id) != null) {
|
||||
flowHospital.flowCleaned(fiber)
|
||||
}
|
||||
}
|
||||
is ErrorState.Errored -> {
|
||||
if (hospitalisedFlows.putIfAbsent(fiber.id, fiber) == null) {
|
||||
flowHospital.flowErrored(fiber)
|
||||
}
|
||||
}
|
||||
}
|
||||
if (nextState.isRemoved) {
|
||||
hospitalisedFlows.remove(fiber.id)
|
||||
}
|
||||
return Pair(continuation, nextState)
|
||||
}
|
||||
}
|
@ -0,0 +1,24 @@
|
||||
package net.corda.node.services.statemachine.interceptors
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import com.codahale.metrics.MetricRegistry
|
||||
import net.corda.node.services.statemachine.*
|
||||
import net.corda.node.services.statemachine.transitions.FlowContinuation
|
||||
import net.corda.node.services.statemachine.transitions.TransitionResult
|
||||
|
||||
class MetricInterceptor(val metrics: MetricRegistry, val delegate: TransitionExecutor): TransitionExecutor {
|
||||
@Suspendable
|
||||
override fun executeTransition(fiber: FlowFiber, previousState: StateMachineState, event: Event, transition: TransitionResult, actionExecutor: ActionExecutor): Pair<FlowContinuation, StateMachineState> {
|
||||
val metricActionInterceptor = MetricActionInterceptor(metrics, actionExecutor)
|
||||
return delegate.executeTransition(fiber, previousState, event, transition, metricActionInterceptor)
|
||||
}
|
||||
}
|
||||
|
||||
class MetricActionInterceptor(val metrics: MetricRegistry, val delegate: ActionExecutor): ActionExecutor {
|
||||
@Suspendable
|
||||
override fun executeAction(fiber: FlowFiber, action: Action) {
|
||||
val context = metrics.timer("Flows.Actions.${action.javaClass.simpleName}").time()
|
||||
delegate.executeAction(fiber, action)
|
||||
context.stop()
|
||||
}
|
||||
}
|
@ -0,0 +1,31 @@
|
||||
package net.corda.node.services.statemachine.interceptors
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.node.services.statemachine.*
|
||||
import net.corda.node.services.statemachine.transitions.FlowContinuation
|
||||
import net.corda.node.services.statemachine.transitions.TransitionResult
|
||||
import java.time.Instant
|
||||
|
||||
/**
|
||||
* This interceptor simply prints all state machine transitions. Useful for debugging.
|
||||
*/
|
||||
class PrintingInterceptor(val delegate: TransitionExecutor) : TransitionExecutor {
|
||||
companion object {
|
||||
val log = contextLogger()
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
override fun executeTransition(
|
||||
fiber: FlowFiber,
|
||||
previousState: StateMachineState,
|
||||
event: Event,
|
||||
transition: TransitionResult,
|
||||
actionExecutor: ActionExecutor
|
||||
): Pair<FlowContinuation, StateMachineState> {
|
||||
val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor)
|
||||
val transitionRecord = TransitionDiagnosticRecord(Instant.now(), fiber.id, previousState, nextState, event, transition, continuation)
|
||||
log.info("Transition for flow ${fiber.id} $transitionRecord")
|
||||
return Pair(continuation, nextState)
|
||||
}
|
||||
}
|
@ -0,0 +1,51 @@
|
||||
package net.corda.node.services.statemachine.interceptors
|
||||
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.node.services.statemachine.transitions.FlowContinuation
|
||||
import net.corda.node.services.statemachine.Event
|
||||
import net.corda.node.services.statemachine.StateMachineState
|
||||
import net.corda.node.services.statemachine.transitions.TransitionResult
|
||||
import net.corda.node.utilities.ObjectDiffer
|
||||
import java.time.Instant
|
||||
|
||||
/**
|
||||
* This is a diagnostic record that stores information about a state machine transition and provides pretty printing
|
||||
* by diffing the two states.
|
||||
*/
|
||||
data class TransitionDiagnosticRecord(
|
||||
val timestamp: Instant,
|
||||
val flowId: StateMachineRunId,
|
||||
val previousState: StateMachineState,
|
||||
val nextState: StateMachineState,
|
||||
val event: Event,
|
||||
val transition: TransitionResult,
|
||||
val continuation: FlowContinuation
|
||||
) {
|
||||
override fun toString(): String {
|
||||
val diffIntended = ObjectDiffer.diff(previousState, transition.newState)
|
||||
val diffNext = ObjectDiffer.diff(previousState, nextState)
|
||||
return (
|
||||
listOf(
|
||||
"",
|
||||
" --- Transition of flow $flowId ---",
|
||||
" Timestamp: $timestamp",
|
||||
" Event: $event",
|
||||
" Actions: ",
|
||||
" ${transition.actions.joinToString("\n ")}",
|
||||
" Continuation: ${transition.continuation}"
|
||||
) +
|
||||
if (diffIntended != diffNext) {
|
||||
listOf(
|
||||
" Diff between previous and intended state:",
|
||||
"${diffIntended?.toPaths()?.joinToString("")}"
|
||||
)
|
||||
} else {
|
||||
emptyList()
|
||||
} + listOf(
|
||||
|
||||
" Diff between previous and next state:",
|
||||
"${diffNext?.toPaths()?.joinToString("")}"
|
||||
)
|
||||
).joinToString("\n")
|
||||
}
|
||||
}
|
@ -0,0 +1,186 @@
|
||||
package net.corda.node.services.statemachine.transitions
|
||||
|
||||
import net.corda.core.flows.UnexpectedFlowEndException
|
||||
import net.corda.node.services.statemachine.*
|
||||
|
||||
/**
|
||||
* This transition handles incoming session messages. It handles the following cases:
|
||||
* - DataSessionMessage: these arrive to initiated and confirmed sessions and are expected to be received by the flow.
|
||||
* - ConfirmSessionMessage: these arrive as a response to an InitialSessionMessage and include information about the
|
||||
* counterparty flow's session ID as well as their [FlowInfo].
|
||||
* - ErrorSessionMessage: these arrive to initiated and confirmed sessions and put the corresponding session into an
|
||||
* "errored" state. This means that whenever that session is subsequently interacted with the error will be thrown
|
||||
* in the flow.
|
||||
* - RejectSessionMessage: these arrive as a response to an InitialSessionMessage when the initiation failed. It
|
||||
* behaves similarly to ErrorSessionMessage aside from the type of exceptions stored/raised.
|
||||
* - EndSessionMessage: these are sent when the counterparty flow has finished. They put the corresponding session into
|
||||
* an "ended" state. This means that subsequent sends on this session will fail, and receives will start failing
|
||||
* after the buffer of already received messages is drained.
|
||||
*/
|
||||
class DeliverSessionMessageTransition(
|
||||
override val context: TransitionContext,
|
||||
override val startingState: StateMachineState,
|
||||
val event: Event.DeliverSessionMessage
|
||||
) : Transition {
|
||||
override fun transition(): TransitionResult {
|
||||
return builder {
|
||||
// Add the DeduplicationHandler to the pending ones ASAP so in case an error happens we still know
|
||||
// about the message. Note that in case of an error during deliver this message *will be acked*.
|
||||
// For example if the session corresponding to the message is not found the message is still acked to free
|
||||
// up the broker but the flow will error.
|
||||
currentState = currentState.copy(
|
||||
pendingDeduplicationHandlers = currentState.pendingDeduplicationHandlers + event.deduplicationHandler
|
||||
)
|
||||
// Check whether we have a session corresponding to the message.
|
||||
val existingSession = startingState.checkpoint.sessions[event.sessionMessage.recipientSessionId]
|
||||
if (existingSession == null) {
|
||||
freshErrorTransition(CannotFindSessionException(event.sessionMessage.recipientSessionId))
|
||||
} else {
|
||||
val payload = event.sessionMessage.payload
|
||||
// Dispatch based on what kind of message it is.
|
||||
val _exhaustive = when (payload) {
|
||||
is ConfirmSessionMessage -> confirmMessageTransition(existingSession, payload)
|
||||
is DataSessionMessage -> dataMessageTransition(existingSession, payload)
|
||||
is ErrorSessionMessage -> errorMessageTransition(existingSession, payload)
|
||||
is RejectSessionMessage -> rejectMessageTransition(existingSession, payload)
|
||||
is EndSessionMessage -> endMessageTransition()
|
||||
}
|
||||
}
|
||||
if (!isErrored()) {
|
||||
persistCheckpoint()
|
||||
}
|
||||
// Schedule a DoRemainingWork to check whether the flow needs to be woken up.
|
||||
actions.add(Action.ScheduleEvent(Event.DoRemainingWork))
|
||||
FlowContinuation.ProcessEvents
|
||||
}
|
||||
}
|
||||
|
||||
private fun TransitionBuilder.confirmMessageTransition(sessionState: SessionState, message: ConfirmSessionMessage) {
|
||||
// We received a confirmation message. The corresponding session state must be Initiating.
|
||||
when (sessionState) {
|
||||
is SessionState.Initiating -> {
|
||||
// Create the new session state that is now Initiated.
|
||||
val initiatedSession = SessionState.Initiated(
|
||||
peerParty = event.sender,
|
||||
peerFlowInfo = message.initiatedFlowInfo,
|
||||
receivedMessages = emptyList(),
|
||||
initiatedState = InitiatedSessionState.Live(message.initiatedSessionId),
|
||||
errors = emptyList()
|
||||
)
|
||||
val newCheckpoint = currentState.checkpoint.copy(
|
||||
sessions = currentState.checkpoint.sessions + (event.sessionMessage.recipientSessionId to initiatedSession)
|
||||
)
|
||||
// Send messages that were buffered pending confirmation of session.
|
||||
val sendActions = sessionState.bufferedMessages.map { (deduplicationId, bufferedMessage) ->
|
||||
val existingMessage = ExistingSessionMessage(message.initiatedSessionId, bufferedMessage)
|
||||
Action.SendExisting(initiatedSession.peerParty, existingMessage, deduplicationId)
|
||||
}
|
||||
actions.addAll(sendActions)
|
||||
currentState = currentState.copy(checkpoint = newCheckpoint)
|
||||
}
|
||||
else -> freshErrorTransition(UnexpectedEventInState())
|
||||
}
|
||||
}
|
||||
|
||||
private fun TransitionBuilder.dataMessageTransition(sessionState: SessionState, message: DataSessionMessage) {
|
||||
// We received a data message. The corresponding session must be Initiated.
|
||||
return when (sessionState) {
|
||||
is SessionState.Initiated -> {
|
||||
// Buffer the message in the session's receivedMessages buffer.
|
||||
val newSessionState = sessionState.copy(
|
||||
receivedMessages = sessionState.receivedMessages + message
|
||||
)
|
||||
currentState = currentState.copy(
|
||||
checkpoint = currentState.checkpoint.copy(
|
||||
sessions = startingState.checkpoint.sessions + (event.sessionMessage.recipientSessionId to newSessionState)
|
||||
)
|
||||
)
|
||||
}
|
||||
else -> freshErrorTransition(UnexpectedEventInState())
|
||||
}
|
||||
}
|
||||
|
||||
private fun TransitionBuilder.errorMessageTransition(sessionState: SessionState, payload: ErrorSessionMessage) {
|
||||
val exception: Throwable = if (payload.flowException == null) {
|
||||
UnexpectedFlowEndException("Counter-flow errored", cause = null, originalErrorId = payload.errorId)
|
||||
} else {
|
||||
payload.flowException.originalErrorId = payload.errorId
|
||||
payload.flowException
|
||||
}
|
||||
|
||||
return when (sessionState) {
|
||||
is SessionState.Initiated -> {
|
||||
val checkpoint = currentState.checkpoint
|
||||
val sessionId = event.sessionMessage.recipientSessionId
|
||||
val flowError = FlowError(payload.errorId, exception)
|
||||
val newSessionState = sessionState.copy(errors = sessionState.errors + flowError)
|
||||
currentState = currentState.copy(
|
||||
checkpoint = checkpoint.copy(
|
||||
sessions = checkpoint.sessions + (sessionId to newSessionState)
|
||||
)
|
||||
)
|
||||
}
|
||||
else -> freshErrorTransition(UnexpectedEventInState())
|
||||
}
|
||||
}
|
||||
|
||||
private fun TransitionBuilder.rejectMessageTransition(sessionState: SessionState, payload: RejectSessionMessage) {
|
||||
val exception = UnexpectedFlowEndException(payload.message, cause = null, originalErrorId = payload.errorId)
|
||||
return when (sessionState) {
|
||||
is SessionState.Initiating -> {
|
||||
if (sessionState.rejectionError != null) {
|
||||
// Double reject
|
||||
freshErrorTransition(UnexpectedEventInState())
|
||||
} else {
|
||||
val checkpoint = currentState.checkpoint
|
||||
val sessionId = event.sessionMessage.recipientSessionId
|
||||
val flowError = FlowError(payload.errorId, exception)
|
||||
currentState = currentState.copy(
|
||||
checkpoint = checkpoint.copy(
|
||||
sessions = checkpoint.sessions + (sessionId to sessionState.copy(rejectionError = flowError))
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
else -> freshErrorTransition(UnexpectedEventInState())
|
||||
}
|
||||
}
|
||||
|
||||
private fun TransitionBuilder.persistCheckpoint() {
|
||||
// We persist the message as soon as it arrives.
|
||||
actions.addAll(arrayOf(
|
||||
Action.CreateTransaction,
|
||||
Action.PersistCheckpoint(context.id, currentState.checkpoint),
|
||||
Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers),
|
||||
Action.CommitTransaction,
|
||||
Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers)
|
||||
))
|
||||
currentState = currentState.copy(
|
||||
pendingDeduplicationHandlers = emptyList(),
|
||||
isAnyCheckpointPersisted = true
|
||||
)
|
||||
}
|
||||
|
||||
private fun TransitionBuilder.endMessageTransition() {
|
||||
val sessionId = event.sessionMessage.recipientSessionId
|
||||
val sessions = currentState.checkpoint.sessions
|
||||
val sessionState = sessions[sessionId]
|
||||
if (sessionState == null) {
|
||||
return freshErrorTransition(CannotFindSessionException(sessionId))
|
||||
}
|
||||
when (sessionState) {
|
||||
is SessionState.Initiated -> {
|
||||
val newSessionState = sessionState.copy(initiatedState = InitiatedSessionState.Ended)
|
||||
currentState = currentState.copy(
|
||||
checkpoint = currentState.checkpoint.copy(
|
||||
sessions = sessions + (sessionId to newSessionState)
|
||||
)
|
||||
)
|
||||
}
|
||||
else -> {
|
||||
freshErrorTransition(UnexpectedEventInState())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,37 @@
|
||||
package net.corda.node.services.statemachine.transitions
|
||||
|
||||
import net.corda.node.services.statemachine.*
|
||||
|
||||
/**
|
||||
* This transition checks the current state of the flow and determines whether anything needs to be done.
|
||||
*/
|
||||
class DoRemainingWorkTransition(
|
||||
override val context: TransitionContext,
|
||||
override val startingState: StateMachineState
|
||||
) : Transition {
|
||||
override fun transition(): TransitionResult {
|
||||
val checkpoint = startingState.checkpoint
|
||||
// If the flow is removed or has been resumed don't do work.
|
||||
if (startingState.isFlowResumed || startingState.isRemoved) {
|
||||
return TransitionResult(startingState)
|
||||
}
|
||||
// Check whether the flow is errored
|
||||
return when (checkpoint.errorState) {
|
||||
is ErrorState.Clean -> cleanTransition()
|
||||
is ErrorState.Errored -> erroredTransition(checkpoint.errorState)
|
||||
}
|
||||
}
|
||||
|
||||
// If the flow is clean check the FlowState
|
||||
private fun cleanTransition(): TransitionResult {
|
||||
val checkpoint = startingState.checkpoint
|
||||
return when (checkpoint.flowState) {
|
||||
is FlowState.Unstarted -> UnstartedFlowTransition(context, startingState, checkpoint.flowState).transition()
|
||||
is FlowState.Started -> StartedFlowTransition(context, startingState, checkpoint.flowState).transition()
|
||||
}
|
||||
}
|
||||
|
||||
private fun erroredTransition(errorState: ErrorState.Errored): TransitionResult {
|
||||
return ErrorFlowTransition(context, startingState, errorState).transition()
|
||||
}
|
||||
}
|
@ -0,0 +1,125 @@
|
||||
package net.corda.node.services.statemachine.transitions
|
||||
|
||||
import net.corda.core.flows.FlowException
|
||||
import net.corda.node.services.statemachine.*
|
||||
|
||||
/**
|
||||
* This transition defines what should happen when a flow has errored.
|
||||
*
|
||||
* In general there are two flow-level error conditions:
|
||||
*
|
||||
* - Internal exceptions. These may arise due to problems in the flow framework or errors during state machine
|
||||
* transitions e.g. network or database failure.
|
||||
* - User-raised exceptions. These are exceptions that are (re)raised in user code, allowing the user to catch them.
|
||||
* These may come from illegal flow API calls, and FlowExceptions or other counterparty failures that are re-raised
|
||||
* when the flow tries to use the corresponding sessions.
|
||||
*
|
||||
* Both internal exceptions and uncaught user-raised exceptions cause the flow to be errored. This flags the flow as
|
||||
* unable to be resumed. When a flow is in this state an external source (e.g. Flow hospital) may decide to
|
||||
*
|
||||
* 1. Retry it (not implemented yet). This throws away the errored state and re-tries from the last clean checkpoint.
|
||||
* 2. Start error propagation. This seals the flow as errored permanently and propagates the associated error(s) to
|
||||
* all live sessions. This causes these sessions to errored on the other side, which may in turn cause the
|
||||
* counter-flows themselves to errored.
|
||||
*
|
||||
* See [net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor] for how to detect flow errors.
|
||||
*
|
||||
* Note that in general we handle multiple errors at a time as several error conditions may arise at the same time and
|
||||
* new errors may arise while the flow is in the errored state already.
|
||||
*/
|
||||
class ErrorFlowTransition(
|
||||
override val context: TransitionContext,
|
||||
override val startingState: StateMachineState,
|
||||
private val errorState: ErrorState.Errored
|
||||
) : Transition {
|
||||
override fun transition(): TransitionResult {
|
||||
val allErrors: List<FlowError> = errorState.errors
|
||||
val remainingErrorsToPropagate: List<FlowError> = allErrors.subList(errorState.propagatedIndex, allErrors.size)
|
||||
val errorMessages: List<ErrorSessionMessage> = remainingErrorsToPropagate.map(this::createErrorMessageFromError)
|
||||
|
||||
return builder {
|
||||
// If we're errored and propagating do the actual propagation and update the index.
|
||||
if (remainingErrorsToPropagate.isNotEmpty() && errorState.propagating) {
|
||||
val (initiatedSessions, newSessions) = bufferErrorMessagesInInitiatingSessions(startingState.checkpoint.sessions, errorMessages)
|
||||
val newCheckpoint = startingState.checkpoint.copy(
|
||||
errorState = errorState.copy(propagatedIndex = allErrors.size),
|
||||
sessions = newSessions
|
||||
)
|
||||
currentState = currentState.copy(checkpoint = newCheckpoint)
|
||||
actions.add(Action.PropagateErrors(errorMessages, initiatedSessions))
|
||||
}
|
||||
|
||||
// If we're errored but not propagating keep processing events.
|
||||
if (remainingErrorsToPropagate.isNotEmpty() && !errorState.propagating) {
|
||||
return@builder FlowContinuation.ProcessEvents
|
||||
}
|
||||
|
||||
// If we haven't been removed yet remove the flow.
|
||||
if (!currentState.isRemoved) {
|
||||
actions.add(Action.CreateTransaction)
|
||||
if (currentState.isAnyCheckpointPersisted) {
|
||||
actions.add(Action.RemoveCheckpoint(context.id))
|
||||
}
|
||||
actions.addAll(arrayOf(
|
||||
Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers),
|
||||
Action.ReleaseSoftLocks(context.id.uuid),
|
||||
Action.CommitTransaction,
|
||||
Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers),
|
||||
Action.RemoveSessionBindings(currentState.checkpoint.sessions.keys)
|
||||
))
|
||||
|
||||
currentState = currentState.copy(
|
||||
pendingDeduplicationHandlers = emptyList(),
|
||||
isRemoved = true
|
||||
)
|
||||
|
||||
val removalReason = FlowRemovalReason.ErrorFinish(allErrors)
|
||||
actions.add(Action.RemoveFlow(context.id, removalReason, currentState))
|
||||
FlowContinuation.Abort
|
||||
} else {
|
||||
// Otherwise keep processing events. This branch happens when there are some outstanding initiating
|
||||
// sessions that prevent the removal of the flow.
|
||||
FlowContinuation.ProcessEvents
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun createErrorMessageFromError(error: FlowError): ErrorSessionMessage {
|
||||
val exception = error.exception
|
||||
// If the exception doesn't contain an originalErrorId that means it's a fresh FlowException that should
|
||||
// propagate to the neighbouring flows. If it has the ID filled in that means it's a rethrown FlowException and
|
||||
// shouldn't be propagated.
|
||||
return if (exception is FlowException && exception.originalErrorId == null) {
|
||||
ErrorSessionMessage(flowException = exception, errorId = error.errorId)
|
||||
} else {
|
||||
ErrorSessionMessage(flowException = null, errorId = error.errorId)
|
||||
}
|
||||
}
|
||||
|
||||
    // Buffer the error messages in Initiating sessions and return the Initiated ones.
|
||||
private fun bufferErrorMessagesInInitiatingSessions(
|
||||
sessions: Map<SessionId, SessionState>,
|
||||
errorMessages: List<ErrorSessionMessage>
|
||||
): Pair<List<SessionState.Initiated>, Map<SessionId, SessionState>> {
|
||||
val newSessions = sessions.mapValues { (sourceSessionId, sessionState) ->
|
||||
if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) {
|
||||
// *prepend* the error messages in order to error the other sessions ASAP. The other messages will
|
||||
// be delivered all the same, they just won't trigger flow resumption because of dirtiness.
|
||||
val errorMessagesWithDeduplication = errorMessages.map {
|
||||
DeduplicationId.createForError(it.errorId, sourceSessionId) to it
|
||||
}
|
||||
sessionState.copy(bufferedMessages = errorMessagesWithDeduplication + sessionState.bufferedMessages)
|
||||
} else {
|
||||
sessionState
|
||||
}
|
||||
}
|
||||
val initiatedSessions = sessions.values.mapNotNull { session ->
|
||||
if (session is SessionState.Initiated && session.errors.isEmpty()) {
|
||||
session
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
return Pair(initiatedSessions, newSessions)
|
||||
}
|
||||
}
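
As a rough illustration of the propagation behaviour documented above, the sketch below shows what an errored flow looks like from a counterparty's perspective: a fresh FlowException raised on one side is re-thrown at the other side's next receive, while any other failure surfaces as an UnexpectedFlowEndException that is correlated only by its error ID. This is a hypothetical pair of flows written against the public flow API, not part of this change; the class names and payloads are invented, and the exact behaviour should be checked against the flow framework in this tree.

import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.*
import net.corda.core.identity.Party

@InitiatingFlow
class IllustrativeInitiator(private val counterparty: Party) : FlowLogic<Unit>() {
    @Suspendable
    override fun call() {
        val session = initiateFlow(counterparty)
        session.send("ping")
        // If something goes wrong from here on, the error is propagated to the counter-flow below.
    }
}

@InitiatedBy(IllustrativeInitiator::class)
class IllustrativeResponder(private val otherSide: FlowSession) : FlowLogic<Unit>() {
    @Suspendable
    override fun call() {
        try {
            // If the initiator errors while we are suspended here, the propagated error is thrown from receive().
            val payload = otherSide.receive<String>().unwrap { it }
            logger.info("Received $payload")
        } catch (e: FlowException) {
            // A fresh FlowException on the initiator side is re-thrown here with its original content.
            logger.warn("Counter-flow failed: ${e.message}")
            throw e
        } catch (e: UnexpectedFlowEndException) {
            // Any other initiator-side error arrives as an UnexpectedFlowEndException carrying only an error ID.
            logger.warn("Counter-flow ended unexpectedly")
            throw e
        }
    }
}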
|
@ -0,0 +1,410 @@
|
||||
package net.corda.node.services.statemachine.transitions
|
||||
|
||||
import net.corda.core.flows.FlowInfo
|
||||
import net.corda.core.flows.FlowSession
|
||||
import net.corda.core.flows.UnexpectedFlowEndException
|
||||
import net.corda.core.internal.FlowIORequest
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.utilities.toNonEmptySet
|
||||
import net.corda.node.services.statemachine.*
|
||||
|
||||
/**
|
||||
* This transition describes what should happen with a specific [FlowIORequest]. Note that at this time the request
|
||||
 * is persisted (unless the checkpoint was skipped) and the user-space DB transaction is committed.
|
||||
*
|
||||
* Before this transition we either did a checkpoint or the checkpoint was restored from the database.
|
||||
*/
|
||||
class StartedFlowTransition(
|
||||
override val context: TransitionContext,
|
||||
override val startingState: StateMachineState,
|
||||
val started: FlowState.Started
|
||||
) : Transition {
|
||||
override fun transition(): TransitionResult {
|
||||
val flowIORequest = started.flowIORequest
|
||||
val checkpoint = startingState.checkpoint
|
||||
val errorsToThrow = collectRelevantErrorsToThrow(flowIORequest, checkpoint)
|
||||
if (errorsToThrow.isNotEmpty()) {
|
||||
return TransitionResult(
|
||||
newState = startingState.copy(isFlowResumed = true),
|
||||
// throw the first exception. TODO should this aggregate all of them somehow?
|
||||
actions = listOf(Action.CreateTransaction),
|
||||
continuation = FlowContinuation.Throw(errorsToThrow[0])
|
||||
)
|
||||
}
|
||||
return when (flowIORequest) {
|
||||
is FlowIORequest.Send -> sendTransition(flowIORequest)
|
||||
is FlowIORequest.Receive -> receiveTransition(flowIORequest)
|
||||
is FlowIORequest.SendAndReceive -> sendAndReceiveTransition(flowIORequest)
|
||||
is FlowIORequest.WaitForLedgerCommit -> waitForLedgerCommitTransition(flowIORequest)
|
||||
is FlowIORequest.Sleep -> sleepTransition(flowIORequest)
|
||||
is FlowIORequest.GetFlowInfo -> getFlowInfoTransition(flowIORequest)
|
||||
is FlowIORequest.WaitForSessionConfirmations -> waitForSessionConfirmationsTransition()
|
||||
is FlowIORequest.ExecuteAsyncOperation<*> -> executeAsyncOperation(flowIORequest)
|
||||
}
|
||||
}
|
||||
|
||||
private fun waitForSessionConfirmationsTransition(): TransitionResult {
|
||||
return builder {
|
||||
if (currentState.checkpoint.sessions.values.any { it is SessionState.Initiating }) {
|
||||
FlowContinuation.ProcessEvents
|
||||
} else {
|
||||
resumeFlowLogic(Unit)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun getFlowInfoTransition(flowIORequest: FlowIORequest.GetFlowInfo): TransitionResult {
|
||||
val sessionIdToSession = LinkedHashMap<SessionId, FlowSessionImpl>()
|
||||
for (session in flowIORequest.sessions) {
|
||||
sessionIdToSession[(session as FlowSessionImpl).sourceSessionId] = session
|
||||
}
|
||||
return builder {
|
||||
// Initialise uninitialised sessions in order to receive the associated FlowInfo. Some or all sessions may
|
||||
// not be initialised yet.
|
||||
sendInitialSessionMessagesIfNeeded(sessionIdToSession.keys)
|
||||
val flowInfoMap = getFlowInfoFromSessions(sessionIdToSession)
|
||||
if (flowInfoMap == null) {
|
||||
FlowContinuation.ProcessEvents
|
||||
} else {
|
||||
resumeFlowLogic(flowInfoMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun TransitionBuilder.getFlowInfoFromSessions(sessionIdToSession: Map<SessionId, FlowSessionImpl>): Map<FlowSession, FlowInfo>? {
|
||||
val checkpoint = currentState.checkpoint
|
||||
val resultMap = LinkedHashMap<FlowSession, FlowInfo>()
|
||||
for ((sessionId, session) in sessionIdToSession) {
|
||||
val sessionState = checkpoint.sessions[sessionId]
|
||||
if (sessionState is SessionState.Initiated) {
|
||||
resultMap[session] = sessionState.peerFlowInfo
|
||||
} else {
|
||||
return null
|
||||
}
|
||||
}
|
||||
return resultMap
|
||||
}
|
||||
|
||||
private fun sleepTransition(flowIORequest: FlowIORequest.Sleep): TransitionResult {
|
||||
return builder {
|
||||
actions.add(Action.SleepUntil(flowIORequest.wakeUpAfter))
|
||||
resumeFlowLogic(Unit)
|
||||
}
|
||||
}
|
||||
|
||||
private fun waitForLedgerCommitTransition(flowIORequest: FlowIORequest.WaitForLedgerCommit): TransitionResult {
|
||||
return if (!startingState.isTransactionTracked) {
|
||||
TransitionResult(
|
||||
newState = startingState.copy(isTransactionTracked = true),
|
||||
actions = listOf(
|
||||
Action.CreateTransaction,
|
||||
Action.TrackTransaction(flowIORequest.hash),
|
||||
Action.CommitTransaction
|
||||
)
|
||||
)
|
||||
} else {
|
||||
TransitionResult(startingState)
|
||||
}
|
||||
}
|
||||
|
||||
private fun sendAndReceiveTransition(flowIORequest: FlowIORequest.SendAndReceive): TransitionResult {
|
||||
val sessionIdToMessage = LinkedHashMap<SessionId, SerializedBytes<Any>>()
|
||||
val sessionIdToSession = LinkedHashMap<SessionId, FlowSessionImpl>()
|
||||
for ((session, message) in flowIORequest.sessionToMessage) {
|
||||
val sessionId = (session as FlowSessionImpl).sourceSessionId
|
||||
sessionIdToMessage[sessionId] = message
|
||||
sessionIdToSession[sessionId] = session
|
||||
}
|
||||
return builder {
|
||||
sendToSessionsTransition(sessionIdToMessage)
|
||||
if (isErrored()) {
|
||||
FlowContinuation.ProcessEvents
|
||||
} else {
|
||||
val receivedMap = receiveFromSessionsTransition(sessionIdToSession)
|
||||
if (receivedMap == null) {
|
||||
// We don't yet have the messages, change the suspension to be on Receive
|
||||
val newIoRequest = FlowIORequest.Receive(flowIORequest.sessionToMessage.keys.toNonEmptySet())
|
||||
currentState = currentState.copy(
|
||||
checkpoint = currentState.checkpoint.copy(
|
||||
flowState = FlowState.Started(newIoRequest, started.frozenFiber)
|
||||
)
|
||||
)
|
||||
FlowContinuation.ProcessEvents
|
||||
} else {
|
||||
resumeFlowLogic(receivedMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun receiveTransition(flowIORequest: FlowIORequest.Receive): TransitionResult {
|
||||
return builder {
|
||||
val sessionIdToSession = LinkedHashMap<SessionId, FlowSessionImpl>()
|
||||
for (session in flowIORequest.sessions) {
|
||||
sessionIdToSession[(session as FlowSessionImpl).sourceSessionId] = session
|
||||
}
|
||||
            // Send the initial session messages to any sessions that have not been initiated yet.
|
||||
sendInitialSessionMessagesIfNeeded(sessionIdToSession.keys)
|
||||
val receivedMap = receiveFromSessionsTransition(sessionIdToSession)
|
||||
if (receivedMap == null) {
|
||||
FlowContinuation.ProcessEvents
|
||||
} else {
|
||||
resumeFlowLogic(receivedMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun TransitionBuilder.receiveFromSessionsTransition(
|
||||
sourceSessionIdToSessionMap: Map<SessionId, FlowSessionImpl>
|
||||
): Map<FlowSession, SerializedBytes<Any>>? {
|
||||
val checkpoint = currentState.checkpoint
|
||||
val pollResult = pollSessionMessages(checkpoint.sessions, sourceSessionIdToSessionMap.keys) ?: return null
|
||||
val resultMap = LinkedHashMap<FlowSession, SerializedBytes<Any>>()
|
||||
for ((sessionId, message) in pollResult.messages) {
|
||||
val session = sourceSessionIdToSessionMap[sessionId]!!
|
||||
resultMap[session] = message
|
||||
}
|
||||
currentState = currentState.copy(
|
||||
checkpoint = checkpoint.copy(sessions = pollResult.newSessionMap)
|
||||
)
|
||||
return resultMap
|
||||
}
|
||||
|
||||
data class PollResult(
|
||||
val messages: Map<SessionId, SerializedBytes<Any>>,
|
||||
val newSessionMap: SessionMap
|
||||
)
|
||||
private fun pollSessionMessages(sessions: SessionMap, sessionIds: Set<SessionId>): PollResult? {
|
||||
val newSessionMessages = LinkedHashMap(sessions)
|
||||
val resultMessages = LinkedHashMap<SessionId, SerializedBytes<Any>>()
|
||||
var someNotFound = false
|
||||
for (sessionId in sessionIds) {
|
||||
val sessionState = sessions[sessionId]
|
||||
when (sessionState) {
|
||||
is SessionState.Initiated -> {
|
||||
val messages = sessionState.receivedMessages
|
||||
if (messages.isEmpty()) {
|
||||
someNotFound = true
|
||||
} else {
|
||||
newSessionMessages[sessionId] = sessionState.copy(receivedMessages = messages.subList(1, messages.size).toList())
|
||||
resultMessages[sessionId] = messages[0].payload
|
||||
}
|
||||
}
|
||||
else -> {
|
||||
someNotFound = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return if (someNotFound) {
|
||||
            null
|
||||
} else {
|
||||
PollResult(resultMessages, newSessionMessages)
|
||||
}
|
||||
}
|
||||
|
||||
private fun TransitionBuilder.sendInitialSessionMessagesIfNeeded(sourceSessions: Set<SessionId>) {
|
||||
val checkpoint = startingState.checkpoint
|
||||
val newSessions = LinkedHashMap<SessionId, SessionState>(checkpoint.sessions)
|
||||
var index = 0
|
||||
for (sourceSessionId in sourceSessions) {
|
||||
val sessionState = checkpoint.sessions[sourceSessionId]
|
||||
if (sessionState == null) {
|
||||
return freshErrorTransition(CannotFindSessionException(sourceSessionId))
|
||||
}
|
||||
if (sessionState !is SessionState.Uninitiated) {
|
||||
continue
|
||||
}
|
||||
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++)
|
||||
val initialMessage = createInitialSessionMessage(sessionState.initiatingSubFlow, sourceSessionId, null)
|
||||
actions.add(Action.SendInitial(sessionState.party, initialMessage, deduplicationId))
|
||||
newSessions[sourceSessionId] = SessionState.Initiating(
|
||||
bufferedMessages = emptyList(),
|
||||
rejectionError = null
|
||||
)
|
||||
}
|
||||
currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions))
|
||||
}
|
||||
|
||||
private fun sendTransition(flowIORequest: FlowIORequest.Send): TransitionResult {
|
||||
return builder {
|
||||
val sessionIdToMessage = flowIORequest.sessionToMessage.mapKeys {
|
||||
sessionToSessionId(it.key)
|
||||
}
|
||||
sendToSessionsTransition(sessionIdToMessage)
|
||||
if (isErrored()) {
|
||||
FlowContinuation.ProcessEvents
|
||||
} else {
|
||||
resumeFlowLogic(Unit)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun TransitionBuilder.sendToSessionsTransition(sourceSessionIdToMessage: Map<SessionId, SerializedBytes<Any>>) {
|
||||
val checkpoint = startingState.checkpoint
|
||||
val newSessions = LinkedHashMap(checkpoint.sessions)
|
||||
var index = 0
|
||||
for ((sourceSessionId, message) in sourceSessionIdToMessage) {
|
||||
val existingSessionState = checkpoint.sessions[sourceSessionId]
|
||||
if (existingSessionState == null) {
|
||||
return freshErrorTransition(CannotFindSessionException(sourceSessionId))
|
||||
} else {
|
||||
val sessionMessage = DataSessionMessage(message)
|
||||
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++)
|
||||
val _exhaustive = when (existingSessionState) {
|
||||
is SessionState.Uninitiated -> {
|
||||
val initialMessage = createInitialSessionMessage(existingSessionState.initiatingSubFlow, sourceSessionId, message)
|
||||
actions.add(Action.SendInitial(existingSessionState.party, initialMessage, deduplicationId))
|
||||
newSessions[sourceSessionId] = SessionState.Initiating(
|
||||
bufferedMessages = emptyList(),
|
||||
rejectionError = null
|
||||
)
|
||||
Unit
|
||||
}
|
||||
is SessionState.Initiating -> {
|
||||
// We're initiating this session, buffer the message
|
||||
val newBufferedMessages = existingSessionState.bufferedMessages + Pair(deduplicationId, sessionMessage)
|
||||
newSessions[sourceSessionId] = existingSessionState.copy(bufferedMessages = newBufferedMessages)
|
||||
}
|
||||
is SessionState.Initiated -> {
|
||||
when (existingSessionState.initiatedState) {
|
||||
is InitiatedSessionState.Live -> {
|
||||
val sinkSessionId = existingSessionState.initiatedState.peerSinkSessionId
|
||||
val existingMessage = ExistingSessionMessage(sinkSessionId, sessionMessage)
|
||||
actions.add(Action.SendExisting(existingSessionState.peerParty, existingMessage, deduplicationId))
|
||||
Unit
|
||||
}
|
||||
InitiatedSessionState.Ended -> {
|
||||
return freshErrorTransition(IllegalStateException("Tried to send to ended session $sourceSessionId"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions))
|
||||
}
|
||||
|
||||
private fun sessionToSessionId(session: FlowSession): SessionId {
|
||||
return (session as FlowSessionImpl).sourceSessionId
|
||||
}
|
||||
|
||||
private fun collectErroredSessionErrors(sessionIds: Collection<SessionId>, checkpoint: Checkpoint): List<Throwable> {
|
||||
return sessionIds.flatMap { sessionId ->
|
||||
val sessionState = checkpoint.sessions[sessionId]!!
|
||||
when (sessionState) {
|
||||
is SessionState.Uninitiated -> emptyList()
|
||||
is SessionState.Initiating -> {
|
||||
if (sessionState.rejectionError == null) {
|
||||
emptyList()
|
||||
} else {
|
||||
listOf(sessionState.rejectionError.exception)
|
||||
}
|
||||
}
|
||||
is SessionState.Initiated -> sessionState.errors.map(FlowError::exception)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun collectErroredInitiatingSessionErrors(checkpoint: Checkpoint): List<Throwable> {
|
||||
return checkpoint.sessions.values.mapNotNull { sessionState ->
|
||||
(sessionState as? SessionState.Initiating)?.rejectionError?.exception
|
||||
}
|
||||
}
|
||||
|
||||
private fun collectEndedSessionErrors(sessionIds: Collection<SessionId>, checkpoint: Checkpoint): List<Throwable> {
|
||||
return sessionIds.mapNotNull { sessionId ->
|
||||
val sessionState = checkpoint.sessions[sessionId]!!
|
||||
when (sessionState) {
|
||||
is SessionState.Initiated -> {
|
||||
if (sessionState.initiatedState is InitiatedSessionState.Ended) {
|
||||
UnexpectedFlowEndException(
|
||||
"Tried to access ended session $sessionId",
|
||||
cause = null,
|
||||
originalErrorId = context.secureRandom.nextLong()
|
||||
)
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun collectEndedEmptySessionErrors(sessionIds: Collection<SessionId>, checkpoint: Checkpoint): List<Throwable> {
|
||||
return sessionIds.mapNotNull { sessionId ->
|
||||
val sessionState = checkpoint.sessions[sessionId]!!
|
||||
when (sessionState) {
|
||||
is SessionState.Initiated -> {
|
||||
if (sessionState.initiatedState is InitiatedSessionState.Ended &&
|
||||
sessionState.receivedMessages.isEmpty()) {
|
||||
UnexpectedFlowEndException(
|
||||
"Tried to access ended session $sessionId with empty buffer",
|
||||
cause = null,
|
||||
originalErrorId = context.secureRandom.nextLong()
|
||||
)
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun collectRelevantErrorsToThrow(flowIORequest: FlowIORequest<*>, checkpoint: Checkpoint): List<Throwable> {
|
||||
return when (flowIORequest) {
|
||||
is FlowIORequest.Send -> {
|
||||
val sessionIds = flowIORequest.sessionToMessage.keys.map(this::sessionToSessionId)
|
||||
collectErroredSessionErrors(sessionIds, checkpoint) + collectEndedSessionErrors(sessionIds, checkpoint)
|
||||
}
|
||||
is FlowIORequest.Receive -> {
|
||||
val sessionIds = flowIORequest.sessions.map(this::sessionToSessionId)
|
||||
collectErroredSessionErrors(sessionIds, checkpoint) + collectEndedEmptySessionErrors(sessionIds, checkpoint)
|
||||
}
|
||||
is FlowIORequest.SendAndReceive -> {
|
||||
val sessionIds = flowIORequest.sessionToMessage.keys.map(this::sessionToSessionId)
|
||||
collectErroredSessionErrors(sessionIds, checkpoint) + collectEndedSessionErrors(sessionIds, checkpoint)
|
||||
}
|
||||
is FlowIORequest.WaitForLedgerCommit -> {
|
||||
collectErroredSessionErrors(checkpoint.sessions.keys, checkpoint)
|
||||
}
|
||||
is FlowIORequest.GetFlowInfo -> {
|
||||
collectErroredSessionErrors(flowIORequest.sessions.map(this::sessionToSessionId), checkpoint)
|
||||
}
|
||||
is FlowIORequest.Sleep -> {
|
||||
emptyList()
|
||||
}
|
||||
is FlowIORequest.WaitForSessionConfirmations -> {
|
||||
collectErroredInitiatingSessionErrors(checkpoint)
|
||||
}
|
||||
is FlowIORequest.ExecuteAsyncOperation<*> -> {
|
||||
emptyList()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun createInitialSessionMessage(
|
||||
initiatingSubFlow: SubFlow.Initiating,
|
||||
sourceSessionId: SessionId,
|
||||
payload: SerializedBytes<Any>?
|
||||
): InitialSessionMessage {
|
||||
return InitialSessionMessage(
|
||||
initiatorSessionId = sourceSessionId,
|
||||
                // We add additional entropy to the initiated side's deduplication seed.
|
||||
initiationEntropy = context.secureRandom.nextLong(),
|
||||
initiatorFlowClassName = initiatingSubFlow.classToInitiateWith.name,
|
||||
flowVersion = initiatingSubFlow.flowInfo.flowVersion,
|
||||
appName = initiatingSubFlow.flowInfo.appName,
|
||||
firstPayload = payload
|
||||
)
|
||||
}
|
||||
|
||||
private fun executeAsyncOperation(flowIORequest: FlowIORequest.ExecuteAsyncOperation<*>): TransitionResult {
|
||||
return builder {
|
||||
actions.add(Action.ExecuteAsyncOperation(flowIORequest.operation))
|
||||
FlowContinuation.ProcessEvents
|
||||
}
|
||||
}
|
||||
}
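
For orientation, each branch handled by this transition corresponds to a call a flow makes on FlowLogic or FlowSession. The snippet below is a purely illustrative initiator annotated with the FlowIORequest each call is expected to suspend on; the counterparty, payloads and class name are invented, and the exact signatures (in particular sleep and getCounterpartyFlowInfo) should be double-checked against the FlowLogic/FlowSession versions in this tree.

import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.InitiatingFlow
import net.corda.core.identity.Party
import java.time.Duration

@InitiatingFlow
class SuspensionPointsExample(private val counterparty: Party) : FlowLogic<Unit>() {
    @Suspendable
    override fun call() {
        val session = initiateFlow(counterparty)                          // Event.InitiateFlow; no suspension yet
        session.send("ping")                                              // suspends on FlowIORequest.Send
        val pong = session.receive<String>().unwrap { it }                // suspends on FlowIORequest.Receive
        val more = session.sendAndReceive<String>("again").unwrap { it }  // suspends on FlowIORequest.SendAndReceive
        val peerInfo = session.getCounterpartyFlowInfo()                  // suspends on FlowIORequest.GetFlowInfo
        sleep(Duration.ofSeconds(1))                                      // suspends on FlowIORequest.Sleep
        // waitForLedgerCommit(txId) would suspend on FlowIORequest.WaitForLedgerCommit in the same way.
        logger.info("pong=$pong more=$more peer=$peerInfo")
    }
}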
|
@ -0,0 +1,30 @@
|
||||
package net.corda.node.services.statemachine.transitions
|
||||
|
||||
import net.corda.core.flows.*
|
||||
import net.corda.node.services.statemachine.*
|
||||
import java.security.SecureRandom
|
||||
|
||||
/**
|
||||
* @property eventQueueSize the size of a flow's event queue. If the queue gets full the thread scheduling the event
|
||||
 * will block. An example scenario would be a flow that is waiting for a lot of messages at once but is slow to
 * process each one.
|
||||
*/
|
||||
data class StateMachineConfiguration(
|
||||
val eventQueueSize: Int
|
||||
) {
|
||||
companion object {
|
||||
val default = StateMachineConfiguration(
|
||||
eventQueueSize = 16
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
class StateMachine(
|
||||
val id: StateMachineRunId,
|
||||
val configuration: StateMachineConfiguration,
|
||||
val secureRandom: SecureRandom
|
||||
) {
|
||||
fun transition(event: Event, state: StateMachineState): TransitionResult {
|
||||
return TopLevelTransition(TransitionContext(id, configuration, secureRandom), state, event).transition()
|
||||
}
|
||||
}
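
A small, hypothetical usage sketch of the configuration above: take the defaults and enlarge the per-flow event queue for a node that expects bursts of session messages. Where (or whether) the enclosing state machine manager exposes this knob is outside this diff, so treat the values as illustrative only.

// Illustrative only: enlarge the event queue relative to the default of 16.
val burstyConfig = StateMachineConfiguration.default.copy(eventQueueSize = 128)

// The configuration is then handed to each per-flow StateMachine together with the flow's run ID and
// the node's SecureRandom, e.g. StateMachine(id = runId, configuration = burstyConfig, secureRandom = secureRandom).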
|
@ -0,0 +1,233 @@
|
||||
package net.corda.node.services.statemachine.transitions
|
||||
|
||||
import net.corda.core.flows.InitiatingFlow
|
||||
import net.corda.core.internal.FlowIORequest
|
||||
import net.corda.core.utilities.Try
|
||||
import net.corda.node.services.statemachine.*
|
||||
|
||||
/**
|
||||
* This is the top level event-handling transition function capable of handling any [Event].
|
||||
*
|
||||
* It is a *pure* function taking a state machine state and an event, returning the next state along with a list of IO
|
||||
* actions to execute.
|
||||
*/
|
||||
class TopLevelTransition(
|
||||
override val context: TransitionContext,
|
||||
override val startingState: StateMachineState,
|
||||
val event: Event
|
||||
) : Transition {
|
||||
override fun transition(): TransitionResult {
|
||||
return when (event) {
|
||||
is Event.DoRemainingWork -> DoRemainingWorkTransition(context, startingState).transition()
|
||||
is Event.DeliverSessionMessage -> DeliverSessionMessageTransition(context, startingState, event).transition()
|
||||
is Event.Error -> errorTransition(event)
|
||||
is Event.TransactionCommitted -> transactionCommittedTransition(event)
|
||||
is Event.SoftShutdown -> softShutdownTransition()
|
||||
is Event.StartErrorPropagation -> startErrorPropagationTransition()
|
||||
is Event.EnterSubFlow -> enterSubFlowTransition(event)
|
||||
is Event.LeaveSubFlow -> leaveSubFlowTransition()
|
||||
is Event.Suspend -> suspendTransition(event)
|
||||
is Event.FlowFinish -> flowFinishTransition(event)
|
||||
is Event.InitiateFlow -> initiateFlowTransition(event)
|
||||
is Event.AsyncOperationCompletion -> asyncOperationCompletionTransition(event)
|
||||
}
|
||||
}
|
||||
|
||||
private fun errorTransition(event: Event.Error): TransitionResult {
|
||||
return builder {
|
||||
freshErrorTransition(event.exception)
|
||||
FlowContinuation.ProcessEvents
|
||||
}
|
||||
}
|
||||
|
||||
private fun transactionCommittedTransition(event: Event.TransactionCommitted): TransitionResult {
|
||||
return builder {
|
||||
val checkpoint = currentState.checkpoint
|
||||
if (currentState.isTransactionTracked &&
|
||||
checkpoint.flowState is FlowState.Started &&
|
||||
checkpoint.flowState.flowIORequest is FlowIORequest.WaitForLedgerCommit &&
|
||||
checkpoint.flowState.flowIORequest.hash == event.transaction.id) {
|
||||
currentState = currentState.copy(isTransactionTracked = false)
|
||||
if (isErrored()) {
|
||||
return@builder FlowContinuation.ProcessEvents
|
||||
}
|
||||
resumeFlowLogic(event.transaction)
|
||||
} else {
|
||||
freshErrorTransition(UnexpectedEventInState())
|
||||
FlowContinuation.ProcessEvents
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun softShutdownTransition(): TransitionResult {
|
||||
val lastState = startingState.copy(isRemoved = true)
|
||||
return TransitionResult(
|
||||
newState = lastState,
|
||||
actions = listOf(
|
||||
Action.RemoveSessionBindings(startingState.checkpoint.sessions.keys),
|
||||
Action.RemoveFlow(context.id, FlowRemovalReason.SoftShutdown, lastState)
|
||||
),
|
||||
continuation = FlowContinuation.Abort
|
||||
)
|
||||
}
|
||||
|
||||
private fun startErrorPropagationTransition(): TransitionResult {
|
||||
return builder {
|
||||
val errorState = currentState.checkpoint.errorState
|
||||
when (errorState) {
|
||||
ErrorState.Clean -> freshErrorTransition(UnexpectedEventInState())
|
||||
is ErrorState.Errored -> {
|
||||
currentState = currentState.copy(
|
||||
checkpoint = currentState.checkpoint.copy(
|
||||
errorState = errorState.copy(propagating = true)
|
||||
)
|
||||
)
|
||||
actions.add(Action.ScheduleEvent(Event.DoRemainingWork))
|
||||
}
|
||||
}
|
||||
FlowContinuation.ProcessEvents
|
||||
}
|
||||
}
|
||||
|
||||
private fun enterSubFlowTransition(event: Event.EnterSubFlow): TransitionResult {
|
||||
return builder {
|
||||
val subFlow = SubFlow.create(event.subFlowClass)
|
||||
when (subFlow) {
|
||||
is Try.Success -> {
|
||||
currentState = currentState.copy(
|
||||
checkpoint = currentState.checkpoint.copy(
|
||||
subFlowStack = currentState.checkpoint.subFlowStack + subFlow.value
|
||||
)
|
||||
)
|
||||
}
|
||||
is Try.Failure -> {
|
||||
freshErrorTransition(subFlow.exception)
|
||||
}
|
||||
}
|
||||
FlowContinuation.ProcessEvents
|
||||
}
|
||||
}
|
||||
|
||||
private fun leaveSubFlowTransition(): TransitionResult {
|
||||
return builder {
|
||||
val checkpoint = currentState.checkpoint
|
||||
if (checkpoint.subFlowStack.isEmpty()) {
|
||||
freshErrorTransition(UnexpectedEventInState())
|
||||
} else {
|
||||
currentState = currentState.copy(
|
||||
checkpoint = checkpoint.copy(
|
||||
subFlowStack = checkpoint.subFlowStack.subList(0, checkpoint.subFlowStack.size - 1).toList()
|
||||
)
|
||||
)
|
||||
}
|
||||
FlowContinuation.ProcessEvents
|
||||
}
|
||||
}
|
||||
|
||||
private fun suspendTransition(event: Event.Suspend): TransitionResult {
|
||||
return builder {
|
||||
val newCheckpoint = currentState.checkpoint.copy(
|
||||
flowState = FlowState.Started(event.ioRequest, event.fiber),
|
||||
numberOfSuspends = currentState.checkpoint.numberOfSuspends + 1
|
||||
)
|
||||
actions.addAll(arrayOf(
|
||||
Action.PersistCheckpoint(context.id, newCheckpoint),
|
||||
Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers),
|
||||
Action.CommitTransaction,
|
||||
Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers),
|
||||
Action.ScheduleEvent(Event.DoRemainingWork)
|
||||
))
|
||||
currentState = currentState.copy(
|
||||
checkpoint = newCheckpoint,
|
||||
pendingDeduplicationHandlers = emptyList(),
|
||||
isFlowResumed = false,
|
||||
isAnyCheckpointPersisted = true
|
||||
)
|
||||
FlowContinuation.ProcessEvents
|
||||
}
|
||||
}
|
||||
|
||||
private fun flowFinishTransition(event: Event.FlowFinish): TransitionResult {
|
||||
return builder {
|
||||
val checkpoint = currentState.checkpoint
|
||||
when (checkpoint.errorState) {
|
||||
ErrorState.Clean -> {
|
||||
val pendingDeduplicationHandlers = currentState.pendingDeduplicationHandlers
|
||||
currentState = currentState.copy(
|
||||
checkpoint = checkpoint.copy(
|
||||
numberOfSuspends = checkpoint.numberOfSuspends + 1
|
||||
),
|
||||
pendingDeduplicationHandlers = emptyList(),
|
||||
isFlowResumed = false,
|
||||
isRemoved = true
|
||||
)
|
||||
val allSourceSessionIds = checkpoint.sessions.keys
|
||||
if (currentState.isAnyCheckpointPersisted) {
|
||||
actions.add(Action.RemoveCheckpoint(context.id))
|
||||
}
|
||||
actions.addAll(arrayOf(
|
||||
Action.PersistDeduplicationFacts(pendingDeduplicationHandlers),
|
||||
Action.ReleaseSoftLocks(event.softLocksId),
|
||||
Action.CommitTransaction,
|
||||
Action.AcknowledgeMessages(pendingDeduplicationHandlers),
|
||||
Action.RemoveSessionBindings(allSourceSessionIds),
|
||||
Action.RemoveFlow(context.id, FlowRemovalReason.OrderlyFinish(event.returnValue), currentState)
|
||||
))
|
||||
sendEndMessages()
|
||||
// Resume to end fiber
|
||||
FlowContinuation.Resume(null)
|
||||
}
|
||||
is ErrorState.Errored -> {
|
||||
currentState = currentState.copy(isFlowResumed = false)
|
||||
actions.add(Action.RollbackTransaction)
|
||||
FlowContinuation.ProcessEvents
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun TransitionBuilder.sendEndMessages() {
|
||||
val sendEndMessageActions = currentState.checkpoint.sessions.values.mapIndexed { index, state ->
|
||||
if (state is SessionState.Initiated && state.initiatedState is InitiatedSessionState.Live) {
|
||||
val message = ExistingSessionMessage(state.initiatedState.peerSinkSessionId, EndSessionMessage)
|
||||
val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index)
|
||||
Action.SendExisting(state.peerParty, message, deduplicationId)
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}.filterNotNull()
|
||||
actions.addAll(sendEndMessageActions)
|
||||
}
|
||||
|
||||
private fun initiateFlowTransition(event: Event.InitiateFlow): TransitionResult {
|
||||
return builder {
|
||||
val checkpoint = currentState.checkpoint
|
||||
val initiatingSubFlow = getClosestAncestorInitiatingSubFlow(checkpoint)
|
||||
if (initiatingSubFlow == null) {
|
||||
freshErrorTransition(IllegalStateException("Tried to initiate in a flow not annotated with @${InitiatingFlow::class.java.simpleName}"))
|
||||
return@builder FlowContinuation.ProcessEvents
|
||||
}
|
||||
val sourceSessionId = SessionId.createRandom(context.secureRandom)
|
||||
val sessionImpl = FlowSessionImpl(event.party, sourceSessionId)
|
||||
val newSessions = checkpoint.sessions + (sourceSessionId to SessionState.Uninitiated(event.party, initiatingSubFlow))
|
||||
currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions))
|
||||
actions.add(Action.AddSessionBinding(context.id, sourceSessionId))
|
||||
FlowContinuation.Resume(sessionImpl)
|
||||
}
|
||||
}
|
||||
|
||||
private fun getClosestAncestorInitiatingSubFlow(checkpoint: Checkpoint): SubFlow.Initiating? {
|
||||
for (subFlow in checkpoint.subFlowStack.asReversed()) {
|
||||
if (subFlow is SubFlow.Initiating) {
|
||||
return subFlow
|
||||
}
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
private fun asyncOperationCompletionTransition(event: Event.AsyncOperationCompletion): TransitionResult {
|
||||
return builder {
|
||||
resumeFlowLogic(event.returnValue)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,32 @@
|
||||
package net.corda.node.services.statemachine.transitions
|
||||
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.node.services.statemachine.StateMachineState
|
||||
import java.security.SecureRandom
|
||||
|
||||
/**
|
||||
* An interface used to separate out different parts of the state machine transition function.
|
||||
*/
|
||||
interface Transition {
|
||||
/** The context of the transition. */
|
||||
val context: TransitionContext
|
||||
/** The state the transition is starting in. */
|
||||
val startingState: StateMachineState
|
||||
/** The (almost) pure transition function. The only side-effect we allow is random number generation. */
|
||||
fun transition(): TransitionResult
|
||||
|
||||
/**
|
||||
     * A helper that runs [build] against a fresh [TransitionBuilder] and assembles the resulting [TransitionResult].
|
||||
*/
|
||||
fun builder(build: TransitionBuilder.() -> FlowContinuation): TransitionResult {
|
||||
val builder = TransitionBuilder(context, startingState)
|
||||
val continuation = build(builder)
|
||||
return TransitionResult(builder.currentState, builder.actions, continuation)
|
||||
}
|
||||
}
|
||||
|
||||
class TransitionContext(
|
||||
val id: StateMachineRunId,
|
||||
val configuration: StateMachineConfiguration,
|
||||
val secureRandom: SecureRandom
|
||||
)
|
@ -0,0 +1,74 @@
|
||||
package net.corda.node.services.statemachine.transitions
|
||||
|
||||
import net.corda.core.flows.IdentifiableException
|
||||
import net.corda.node.services.statemachine.*
|
||||
|
||||
// This is a file defining some common utilities for creating state machine transitions.
|
||||
|
||||
/**
|
||||
 * A builder that helps create [Transition]s. It allows for a more imperative style of specifying the transition.
|
||||
*/
|
||||
class TransitionBuilder(val context: TransitionContext, initialState: StateMachineState) {
|
||||
/** The current state machine state of the builder */
|
||||
var currentState = initialState
|
||||
/** The list of actions to execute */
|
||||
val actions = ArrayList<Action>()
|
||||
|
||||
/** Check if [currentState] state is errored */
|
||||
fun isErrored(): Boolean = currentState.checkpoint.errorState is ErrorState.Errored
|
||||
|
||||
/**
|
||||
* Transition the builder into an error state because of a fresh error that happened.
|
||||
* Existing actions and the current state are thrown away, and the initial state is dirtied.
|
||||
*
|
||||
* @param error the error.
|
||||
*/
|
||||
fun freshErrorTransition(error: Throwable) {
|
||||
val flowError = FlowError(
|
||||
errorId = (error as? IdentifiableException)?.errorId ?: context.secureRandom.nextLong(),
|
||||
exception = error
|
||||
)
|
||||
errorTransition(flowError)
|
||||
}
|
||||
|
||||
/**
|
||||
* Transition the builder into an error state because of a list of errors that happened.
|
||||
* Existing actions and the current state are thrown away, and the initial state is dirtied.
|
||||
*
|
||||
     * @param errors the errors.
|
||||
*/
|
||||
fun errorsTransition(errors: List<FlowError>) {
|
||||
currentState = currentState.copy(
|
||||
checkpoint = currentState.checkpoint.copy(
|
||||
errorState = currentState.checkpoint.errorState.addErrors(errors)
|
||||
),
|
||||
isFlowResumed = false
|
||||
)
|
||||
actions.clear()
|
||||
actions.addAll(arrayOf(
|
||||
Action.RollbackTransaction,
|
||||
Action.ScheduleEvent(Event.DoRemainingWork)
|
||||
))
|
||||
}
|
||||
|
||||
/**
|
||||
     * Transition the builder into an error state because a non-fresh error has happened.
|
||||
* Existing actions and the current state are thrown away, and the initial state is dirtied.
|
||||
*
|
||||
* @param error the error.
|
||||
*/
|
||||
fun errorTransition(error: FlowError) {
|
||||
errorsTransition(listOf(error))
|
||||
}
|
||||
|
||||
fun resumeFlowLogic(result: Any?): FlowContinuation {
|
||||
actions.add(Action.CreateTransaction)
|
||||
currentState = currentState.copy(isFlowResumed = true)
|
||||
return FlowContinuation.Resume(result)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
class CannotFindSessionException(sessionId: SessionId) : IllegalStateException("Couldn't find session with id $sessionId")
|
||||
class UnexpectedEventInState : IllegalStateException("Unexpected event")
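
To make the builder's intended use concrete, here is a minimal, hypothetical transition written with the helpers above: it resumes the flow when the checkpoint has no sessions and otherwise records a fresh error and keeps processing events. The precondition is invented purely for illustration and the class is not part of this change.

package net.corda.node.services.statemachine.transitions

import net.corda.node.services.statemachine.StateMachineState

class ExampleGuardTransition(
        override val context: TransitionContext,
        override val startingState: StateMachineState
) : Transition {
    override fun transition(): TransitionResult {
        return builder {
            if (currentState.checkpoint.sessions.isEmpty()) {
                // Nothing to wait on: hand control back to the flow with a Unit result.
                resumeFlowLogic(Unit)
            } else {
                // Illustrative failure path: pending actions are discarded and the checkpoint is dirtied.
                freshErrorTransition(IllegalStateException("Example precondition failed"))
                FlowContinuation.ProcessEvents
            }
        }
    }
}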
|
@ -0,0 +1,46 @@
|
||||
package net.corda.node.services.statemachine.transitions
|
||||
|
||||
import net.corda.node.services.statemachine.Action
|
||||
import net.corda.node.services.statemachine.StateMachineState
|
||||
|
||||
/**
|
||||
 * A data structure capturing the intended new state of the flow, the actions to be executed as part of the transition
|
||||
* and a [FlowContinuation].
|
||||
*
|
||||
 * Read this data structure as an instruction to the state machine executor:
|
||||
* "Transition to [newState] *if* [actions] execute cleanly. If so, use [continuation] to decide what to do next. If
|
||||
* there was an error it's up to you what to do".
|
||||
* Also see [net.corda.node.services.statemachine.TransitionExecutorImpl] on how this is interpreted.
|
||||
*/
|
||||
data class TransitionResult(
|
||||
val newState: StateMachineState,
|
||||
val actions: List<Action> = emptyList(),
|
||||
val continuation: FlowContinuation = FlowContinuation.ProcessEvents
|
||||
)
|
||||
|
||||
/**
|
||||
 * A data structure describing what to do after a transition has succeeded.
|
||||
*/
|
||||
sealed class FlowContinuation {
|
||||
/**
|
||||
* Return to user code with the supplied [result].
|
||||
*/
|
||||
data class Resume(val result: Any?) : FlowContinuation() {
|
||||
override fun toString() = "Resume(result=${result?.javaClass})"
|
||||
}
|
||||
|
||||
/**
|
||||
* Throw an exception [throwable] in user code.
|
||||
*/
|
||||
data class Throw(val throwable: Throwable) : FlowContinuation()
|
||||
|
||||
/**
|
||||
* Keep processing pending events.
|
||||
*/
|
||||
object ProcessEvents : FlowContinuation() { override fun toString() = "ProcessEvents" }
|
||||
|
||||
/**
|
||||
* Immediately abort the flow. Note that this does not imply an error condition.
|
||||
*/
|
||||
object Abort : FlowContinuation() { override fun toString() = "Abort" }
|
||||
}
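
The contract described above can be read as a small interpreter loop. The sketch below is not [net.corda.node.services.statemachine.TransitionExecutorImpl]; it is a deliberately simplified, hypothetical reading that shows the intended order of operations (run the actions, adopt the new state only if they all succeed, then branch on the continuation). The executeAction parameter is an invented placeholder for the node's real action executor, and all error handling is omitted.

import net.corda.node.services.statemachine.Action
import net.corda.node.services.statemachine.StateMachineState

// Hypothetical, heavily simplified interpretation of a TransitionResult.
fun interpret(transition: TransitionResult, executeAction: (Action) -> Unit): StateMachineState {
    // 1. Execute the side-effects. Only if all of them succeed is [newState] adopted as the flow's state.
    transition.actions.forEach(executeAction)
    // 2. The continuation then says what to do with the suspended fiber.
    when (transition.continuation) {
        is FlowContinuation.Resume -> Unit     // resume user code with the contained result
        is FlowContinuation.Throw -> Unit      // resume user code by throwing the contained throwable
        FlowContinuation.ProcessEvents -> Unit // leave the fiber suspended and keep draining its event queue
        FlowContinuation.Abort -> Unit         // tear the fiber down; not necessarily an error
    }
    return transition.newState
}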
|
@ -0,0 +1,80 @@
|
||||
package net.corda.node.services.statemachine.transitions
|
||||
|
||||
import net.corda.core.flows.FlowInfo
|
||||
import net.corda.node.services.statemachine.*
|
||||
|
||||
/**
|
||||
* This transition is responsible for starting the flow from a FlowLogic instance. It creates the first checkpoint and
|
||||
* initialises the initiated session in case the flow is an initiated one.
|
||||
*/
|
||||
class UnstartedFlowTransition(
|
||||
override val context: TransitionContext,
|
||||
override val startingState: StateMachineState,
|
||||
val unstarted: FlowState.Unstarted
|
||||
) : Transition {
|
||||
override fun transition(): TransitionResult {
|
||||
return builder {
|
||||
if (!currentState.isAnyCheckpointPersisted && !currentState.isStartIdempotent) {
|
||||
createInitialCheckpoint()
|
||||
}
|
||||
|
||||
actions.add(Action.SignalFlowHasStarted(context.id))
|
||||
|
||||
if (unstarted.flowStart is FlowStart.Initiated) {
|
||||
initialiseInitiatedSession(unstarted.flowStart)
|
||||
}
|
||||
|
||||
currentState = currentState.copy(isFlowResumed = true)
|
||||
actions.add(Action.CreateTransaction)
|
||||
FlowContinuation.Resume(null)
|
||||
}
|
||||
}
|
||||
|
||||
// Initialise initiated session, store initial payload, send confirmation back.
|
||||
private fun TransitionBuilder.initialiseInitiatedSession(flowStart: FlowStart.Initiated) {
|
||||
val initiatingMessage = flowStart.initiatingMessage
|
||||
val initiatedState = SessionState.Initiated(
|
||||
peerParty = flowStart.peerSession.counterparty,
|
||||
initiatedState = InitiatedSessionState.Live(initiatingMessage.initiatorSessionId),
|
||||
peerFlowInfo = FlowInfo(
|
||||
flowVersion = flowStart.senderCoreFlowVersion ?: initiatingMessage.flowVersion,
|
||||
appName = initiatingMessage.appName
|
||||
),
|
||||
receivedMessages = if (initiatingMessage.firstPayload == null) {
|
||||
emptyList()
|
||||
} else {
|
||||
listOf(DataSessionMessage(initiatingMessage.firstPayload))
|
||||
},
|
||||
errors = emptyList()
|
||||
)
|
||||
val confirmationMessage = ConfirmSessionMessage(flowStart.initiatedSessionId, flowStart.initiatedFlowInfo)
|
||||
val sessionMessage = ExistingSessionMessage(initiatingMessage.initiatorSessionId, confirmationMessage)
|
||||
currentState = currentState.copy(
|
||||
checkpoint = currentState.checkpoint.copy(
|
||||
sessions = mapOf(flowStart.initiatedSessionId to initiatedState)
|
||||
)
|
||||
)
|
||||
actions.add(
|
||||
Action.SendExisting(
|
||||
flowStart.peerSession.counterparty,
|
||||
sessionMessage,
|
||||
DeduplicationId.createForNormal(currentState.checkpoint, 0)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
// Create initial checkpoint and acknowledge triggering messages.
|
||||
private fun TransitionBuilder.createInitialCheckpoint() {
|
||||
actions.addAll(arrayOf(
|
||||
Action.CreateTransaction,
|
||||
Action.PersistCheckpoint(context.id, currentState.checkpoint),
|
||||
Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers),
|
||||
Action.CommitTransaction,
|
||||
Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers)
|
||||
))
|
||||
currentState = currentState.copy(
|
||||
pendingDeduplicationHandlers = emptyList(),
|
||||
isAnyCheckpointPersisted = true
|
||||
)
|
||||
}
|
||||
}
|
@ -197,9 +197,17 @@ class NodeVaultService(
|
||||
if (!netUpdate.isEmpty()) {
|
||||
recordUpdate(netUpdate)
|
||||
mutex.locked {
|
||||
// flowId required by SoftLockManager to perform auto-registration of soft locks for new states
|
||||
// flowId was required by SoftLockManager to perform auto-registration of soft locks for new states
|
||||
val uuid = (Strand.currentStrand() as? FlowStateMachineImpl<*>)?.id?.uuid
|
||||
val vaultUpdate = if (uuid != null) netUpdate.copy(flowId = uuid) else netUpdate
|
||||
if (uuid != null) {
|
||||
val fungible = netUpdate.produced.filter { it.state.data is FungibleAsset<*> }
|
||||
if (fungible.isNotEmpty()) {
|
||||
val stateRefs = fungible.map { it.ref }.toNonEmptySet()
|
||||
log.trace { "Reserving soft locks for flow id $uuid and states $stateRefs" }
|
||||
softLockReserve(uuid, stateRefs)
|
||||
}
|
||||
}
|
||||
updatesPublisher.onNext(vaultUpdate)
|
||||
}
|
||||
}
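
For context, the effect of the block above on a running flow is that any FungibleAsset outputs it records are automatically soft-locked under the flow's own ID until the flow finishes. The snippet below is an illustrative, hypothetical use of the same vault API from non-flow code: it reserves specific states under an explicit lock ID and releases the lock afterwards. The helper name, the lock ID and the state list are all invented.

import net.corda.core.contracts.StateRef
import net.corda.core.node.services.VaultService
import net.corda.core.utilities.toNonEmptySet
import java.util.UUID

// Illustrative only: reserve some states under our own lock ID, then release the lock when done.
fun lockAndRelease(vault: VaultService, states: List<StateRef>) {
    val lockId = UUID.randomUUID()
    vault.softLockReserve(lockId, states.toNonEmptySet())
    try {
        // ... work with the locked states here; other flows cannot select them in the meantime ...
    } finally {
        vault.softLockRelease(lockId)
    }
}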
|
||||
|
@ -1,58 +0,0 @@
|
||||
package net.corda.node.services.vault
|
||||
|
||||
import net.corda.core.contracts.FungibleAsset
|
||||
import net.corda.core.contracts.StateRef
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.node.services.VaultService
|
||||
import net.corda.core.utilities.*
|
||||
import net.corda.node.services.statemachine.FlowStateMachineImpl
|
||||
import net.corda.node.services.statemachine.StateMachineManager
|
||||
import java.util.*
|
||||
|
||||
class VaultSoftLockManager private constructor(private val vault: VaultService) {
|
||||
companion object {
|
||||
private val log = contextLogger()
|
||||
@JvmStatic
|
||||
fun install(vault: VaultService, smm: StateMachineManager) {
|
||||
val manager = VaultSoftLockManager(vault)
|
||||
smm.changes.subscribe { change ->
|
||||
if (change is StateMachineManager.Change.Removed) {
|
||||
val logic = change.logic
|
||||
// Don't run potentially expensive query if the flow didn't lock any states:
|
||||
if ((logic.stateMachine as FlowStateMachineImpl<*>).hasSoftLockedStates) {
|
||||
manager.unregisterSoftLocks(logic.runId.uuid, logic)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Discussion
|
||||
//
|
||||
// The intent of the following approach is to support what might be a common pattern in a flow:
|
||||
// 1. Create state
|
||||
// 2. Do something with state
|
||||
// without possibility of another flow intercepting the state between 1 and 2,
|
||||
// since we cannot lock the state before it exists. e.g. Issue and then Move some Cash.
|
||||
//
|
||||
// The downside is we could have a long running flow that holds a lock for a long period of time.
|
||||
// However, the lock can be programmatically released, like any other soft lock,
|
||||
// should we want a long running flow that creates a visible state mid way through.
|
||||
vault.rawUpdates.subscribe { (_, produced, flowId) ->
|
||||
if (flowId != null) {
|
||||
val fungible = produced.filter { it.state.data is FungibleAsset<*> }
|
||||
if (fungible.isNotEmpty()) {
|
||||
manager.registerSoftLocks(flowId, fungible.map { it.ref }.toNonEmptySet())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun registerSoftLocks(flowId: UUID, stateRefs: NonEmptySet<StateRef>) {
|
||||
log.trace { "Reserving soft locks for flow id $flowId and states $stateRefs" }
|
||||
vault.softLockReserve(flowId, stateRefs)
|
||||
}
|
||||
|
||||
private fun unregisterSoftLocks(flowId: UUID, logic: FlowLogic<*>) {
|
||||
log.trace { "Releasing soft locks for flow ${logic.javaClass.simpleName} with flow id $flowId" }
|
||||
vault.softLockRelease(flowId)
|
||||
}
|
||||
}
|
144
node/src/main/kotlin/net/corda/node/utilities/ObjectDiffer.kt
Normal file
@ -0,0 +1,144 @@
|
||||
package net.corda.node.utilities
|
||||
|
||||
import java.lang.reflect.Method
|
||||
import java.lang.reflect.Modifier
|
||||
import java.lang.reflect.Type
|
||||
import java.time.Instant
|
||||
|
||||
/**
|
||||
* A tree describing the diff between two objects.
|
||||
*
|
||||
* For example:
|
||||
* data class A(val field1: Int, val field2: String, val field3: Unit)
|
||||
* fun main(args: Array<String>) {
|
||||
* val someA = A(1, "hello", Unit)
|
||||
* val someOtherA = A(2, "bello", Unit)
|
||||
* println(ObjectDiffer.diff(someA, someOtherA))
|
||||
* }
|
||||
*
|
||||
* Will give back Step(branches=[(field1, Last(a=1, b=2)), (field2, Last(a=hello, b=bello))])
|
||||
*/
|
||||
sealed class DiffTree {
|
||||
/**
|
||||
* Describes a "step" from the object root. It contains a list of field-subtree pairs.
|
||||
*/
|
||||
data class Step(val branches: List<Pair<String, DiffTree>>) : DiffTree()
|
||||
|
||||
/**
|
||||
* Describes the leaf of the diff. This is either where the diffing was cutoff (e.g. primitives) or where it failed.
|
||||
*/
|
||||
data class Last(val a: Any?, val b: Any?) : DiffTree()
|
||||
|
||||
/**
|
||||
* Flattens the [DiffTree] into a list of [DiffPath]s
|
||||
*/
|
||||
fun toPaths(): List<DiffPath> {
|
||||
return when (this) {
|
||||
is Step -> branches.flatMap { (step, tree) -> tree.toPaths().map { it.copy(path = listOf(step) + it.path) } }
|
||||
is Last -> listOf(DiffPath(emptyList(), a, b))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A diff focused on a single [DiffTree.Last] diff, including the path leading there.
|
||||
*/
|
||||
data class DiffPath(
|
||||
val path: List<String>,
|
||||
val a: Any?,
|
||||
val b: Any?
|
||||
) {
|
||||
override fun toString(): String {
|
||||
return "${path.joinToString(".")}: \n $a\n $b\n"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
 * This is a very simple differ used to diff objects of any kind, intended for diagnostics.
|
||||
*/
|
||||
object ObjectDiffer {
|
||||
fun diff(a: Any?, b: Any?): DiffTree? {
|
||||
if (a == null || b == null) {
|
||||
if (a == b) {
|
||||
return null
|
||||
} else {
|
||||
return DiffTree.Last(a, b)
|
||||
}
|
||||
}
|
||||
if (a != b) {
|
||||
if (a.javaClass.isPrimitive || a.javaClass in diffCutoffClasses) {
|
||||
return DiffTree.Last(a, b)
|
||||
}
|
||||
// TODO deduplicate this code
|
||||
if (a is Map<*, *> && b is Map<*, *>) {
|
||||
val allKeys = a.keys + b.keys
|
||||
val branches = allKeys.mapNotNull { key -> diff(a.get(key), b.get(key))?.let { key.toString() to it } }
|
||||
if (branches.isEmpty()) {
|
||||
return null
|
||||
} else {
|
||||
return DiffTree.Step(branches)
|
||||
}
|
||||
}
|
||||
if (a is java.util.Map<*, *> && b is java.util.Map<*, *>) {
|
||||
val allKeys = a.keySet() + b.keySet()
|
||||
val branches = allKeys.mapNotNull { key -> diff(a.get(key), b.get(key))?.let { key.toString() to it } }
|
||||
if (branches.isEmpty()) {
|
||||
return null
|
||||
} else {
|
||||
return DiffTree.Step(branches)
|
||||
}
|
||||
}
|
||||
val aFields = getFieldFoci(a)
|
||||
val bFields = getFieldFoci(b)
|
||||
try {
|
||||
if (aFields != bFields) {
|
||||
return DiffTree.Last(a, b)
|
||||
} else {
|
||||
// TODO need to account for cases where the fields don't match up (different subclasses)
|
||||
val branches = aFields.map { field -> diff(field.get(a), field.get(b))?.let { field.name to it } }.filterNotNull()
|
||||
if (branches.isEmpty()) {
|
||||
return DiffTree.Last(a, b)
|
||||
} else {
|
||||
return DiffTree.Step(branches)
|
||||
}
|
||||
}
|
||||
} catch (throwable: Exception) {
|
||||
Exception("Error while diffing $a with $b", throwable).printStackTrace(System.out)
|
||||
return DiffTree.Last(a, b)
|
||||
}
|
||||
} else {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
// List of types to cutoff the diffing at.
|
||||
private val diffCutoffClasses: Set<Class<*>> = setOf(
|
||||
String::class.java,
|
||||
Class::class.java,
|
||||
Instant::class.java
|
||||
)
|
||||
|
||||
// A type capturing the accessor to a field. This is a separate abstraction to simple reflection as we identify
|
||||
// getX() and isX() calls as fields as well.
|
||||
private data class FieldFocus(val name: String, val type: Type, val getter: Method) {
|
||||
fun get(obj: Any): Any? {
|
||||
return getter.invoke(obj)
|
||||
}
|
||||
}
|
||||
|
||||
private fun getFieldFoci(obj: Any) : List<FieldFocus> {
|
||||
val foci = ArrayList<FieldFocus>()
|
||||
for (method in obj.javaClass.declaredMethods) {
|
||||
if (Modifier.isStatic(method.modifiers)) {
|
||||
continue
|
||||
}
|
||||
if (method.name.startsWith("get") && method.name.length > 3 && method.parameterCount == 0) {
|
||||
val fieldName = method.name[3].toLowerCase() + method.name.substring(4)
|
||||
foci.add(FieldFocus(fieldName, method.returnType, method))
|
||||
} else if (method.name.startsWith("is") && method.parameterCount == 0) {
|
||||
foci.add(FieldFocus(method.name, method.returnType, method))
|
||||
}
|
||||
}
|
||||
return foci
|
||||
}
|
||||
}
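
Tying the pieces above together, a caller would typically flatten the resulting tree into paths for logging. The example below reuses the data class from the DiffTree documentation; the expected output in the comment is indicative (it follows from the documented Step/Last result) rather than copied from a real run.

data class A(val field1: Int, val field2: String, val field3: Unit)

fun main() {
    val someA = A(1, "hello", Unit)
    val someOtherA = A(2, "bello", Unit)
    val paths = ObjectDiffer.diff(someA, someOtherA)?.toPaths().orEmpty()
    // Expected to contain one DiffPath per differing field, roughly:
    //   field1: 1 vs 2
    //   field2: hello vs bello
    paths.forEach { println(it) }
}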
|
@ -49,10 +49,10 @@ class InMemoryMessagingTests {
|
||||
|
||||
val bits = "test-content".toByteArray()
|
||||
var finalDelivery: Message? = null
|
||||
node2.network.addMessageHandler("test.topic") { msg, _ ->
|
||||
node2.network.addMessageHandler("test.topic") { msg, _, _ ->
|
||||
node2.network.send(msg, node3.network.myAddress)
|
||||
}
|
||||
node3.network.addMessageHandler("test.topic") { msg, _ ->
|
||||
node3.network.addMessageHandler("test.topic") { msg, _, _ ->
|
||||
finalDelivery = msg
|
||||
}
|
||||
|
||||
@ -73,7 +73,7 @@ class InMemoryMessagingTests {
|
||||
val bits = "test-content".toByteArray()
|
||||
|
||||
var counter = 0
|
||||
listOf(node1, node2, node3).forEach { it.network.addMessageHandler("test.topic") { _, _ -> counter++ } }
|
||||
listOf(node1, node2, node3).forEach { it.network.addMessageHandler("test.topic") { _, _, _ -> counter++ } }
|
||||
node1.network.send(node2.network.createMessage("test.topic", data = bits), rigorousMock<AllPossibleRecipients>())
|
||||
mockNet.runNetwork(rounds = 1)
|
||||
assertEquals(3, counter)
|
||||
@ -89,9 +89,10 @@ class InMemoryMessagingTests {
|
||||
val node2 = mockNet.createNode()
|
||||
var received = 0
|
||||
|
||||
node1.network.addMessageHandler("valid_message") { _, _ ->
|
||||
node1.network.addMessageHandler("valid_message") { _, _, _ ->
|
||||
received++
|
||||
}
|
||||
|
||||
val invalidMessage = node2.network.createMessage("invalid_message", data = ByteArray(1))
|
||||
val validMessage = node2.network.createMessage("valid_message", data = ByteArray(1))
|
||||
node2.network.send(invalidMessage, node1.network.myAddress)
|
||||
|
@ -744,6 +744,12 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) {
|
||||
private val database: CordaPersistence,
|
||||
private val delegate: WritableTransactionStorage
|
||||
) : WritableTransactionStorage, SingletonSerializeAsToken() {
|
||||
override fun trackTransaction(id: SecureHash): CordaFuture<SignedTransaction> {
|
||||
return database.transaction {
|
||||
delegate.trackTransaction(id)
|
||||
}
|
||||
}
|
||||
|
||||
override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> {
|
||||
return database.transaction {
|
||||
delegate.track()
|
||||
|
@ -1,6 +1,5 @@
|
||||
package net.corda.node.services.events
|
||||
|
||||
import com.google.common.util.concurrent.MoreExecutors
|
||||
import com.nhaarman.mockito_kotlin.*
|
||||
import net.corda.core.contracts.*
|
||||
import net.corda.core.crypto.SecureHash
|
||||
@ -22,6 +21,7 @@ import net.corda.testing.internal.doLookup
|
||||
import net.corda.testing.internal.rigorousMock
|
||||
import net.corda.testing.node.MockServices
|
||||
import net.corda.testing.node.TestClock
|
||||
import org.junit.Ignore
|
||||
import org.junit.Rule
|
||||
import org.junit.Test
|
||||
import org.junit.rules.TestWatcher
|
||||
@ -42,8 +42,14 @@ open class NodeSchedulerServiceTestBase {
|
||||
protected val testClock = TestClock(rigorousMock<Clock>().also {
|
||||
doReturn(mark).whenever(it).instant()
|
||||
})
|
||||
private val database = rigorousMock<CordaPersistence>().also {
|
||||
doAnswer {
|
||||
val block: DatabaseTransaction.() -> Any? = uncheckedCast(it.arguments[0])
|
||||
rigorousMock<DatabaseTransaction>().block()
|
||||
}.whenever(it).transaction(any())
|
||||
}
|
||||
protected val flowStarter = rigorousMock<FlowStarter>().also {
|
||||
doReturn(openFuture<FlowStateMachine<*>>()).whenever(it).startFlow(any<FlowLogic<*>>(), any())
|
||||
doReturn(openFuture<FlowStateMachine<*>>()).whenever(it).startFlow(any<FlowLogic<*>>(), any(), any())
|
||||
}
|
||||
private val flowsDraingMode = rigorousMock<NodePropertiesStore.FlowsDrainingModeOperations>().also {
|
||||
doReturn(false).whenever(it).isEnabled()
|
||||
@ -76,7 +82,7 @@ open class NodeSchedulerServiceTestBase {
|
||||
|
||||
protected fun assertStarted(flowLogic: FlowLogic<*>) {
|
||||
// Like in assertWaitingFor, use timeout to make verify wait as we often race the call to startFlow:
|
||||
verify(flowStarter, timeout(5000)).startFlow(same(flowLogic)!!, any())
|
||||
verify(flowStarter, timeout(5000)).startFlow(same(flowLogic)!!, any(), any())
|
||||
}
|
||||
|
||||
protected fun assertStarted(event: Event) = assertStarted(event.flowLogic)
|
||||
@ -95,7 +101,6 @@ class NodeSchedulerServiceTest : NodeSchedulerServiceTestBase() {
|
||||
database,
|
||||
flowStarter,
|
||||
servicesForResolution,
|
||||
serverThread = MoreExecutors.directExecutor(),
|
||||
flowLogicRefFactory = flowLogicRefFactory,
|
||||
nodeProperties = nodeProperties,
|
||||
drainingModePollPeriod = Duration.ofSeconds(5),
|
||||
@ -209,7 +214,6 @@ class NodeSchedulerPersistenceTest : NodeSchedulerServiceTestBase() {
|
||||
db,
|
||||
flowStarter,
|
||||
servicesForResolution,
|
||||
serverThread = MoreExecutors.directExecutor(),
|
||||
flowLogicRefFactory = flowLogicRefFactory,
|
||||
nodeProperties = nodeProperties,
|
||||
drainingModePollPeriod = Duration.ofSeconds(5),
|
||||
@ -262,6 +266,7 @@ class NodeSchedulerPersistenceTest : NodeSchedulerServiceTestBase() {
|
||||
newDatabase.close()
|
||||
}
|
||||
|
||||
@Ignore("Temporarily")
|
||||
@Test
|
||||
fun `test that if schedule is updated then the flow is invoked on the correct schedule`() {
|
||||
val dataSourceProps = MockServices.makeTestDataSourceProperties()
|
||||
@ -293,4 +298,4 @@ class NodeSchedulerPersistenceTest : NodeSchedulerServiceTestBase() {
|
||||
scheduler.join()
|
||||
database.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -151,7 +151,9 @@ class ArtemisMessagingTest {
|
||||
createMessagingServer().start()
|
||||
|
||||
val messagingClient = createMessagingClient(platformVersion = platformVersion)
|
||||
messagingClient.addMessageHandler(TOPIC) { message, _ ->
|
||||
messagingClient.addMessageHandler(TOPIC) { message, _, handle ->
|
||||
database.transaction { handle.insideDatabaseTransaction() }
|
||||
handle.afterDatabaseTransaction() // We ACK first so that if it fails we won't get a duplicate in [receivedMessages]
|
||||
receivedMessages.add(message)
|
||||
}
|
||||
startNodeMessagingClient()
|
||||
|
@ -1,33 +1,40 @@
|
||||
package net.corda.node.services.persistence
|
||||
|
||||
import com.google.common.primitives.Ints
|
||||
import net.corda.core.context.InvocationContext
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.core.serialization.SerializationDefaults
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.node.services.api.Checkpoint
|
||||
import net.corda.node.services.api.CheckpointStorage
|
||||
import net.corda.node.services.transactions.PersistentUniquenessProvider
|
||||
import net.corda.core.serialization.serialize
|
||||
import net.corda.node.internal.configureDatabase
|
||||
import net.corda.node.services.api.CheckpointStorage
|
||||
import net.corda.node.services.statemachine.Checkpoint
|
||||
import net.corda.node.services.statemachine.FlowStart
|
||||
import net.corda.node.services.transactions.PersistentUniquenessProvider
|
||||
import net.corda.nodeapi.internal.persistence.CordaPersistence
|
||||
import net.corda.nodeapi.internal.persistence.DatabaseConfig
|
||||
import net.corda.testing.internal.LogHelper
|
||||
import net.corda.testing.core.ALICE_NAME
|
||||
import net.corda.testing.core.SerializationEnvironmentRule
|
||||
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
|
||||
import net.corda.testing.core.TestIdentity
|
||||
import net.corda.testing.internal.LogHelper
|
||||
import net.corda.testing.internal.rigorousMock
|
||||
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.After
|
||||
import org.junit.Before
|
||||
import org.junit.Rule
|
||||
import org.junit.Test
|
||||
import kotlin.streams.toList
|
||||
|
||||
internal fun CheckpointStorage.checkpoints(): List<Checkpoint> {
|
||||
val checkpoints = mutableListOf<Checkpoint>()
|
||||
forEach {
|
||||
checkpoints += it
|
||||
true
|
||||
}
|
||||
return checkpoints
|
||||
internal fun CheckpointStorage.checkpoints(): List<SerializedBytes<Checkpoint>> {
|
||||
val checkpoints = getAllCheckpoints().toList()
|
||||
return checkpoints.map { it.second }
|
||||
}
|
||||
|
||||
class DBCheckpointStorageTests {
|
||||
private companion object {
|
||||
val ALICE = TestIdentity(ALICE_NAME, 70).party
|
||||
}
|
||||
@Rule
|
||||
@JvmField
|
||||
val testSerialization = SerializationEnvironmentRule()
|
||||
@ -50,9 +57,9 @@ class DBCheckpointStorageTests {
|
||||
|
||||
@Test
fun `add new checkpoint`() {
val checkpoint = newCheckpoint()
val (id, checkpoint) = newCheckpoint()
database.transaction {
checkpointStorage.addCheckpoint(checkpoint)
checkpointStorage.addCheckpoint(id, checkpoint)
}
database.transaction {
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
@ -65,12 +72,12 @@ class DBCheckpointStorageTests {

@Test
fun `remove checkpoint`() {
val checkpoint = newCheckpoint()
val (id, checkpoint) = newCheckpoint()
database.transaction {
checkpointStorage.addCheckpoint(checkpoint)
checkpointStorage.addCheckpoint(id, checkpoint)
}
database.transaction {
checkpointStorage.removeCheckpoint(checkpoint)
checkpointStorage.removeCheckpoint(id)
}
database.transaction {
assertThat(checkpointStorage.checkpoints()).isEmpty()
@ -83,12 +90,12 @@ class DBCheckpointStorageTests {

@Test
fun `add and remove checkpoint in single commit operate`() {
val checkpoint = newCheckpoint()
val checkpoint2 = newCheckpoint()
val (id, checkpoint) = newCheckpoint()
val (id2, checkpoint2) = newCheckpoint()
database.transaction {
checkpointStorage.addCheckpoint(checkpoint)
checkpointStorage.addCheckpoint(checkpoint2)
checkpointStorage.removeCheckpoint(checkpoint)
checkpointStorage.addCheckpoint(id, checkpoint)
checkpointStorage.addCheckpoint(id2, checkpoint2)
checkpointStorage.removeCheckpoint(id)
}
database.transaction {
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint2)
@ -101,16 +108,16 @@ class DBCheckpointStorageTests {

@Test
fun `add two checkpoints then remove first one`() {
val firstCheckpoint = newCheckpoint()
val (id, firstCheckpoint) = newCheckpoint()
database.transaction {
checkpointStorage.addCheckpoint(firstCheckpoint)
checkpointStorage.addCheckpoint(id, firstCheckpoint)
}
val secondCheckpoint = newCheckpoint()
val (id2, secondCheckpoint) = newCheckpoint()
database.transaction {
checkpointStorage.addCheckpoint(secondCheckpoint)
checkpointStorage.addCheckpoint(id2, secondCheckpoint)
}
database.transaction {
checkpointStorage.removeCheckpoint(firstCheckpoint)
checkpointStorage.removeCheckpoint(id)
}
database.transaction {
assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint)
@ -123,9 +130,9 @@ class DBCheckpointStorageTests {

@Test
fun `add checkpoint and then remove after 'restart'`() {
val originalCheckpoint = newCheckpoint()
val (id, originalCheckpoint) = newCheckpoint()
database.transaction {
checkpointStorage.addCheckpoint(originalCheckpoint)
checkpointStorage.addCheckpoint(id, originalCheckpoint)
}
newCheckpointStorage()
val reconstructedCheckpoint = database.transaction {
@ -135,7 +142,7 @@ class DBCheckpointStorageTests {
assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint)
}
database.transaction {
checkpointStorage.removeCheckpoint(reconstructedCheckpoint)
checkpointStorage.removeCheckpoint(id)
}
database.transaction {
assertThat(checkpointStorage.checkpoints()).isEmpty()
@ -148,7 +155,14 @@ class DBCheckpointStorageTests {
}
}
private var checkpointCount = 1
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)))
private fun newCheckpoint(): Pair<StateMachineRunId, SerializedBytes<Checkpoint>> {
val id = StateMachineRunId.createRandom()
val logic: FlowLogic<*> = object : FlowLogic<Unit>() {
override fun call() {}
}
val frozenLogic = logic.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT)
val checkpoint = Checkpoint.create(InvocationContext.shell(), FlowStart.Explicit, logic.javaClass, frozenLogic, ALICE, "").getOrThrow()
return id to checkpoint.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT)
}

}
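A usage sketch of the keyed API this rewritten test exercises, with signatures taken from the hunks above; as in the test, the calls are assumed to run inside a database.transaction block.

import net.corda.core.flows.StateMachineRunId
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.serialize
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.statemachine.Checkpoint

fun storeAndDrop(storage: CheckpointStorage, id: StateMachineRunId, checkpoint: Checkpoint) {
    // Checkpoints are now keyed by StateMachineRunId and stored as serialized bytes.
    val bytes: SerializedBytes<Checkpoint> = checkpoint.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT)
    storage.addCheckpoint(id, bytes)
    // ...and removed by the same id once the flow has finished.
    storage.removeCheckpoint(id)
}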
@ -2,6 +2,7 @@ package net.corda.node.services.statemachine

import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand
import co.paralleluniverse.strands.concurrent.Semaphore
import net.corda.core.concurrent.CordaFuture
import net.corda.core.contracts.ContractState
@ -67,10 +68,6 @@ class FlowFrameworkTests {
private lateinit var alice: Party
private lateinit var bob: Party

private fun StartedNode<*>.flushSmm() {
(this.smm as StateMachineManagerImpl).executor.flush()
}

@Before
fun start() {
mockNet = InternalMockNetwork(
@ -109,6 +106,19 @@ class FlowFrameworkTests {
assertThat(flow.lazyTime).isNotNull()
}

class ThrowingActionExecutor(private val exception: Exception, val delegate: ActionExecutor) : ActionExecutor {
var thrown = false
@Suspendable
override fun executeAction(fiber: FlowFiber, action: Action) {
if (thrown) {
delegate.executeAction(fiber, action)
} else {
thrown = true
throw exception
}
}
}
@Test
fun `exception while fiber suspended`() {
bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) }
@ -116,16 +126,15 @@ class FlowFrameworkTests {
val fiber = aliceNode.services.startFlow(flow) as FlowStateMachineImpl
// Before the flow runs change the suspend action to throw an exception
val exceptionDuringSuspend = Exception("Thrown during suspend")
fiber.actionOnSuspend = {
throw exceptionDuringSuspend
}
val throwingActionExecutor = ThrowingActionExecutor(exceptionDuringSuspend, fiber.transientValues!!.value.actionExecutor)
fiber.transientValues = TransientReference(fiber.transientValues!!.value.copy(actionExecutor = throwingActionExecutor))
mockNet.runNetwork()
assertThatThrownBy {
fiber.resultFuture.getOrThrow()
}.isSameAs(exceptionDuringSuspend)
assertThat(aliceNode.smm.allStateMachines).isEmpty()
// Make sure the fiber does actually terminate
assertThat(fiber.isTerminated).isTrue()
assertThat(fiber.state).isEqualTo(Strand.State.WAITING)
}
@Test
@ -148,7 +157,6 @@ class FlowFrameworkTests {
aliceNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) }
bobNode.services.startFlow(ReceiveFlow(alice).nonTerminating()) // Prepare checkpointed receive flow
// Make sure the add() has finished initial processing.
bobNode.flushSmm()
bobNode.internals.disableDBCloseOnStop()
bobNode.dispose() // kill receiver
val restoredFlow = bobNode.restartAndGetRestoredFlow<ReceiveFlow>()
@ -174,7 +182,6 @@ class FlowFrameworkTests {
assertEquals(1, bobNode.checkpointStorage.checkpoints().size)
}
// Make sure the add() has finished initial processing.
bobNode.flushSmm()
bobNode.internals.disableDBCloseOnStop()
// Restart node and thus reload the checkpoint and resend the message with same UUID
bobNode.dispose()
@ -187,7 +194,6 @@ class FlowFrameworkTests {
val (firstAgain, fut1) = node2b.getSingleFlow<PingPongFlow>()
// Run the network which will also fire up the second flow. First message should get deduped. So message data stays in sync.
mockNet.runNetwork()
node2b.flushSmm()
fut1.getOrThrow()

val receivedCount = receivedSessionMessages.count { it.isPayloadTransfer }
@ -216,6 +222,8 @@ class FlowFrameworkTests {
val payload = "Hello World"
aliceNode.services.startFlow(SendFlow(payload, bob, charlie))
mockNet.runNetwork()
bobNode.internals.acceptableLiveFiberCountOnStop = 1
charlieNode.internals.acceptableLiveFiberCountOnStop = 1
val bobFlow = bobNode.getSingleFlow<InitiatedReceiveFlow>().first
val charlieFlow = charlieNode.getSingleFlow<InitiatedReceiveFlow>().first
assertThat(bobFlow.receivedPayloads[0]).isEqualTo(payload)
@ -234,9 +242,6 @@ class FlowFrameworkTests {
aliceNode sent normalEnd to charlieNode
//There's no session end from the other flows as they're manually suspended
)

bobNode.internals.acceptableLiveFiberCountOnStop = 1
charlieNode.internals.acceptableLiveFiberCountOnStop = 1
}
@Test
@ -338,7 +343,9 @@ class FlowFrameworkTests {

mockNet.runNetwork()

assertThat(erroringFlowSteps.get()).containsExactly(
erroringFlowFuture.getOrThrow()
val flowSteps = erroringFlowSteps.get()
assertThat(flowSteps).containsExactly(
Notification.createOnNext(ExceptionFlow.START_STEP),
Notification.createOnError(erroringFlowFuture.get().exceptionThrown)
)
@ -378,8 +385,8 @@ class FlowFrameworkTests {
assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty()
}

assertThat(receivingFiber.isTerminated).isTrue()
assertThat((erroringFlow.get().stateMachine as FlowStateMachineImpl).isTerminated).isTrue()
assertThat(receivingFiber.state).isEqualTo(Strand.State.WAITING)
assertThat((erroringFlow.get().stateMachine as FlowStateMachineImpl).state).isEqualTo(Strand.State.WAITING)
assertThat(erroringFlowSteps.get()).containsExactly(
Notification.createOnNext(ExceptionFlow.START_STEP),
Notification.createOnError(erroringFlow.get().exceptionThrown)
@ -396,7 +403,7 @@ class FlowFrameworkTests {
}

@Test
fun `FlowException propagated in invocation chain`() {
fun `FlowException only propagated to parent`() {
val charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME))
val charlie = charlieNode.info.singleIdentity()

@ -404,9 +411,8 @@ class FlowFrameworkTests {
bobNode.registerFlowFactory(ReceiveFlow::class) { ReceiveFlow(charlie) }
val receivingFiber = aliceNode.services.startFlow(ReceiveFlow(bob))
mockNet.runNetwork()
assertThatExceptionOfType(MyFlowException::class.java)
assertThatExceptionOfType(UnexpectedFlowEndException::class.java)
.isThrownBy { receivingFiber.resultFuture.getOrThrow() }
.withMessage("Chain")
}
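The renamed test above pins down the propagation rule: a FlowException reaches only the flow's direct parent session, while flows further up the chain observe an UnexpectedFlowEndException. A minimal sketch of that rule with hypothetical flow names; only the direct initiator of RefuseFlow would catch MyFlowException itself.

import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.flows.InitiatedBy
import net.corda.core.flows.InitiatingFlow
import net.corda.core.identity.Party
import net.corda.core.utilities.unwrap

class MyFlowException(message: String) : FlowException(message)

@InitiatingFlow
class AskFlow(private val counterparty: Party) : FlowLogic<String>() {
    @Suspendable
    override fun call(): String {
        val session = initiateFlow(counterparty)
        // A FlowException thrown by the counter-flow is re-thrown here as-is; any other
        // failure in the counter-flow surfaces as UnexpectedFlowEndException instead.
        return session.receive<String>().unwrap { it }
    }
}

@InitiatedBy(AskFlow::class)
class RefuseFlow(private val otherSide: FlowSession) : FlowLogic<Unit>() {
    @Suspendable
    override fun call() {
        throw MyFlowException("Nothing for you")
    }
}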
@Test
@ -558,10 +564,8 @@ class FlowFrameworkTests {

@Test
fun `customised client flow which has annotated @InitiatingFlow again`() {
val result = aliceNode.services.startFlow(IncorrectCustomSendFlow("Hello", bob)).resultFuture
mockNet.runNetwork()
assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy {
result.getOrThrow()
aliceNode.services.startFlow(IncorrectCustomSendFlow("Hello", bob)).resultFuture
}.withMessageContaining(InitiatingFlow::class.java.simpleName)
}

@ -635,24 +639,6 @@ class FlowFrameworkTests {
assertThat(result.getOrThrow()).isEqualTo("HelloHello")
}

@Test
fun `double initiateFlow throws`() {
val future = aliceNode.services.startFlow(DoubleInitiatingFlow()).resultFuture
mockNet.runNetwork()
assertThatExceptionOfType(IllegalStateException::class.java)
.isThrownBy { future.getOrThrow() }
.withMessageContaining("Attempted to initiateFlow() twice")
}

@InitiatingFlow
private class DoubleInitiatingFlow : FlowLogic<Unit>() {
@Suspendable
override fun call() {
initiateFlow(ourIdentity)
initiateFlow(ourIdentity)
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////
//region Helpers
@ -685,7 +671,6 @@ class FlowFrameworkTests {
private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage {
return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
}

private fun sessionConfirm(flowVersion: Int = 1) = ExistingSessionMessage(SessionId(0), ConfirmSessionMessage(SessionId(0), FlowInfo(flowVersion, "")))
private fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(0), DataSessionMessage(payload.serialize()))
private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0)
@ -694,7 +679,7 @@ class FlowFrameworkTests {
private fun StartedNode<*>.sendSessionMessage(message: SessionMessage, destination: Party) {
services.networkService.apply {
val address = getAddressOfParty(PartyInfo.SingleNode(destination, emptyList()))
send(createMessage(StateMachineManagerImpl.sessionTopic, message.serialize().bytes), address)
send(createMessage(FlowMessagingImpl.sessionTopic, message.serialize().bytes), address)
}
}

@ -720,7 +705,7 @@ class FlowFrameworkTests {
}

private fun Observable<MessageTransfer>.toSessionTransfers(): Observable<SessionTransfer> {
return filter { it.getMessage().topic == StateMachineManagerImpl.sessionTopic }.map {
return filter { it.getMessage().topic == FlowMessagingImpl.sessionTopic }.map {
val from = it.sender.id
val message = it.messageData.deserialize<SessionMessage>()
SessionTransfer(from, sanitise(message), it.recipients)
@ -5,10 +5,8 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.internal.InputStreamAndHash
import net.corda.core.node.ServiceHub
import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.getOrThrow
import net.corda.node.services.api.StartedNodeServices
import net.corda.testing.common.internal.testNetworkParameters
import net.corda.testing.contracts.DummyContract
import net.corda.testing.contracts.DummyState
@ -22,7 +20,6 @@ import net.corda.testing.node.StartedMockNode
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Before
import org.junit.Ignore
import org.junit.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
@ -90,12 +87,11 @@ class MaxTransactionSizeTests {
assertEquals(hash1, bigFile1.sha256)
SendLargeTransactionFlow(notary, bob, hash1, hash2, hash3, hash4, verify = false)
}
val ex = assertFailsWith<UnexpectedFlowEndException> {
assertFailsWith<UnexpectedFlowEndException> {
val future = aliceNode.startFlow(flow)
mockNet.runNetwork()
future.getOrThrow()
}
assertThat(ex).hasMessageContaining("Counterparty flow on O=Bob Plc, L=Rome, C=IT had an internal error and has terminated")
}

@StartableByRPC
@ -135,4 +131,4 @@ class MaxTransactionSizeTests {
otherSide.send(Unit)
}
}
}
}
@ -210,7 +210,7 @@ class NotaryServiceTests {

private fun runNotarisationAndInterceptClientPayload(payloadModifier: (NotarisationPayload) -> NotarisationPayload) {
aliceNode.setMessagingServiceSpy(object : MessagingServiceSpy(aliceNode.network) {
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, additionalHeaders: Map<String, String>) {
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) {
val messageData = message.data.deserialize<Any>() as? InitialSessionMessage
val payload = messageData?.firstPayload!!.deserialize()
@ -84,9 +84,12 @@ class VaultSoftLockManagerTest {
private val mockNet = InternalMockNetwork(cordappPackages = listOf(ContractImpl::class.packageName), defaultFactory = { args ->
object : InternalMockNetwork.MockNode(args) {
override fun makeVaultService(keyManagementService: KeyManagementService, services: ServicesForResolution, hibernateConfig: HibernateConfiguration): VaultServiceInternal {
val node = this
val realVault = super.makeVaultService(keyManagementService, services, hibernateConfig)
return object : VaultServiceInternal by realVault {
override fun softLockRelease(lockId: UUID, stateRefs: NonEmptySet<StateRef>?) {
// Should be called before flow is removed
assertEquals(1, node.started!!.smm.allStateMachines.size)
mockVault.softLockRelease(lockId, stateRefs) // No need to also call the real one for these tests.
}
}
@ -18,6 +18,7 @@ ext['artemis.version'] = "$artemis_version"
ext['hibernate.version'] = "$hibernate_version"
ext['selenium.version'] = "$selenium_version"
ext['jackson.version'] = "$jackson_version"
ext['dropwizard-metrics.version'] = "$metrics_version"

apply plugin: 'java'
apply plugin: 'kotlin'
@ -18,6 +18,7 @@ buildscript {
ext['artemis.version'] = "$artemis_version"
ext['hibernate.version'] = "$hibernate_version"
ext['jackson.version'] = "$jackson_version"
ext['dropwizard-metrics.version'] = "$metrics_version"

apply plugin: 'java'
@ -18,10 +18,8 @@ import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.trace
import net.corda.node.services.messaging.Message
import net.corda.node.services.messaging.MessageHandlerRegistration
import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.services.messaging.*
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.utilities.AffinityExecutor
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.testing.node.internal.InMemoryMessage
@ -290,9 +288,16 @@ class InMemoryMessagingNetwork private constructor(
private data class InMemoryReceivedMessage(override val topic: String,
override val data: ByteSequence,
override val platformVersion: Int,
override val uniqueMessageId: String,
override val uniqueMessageId: DeduplicationId,
override val debugTimestamp: Instant,
override val peer: CordaX500Name) : ReceivedMessage
override val peer: CordaX500Name,
override val senderUUID: String? = null,
override val senderSeqNo: Long? = null,
/** Note this flag is never set in the in memory network. */
override val isSessionInit: Boolean = false) : ReceivedMessage {

override val additionalHeaders: Map<String, String> = emptyMap()
}
/**
* A class that provides an abstraction over the nodes' messaging service that also contains the ability to
@ -319,7 +324,7 @@ class InMemoryMessagingNetwork private constructor(
private val peerHandle: PeerHandle,
private val executor: AffinityExecutor,
private val database: CordaPersistence) : SingletonSerializeAsToken(), InternalMockMessagingService {
private inner class Handler(val topicSession: String, val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration
private inner class Handler(val topicSession: String, val callback: MessageHandler) : MessageHandlerRegistration

@Volatile
private var running = true
@ -330,7 +335,7 @@ class InMemoryMessagingNetwork private constructor(
}

private val state = ThreadBox(InnerState())
private val processedMessages: MutableSet<String> = Collections.synchronizedSet(HashSet<String>())
private val processedMessages: MutableSet<DeduplicationId> = Collections.synchronizedSet(HashSet<DeduplicationId>())

override val myAddress: PeerHandle get() = peerHandle

@ -353,7 +358,7 @@ class InMemoryMessagingNetwork private constructor(
}
}

override fun addMessageHandler(topic: String, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
override fun addMessageHandler(topic: String, callback: MessageHandler): MessageHandlerRegistration {
check(running)
val (handler, transfers) = state.locked {
val handler = Handler(topic, callback).apply { handlers.add(this) }
@ -374,7 +379,7 @@ class InMemoryMessagingNetwork private constructor(
state.locked { check(handlers.remove(registration as Handler)) }
}

override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, additionalHeaders: Map<String, String>) {
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) {
check(running)
msgSend(this, message, target)
if (!sendManuallyPumped) {
@ -400,7 +405,7 @@ class InMemoryMessagingNetwork private constructor(
override fun cancelRedelivery(retryId: Long) {}

/** Returns the given (topic & session, data) pair as a newly created message object. */
override fun createMessage(topic: String, data: ByteArray, deduplicationId: String): Message {
override fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId, additionalHeaders: Map<String, String>): Message {
return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId)
}
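A small usage sketch of the widened messaging signatures in this hunk: createMessage now takes a DeduplicationId plus additionalHeaders, and send accepts additionalHeaders. The topic, id string and header values are made up for illustration, and DeduplicationId is assumed to wrap a single string constructor argument.

import net.corda.core.messaging.MessageRecipients
import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.statemachine.DeduplicationId

fun sendWithHeaders(messaging: MessagingService, target: MessageRecipients, bytes: ByteArray) {
    val message = messaging.createMessage(
            topic = "example.topic",                         // illustrative topic
            data = bytes,
            deduplicationId = DeduplicationId("example-id"), // assumed single-string constructor
            additionalHeaders = emptyMap()
    )
    messaging.send(message, target, retryId = null, sequenceKey = target, additionalHeaders = mapOf("example-header" to "demo"))
}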
@ -465,7 +470,7 @@ class InMemoryMessagingNetwork private constructor(
database.transaction {
for (handler in deliverTo) {
try {
handler.callback(transfer.toReceivedMessage(), handler)
handler.callback(transfer.toReceivedMessage(), handler, DummyDeduplicationHandler())
} catch (e: Exception) {
log.error("Caught exception in handler for $this/${handler.topicSession}", e)
}
@ -489,5 +494,12 @@ class InMemoryMessagingNetwork private constructor(
message.debugTimestamp,
sender.name)
}

private class DummyDeduplicationHandler : DeduplicationHandler {
override fun afterDatabaseTransaction() {
}
override fun insideDatabaseTransaction() {
}
}
}
@ -2,6 +2,7 @@ package net.corda.testing.node.internal

import net.corda.core.utilities.ByteSequence
import net.corda.node.services.messaging.Message
import net.corda.node.services.statemachine.DeduplicationId
import java.time.Instant

/**
@ -9,7 +10,11 @@ import java.time.Instant
*/
data class InMemoryMessage(override val topic: String,
override val data: ByteSequence,
override val uniqueMessageId: String,
override val debugTimestamp: Instant = Instant.now()) : Message {
override val uniqueMessageId: DeduplicationId,
override val debugTimestamp: Instant = Instant.now(),
override val senderUUID: String? = null) : Message {

override val additionalHeaders: Map<String, String> = emptyMap()

override fun toString() = "$topic#${String(data.bytes)}"
}
@ -15,7 +15,6 @@ import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.createDirectories
import net.corda.core.internal.createDirectory
import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.RPCOps
import net.corda.core.messaging.SingleMessageRecipient
@ -231,7 +230,7 @@ open class InternalMockNetwork(private val cordappPackages: List<String>,
private val entropyRoot = args.entropyRoot
var counter = entropyRoot
override val log get() = staticLog
override val serverThread: AffinityExecutor =
override val serverThread: AffinityExecutor.ServiceAffinityExecutor =
if (mockNet.threadPerNode) {
ServiceAffinityExecutor("Mock node $id thread", 1)
} else {
@ -1,18 +1,25 @@
package net.corda.testing.node.internal

import net.corda.core.concurrent.CordaFuture
import net.corda.core.crypto.SecureHash
import net.corda.core.internal.concurrent.doneFuture
import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.toFuture
import net.corda.core.transactions.SignedTransaction
import net.corda.node.services.api.WritableTransactionStorage
import rx.Observable
import rx.subjects.PublishSubject
import java.util.HashMap
import java.util.*

/**
* A class which provides an implementation of [WritableTransactionStorage] which is used in [MockServices]
*/
open class MockTransactionStorage : WritableTransactionStorage, SingletonSerializeAsToken() {
override fun trackTransaction(id: SecureHash): CordaFuture<SignedTransaction> {
return txns[id]?.let { doneFuture(it) } ?: _updatesPublisher.filter { it.id == id }.toFuture()
}

override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> {
return DataFeed(txns.values.toList(), _updatesPublisher)
}
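A short usage sketch for the MockTransactionStorage shown above: trackTransaction returns a future that is already complete for a recorded transaction and otherwise completes when the transaction is added later. addTransaction comes from the WritableTransactionStorage contract and is an assumption here rather than something this hunk shows.

import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.getOrThrow
import net.corda.testing.node.internal.MockTransactionStorage

fun awaitRecorded(storage: MockTransactionStorage, tx: SignedTransaction): SignedTransaction {
    val future = storage.trackTransaction(tx.id)  // pending until the transaction is recorded
    storage.addTransaction(tx)                    // recording publishes an update...
    return future.getOrThrow()                    // ...which completes the future
}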