Merge pull request #2964 from corda/CORDA-1334/aslemmer-enterprise-smm-port

CORDA-1334: port enterprise statemachine
This commit is contained in:
Andras Slemmer 2018-04-23 16:42:10 +01:00 committed by GitHub
commit 640e5c6088
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
82 changed files with 4761 additions and 1839 deletions

View File

@ -1200,7 +1200,7 @@ public static final class net.corda.core.flows.FinalityFlow$Companion extends ja
@org.jetbrains.annotations.NotNull public net.corda.core.utilities.ProgressTracker childProgressTracker()
public static final net.corda.core.flows.FinalityFlow$Companion$NOTARISING INSTANCE
##
@net.corda.core.serialization.CordaSerializable public class net.corda.core.flows.FlowException extends net.corda.core.CordaException
@net.corda.core.serialization.CordaSerializable public class net.corda.core.flows.FlowException extends net.corda.core.CordaException implements net.corda.core.flows.IdentifiableException
public <init>()
public <init>(String)
public <init>(String, Throwable)
@ -1589,9 +1589,10 @@ public final class net.corda.core.flows.TransactionParts extends java.lang.Objec
public int hashCode()
public String toString()
##
@net.corda.core.serialization.CordaSerializable public final class net.corda.core.flows.UnexpectedFlowEndException extends net.corda.core.CordaRuntimeException
@net.corda.core.serialization.CordaSerializable public final class net.corda.core.flows.UnexpectedFlowEndException extends net.corda.core.CordaRuntimeException implements net.corda.core.flows.IdentifiableException
public <init>(String)
public <init>(String, Throwable)
public <init>(String, Throwable, Long)
##
@net.corda.core.DoNotImplement @net.corda.core.serialization.CordaSerializable public abstract class net.corda.core.identity.AbstractParty extends java.lang.Object
public <init>(java.security.PublicKey)

View File

@ -47,6 +47,7 @@ buildscript {
ext.bouncycastle_version = constants.getProperty("bouncycastleVersion")
ext.guava_version = constants.getProperty("guavaVersion")
ext.caffeine_version = constants.getProperty("caffeineVersion")
ext.metrics_version = constants.getProperty("metricsVersion")
ext.okhttp_version = '3.5.0'
ext.netty_version = '4.1.9.Final'
ext.typesafe_config_version = constants.getProperty("typesafeConfigVersion")

View File

@ -8,3 +8,4 @@ jsr305Version=3.0.2
artifactoryPluginVersion=4.4.18
snakeYamlVersion=1.19
caffeineVersion=2.6.2
metricsVersion=3.2.5

View File

@ -0,0 +1,16 @@
package net.corda.core.flows;
import javax.annotation.Nullable;
/**
* An exception that may be identified with an ID. If an exception originates in a counter-flow this ID will be
* propagated. This allows correlation of error conditions across different flows.
*/
public interface IdentifiableException {
/**
* @return the ID of the error, or null if the error doesn't have it set (yet).
*/
default @Nullable Long getErrorId() {
return null;
}
}

View File

@ -7,16 +7,27 @@ import net.corda.core.CordaRuntimeException
/**
* Exception which can be thrown by a [FlowLogic] at any point in its logic to unexpectedly bring it to a permanent end.
* The exception will propagate to all counterparty flows and will be thrown on their end the next time they wait on a
* [FlowSession.receive] or [FlowSession.sendAndReceive]. Any flow which no longer needs to do a receive, or has already ended,
* will not receive the exception (if this is required then have them wait for a confirmation message).
* [FlowSession.receive] or [FlowSession.sendAndReceive]. Any flow which no longer needs to do a receive, or has already
* ended, will not receive the exception (if this is required then have them wait for a confirmation message).
*
* If the *rethrown* [FlowException] is uncaught in counterparty flows and propagation triggers then the exception is
* downgraded to an [UnexpectedFlowEndException]. This means only immediate counterparty flows will receive information
* about what the exception was.
*
* [FlowException] (or a subclass) can be a valid expected response from a flow, particularly ones which act as a service.
* It is recommended a [FlowLogic] document the [FlowException] types it can throw.
*
* @property originalErrorId the ID backing [getErrorId]. If null it will be set dynamically by the flow framework when
* the exception is handled. This ID is propagated to counterparty flows, even when the [FlowException] is
* downgraded to an [UnexpectedFlowEndException]. This is so the error conditions may be correlated later on.
*/
open class FlowException(message: String?, cause: Throwable?) : CordaException(message, cause) {
open class FlowException(message: String?, cause: Throwable?) :
CordaException(message, cause), IdentifiableException {
constructor(message: String?) : this(message, null)
constructor(cause: Throwable?) : this(cause?.toString(), cause)
constructor() : this(null, null)
var originalErrorId: Long? = null
override fun getErrorId(): Long? = originalErrorId
}
// DOCEND 1
@ -25,6 +36,9 @@ open class FlowException(message: String?, cause: Throwable?) : CordaException(m
* that we were not expecting), or the other side had an internal error, or the other side terminated when we
* were waiting for a response.
*/
class UnexpectedFlowEndException(message: String?, cause: Throwable?) : CordaRuntimeException(message, cause) {
constructor(msg: String) : this(msg, null)
}
class UnexpectedFlowEndException(message: String, cause: Throwable?, val originalErrorId: Long?) :
CordaRuntimeException(message, cause), IdentifiableException {
constructor(message: String, cause: Throwable?) : this(message, cause, null)
constructor(message: String) : this(message, null)
override fun getErrorId(): Long? = originalErrorId
}

View File

@ -6,20 +6,17 @@ import net.corda.core.CordaInternal
import net.corda.core.crypto.SecureHash
import net.corda.core.identity.Party
import net.corda.core.identity.PartyAndCertificate
import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.abbreviate
import net.corda.core.internal.uncheckedCast
import net.corda.core.internal.*
import net.corda.core.messaging.DataFeed
import net.corda.core.node.NodeInfo
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.serialize
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.UntrustworthyData
import net.corda.core.utilities.debug
import net.corda.core.utilities.*
import org.slf4j.Logger
import java.time.Duration
import java.time.Instant
/**
* A sub-class of [FlowLogic<T>] implements a flow using direct, straight line blocking code. Thus you
@ -77,12 +74,19 @@ abstract class FlowLogic<out T> {
*/
@Suspendable
@JvmStatic
@JvmOverloads
@Throws(FlowException::class)
fun sleep(duration: Duration) {
fun sleep(duration: Duration, maySkipCheckpoint: Boolean = false) {
if (duration > Duration.ofMinutes(5)) {
throw FlowException("Attempt to sleep for longer than 5 minutes is not supported. Consider using SchedulableState.")
}
(Strand.currentStrand() as? FlowStateMachine<*>)?.sleepUntil(Instant.now() + duration) ?: Strand.sleep(duration.toMillis())
val fiber = (Strand.currentStrand() as? FlowStateMachine<*>)
if (fiber == null) {
Strand.sleep(duration.toMillis())
} else {
val request = FlowIORequest.Sleep(wakeUpAfter = fiber.serviceHub.clock.instant() + duration)
fiber.suspend(request, maySkipCheckpoint = maySkipCheckpoint)
}
}
}
@ -94,7 +98,7 @@ abstract class FlowLogic<out T> {
/**
* Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is
* only available once the flow has started, which means it cannnot be accessed in the constructor. Either
* only available once the flow has started, which means it cannot be accessed in the constructor. Either
* access this lazily or from inside [call].
*/
val serviceHub: ServiceHub get() = stateMachine.serviceHub
@ -104,7 +108,7 @@ abstract class FlowLogic<out T> {
* that this function does not communicate in itself, the counter-flow will be kicked off by the first send/receive.
*/
@Suspendable
fun initiateFlow(party: Party): FlowSession = stateMachine.initiateFlow(party, flowUsedForSessions)
fun initiateFlow(party: Party): FlowSession = stateMachine.initiateFlow(party)
/**
* Specifies the identity, with certificate, to use for this flow. This will be one of the multiple identities that
@ -114,7 +118,10 @@ abstract class FlowLogic<out T> {
* Note: The current implementation returns the single identity of the node. This will change once multiple identities
* is implemented.
*/
val ourIdentityAndCert: PartyAndCertificate get() = stateMachine.ourIdentityAndCert
val ourIdentityAndCert: PartyAndCertificate get() {
return serviceHub.myInfo.legalIdentitiesAndCerts.find { it.party == stateMachine.ourIdentity }
?: throw IllegalStateException("Identity specified by ${stateMachine.id} (${stateMachine.ourIdentity}) is not one of ours!")
}
/**
* Specifies the identity to use for this flow. This will be one of the multiple identities that belong to this node.
@ -124,8 +131,14 @@ abstract class FlowLogic<out T> {
* Note: The current implementation returns the single identity of the node. This will change once multiple identities
* is implemented.
*/
val ourIdentity: Party get() = ourIdentityAndCert.party
val ourIdentity: Party get() = stateMachine.ourIdentity
// Used to implement the deprecated send/receive functions using Party. When such a deprecated function is used we
// create a fresh session for the Party, put it here and use it in subsequent deprecated calls.
private val deprecatedPartySessionMap = HashMap<Party, FlowSession>()
private fun getDeprecatedSessionForParty(party: Party): FlowSession {
return deprecatedPartySessionMap.getOrPut(party) { initiateFlow(party) }
}
/**
* Returns a [FlowInfo] object describing the flow [otherParty] is using. With [FlowInfo.flowVersion] it
* provides the necessary information needed for the evolution of flows and enabling backwards compatibility.
@ -133,9 +146,9 @@ abstract class FlowLogic<out T> {
* This method can be called before any send or receive has been done with [otherParty]. In such a case this will force
* them to start their flow.
*/
@Deprecated("Use FlowSession.getFlowInfo()", level = DeprecationLevel.WARNING)
@Deprecated("Use FlowSession.getCounterpartyFlowInfo()", level = DeprecationLevel.WARNING)
@Suspendable
fun getFlowInfo(otherParty: Party): FlowInfo = stateMachine.getFlowInfo(otherParty, flowUsedForSessions, maySkipCheckpoint = false)
fun getFlowInfo(otherParty: Party): FlowInfo = getDeprecatedSessionForParty(otherParty).getCounterpartyFlowInfo()
/**
* Serializes and queues the given [payload] object for sending to the [otherParty]. Suspends until a response
@ -169,31 +182,7 @@ abstract class FlowLogic<out T> {
@Deprecated("Use FlowSession.sendAndReceive()", level = DeprecationLevel.WARNING)
@Suspendable
open fun <R : Any> sendAndReceive(receiveType: Class<R>, otherParty: Party, payload: Any): UntrustworthyData<R> {
return stateMachine.sendAndReceive(receiveType, otherParty, payload, flowUsedForSessions, retrySend = false, maySkipCheckpoint = false)
}
/**
* Similar to [sendAndReceive] but also instructs the `payload` to be redelivered until the expected message is received.
*
* Note that this method should NOT be used for regular party-to-party communication, use [sendAndReceive] instead.
* It is only intended for the case where the [otherParty] is running a distributed service with an idempotent
* flow which only accepts a single request and sends back a single response e.g. a notary or certain types of
* oracle services. If one or more nodes in the service cluster go down mid-session, the message will be redelivered
* to a different one, so there is no need to wait until the initial node comes back up to obtain a response.
*/
@Deprecated("Use FlowSession.sendAndReceiveWithRetry()", level = DeprecationLevel.WARNING)
internal inline fun <reified R : Any> sendAndReceiveWithRetry(otherParty: Party, payload: Any): UntrustworthyData<R> {
return stateMachine.sendAndReceive(R::class.java, otherParty, payload, flowUsedForSessions, retrySend = true, maySkipCheckpoint = false)
}
@Suspendable
internal fun <R : Any> FlowSession.sendAndReceiveWithRetry(receiveType: Class<R>, payload: Any): UntrustworthyData<R> {
return stateMachine.sendAndReceive(receiveType, counterparty, payload, flowUsedForSessions, retrySend = true, maySkipCheckpoint = false)
}
@Suspendable
internal inline fun <reified R : Any> FlowSession.sendAndReceiveWithRetry(payload: Any): UntrustworthyData<R> {
return stateMachine.sendAndReceive(R::class.java, counterparty, payload, flowUsedForSessions, retrySend = true, maySkipCheckpoint = false)
return getDeprecatedSessionForParty(otherParty).sendAndReceive(receiveType, payload)
}
/**
@ -218,9 +207,37 @@ abstract class FlowLogic<out T> {
@Deprecated("Use FlowSession.receive()", level = DeprecationLevel.WARNING)
@Suspendable
open fun <R : Any> receive(receiveType: Class<R>, otherParty: Party): UntrustworthyData<R> {
return stateMachine.receive(receiveType, otherParty, flowUsedForSessions, maySkipCheckpoint = false)
return getDeprecatedSessionForParty(otherParty).receive(receiveType)
}
/**
* Queues the given [payload] for sending to the [otherParty] and continues without suspending.
*
* Note that the other party may receive the message at some arbitrary later point or not at all: if [otherParty]
* is offline then message delivery will be retried until it comes back or until the message is older than the
* network's event horizon time.
*/
@Deprecated("Use FlowSession.send()", level = DeprecationLevel.WARNING)
@Suspendable
open fun send(otherParty: Party, payload: Any) {
getDeprecatedSessionForParty(otherParty).send(payload)
}
@Suspendable
internal fun <R : Any> FlowSession.sendAndReceiveWithRetry(receiveType: Class<R>, payload: Any): UntrustworthyData<R> {
val request = FlowIORequest.SendAndReceive(
sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT)),
shouldRetrySend = true
)
return stateMachine.suspend(request, maySkipCheckpoint = false)[this]!!.checkPayloadIs(receiveType)
}
@Suspendable
internal inline fun <reified R : Any> FlowSession.sendAndReceiveWithRetry(payload: Any): UntrustworthyData<R> {
return sendAndReceiveWithRetry(R::class.java, payload)
}
/** Suspends until a message has been received for each session in the specified [sessions].
*
* Consider [receiveAll(receiveType: Class<R>, sessions: List<FlowSession>): List<UntrustworthyData<R>>] when the same type is expected from all sessions.
@ -232,8 +249,14 @@ abstract class FlowLogic<out T> {
* @returns a [Map] containing the objects received, wrapped in an [UntrustworthyData], by the [FlowSession]s who sent them.
*/
@Suspendable
open fun receiveAllMap(sessions: Map<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>> {
return stateMachine.receiveAll(sessions, this)
@JvmOverloads
open fun receiveAllMap(sessions: Map<FlowSession, Class<out Any>>, maySkipCheckpoint: Boolean = false): Map<FlowSession, UntrustworthyData<Any>> {
enforceNoPrimitiveInReceive(sessions.values)
val replies = stateMachine.suspend(
ioRequest = FlowIORequest.Receive(sessions.keys.toNonEmptySet()),
maySkipCheckpoint = maySkipCheckpoint
)
return replies.mapValues { (session, payload) -> payload.checkPayloadIs(sessions[session]!!) }
}
/**
@ -248,24 +271,13 @@ abstract class FlowLogic<out T> {
* @returns a [List] containing the objects received, wrapped in an [UntrustworthyData], with the same order of [sessions].
*/
@Suspendable
open fun <R : Any> receiveAll(receiveType: Class<R>, sessions: List<FlowSession>): List<UntrustworthyData<R>> {
@JvmOverloads
open fun <R : Any> receiveAll(receiveType: Class<R>, sessions: List<FlowSession>, maySkipCheckpoint: Boolean = false): List<UntrustworthyData<R>> {
enforceNoPrimitiveInReceive(listOf(receiveType))
enforceNoDuplicates(sessions)
return castMapValuesToKnownType(receiveAllMap(associateSessionsToReceiveType(receiveType, sessions)))
}
/**
* Queues the given [payload] for sending to the [otherParty] and continues without suspending.
*
* Note that the other party may receive the message at some arbitrary later point or not at all: if [otherParty]
* is offline then message delivery will be retried until it comes back or until the message is older than the
* network's event horizon time.
*/
@Deprecated("Use FlowSession.send()", level = DeprecationLevel.WARNING)
@Suspendable
open fun send(otherParty: Party, payload: Any) {
stateMachine.send(otherParty, payload, flowUsedForSessions, maySkipCheckpoint = false)
}
/**
* Invokes the given subflow. This function returns once the subflow completes successfully with the result
* returned by that subflow's [call] method. If the subflow has a progress tracker, it is attached to the
@ -283,11 +295,8 @@ abstract class FlowLogic<out T> {
open fun <R> subFlow(subLogic: FlowLogic<R>): R {
subLogic.stateMachine = stateMachine
maybeWireUpProgressTracking(subLogic)
if (!subLogic.javaClass.isAnnotationPresent(InitiatingFlow::class.java)) {
subLogic.flowUsedForSessions = flowUsedForSessions
}
logger.debug { "Calling subflow: $subLogic" }
val result = subLogic.call()
val result = stateMachine.subFlow(subLogic)
logger.debug { "Subflow finished with result ${result.toString().abbreviate(300)}" }
// It's easy to forget this when writing flows so we just step it to the DONE state when it completes.
subLogic.progressTracker?.currentStep = ProgressTracker.DONE
@ -384,7 +393,8 @@ abstract class FlowLogic<out T> {
@Suspendable
@JvmOverloads
fun waitForLedgerCommit(hash: SecureHash, maySkipCheckpoint: Boolean = false): SignedTransaction {
return stateMachine.waitForLedgerCommit(hash, this, maySkipCheckpoint = maySkipCheckpoint)
val request = FlowIORequest.WaitForLedgerCommit(hash)
return stateMachine.suspend(request, maySkipCheckpoint = maySkipCheckpoint)
}
/**
@ -427,11 +437,6 @@ abstract class FlowLogic<out T> {
_stateMachine = value
}
// This is the flow used for managing sessions. It defaults to the current flow but if this is an inlined sub-flow
// then it will point to the flow it's been inlined to.
@Suppress("LeakingThis")
private var flowUsedForSessions: FlowLogic<*> = this
private fun maybeWireUpProgressTracking(subLogic: FlowLogic<*>) {
val ours = progressTracker
val theirs = subLogic.progressTracker
@ -448,6 +453,11 @@ abstract class FlowLogic<out T> {
require(sessions.size == sessions.toSet().size) { "A flow session can only appear once as argument." }
}
private fun enforceNoPrimitiveInReceive(types: Collection<Class<*>>) {
val primitiveTypes = types.filter { it.isPrimitive }
require(primitiveTypes.isEmpty()) { "Cannot receive primitive type(s) $primitiveTypes" }
}
private fun <R> associateSessionsToReceiveType(receiveType: Class<R>, sessions: List<FlowSession>): Map<FlowSession, Class<R>> {
return sessions.associateByTo(LinkedHashMap(), { it }, { receiveType })
}
@ -472,4 +482,4 @@ data class FlowInfo(
* to deduplicate it from other releases of the same CorDapp, typically a version string. See the
* [CorDapp JAR format](https://docs.corda.net/cordapp-build-systems.html#cordapp-jar-format) for more details.
*/
val appName: String)
val appName: String)

View File

@ -0,0 +1,23 @@
package net.corda.core.internal
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.concurrent.CordaFuture
import net.corda.core.flows.FlowLogic
import net.corda.core.serialization.CordaSerializable
/**
* Interface for arbitrary operations that can be invoked in a flow asynchronously - the flow will suspend until the
* operation completes. Operation parameters are expected to be injected via constructor.
*/
@CordaSerializable
interface FlowAsyncOperation<R : Any> {
/** Performs the operation in a non-blocking fashion. */
fun execute(): CordaFuture<R>
}
/** Executes the specified [operation] and suspends until operation completion. */
@Suspendable
fun <T, R : Any> FlowLogic<T>.executeAsync(operation: FlowAsyncOperation<R>, maySkipCheckpoint: Boolean = false): R {
val request = FlowIORequest.ExecuteAsyncOperation(operation)
return stateMachine.suspend(request, maySkipCheckpoint)
}

View File

@ -0,0 +1,89 @@
package net.corda.core.internal
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowSession
import net.corda.core.serialization.SerializedBytes
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.NonEmptySet
import java.time.Instant
/**
* A [FlowIORequest] represents an IO request of a flow when it suspends. It is persisted in checkpoints.
*/
sealed class FlowIORequest<out R : Any> {
/**
* Send messages to sessions.
*
* @property sessionToMessage a map from session to message-to-be-sent.
* @property shouldRetrySend specifies whether the send should be retried.
*/
data class Send(
val sessionToMessage: Map<FlowSession, SerializedBytes<Any>>,
val shouldRetrySend: Boolean
) : FlowIORequest<Unit>() {
override fun toString() = "Send(" +
"sessionToMessage=${sessionToMessage.mapValues { it.value.hash }}, " +
"shouldRetrySend=$shouldRetrySend" +
")"
}
/**
* Receive messages from sessions.
*
* @property sessions the sessions to receive messages from.
* @return a map from session to received message.
*/
data class Receive(
val sessions: NonEmptySet<FlowSession>
) : FlowIORequest<Map<FlowSession, SerializedBytes<Any>>>()
/**
* Send and receive messages from the specified sessions.
*
* @property sessionToMessage a map from session to message-to-be-sent. The keys also specify which sessions to
* receive from.
* @property shouldRetrySend specifies whether the send should be retried.
* @return a map from session to received message.
*/
data class SendAndReceive(
val sessionToMessage: Map<FlowSession, SerializedBytes<Any>>,
val shouldRetrySend: Boolean
) : FlowIORequest<Map<FlowSession, SerializedBytes<Any>>>() {
override fun toString() = "SendAndReceive(${sessionToMessage.mapValues { (key, value) ->
"$key=${value.hash}" }}, shouldRetrySend=$shouldRetrySend)"
}
/**
* Wait for a transaction to be committed to the database.
*
* @property hash the hash of the transaction.
* @return the committed transaction.
*/
data class WaitForLedgerCommit(val hash: SecureHash) : FlowIORequest<SignedTransaction>()
/**
* Get the FlowInfo of the specified sessions.
*
* @property sessions the sessions to get the FlowInfo of.
* @return a map from session to FlowInfo.
*/
data class GetFlowInfo(val sessions: NonEmptySet<FlowSession>) : FlowIORequest<Map<FlowSession, FlowInfo>>()
/**
* Suspend the flow until the specified time.
*
* @property wakeUpAfter the time to sleep until.
*/
data class Sleep(val wakeUpAfter: Instant) : FlowIORequest<Unit>()
/**
* Suspend the flow until all Initiating sessions are confirmed.
*/
object WaitForSessionConfirmations : FlowIORequest<Unit>()
/**
* Execute the specified [operation], suspend the flow until completion.
*/
data class ExecuteAsyncOperation<T : Any>(val operation: FlowAsyncOperation<T>) : FlowIORequest<T>()
}

View File

@ -1,64 +1,42 @@
package net.corda.core.internal
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.DoNotImplement
import net.corda.core.concurrent.CordaFuture
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.identity.PartyAndCertificate
import net.corda.core.context.InvocationContext
import net.corda.core.node.ServiceHub
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.UntrustworthyData
import org.slf4j.Logger
import java.time.Instant
/** This is an internal interface that is implemented by code in the node module. You should look at [FlowLogic]. */
interface FlowStateMachine<R> {
@DoNotImplement
interface FlowStateMachine<FLOWRETURN> {
@Suspendable
fun getFlowInfo(otherParty: Party, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): FlowInfo
fun <SUSPENDRETURN : Any> suspend(ioRequest: FlowIORequest<SUSPENDRETURN>, maySkipCheckpoint: Boolean): SUSPENDRETURN
@Suspendable
fun initiateFlow(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSession
@Suspendable
fun <T : Any> sendAndReceive(receiveType: Class<T>,
otherParty: Party,
payload: Any,
sessionFlow: FlowLogic<*>,
retrySend: Boolean,
maySkipCheckpoint: Boolean): UntrustworthyData<T>
@Suspendable
fun <T : Any> receive(receiveType: Class<T>, otherParty: Party, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): UntrustworthyData<T>
@Suspendable
fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean)
@Suspendable
fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): SignedTransaction
@Suspendable
fun sleepUntil(until: Instant)
fun initiateFlow(party: Party): FlowSession
fun checkFlowPermission(permissionName: String, extraAuditData: Map<String, String>)
fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map<String, String>)
@Suspendable
fun <SUBFLOWRETURN> subFlow(subFlow: FlowLogic<SUBFLOWRETURN>): SUBFLOWRETURN
@Suspendable
fun flowStackSnapshot(flowClass: Class<out FlowLogic<*>>): FlowStackSnapshot?
@Suspendable
fun persistFlowStackSnapshot(flowClass: Class<out FlowLogic<*>>)
val logic: FlowLogic<R>
val logic: FlowLogic<FLOWRETURN>
val serviceHub: ServiceHub
val logger: Logger
val id: StateMachineRunId
val resultFuture: CordaFuture<R>
val resultFuture: CordaFuture<FLOWRETURN>
val context: InvocationContext
val ourIdentityAndCert: PartyAndCertificate
@Suspendable
fun receiveAll(sessions: Map<FlowSession, Class<out Any>>, sessionFlow: FlowLogic<*>): Map<FlowSession, UntrustworthyData<Any>>
val ourIdentity: Party
}

View File

@ -1,6 +1,7 @@
package net.corda.core.node.services
import net.corda.core.DoNotImplement
import net.corda.core.concurrent.CordaFuture
import net.corda.core.crypto.SecureHash
import net.corda.core.messaging.DataFeed
import net.corda.core.transactions.SignedTransaction
@ -26,4 +27,9 @@ interface TransactionStorage {
* Returns all currently stored transactions and further fresh ones.
*/
fun track(): DataFeed<List<SignedTransaction>, SignedTransaction>
/**
* Returns a future that completes with the transaction corresponding to [id] once it has been committed
*/
fun trackTransaction(id: SecureHash): CordaFuture<SignedTransaction>
}

View File

@ -2,6 +2,9 @@ package net.corda.core.utilities
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.FlowException
import net.corda.core.internal.castIfPossible
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes
import java.io.Serializable
/**
@ -29,3 +32,15 @@ class UntrustworthyData<out T>(@PublishedApi internal val fromUntrustedWorld: T)
}
inline fun <T, R> UntrustworthyData<T>.unwrap(validator: (T) -> R): R = validator(fromUntrustedWorld)
fun <T : Any> SerializedBytes<Any>.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
val payloadData: T = try {
val serializer = SerializationDefaults.SERIALIZATION_FACTORY
serializer.deserialize(this, type, SerializationDefaults.P2P_CONTEXT)
} catch (ex: Exception) {
throw IllegalArgumentException("Payload invalid", ex)
}
return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?:
throw IllegalArgumentException("We were expecting a ${type.name} but we instead got a " +
"${payloadData.javaClass.name} (${payloadData})")
}

View File

@ -61,9 +61,8 @@ public class FlowsInJavaTest {
fail("ExecutionException should have been thrown");
} catch (ExecutionException e) {
assertThat(e.getCause())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("primitive")
.hasMessageContaining(receiveType.getName());
.hasMessageContaining(Primitives.unwrap(receiveType).getName());
}
}

View File

@ -107,7 +107,7 @@ dependencies {
}
// Coda Hale's Metrics: for monitoring of key statistics
compile "io.dropwizard.metrics:metrics-core:3.1.2"
compile "io.dropwizard.metrics:metrics-core:$metrics_version"
// JimFS: in memory java.nio filesystem. Used for test and simulation utilities.
compile "com.google.jimfs:jimfs:1.1"

View File

@ -144,7 +144,7 @@ class P2PMessagingTest {
distributedServiceNodes.forEach {
val nodeName = it.services.myInfo.legalIdentitiesAndCerts.first().name
it.internalServices.networkService.addMessageHandler("test.request") { netMessage, _ ->
it.internalServices.networkService.addMessageHandler("test.request") { netMessage, _, handler ->
crashingNodes.requestsReceived.incrementAndGet()
crashingNodes.firstRequestReceived.countDown()
// The node which receives the first request will ignore all requests
@ -159,6 +159,7 @@ class P2PMessagingTest {
val response = it.internalServices.networkService.createMessage("test.response", responseMessage.serialize().bytes)
it.internalServices.networkService.send(response, request.replyTo)
}
handler.afterDatabaseTransaction()
}
}
return crashingNodes
@ -186,10 +187,11 @@ class P2PMessagingTest {
}
private fun InProcess.respondWith(message: Any) {
internalServices.networkService.addMessageHandler("test.request") { netMessage, _ ->
internalServices.networkService.addMessageHandler("test.request") { netMessage, _, handler ->
val request = netMessage.data.deserialize<TestRequest>()
val response = internalServices.networkService.createMessage("test.response", message.serialize().bytes)
internalServices.networkService.send(response, request.replyTo)
handler.afterDatabaseTransaction()
}
}
@ -211,11 +213,12 @@ class P2PMessagingTest {
*/
inline fun MessagingService.runOnNextMessage(topic: String, crossinline callback: (ReceivedMessage) -> Unit) {
val consumed = AtomicBoolean()
addMessageHandler(topic) { msg, reg ->
addMessageHandler(topic) { msg, reg, handler ->
removeMessageHandler(reg)
check(!consumed.getAndSet(true)) { "Called more than once" }
check(msg.topic == topic) { "Topic/session mismatch: ${msg.topic} vs $topic" }
callback(msg)
handler.afterDatabaseTransaction()
}
}

View File

@ -8,6 +8,7 @@ import net.corda.confidential.SwapIdentitiesHandler
import net.corda.core.CordaException
import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext
import net.corda.core.crypto.newSecureRandom
import net.corda.core.crypto.sign
import net.corda.core.flows.*
import net.corda.core.identity.CordaX500Name
@ -47,6 +48,7 @@ import net.corda.node.services.events.NodeSchedulerService
import net.corda.node.services.events.ScheduledActivityObserver
import net.corda.node.services.identity.PersistentIdentityService
import net.corda.node.services.keys.PersistentKeyManagementService
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.network.*
import net.corda.node.services.persistence.*
@ -56,7 +58,6 @@ import net.corda.node.services.statemachine.*
import net.corda.node.services.transactions.*
import net.corda.node.services.upgrade.ContractUpgradeServiceImpl
import net.corda.node.services.vault.NodeVaultService
import net.corda.node.services.vault.VaultSoftLockManager
import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.JVMAgentRegistry
import net.corda.node.utilities.NamedThreadFactory
@ -131,7 +132,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
// We will run as much stuff in this single thread as possible to keep the risk of thread safety bugs low during the
// low-performance prototyping period.
protected abstract val serverThread: AffinityExecutor
protected abstract val serverThread: AffinityExecutor.ServiceAffinityExecutor
private val cordappServices = MutableClassToInstanceMap.create<SerializeAsToken>()
private val flowFactories = ConcurrentHashMap<Class<out FlowLogic<*>>, InitiatedFlowFactory<*>>()
@ -248,7 +249,6 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
flowStarter,
servicesForResolution,
unfinishedSchedules = busyNodeLatch,
serverThread = serverThread,
flowLogicRefFactory = flowLogicRefFactory,
drainingModePollPeriod = configuration.drainingModePollPeriod,
nodeProperties = nodeProperties)
@ -385,11 +385,12 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
protected abstract fun myAddresses(): List<NetworkHostAndPort>
protected open fun makeStateMachineManager(database: CordaPersistence): StateMachineManager {
return StateMachineManagerImpl(
return SingleThreadedStateMachineManager(
services,
checkpointStorage,
serverThread,
database,
newSecureRandom(),
busyNodeLatch,
cordappLoader.appClassLoader
)
@ -639,7 +640,6 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
protected open fun makeTransactionStorage(database: CordaPersistence, transactionCacheSizeBytes: Long): WritableTransactionStorage = DBTransactionStorage(transactionCacheSizeBytes)
private fun makeVaultObservers(schedulerService: SchedulerService, hibernateConfig: HibernateConfiguration, smm: StateMachineManager, schemaService: SchemaService, flowLogicRefFactory: FlowLogicRefFactory) {
VaultSoftLockManager.install(services.vaultService, smm)
ScheduledActivityObserver.install(services.vaultService, schedulerService, flowLogicRefFactory)
HibernateObserver.install(services.vaultService.rawUpdates, hibernateConfig, schemaService)
}
@ -894,8 +894,8 @@ internal fun logVendorString(database: CordaPersistence, log: Logger) {
}
internal class FlowStarterImpl(private val serverThread: AffinityExecutor, private val smm: StateMachineManager, private val flowLogicRefFactory: FlowLogicRefFactory) : FlowStarter {
override fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext): CordaFuture<FlowStateMachine<T>> {
return serverThread.fetchFrom { smm.startFlow(logic, context) }
override fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext, deduplicationHandler: DeduplicationHandler?): CordaFuture<FlowStateMachine<T>> {
return smm.startFlow(logic, context, ourIdentity = null, deduplicationHandler = deduplicationHandler)
}
override fun <T> invokeFlowAsync(

View File

@ -1,42 +1,28 @@
package net.corda.node.services.api
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.StateMachineRunId
import net.corda.core.serialization.SerializedBytes
import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.node.services.statemachine.Checkpoint
import java.util.stream.Stream
/**
* Thread-safe storage of fiber checkpoints.
*/
interface CheckpointStorage {
/**
* Add a new checkpoint to the store.
*/
fun addCheckpoint(checkpoint: Checkpoint)
fun addCheckpoint(id: StateMachineRunId, checkpoint: SerializedBytes<Checkpoint>)
/**
* Remove existing checkpoint from the store. It is an error to attempt to remove a checkpoint which doesn't exist
* in the store. Doing so will throw an [IllegalArgumentException].
* Remove existing checkpoint from the store.
* @return whether the id matched a checkpoint that was removed.
*/
fun removeCheckpoint(checkpoint: Checkpoint)
fun removeCheckpoint(id: StateMachineRunId): Boolean
/**
* Allows the caller to process safely in a thread safe fashion the set of all checkpoints.
* The checkpoints are only valid during the lifetime of a single call to the block, to allow memory management.
* Return false from the block to terminate further iteration.
* Stream all checkpoints from the store. If this is backed by a database the stream will be valid until the
* underlying database connection is closed, so any processing should happen before it is closed.
*/
fun forEach(block: (Checkpoint) -> Boolean)
}
// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
class Checkpoint(val serializedFiber: SerializedBytes<FlowStateMachineImpl<*>>) {
val id: SecureHash get() = serializedFiber.hash
override fun equals(other: Any?): Boolean = other === this || other is Checkpoint && other.id == this.id
override fun hashCode(): Int = id.hashCode()
override fun toString(): String = "${javaClass.simpleName}(id=$id)"
fun getAllCheckpoints(): Stream<Pair<StateMachineRunId, SerializedBytes<Checkpoint>>>
}

View File

@ -19,6 +19,7 @@ import net.corda.core.utilities.contextLogger
import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.internal.cordapp.CordappProviderInternal
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.network.NetworkMapUpdater
import net.corda.node.services.statemachine.FlowStateMachineImpl
@ -135,8 +136,9 @@ interface FlowStarter {
/**
* Starts an already constructed flow. Note that you must be on the server thread to call this method.
* @param context indicates who started the flow, see: [InvocationContext].
* @param deduplicationHandler allows exactly-once start of the flow, see [DeduplicationHandler]
*/
fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext): CordaFuture<FlowStateMachine<T>>
fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext, deduplicationHandler: DeduplicationHandler? = null): CordaFuture<FlowStateMachine<T>>
/**
* Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the flow.

View File

@ -26,10 +26,12 @@ import net.corda.node.MutableClock
import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.NodePropertiesStore
import net.corda.node.services.api.SchedulerService
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.utilities.PersistentMap
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import org.apache.activemq.artemis.utils.ReusableLatch
import org.apache.mina.util.ConcurrentHashSet
import org.slf4j.Logger
import java.io.Serializable
import java.time.Duration
@ -61,7 +63,6 @@ class NodeSchedulerService(private val clock: CordaClock,
private val flowStarter: FlowStarter,
private val servicesForResolution: ServicesForResolution,
private val unfinishedSchedules: ReusableLatch = ReusableLatch(),
private val serverThread: Executor,
private val flowLogicRefFactory: FlowLogicRefFactory,
private val nodeProperties: NodePropertiesStore,
private val drainingModePollPeriod: Duration,
@ -164,6 +165,10 @@ class NodeSchedulerService(private val clock: CordaClock,
var rescheduled: GuavaSettableFuture<Boolean>? = null
}
// Used to de-duplicate flow starts in case a flow is starting but the corresponding entry hasn't been removed yet
// from the database
private val startingStateRefs = ConcurrentHashSet<ScheduledStateRef>()
private val mutex = ThreadBox(InnerState())
// We need the [StateMachineManager] to be constructed before this is called in case it schedules a flow.
fun start() {
@ -173,6 +178,29 @@ class NodeSchedulerService(private val clock: CordaClock,
}
}
/**
* Stop scheduler service.
*/
fun stop() {
mutex.locked {
schedulerTimerExecutor.shutdown()
scheduledStatesQueue.clear()
scheduledStates.clear()
}
}
/**
* Resume scheduler service after having called [stop].
*/
fun resume() {
mutex.locked {
schedulerTimerExecutor = Executors.newSingleThreadExecutor()
scheduledStates.putAll(createMap())
scheduledStatesQueue.addAll(scheduledStates.values)
rescheduleWakeUp()
}
}
override fun scheduleStateActivity(action: ScheduledStateRef) {
log.trace { "Schedule $action" }
val previousState = scheduledStates[action.ref]
@ -181,7 +209,7 @@ class NodeSchedulerService(private val clock: CordaClock,
val previousEarliest = scheduledStatesQueue.peek()
scheduledStatesQueue.remove(previousState)
scheduledStatesQueue.add(action)
if (previousState == null) {
if (previousState == null && action !in startingStateRefs) {
unfinishedSchedules.countUp()
}
@ -212,7 +240,7 @@ class NodeSchedulerService(private val clock: CordaClock,
}
}
private val schedulerTimerExecutor = Executors.newSingleThreadExecutor()
private var schedulerTimerExecutor = Executors.newSingleThreadExecutor()
/**
* This method first cancels the [java.util.concurrent.Future] for any pending action so that the
* [awaitWithDeadline] used below drops through without running the action. We then create a new
@ -254,25 +282,41 @@ class NodeSchedulerService(private val clock: CordaClock,
schedulerTimerExecutor.join()
}
private inner class FlowStartDeduplicationHandler(val scheduledState: ScheduledStateRef) : DeduplicationHandler {
override fun insideDatabaseTransaction() {
scheduledStates.remove(scheduledState.ref)
}
override fun afterDatabaseTransaction() {
startingStateRefs.remove(scheduledState)
}
override fun toString(): String {
return "${javaClass.simpleName}($scheduledState)"
}
}
private fun onTimeReached(scheduledState: ScheduledStateRef) {
serverThread.execute {
var flowName: String? = "(unknown)"
try {
database.transaction {
val scheduledFlow = getScheduledFlow(scheduledState)
if (scheduledFlow != null) {
flowName = scheduledFlow.javaClass.name
// TODO refactor the scheduler to store and propagate the original invocation context
val context = InvocationContext.newInstance(InvocationOrigin.Scheduled(scheduledState))
val future = flowStarter.startFlow(scheduledFlow, context).flatMap { it.resultFuture }
future.then {
unfinishedSchedules.countDown()
}
var flowName: String? = "(unknown)"
try {
// We need to check this before the database transaction, otherwise there is a subtle race between a
// doubly-reached deadline and the removal from [startingStateRefs].
if (scheduledState !in startingStateRefs) {
val scheduledFlow = database.transaction { getScheduledFlow(scheduledState) }
if (scheduledFlow != null) {
startingStateRefs.add(scheduledState)
flowName = scheduledFlow.javaClass.name
// TODO refactor the scheduler to store and propagate the original invocation context
val context = InvocationContext.newInstance(InvocationOrigin.Scheduled(scheduledState))
val deduplicationHandler = FlowStartDeduplicationHandler(scheduledState)
val future = flowStarter.startFlow(scheduledFlow, context, deduplicationHandler).flatMap { it.resultFuture }
future.then {
unfinishedSchedules.countDown()
}
}
} catch (e: Exception) {
log.error("Failed to start scheduled flow $flowName for $scheduledState due to an internal error", e)
}
} catch (e: Exception) {
log.error("Failed to start scheduled flow $flowName for $scheduledState due to an internal error", e)
}
}
@ -304,7 +348,6 @@ class NodeSchedulerService(private val clock: CordaClock,
}
else -> {
log.trace { "Scheduler starting FlowLogic $flowLogic" }
scheduledStates.remove(scheduledState.ref)
scheduledStatesQueue.remove(scheduledState)
flowLogic
}
@ -328,4 +371,4 @@ class NodeSchedulerService(private val clock: CordaClock,
null
}
}
}
}

View File

@ -1,6 +1,7 @@
package net.corda.node.services.messaging
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.crypto.newSecureRandom
import net.corda.core.identity.CordaX500Name
import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.SingleMessageRecipient
@ -8,8 +9,8 @@ import net.corda.core.node.services.PartyInfo
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.serialize
import net.corda.core.utilities.ByteSequence
import net.corda.node.services.statemachine.DeduplicationId
import java.time.Instant
import java.util.*
import javax.annotation.concurrent.ThreadSafe
/**
@ -35,7 +36,7 @@ interface MessagingService {
*
* @param topic identifier for the topic to listen for messages arriving on.
*/
fun addMessageHandler(topic: String, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
fun addMessageHandler(topic: String, callback: MessageHandler): MessageHandlerRegistration
/**
* Removes a handler given the object returned from [addMessageHandler]. The callback will no longer be invoked once
@ -66,8 +67,7 @@ interface MessagingService {
message: Message,
target: MessageRecipients,
retryId: Long? = null,
sequenceKey: Any = target,
additionalHeaders: Map<String, String> = emptyMap()
sequenceKey: Any = target
)
/** A message with a target and sequenceKey specified. */
@ -97,7 +97,7 @@ interface MessagingService {
* @param additionalProperties optional additional message headers.
* @param topic identifier for the topic the message is sent to.
*/
fun createMessage(topic: String, data: ByteArray, deduplicationId: String = UUID.randomUUID().toString()): Message
fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId = DeduplicationId.createRandom(newSecureRandom()), additionalHeaders: Map<String, String> = emptyMap()): Message
/** Given information about either a specific node or a service returns its corresponding address */
fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients
@ -106,9 +106,8 @@ interface MessagingService {
val myAddress: SingleMessageRecipient
}
fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: String = UUID.randomUUID().toString(), retryId: Long? = null)
= send(createMessage(topicSession, payload.serialize().bytes, deduplicationId), to, retryId)
fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: DeduplicationId = DeduplicationId.createRandom(newSecureRandom()), retryId: Long? = null, additionalHeaders: Map<String, String> = emptyMap())
= send(createMessage(topicSession, payload.serialize().bytes, deduplicationId, additionalHeaders), to, retryId)
interface MessageHandlerRegistration
@ -127,7 +126,9 @@ interface Message {
val topic: String
val data: ByteSequence
val debugTimestamp: Instant
val uniqueMessageId: String
val uniqueMessageId: DeduplicationId
val senderUUID: String?
val additionalHeaders: Map<String, String>
}
// TODO Have ReceivedMessage point to the TLS certificate of the peer, and [peer] would simply be the subject DN of that.
@ -138,6 +139,10 @@ interface ReceivedMessage : Message {
val peer: CordaX500Name
/** Platform version of the sender's node. */
val platformVersion: Int
/** Sequence number of message with respect to senderUUID */
val senderSeqNo: Long?
/** True if a flow session init message */
val isSessionInit: Boolean
}
/** A singleton that's useful for validating topic strings */
@ -147,3 +152,29 @@ object TopicStringValidator {
fun check(tag: String) = require(regex.matcher(tag).matches())
}
/**
* This handler is used to implement exactly-once delivery of an event on top of a possibly duplicated one. This is done
* using two hooks that are called from the event processor, one called from the database transaction committing the
* side-effect caused by the event, and another one called after the transaction has committed successfully.
*
* For example for messaging we can use [insideDatabaseTransaction] to store the message's unique ID for later
* deduplication, and [afterDatabaseTransaction] to acknowledge the message and stop retries.
*
* We also use this for exactly-once start of a scheduled flow, [insideDatabaseTransaction] is used to remove the
* to-be-scheduled state of the flow, [afterDatabaseTransaction] is used for cleanup of in-memory bookkeeping.
*/
interface DeduplicationHandler {
/**
* This will be run inside a database transaction that commits the side-effect of the event, allowing the
* implementor to persist the event delivery fact atomically with the side-effect.
*/
fun insideDatabaseTransaction()
/**
* This will be run strictly after the side-effect has been committed successfully and may be used for
* cleanup/acknowledgement/stopping of retries.
*/
fun afterDatabaseTransaction()
}
typealias MessageHandler = (ReceivedMessage, MessageHandlerRegistration, DeduplicationHandler) -> Unit

View File

@ -0,0 +1,86 @@
package net.corda.node.services.messaging
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.SettableFuture
import com.codahale.metrics.MetricRegistry
import net.corda.core.messaging.MessageRecipients
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.trace
import net.corda.node.VersionInfo
import net.corda.node.services.statemachine.FlowMessagingImpl
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders
import org.apache.activemq.artemis.api.core.ActiveMQDuplicateIdException
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ClientProducer
import org.apache.activemq.artemis.api.core.client.ClientSession
import java.util.concurrent.ArrayBlockingQueue
import java.util.concurrent.ExecutionException
import java.util.concurrent.atomic.AtomicLong
import kotlin.concurrent.thread
interface AddressToArtemisQueueResolver {
/**
* Resolves a [MessageRecipients] to an Artemis queue name, creating the underlying queue if needed.
*/
fun resolveTargetToArtemisQueue(address: MessageRecipients): String
}
/**
* The [MessagingExecutor] is responsible for handling send and acknowledge jobs. It batches them using a bounded
* blocking queue, submits the jobs asynchronously and then waits for them to flush using [ClientSession.commit].
* Note that even though we buffer in theory this shouldn't increase latency as the executor is immediately woken up if
* it was waiting. The number of jobs in the queue is only ever greater than 1 if the commit takes a long time.
*/
class MessagingExecutor(
val session: ClientSession,
val producer: ClientProducer,
val versionInfo: VersionInfo,
val resolver: AddressToArtemisQueueResolver,
val ourSenderUUID: String
) {
private val cordaVendor = SimpleString(versionInfo.vendor)
private val releaseVersion = SimpleString(versionInfo.releaseVersion)
private val ourSenderSeqNo = AtomicLong()
private companion object {
val log = contextLogger()
val amqDelayMillis = System.getProperty("amq.delivery.delay.ms", "0").toInt()
}
fun send(message: Message, target: MessageRecipients) {
val mqAddress = resolver.resolveTargetToArtemisQueue(target)
val artemisMessage = cordaToArtemisMessage(message)
log.trace {
"Send to: $mqAddress topic: ${message.topic} " +
"sessionID: ${message.topic} id: ${message.uniqueMessageId}"
}
producer.send(SimpleString(mqAddress), artemisMessage)
}
fun acknowledge(message: ClientMessage) {
message.individualAcknowledge()
}
internal fun cordaToArtemisMessage(message: Message): ClientMessage? {
return session.createMessage(true).apply {
putStringProperty(P2PMessagingHeaders.cordaVendorProperty, cordaVendor)
putStringProperty(P2PMessagingHeaders.releaseVersionProperty, releaseVersion)
putIntProperty(P2PMessagingHeaders.platformVersionProperty, versionInfo.platformVersion)
putStringProperty(P2PMessagingHeaders.topicProperty, SimpleString(message.topic))
writeBodyBufferBytes(message.data.bytes)
// Use the magic deduplication property built into Artemis as our message identity too
putStringProperty(org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString))
// If we are the sender (ie. we are not going through recovery of some sort), use sequence number short cut.
if (ourSenderUUID == message.senderUUID) {
putStringProperty(P2PMessagingHeaders.senderUUID, SimpleString(ourSenderUUID))
putLongProperty(P2PMessagingHeaders.senderSeqNo, ourSenderSeqNo.getAndIncrement())
}
// For demo purposes - if set then add a delay to messages in order to demonstrate that the flows are doing as intended
if (amqDelayMillis > 0 && message.topic == FlowMessagingImpl.sessionTopic) {
putLongProperty(org.apache.activemq.artemis.api.core.Message.HDR_SCHEDULED_DELIVERY_TIME, System.currentTimeMillis() + amqDelayMillis)
}
message.additionalHeaders.forEach { key, value -> putStringProperty(key, value) }
}
}
}

View File

@ -0,0 +1,108 @@
package net.corda.node.services.messaging
import net.corda.core.crypto.SecureHash
import net.corda.core.identity.CordaX500Name
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.utilities.AppendOnlyPersistentMap
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import java.io.Serializable
import java.time.Instant
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import javax.persistence.Column
import javax.persistence.Entity
import javax.persistence.Id
/**
* Encapsulate the de-duplication logic.
*/
class P2PMessageDeduplicator(private val database: CordaPersistence) {
val ourSenderUUID = UUID.randomUUID().toString()
// A temporary in-memory set of deduplication IDs and associated high water mark details.
// When we receive a message we don't persist the ID immediately,
// so we store the ID here in the meantime (until the persisting db tx has committed). This is because Artemis may
// redeliver messages to the same consumer if they weren't ACKed.
private val beingProcessedMessages = ConcurrentHashMap<DeduplicationId, MessageMeta>()
private val processedMessages = createProcessedMessages()
private fun createProcessedMessages(): AppendOnlyPersistentMap<DeduplicationId, MessageMeta, ProcessedMessage, String> {
return AppendOnlyPersistentMap(
toPersistentEntityKey = { it.toString },
fromPersistentEntity = { Pair(DeduplicationId(it.id), MessageMeta(it.insertionTime, it.hash, it.seqNo)) },
toPersistentEntity = { key: DeduplicationId, value: MessageMeta ->
ProcessedMessage().apply {
id = key.toString
insertionTime = value.insertionTime
hash = value.senderHash
seqNo = value.senderSeqNo
}
},
persistentEntityClass = ProcessedMessage::class.java
)
}
private fun isDuplicateInDatabase(msg: ReceivedMessage): Boolean = database.transaction { msg.uniqueMessageId in processedMessages }
// We need to incorporate the sending party, and the sessionInit flag as per the in-memory cache.
private fun senderHash(senderKey: SenderKey) = SecureHash.sha256(senderKey.peer.toString() + senderKey.isSessionInit.toString() + senderKey.senderUUID).toString()
/**
* @return true if we have seen this message before.
*/
fun isDuplicate(msg: ReceivedMessage): Boolean {
if (beingProcessedMessages.containsKey(msg.uniqueMessageId)) {
return true
}
return isDuplicateInDatabase(msg)
}
/**
* Called the first time we encounter [deduplicationId].
*/
fun signalMessageProcessStart(msg: ReceivedMessage) {
val receivedSenderUUID = msg.senderUUID
val receivedSenderSeqNo = msg.senderSeqNo
// We don't want a mix of nulls and values so we ensure that here.
val senderHash: String? = if (receivedSenderUUID != null && receivedSenderSeqNo != null) senderHash(SenderKey(receivedSenderUUID, msg.peer, msg.isSessionInit)) else null
val senderSeqNo: Long? = if (senderHash != null) msg.senderSeqNo else null
beingProcessedMessages[msg.uniqueMessageId] = MessageMeta(Instant.now(), senderHash, senderSeqNo)
}
/**
* Called inside a DB transaction to persist [deduplicationId].
*/
fun persistDeduplicationId(deduplicationId: DeduplicationId) {
processedMessages[deduplicationId] = beingProcessedMessages[deduplicationId]!!
}
/**
* Called after the DB transaction persisting [deduplicationId] committed.
* Any subsequent redelivery will be deduplicated using the DB.
*/
fun signalMessageProcessFinish(deduplicationId: DeduplicationId) {
beingProcessedMessages.remove(deduplicationId)
}
@Entity
@javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_ids")
class ProcessedMessage(
@Id
@Column(name = "message_id", length = 64)
var id: String = "",
@Column(name = "insertion_time")
var insertionTime: Instant = Instant.now(),
@Column(name = "sender", length = 64)
var hash: String? = "",
@Column(name = "sequence_number")
var seqNo: Long? = null
) : Serializable
private data class MessageMeta(val insertionTime: Instant, val senderHash: String?, val senderSeqNo: Long?)
private data class SenderKey(val senderUUID: String, val peer: CordaX500Name, val isSessionInit: Boolean)
}

View File

@ -1,8 +1,11 @@
package net.corda.node.services.messaging
import co.paralleluniverse.fibers.Suspendable
import com.codahale.metrics.MetricRegistry
import net.corda.core.crypto.toStringShort
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.ThreadBox
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.SingleMessageRecipient
@ -21,9 +24,8 @@ import net.corda.node.internal.artemis.ReactiveArtemisConsumer
import net.corda.node.internal.artemis.ReactiveArtemisConsumer.Companion.multiplex
import net.corda.node.services.api.NetworkMapCacheInternal
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.statemachine.StateMachineManagerImpl
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.AppendOnlyPersistentMap
import net.corda.node.utilities.PersistentMap
import net.corda.nodeapi.ArtemisTcpTransport
import net.corda.nodeapi.ConnectionDirection
@ -38,7 +40,8 @@ import net.corda.nodeapi.internal.bridging.BridgeEntry
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException
import org.apache.activemq.artemis.api.core.Message.*
import org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID
import org.apache.activemq.artemis.api.core.Message.HDR_VALIDATED_USER
import org.apache.activemq.artemis.api.core.RoutingType
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.*
@ -50,15 +53,16 @@ import java.io.Serializable
import java.security.PublicKey
import java.time.Instant
import java.util.*
import java.util.concurrent.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CountDownLatch
import java.util.concurrent.ScheduledFuture
import java.util.concurrent.TimeUnit
import javax.annotation.concurrent.ThreadSafe
import javax.persistence.Column
import javax.persistence.Entity
import javax.persistence.Id
import javax.persistence.Lob
// TODO: Stop the wallet explorer and other clients from using this class and get rid of persistentInbox
/**
* This class implements the [MessagingService] API using Apache Artemis, the successor to their ActiveMQ product.
* Artemis is a message queue broker and here we run a client connecting to the specified broker instance
@ -85,7 +89,7 @@ import javax.persistence.Lob
* @param maxMessageSize A bound applied to the message size.
*/
@ThreadSafe
class P2PMessagingClient(private val config: NodeConfiguration,
class P2PMessagingClient(val config: NodeConfiguration,
private val versionInfo: VersionInfo,
private val serverAddress: NetworkHostAndPort,
private val myIdentity: PublicKey,
@ -97,26 +101,11 @@ class P2PMessagingClient(private val config: NodeConfiguration,
private val maxMessageSize: Int,
private val isDrainingModeOn: () -> Boolean,
private val drainingModeWasChangedEvents: Observable<Pair<Boolean, Boolean>>
) : SingletonSerializeAsToken(), MessagingService, AutoCloseable {
) : SingletonSerializeAsToken(), MessagingService, AddressToArtemisQueueResolver, AutoCloseable {
companion object {
private val log = contextLogger()
private val amqDelayMillis = System.getProperty("amq.delivery.delay.ms", "0").toInt()
private const val messageMaxRetryCount: Int = 3
fun createProcessedMessage(): AppendOnlyPersistentMap<String, Instant, ProcessedMessage, String> {
return AppendOnlyPersistentMap(
toPersistentEntityKey = { it },
fromPersistentEntity = { Pair(it.uuid, it.insertionTime) },
toPersistentEntity = { key: String, value: Instant ->
ProcessedMessage().apply {
uuid = key
insertionTime = value
}
},
persistentEntityClass = ProcessedMessage::class.java
)
}
fun createMessageToRedeliver(): PersistentMap<Long, Pair<Message, MessageRecipients>, RetryMessage, Long> {
return PersistentMap(
toPersistentEntityKey = { it },
@ -137,7 +126,7 @@ class P2PMessagingClient(private val config: NodeConfiguration,
)
}
private class NodeClientMessage(override val topic: String, override val data: ByteSequence, override val uniqueMessageId: String) : Message {
private class NodeClientMessage(override val topic: String, override val data: ByteSequence, override val uniqueMessageId: DeduplicationId, override val senderUUID: String?, override val additionalHeaders: Map<String, String>) : Message {
override val debugTimestamp: Instant = Instant.now()
override fun toString() = "$topic#${String(data.bytes)}"
}
@ -165,32 +154,17 @@ class P2PMessagingClient(private val config: NodeConfiguration,
private val scheduledMessageRedeliveries = ConcurrentHashMap<Long, ScheduledFuture<*>>()
/** A registration to handle messages of different types */
data class Handler(val topic: String,
val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration
private val cordaVendor = SimpleString(versionInfo.vendor)
private val releaseVersion = SimpleString(versionInfo.releaseVersion)
/** An executor for sending messages */
private val messagingExecutor = AffinityExecutor.ServiceAffinityExecutor("Messaging ${myIdentity.toStringShort()}", 1)
data class HandlerRegistration(val topic: String, val callback: Any) : MessageHandlerRegistration
override val myAddress: SingleMessageRecipient = NodeAddress(myIdentity, advertisedAddress)
private val messageRedeliveryDelaySeconds = config.messageRedeliveryDelaySeconds.toLong()
private val state = ThreadBox(InnerState())
private val knownQueues = Collections.newSetFromMap(ConcurrentHashMap<String, Boolean>())
private val handlers = CopyOnWriteArrayList<Handler>()
private val processedMessages = createProcessedMessage()
private val handlers = ConcurrentHashMap<String, MessageHandler>()
@Entity
@javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_ids")
class ProcessedMessage(
@Id
@Column(name = "message_id", length = 64)
var uuid: String = "",
@Column(name = "insertion_time")
var insertionTime: Instant = Instant.now()
) : Serializable
private val deduplicator = P2PMessageDeduplicator(database)
internal var messagingExecutor: MessagingExecutor? = null
@Entity
@javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_retry")
@ -246,6 +220,14 @@ class P2PMessagingClient(private val config: NodeConfiguration,
inboxes.forEach { createQueueIfAbsent(it, producerSession!!) }
p2pConsumer = P2PMessagingConsumer(inboxes, createNewSession, isDrainingModeOn, drainingModeWasChangedEvents)
messagingExecutor = MessagingExecutor(
producerSession!!,
producer!!,
versionInfo,
this@P2PMessagingClient,
ourSenderUUID = deduplicator.ourSenderUUID
)
registerBridgeControl(bridgeSession!!, inboxes.toList())
enumerateBridges(bridgeSession!!, inboxes.toList())
}
@ -255,7 +237,9 @@ class P2PMessagingClient(private val config: NodeConfiguration,
private fun InnerState.registerBridgeControl(session: ClientSession, inboxes: List<String>) {
val bridgeNotifyQueue = "$BRIDGE_NOTIFY.${myIdentity.toStringShort()}"
session.createTemporaryQueue(BRIDGE_NOTIFY, RoutingType.MULTICAST, bridgeNotifyQueue)
if (!session.queueQuery(SimpleString(bridgeNotifyQueue)).isExists) {
session.createTemporaryQueue(BRIDGE_NOTIFY, RoutingType.MULTICAST, bridgeNotifyQueue)
}
val bridgeConsumer = session.createConsumer(bridgeNotifyQueue)
bridgeNotifyConsumer = bridgeConsumer
bridgeConsumer.setMessageHandler { msg ->
@ -273,7 +257,7 @@ class P2PMessagingClient(private val config: NodeConfiguration,
networkChangeSubscription = networkMap.changed.subscribe { updateBridgesOnNetworkChange(it) }
}
private fun sendBridgeControl(message: BridgeControl) {
private fun sendBridgeControl(message: BridgeControl) {
state.locked {
val controlPacket = message.serialize(context = SerializationDefaults.P2P_CONTEXT).bytes
val artemisMessage = producerSession!!.createMessage(false)
@ -343,38 +327,35 @@ class P2PMessagingClient(private val config: NodeConfiguration,
private fun resumeMessageRedelivery() {
messagesToRedeliver.forEach { retryId, (message, target) ->
sendInternal(message, target, retryId)
send(message, target, retryId)
}
}
private val shutdownLatch = CountDownLatch(1)
var runningFuture = openFuture<Unit>()
/**
* Starts the p2p event loop: this method only returns once [stop] has been called.
*/
fun run() {
val latch = CountDownLatch(1)
try {
val consumer = state.locked {
check(started) { "start must be called first" }
check(!running) { "run can't be called twice" }
running = true
runningFuture.set(Unit)
// If it's null, it means we already called stop, so return immediately.
if (p2pConsumer == null) {
return
}
eventsSubscription = p2pConsumer!!.messages
.doOnError { error -> throw error }
.doOnNext { artemisMessage ->
val receivedMessage = artemisToCordaMessage(artemisMessage)
receivedMessage?.let {
deliver(it)
}
artemisMessage.acknowledge()
}
.doOnNext { message -> deliver(message) }
// this `run()` method is semantically meant to block until the message consumption runs, hence the latch here
.doOnCompleted(latch::countDown)
.doOnError { error -> throw error }
.subscribe()
p2pConsumer!!
}
@ -391,10 +372,13 @@ class P2PMessagingClient(private val config: NodeConfiguration,
val user = requireNotNull(message.getStringProperty(HDR_VALIDATED_USER)) { "Message is not authenticated" }
val platformVersion = message.required(P2PMessagingHeaders.platformVersionProperty) { getIntProperty(it) }
// Use the magic deduplication property built into Artemis as our message identity too
val uuid = message.required(HDR_DUPLICATE_DETECTION_ID) { message.getStringProperty(it) }
log.info("Received message from: ${message.address} user: $user topic: $topic uuid: $uuid")
val uniqueMessageId = message.required(HDR_DUPLICATE_DETECTION_ID) { DeduplicationId(message.getStringProperty(it)) }
val receivedSenderUUID = message.getStringProperty(P2PMessagingHeaders.senderUUID)
val receivedSenderSeqNo = if (message.containsProperty(P2PMessagingHeaders.senderSeqNo)) message.getLongProperty(P2PMessagingHeaders.senderSeqNo) else null
val isSessionInit = message.getStringProperty(P2PMessagingHeaders.Type.KEY) == P2PMessagingHeaders.Type.SESSION_INIT_VALUE
log.trace { "Received message from: ${message.address} user: $user topic: $topic id: $uniqueMessageId senderUUID: $receivedSenderUUID senderSeqNo: $receivedSenderSeqNo isSessionInit: $isSessionInit" }
return ArtemisReceivedMessage(topic, CordaX500Name.parse(user), platformVersion, uuid, message)
return ArtemisReceivedMessage(topic, CordaX500Name.parse(user), platformVersion, uniqueMessageId, receivedSenderUUID, receivedSenderSeqNo, isSessionInit, message)
} catch (e: Exception) {
log.error("Unable to process message, ignoring it: $message", e)
return null
@ -409,52 +393,55 @@ class P2PMessagingClient(private val config: NodeConfiguration,
private class ArtemisReceivedMessage(override val topic: String,
override val peer: CordaX500Name,
override val platformVersion: Int,
override val uniqueMessageId: String,
override val uniqueMessageId: DeduplicationId,
override val senderUUID: String?,
override val senderSeqNo: Long?,
override val isSessionInit: Boolean,
private val message: ClientMessage) : ReceivedMessage {
override val data: ByteSequence by lazy { OpaqueBytes(ByteArray(message.bodySize).apply { message.bodyBuffer.readBytes(this) }) }
override val debugTimestamp: Instant get() = Instant.ofEpochMilli(message.timestamp)
override val additionalHeaders: Map<String, String> = emptyMap()
override fun toString() = "$topic#$data"
}
private fun deliver(msg: ReceivedMessage): Boolean {
state.checkNotLocked()
// Because handlers is a COW list, the loop inside filter will operate on a snapshot. Handlers being added
// or removed whilst the filter is executing will not affect anything.
val deliverTo = handlers.filter { it.topic.isBlank() || it.topic == msg.topic }
try {
// This will perform a BLOCKING call onto the executor. Thus if the handlers are slow, we will
// be slow, and Artemis can handle that case intelligently. We don't just invoke the handler
// directly in order to ensure that we have the features of the AffinityExecutor class throughout
// the bulk of the codebase and other non-messaging jobs can be scheduled onto the server executor
// easily.
//
// Note that handlers may re-enter this class. We aren't holding any locks and methods like
// start/run/stop have re-entrancy assertions at the top, so it is OK.
nodeExecutor.fetchFrom {
database.transaction {
if (msg.uniqueMessageId in processedMessages) {
log.trace { "Discard duplicate message ${msg.uniqueMessageId} for ${msg.topic}" }
} else {
if (deliverTo.isEmpty()) {
// TODO: Implement dead letter queue, and send it there.
log.warn("Received message ${msg.uniqueMessageId} for ${msg.topic} that doesn't have any registered handlers yet")
} else {
callHandlers(msg, deliverTo)
}
// TODO We will at some point need to decide a trimming policy for the id's
processedMessages[msg.uniqueMessageId] = Instant.now()
}
}
internal fun deliver(artemisMessage: ClientMessage) {
artemisToCordaMessage(artemisMessage)?.let { cordaMessage ->
if (!deduplicator.isDuplicate(cordaMessage)) {
deduplicator.signalMessageProcessStart(cordaMessage)
deliver(cordaMessage, artemisMessage)
} else {
log.trace { "Discard duplicate message ${cordaMessage.uniqueMessageId} for ${cordaMessage.topic}" }
artemisMessage.individualAcknowledge()
}
} catch (e: Exception) {
log.error("Caught exception whilst executing message handler for ${msg.topic}", e)
}
return true
}
private fun callHandlers(msg: ReceivedMessage, deliverTo: List<Handler>) {
for (handler in deliverTo) {
handler.callback(msg, handler)
private fun deliver(msg: ReceivedMessage, artemisMessage: ClientMessage) {
state.checkNotLocked()
val deliverTo = handlers[msg.topic]
if (deliverTo != null) {
try {
deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), MessageDeduplicationHandler(artemisMessage, msg))
} catch (e: Exception) {
log.error("Caught exception whilst executing message handler for ${msg.topic}", e)
}
} else {
log.warn("Received message ${msg.uniqueMessageId} for ${msg.topic} that doesn't have any registered handlers yet")
}
}
inner class MessageDeduplicationHandler(val artemisMessage: ClientMessage, val cordaMessage: ReceivedMessage) : DeduplicationHandler {
override fun insideDatabaseTransaction() {
deduplicator.persistDeduplicationId(cordaMessage.uniqueMessageId)
}
override fun afterDatabaseTransaction() {
deduplicator.signalMessageProcessFinish(cordaMessage.uniqueMessageId)
messagingExecutor!!.acknowledge(artemisMessage)
}
override fun toString(): String {
return "${javaClass.simpleName}(${cordaMessage.uniqueMessageId})"
}
}
@ -470,6 +457,7 @@ class P2PMessagingClient(private val config: NodeConfiguration,
check(started)
val prevRunning = running
running = false
runningFuture = openFuture()
networkChangeSubscription?.unsubscribe()
require(p2pConsumer != null, { "stop can't be called twice" })
require(producer != null, { "stop can't be called twice" })
@ -507,75 +495,42 @@ class P2PMessagingClient(private val config: NodeConfiguration,
override fun close() = stop()
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, additionalHeaders: Map<String, String>) {
sendInternal(message, target, retryId, additionalHeaders)
}
private fun sendInternal(message: Message, target: MessageRecipients, retryId: Long?, additionalHeaders: Map<String, String> = emptyMap()) {
// We have to perform sending on a different thread pool, since using the same pool for messaging and
// fibers leads to Netty buffer memory leaks, caused by both Netty and Quasar fiddling with thread-locals.
messagingExecutor.fetchFrom {
state.locked {
val mqAddress = getMQAddress(target)
val artemisMessage = producerSession!!.createMessage(true).apply {
putStringProperty(P2PMessagingHeaders.cordaVendorProperty, cordaVendor)
putStringProperty(P2PMessagingHeaders.releaseVersionProperty, releaseVersion)
putIntProperty(P2PMessagingHeaders.platformVersionProperty, versionInfo.platformVersion)
putStringProperty(P2PMessagingHeaders.topicProperty, SimpleString(message.topic))
writeBodyBufferBytes(message.data.bytes)
// Use the magic deduplication property built into Artemis as our message identity too
putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId))
// For demo purposes - if set then add a delay to messages in order to demonstrate that the flows are doing as intended
if (amqDelayMillis > 0 && message.topic == StateMachineManagerImpl.sessionTopic) {
putLongProperty(HDR_SCHEDULED_DELIVERY_TIME, System.currentTimeMillis() + amqDelayMillis)
}
additionalHeaders.forEach { key, value -> putStringProperty(key, value) }
}
log.trace {
"Send to: $mqAddress topic: ${message.topic} uuid: ${message.uniqueMessageId}"
}
sendMessage(mqAddress, artemisMessage)
retryId?.let {
database.transaction {
messagesToRedeliver.computeIfAbsent(it, { Pair(message, target) })
}
scheduledMessageRedeliveries[it] = messagingExecutor.schedule({
sendWithRetry(0, mqAddress, artemisMessage, it)
}, messageRedeliveryDelaySeconds, TimeUnit.SECONDS)
}
@Suspendable
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) {
messagingExecutor!!.send(message, target)
retryId?.let {
database.transaction {
messagesToRedeliver.computeIfAbsent(it, { Pair(message, target) })
}
scheduledMessageRedeliveries[it] = nodeExecutor.schedule({
sendWithRetry(0, message, target, retryId)
}, messageRedeliveryDelaySeconds, TimeUnit.SECONDS)
}
}
@Suspendable
override fun send(addressedMessages: List<MessagingService.AddressedMessage>) {
for ((message, target, retryId, sequenceKey) in addressedMessages) {
send(message, target, retryId, sequenceKey)
}
}
private fun sendWithRetry(retryCount: Int, address: String, message: ClientMessage, retryId: Long) {
fun ClientMessage.randomiseDuplicateId() {
putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString()))
}
private fun sendWithRetry(retryCount: Int, message: Message, target: MessageRecipients, retryId: Long) {
log.trace { "Attempting to retry #$retryCount message delivery for $retryId" }
if (retryCount >= messageMaxRetryCount) {
log.warn("Reached the maximum number of retries ($messageMaxRetryCount) for message $message redelivery to $address")
log.warn("Reached the maximum number of retries ($messageMaxRetryCount) for message $message redelivery to $target")
scheduledMessageRedeliveries.remove(retryId)
return
}
message.randomiseDuplicateId()
state.locked {
log.trace { "Retry #$retryCount sending message $message to $address for $retryId" }
sendMessage(address, message)
val messageWithRetryCount = object : Message by message {
override val uniqueMessageId = DeduplicationId("${message.uniqueMessageId.toString}-$retryCount")
}
scheduledMessageRedeliveries[retryId] = messagingExecutor.schedule({
sendWithRetry(retryCount + 1, address, message, retryId)
messagingExecutor!!.send(messageWithRetryCount, target)
scheduledMessageRedeliveries[retryId] = nodeExecutor.schedule({
sendWithRetry(retryCount + 1, message, target, retryId)
}, messageRedeliveryDelaySeconds, TimeUnit.SECONDS)
}
@ -590,18 +545,14 @@ class P2PMessagingClient(private val config: NodeConfiguration,
}
}
private fun Pair<ClientMessage, ReceivedMessage?>.deliver() = deliver(second!!)
private fun Pair<ClientMessage, ReceivedMessage?>.acknowledge() = first.acknowledge()
private fun getMQAddress(target: MessageRecipients): String {
return if (target == myAddress) {
override fun resolveTargetToArtemisQueue(address: MessageRecipients): String {
return if (address == myAddress) {
// If we are sending to ourselves then route the message directly to our P2P queue.
RemoteInboxAddress(myIdentity).queueName
} else {
// Otherwise we send the message to an internal queue for the target residing on our broker. It's then the
// broker's job to route the message to the target's P2P queue.
val internalTargetQueue = (target as? ArtemisAddress)?.queueName
?: throw IllegalArgumentException("Not an Artemis address")
val internalTargetQueue = (address as? ArtemisAddress)?.queueName ?: throw IllegalArgumentException("Not an Artemis address")
state.locked {
createQueueIfAbsent(internalTargetQueue, producerSession!!)
}
@ -630,24 +581,26 @@ class P2PMessagingClient(private val config: NodeConfiguration,
}
}
override fun addMessageHandler(topic: String,
callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
override fun addMessageHandler(topic: String, callback: MessageHandler): MessageHandlerRegistration {
require(!topic.isBlank()) { "Topic must not be blank, as the empty topic is a special case." }
val handler = Handler(topic, callback)
handlers.add(handler)
return handler
handlers.compute(topic) { _, handler ->
if (handler != null) {
throw IllegalStateException("Cannot add another acking handler for $topic, there is already an acking one")
}
callback
}
return HandlerRegistration(topic, callback)
}
override fun removeMessageHandler(registration: MessageHandlerRegistration) {
handlers.remove(registration)
registration as HandlerRegistration
handlers.remove(registration.topic)
}
override fun createMessage(topic: String, data: ByteArray, deduplicationId: String): Message {
// TODO: We could write an object that proxies directly to an underlying MQ message here and avoid copying.
return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId)
override fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId, additionalHeaders: Map<String, String>): Message {
return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId, deduplicator.ourSenderUUID, additionalHeaders)
}
// TODO Rethink PartyInfo idea and merging PeerAddress/ServiceAddress (the only difference is that Service address doesn't hold host and port)
override fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients {
return when (partyInfo) {
is PartyInfo.SingleNode -> NodeAddress(partyInfo.party.owningKey, partyInfo.addresses.single())
@ -720,4 +673,4 @@ private fun ReactiveArtemisConsumer.switchTo(other: ReactiveArtemisConsumer) {
!other.started -> other.start()
!other.connected -> other.connect()
}
}
}

View File

@ -1,11 +1,16 @@
package net.corda.node.services.persistence
import net.corda.core.flows.StateMachineRunId
import net.corda.core.serialization.SerializedBytes
import net.corda.node.services.api.Checkpoint
import net.corda.core.utilities.debug
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.statemachine.Checkpoint
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import net.corda.nodeapi.internal.persistence.currentDBSession
import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY
import org.slf4j.LoggerFactory
import java.util.*
import java.util.stream.Stream
import java.io.Serializable
import javax.persistence.Column
import javax.persistence.Entity
@ -16,6 +21,7 @@ import javax.persistence.Lob
* Simple checkpoint key value storage in DB.
*/
class DBCheckpointStorage : CheckpointStorage {
val log = LoggerFactory.getLogger(this::class.java)
@Entity
@javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}checkpoints")
@ -29,32 +35,30 @@ class DBCheckpointStorage : CheckpointStorage {
var checkpoint: ByteArray = EMPTY_BYTE_ARRAY
) : Serializable
override fun addCheckpoint(checkpoint: Checkpoint) {
currentDBSession().save(DBCheckpoint().apply {
checkpointId = checkpoint.id.toString()
this.checkpoint = checkpoint.serializedFiber.bytes // XXX: Is copying the byte array necessary?
override fun addCheckpoint(id: StateMachineRunId, checkpoint: SerializedBytes<Checkpoint>) {
currentDBSession().saveOrUpdate(DBCheckpoint().apply {
checkpointId = id.uuid.toString()
this.checkpoint = checkpoint.bytes
log.debug { "Checkpoint $checkpointId, size=${this.checkpoint.size}" }
})
}
override fun removeCheckpoint(checkpoint: Checkpoint) {
override fun removeCheckpoint(id: StateMachineRunId): Boolean {
val session = currentDBSession()
val criteriaBuilder = session.criteriaBuilder
val delete = criteriaBuilder.createCriteriaDelete(DBCheckpoint::class.java)
val root = delete.from(DBCheckpoint::class.java)
delete.where(criteriaBuilder.equal(root.get<String>(DBCheckpoint::checkpointId.name), checkpoint.id.toString()))
session.createQuery(delete).executeUpdate()
delete.where(criteriaBuilder.equal(root.get<String>(DBCheckpoint::checkpointId.name), id.uuid.toString()))
return session.createQuery(delete).executeUpdate() > 0
}
override fun forEach(block: (Checkpoint) -> Boolean) {
override fun getAllCheckpoints(): Stream<Pair<StateMachineRunId, SerializedBytes<Checkpoint>>> {
val session = currentDBSession()
val criteriaQuery = session.criteriaBuilder.createQuery(DBCheckpoint::class.java)
val root = criteriaQuery.from(DBCheckpoint::class.java)
criteriaQuery.select(root)
for (row in session.createQuery(criteriaQuery).resultList) {
val checkpoint = Checkpoint(SerializedBytes(row.checkpoint))
if (!block(checkpoint)) {
break
}
return session.createQuery(criteriaQuery).stream().map {
StateMachineRunId(UUID.fromString(it.checkpointId)) to SerializedBytes<Checkpoint>(it.checkpoint)
}
}
}

View File

@ -1,15 +1,19 @@
package net.corda.node.services.persistence
import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.concurrent.CordaFuture
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.TransactionSignature
import net.corda.core.internal.ThreadBox
import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.internal.concurrent.doneFuture
import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.*
import net.corda.core.toFuture
import net.corda.core.transactions.CoreTransaction
import net.corda.core.transactions.SignedTransaction
import net.corda.node.services.api.WritableTransactionStorage
import net.corda.node.utilities.*
import net.corda.node.utilities.AppendOnlyPersistentMapBase
import net.corda.node.utilities.WeightBasedAppendOnlyPersistentMap
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import net.corda.nodeapi.internal.persistence.bufferUntilDatabaseCommit
import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction
@ -92,7 +96,18 @@ class DBTransactionStorage(cacheSizeBytes: Long) : WritableTransactionStorage, S
override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> {
return txStorage.locked {
DataFeed(allPersisted().map { it.second.toSignedTx() }.toList(), updatesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction())
DataFeed(allPersisted().map { it.second.toSignedTx() }.toList(), updates.bufferUntilSubscribed().wrapWithDatabaseTransaction())
}
}
override fun trackTransaction(id: SecureHash): CordaFuture<SignedTransaction> {
return txStorage.locked {
val existingTransaction = get(id)
if (existingTransaction == null) {
updates.filter { it.id == id }.toFuture()
} else {
doneFuture(existingTransaction.toSignedTx())
}
}
}

View File

@ -14,6 +14,7 @@ import net.corda.node.services.api.SchemaService.SchemaOptions
import net.corda.node.services.events.NodeSchedulerService
import net.corda.node.services.identity.PersistentIdentityService
import net.corda.node.services.keys.PersistentKeyManagementService
import net.corda.node.services.messaging.P2PMessageDeduplicator
import net.corda.node.services.messaging.P2PMessagingClient
import net.corda.node.services.persistence.DBCheckpointStorage
import net.corda.node.services.persistence.DBTransactionMappingStorage
@ -43,7 +44,7 @@ class NodeSchemaService(extraSchemas: Set<MappedSchema> = emptySet(), includeNot
PersistentKeyManagementService.PersistentKey::class.java,
NodeSchedulerService.PersistentScheduledState::class.java,
NodeAttachmentService.DBAttachment::class.java,
P2PMessagingClient.ProcessedMessage::class.java,
P2PMessageDeduplicator.ProcessedMessage::class.java,
P2PMessagingClient.RetryMessage::class.java,
PersistentIdentityService.PersistentIdentity::class.java,
PersistentIdentityService.PersistentIdentityNames::class.java,

View File

@ -0,0 +1,138 @@
package net.corda.node.services.statemachine
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.Party
import net.corda.core.internal.FlowAsyncOperation
import net.corda.node.services.messaging.DeduplicationHandler
import java.time.Instant
import java.util.*
/**
* [Action]s are reified IO actions to execute as part of state machine transitions.
*/
sealed class Action {
/**
* Track a transaction hash and notify the state machine once the corresponding transaction has committed.
*/
data class TrackTransaction(val hash: SecureHash) : Action()
/**
* Send an initial session message to [party].
*/
data class SendInitial(
val party: Party,
val initialise: InitialSessionMessage,
val deduplicationId: DeduplicationId
) : Action()
/**
* Send a session message to a [peerParty] with which we have an established session.
*/
data class SendExisting(
val peerParty: Party,
val message: ExistingSessionMessage,
val deduplicationId: DeduplicationId
) : Action()
/**
* Persist the specified [checkpoint].
*/
data class PersistCheckpoint(val id: StateMachineRunId, val checkpoint: Checkpoint) : Action()
/**
* Remove the checkpoint corresponding to [id].
*/
data class RemoveCheckpoint(val id: StateMachineRunId) : Action()
/**
* Persist the deduplication facts of [deduplicationHandlers].
*/
data class PersistDeduplicationFacts(val deduplicationHandlers: List<DeduplicationHandler>) : Action()
/**
* Acknowledge messages in [deduplicationHandlers].
*/
data class AcknowledgeMessages(val deduplicationHandlers: List<DeduplicationHandler>) : Action()
/**
* Propagate [errorMessages] to [sessions].
* @param sessions a map from source session IDs to initiated sessions.
*/
data class PropagateErrors(
val errorMessages: List<ErrorSessionMessage>,
val sessions: List<SessionState.Initiated>
) : Action()
/**
* Create a session binding from [sessionId] to [flowId] to allow routing of incoming messages.
*/
data class AddSessionBinding(val flowId: StateMachineRunId, val sessionId: SessionId) : Action()
/**
* Remove the session bindings corresponding to [sessionIds].
*/
data class RemoveSessionBindings(val sessionIds: Set<SessionId>) : Action()
/**
* Signal that the flow corresponding to [flowId] is considered started.
*/
data class SignalFlowHasStarted(val flowId: StateMachineRunId) : Action()
/**
* Remove the flow corresponding to [flowId].
*/
data class RemoveFlow(
val flowId: StateMachineRunId,
val removalReason: FlowRemovalReason,
val lastState: StateMachineState
) : Action()
/**
* Schedule [event] to self.
*/
data class ScheduleEvent(val event: Event) : Action()
/**
* Sleep until [time].
*/
data class SleepUntil(val time: Instant) : Action()
/**
* Create a new database transaction.
*/
object CreateTransaction : Action() { override fun toString() = "CreateTransaction" }
/**
* Roll back the current database transaction.
*/
object RollbackTransaction : Action() { override fun toString() = "RollbackTransaction" }
/**
* Commit the current database transaction.
*/
object CommitTransaction : Action() { override fun toString() = "CommitTransaction" }
/**
* Execute the specified [operation].
*/
data class ExecuteAsyncOperation(val operation: FlowAsyncOperation<*>) : Action()
/**
* Release soft locks associated with given ID (currently the flow ID).
*/
data class ReleaseSoftLocks(val uuid: UUID?) : Action()
}
/**
* Reason for flow removal.
*/
sealed class FlowRemovalReason {
data class OrderlyFinish(val flowReturnValue: Any?) : FlowRemovalReason()
data class ErrorFinish(val flowErrors: List<FlowError>) : FlowRemovalReason()
object SoftShutdown : FlowRemovalReason() { override fun toString() = "SoftShutdown" }
// TODO Should we remove errored flows? How will the flow hospital work? Perhaps keep them in memory for a while, flush
// them after a timeout, reload them on flow hospital request. In any case if we ever want to remove them
// (e.g. temporarily) then add a case for that here.
}

View File

@ -0,0 +1,14 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable
/**
* An executor of a single [Action].
*/
interface ActionExecutor {
/**
* Execute [action] by [fiber].
*/
@Suspendable
fun executeAction(fiber: FlowFiber, action: Action)
}

View File

@ -0,0 +1,233 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable
import com.codahale.metrics.*
import net.corda.core.internal.concurrent.thenMatch
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.serialize
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.trace
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal
import net.corda.nodeapi.internal.persistence.contextDatabase
import net.corda.nodeapi.internal.persistence.contextTransaction
import net.corda.nodeapi.internal.persistence.contextTransactionOrNull
import java.time.Duration
import java.time.Instant
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicLong
/**
* This is the bottom execution engine of flow side-effects.
*/
class ActionExecutorImpl(
private val services: ServiceHubInternal,
private val checkpointStorage: CheckpointStorage,
private val flowMessaging: FlowMessaging,
private val stateMachineManager: StateMachineManagerInternal,
private val checkpointSerializationContext: SerializationContext,
metrics: MetricRegistry
) : ActionExecutor {
private companion object {
val log = contextLogger()
}
/**
* This [Gauge] just reports the sum of the bytes checkpointed during the last second.
*/
private class LatchedGauge(private val reservoir: Reservoir) : Gauge<Long> {
override fun getValue(): Long {
return reservoir.snapshot.values.sum()
}
}
private val checkpointingMeter = metrics.meter("Flows.Checkpointing Rate")
private val checkpointSizesThisSecond = SlidingTimeWindowReservoir(1, TimeUnit.SECONDS)
private val lastBandwidthUpdate = AtomicLong(0)
private val checkpointBandwidthHist = metrics.register("Flows.CheckpointVolumeBytesPerSecondHist", Histogram(SlidingTimeWindowArrayReservoir(1, TimeUnit.DAYS)))
private val checkpointBandwidth = metrics.register("Flows.CheckpointVolumeBytesPerSecondCurrent", LatchedGauge(checkpointSizesThisSecond))
@Suspendable
override fun executeAction(fiber: FlowFiber, action: Action) {
log.trace { "Flow ${fiber.id} executing $action" }
return when (action) {
is Action.TrackTransaction -> executeTrackTransaction(fiber, action)
is Action.PersistCheckpoint -> executePersistCheckpoint(action)
is Action.PersistDeduplicationFacts -> executePersistDeduplicationIds(action)
is Action.AcknowledgeMessages -> executeAcknowledgeMessages(action)
is Action.PropagateErrors -> executePropagateErrors(action)
is Action.ScheduleEvent -> executeScheduleEvent(fiber, action)
is Action.SleepUntil -> executeSleepUntil(action)
is Action.RemoveCheckpoint -> executeRemoveCheckpoint(action)
is Action.SendInitial -> executeSendInitial(action)
is Action.SendExisting -> executeSendExisting(action)
is Action.AddSessionBinding -> executeAddSessionBinding(action)
is Action.RemoveSessionBindings -> executeRemoveSessionBindings(action)
is Action.SignalFlowHasStarted -> executeSignalFlowHasStarted(action)
is Action.RemoveFlow -> executeRemoveFlow(action)
is Action.CreateTransaction -> executeCreateTransaction()
is Action.RollbackTransaction -> executeRollbackTransaction()
is Action.CommitTransaction -> executeCommitTransaction()
is Action.ExecuteAsyncOperation -> executeAsyncOperation(fiber, action)
is Action.ReleaseSoftLocks -> executeReleaseSoftLocks(action)
}
}
private fun executeReleaseSoftLocks(action: Action.ReleaseSoftLocks) {
if (action.uuid != null) services.vaultService.softLockRelease(action.uuid)
}
@Suspendable
private fun executeTrackTransaction(fiber: FlowFiber, action: Action.TrackTransaction) {
services.validatedTransactions.trackTransaction(action.hash).thenMatch(
success = { transaction ->
fiber.scheduleEvent(Event.TransactionCommitted(transaction))
},
failure = { exception ->
fiber.scheduleEvent(Event.Error(exception))
}
)
}
@Suspendable
private fun executePersistCheckpoint(action: Action.PersistCheckpoint) {
val checkpointBytes = serializeCheckpoint(action.checkpoint)
checkpointStorage.addCheckpoint(action.id, checkpointBytes)
checkpointingMeter.mark()
checkpointSizesThisSecond.update(checkpointBytes.size.toLong())
var lastUpdateTime = lastBandwidthUpdate.get()
while (System.nanoTime() - lastUpdateTime > TimeUnit.SECONDS.toNanos(1)) {
if (lastBandwidthUpdate.compareAndSet(lastUpdateTime, System.nanoTime())) {
val checkpointVolume = checkpointSizesThisSecond.snapshot.values.sum()
checkpointBandwidthHist.update(checkpointVolume)
}
lastUpdateTime = lastBandwidthUpdate.get()
}
}
@Suspendable
private fun executePersistDeduplicationIds(action: Action.PersistDeduplicationFacts) {
for (handle in action.deduplicationHandlers) {
handle.insideDatabaseTransaction()
}
}
@Suspendable
private fun executeAcknowledgeMessages(action: Action.AcknowledgeMessages) {
action.deduplicationHandlers.forEach {
it.afterDatabaseTransaction()
}
}
@Suspendable
private fun executePropagateErrors(action: Action.PropagateErrors) {
action.errorMessages.forEach { error ->
val exception = error.flowException
log.debug("Propagating error", exception)
}
for (sessionState in action.sessions) {
// We cannot propagate if the session isn't live.
if (sessionState.initiatedState !is InitiatedSessionState.Live) {
continue
}
// Don't propagate errors to the originating session
for (errorMessage in action.errorMessages) {
val sinkSessionId = sessionState.initiatedState.peerSinkSessionId
val existingMessage = ExistingSessionMessage(sinkSessionId, errorMessage)
val deduplicationId = DeduplicationId.createForError(errorMessage.errorId, sinkSessionId)
flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, deduplicationId)
}
}
}
@Suspendable
private fun executeScheduleEvent(fiber: FlowFiber, action: Action.ScheduleEvent) {
fiber.scheduleEvent(action.event)
}
@Suspendable
private fun executeSleepUntil(action: Action.SleepUntil) {
// TODO introduce explicit sleep state + wakeup event instead of relying on Fiber.sleep. This is so shutdown
// conditions may "interrupt" the sleep instead of waiting until wakeup.
val duration = Duration.between(Instant.now(), action.time)
Fiber.sleep(duration.toNanos(), TimeUnit.NANOSECONDS)
}
@Suspendable
private fun executeRemoveCheckpoint(action: Action.RemoveCheckpoint) {
checkpointStorage.removeCheckpoint(action.id)
}
@Suspendable
private fun executeSendInitial(action: Action.SendInitial) {
flowMessaging.sendSessionMessage(action.party, action.initialise, action.deduplicationId)
}
@Suspendable
private fun executeSendExisting(action: Action.SendExisting) {
flowMessaging.sendSessionMessage(action.peerParty, action.message, action.deduplicationId)
}
@Suspendable
private fun executeAddSessionBinding(action: Action.AddSessionBinding) {
stateMachineManager.addSessionBinding(action.flowId, action.sessionId)
}
@Suspendable
private fun executeRemoveSessionBindings(action: Action.RemoveSessionBindings) {
stateMachineManager.removeSessionBindings(action.sessionIds)
}
@Suspendable
private fun executeSignalFlowHasStarted(action: Action.SignalFlowHasStarted) {
stateMachineManager.signalFlowHasStarted(action.flowId)
}
@Suspendable
private fun executeRemoveFlow(action: Action.RemoveFlow) {
stateMachineManager.removeFlow(action.flowId, action.removalReason, action.lastState)
}
@Suspendable
private fun executeCreateTransaction() {
if (contextTransactionOrNull != null) {
throw IllegalStateException("Refusing to create a second transaction")
}
contextDatabase.newTransaction()
}
@Suspendable
private fun executeRollbackTransaction() {
contextTransactionOrNull?.close()
}
@Suspendable
private fun executeCommitTransaction() {
try {
contextTransaction.commit()
} finally {
contextTransaction.close()
contextTransactionOrNull = null
}
}
@Suspendable
private fun executeAsyncOperation(fiber: FlowFiber, action: Action.ExecuteAsyncOperation) {
val operationFuture = action.operation.execute()
operationFuture.thenMatch(
success = { result ->
fiber.scheduleEvent(Event.AsyncOperationCompletion(result))
},
failure = { exception ->
fiber.scheduleEvent(Event.Error(exception))
}
)
}
private fun serializeCheckpoint(checkpoint: Checkpoint): SerializedBytes<Checkpoint> {
return checkpoint.serialize(context = checkpointSerializationContext)
}
}

View File

@ -0,0 +1,66 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.strands.concurrent.AbstractQueuedSynchronizer
import co.paralleluniverse.fibers.Suspendable
/**
* Quasar-compatible latch that may be incremented.
*/
class CountUpDownLatch(initialValue: Int) {
// See quasar CountDownLatch
private class Sync(initialValue: Int) : AbstractQueuedSynchronizer() {
init {
state = initialValue
}
override fun tryAcquireShared(arg: Int): Int {
if (arg >= 0) {
return if (state == arg) 1 else -1
} else {
return if (state <= -arg) 1 else -1
}
}
override fun tryReleaseShared(arg: Int): Boolean {
while (true) {
val c = state
if (c == 0)
return false
val nextc = c - Math.min(c, arg)
if (compareAndSetState(c, nextc))
return nextc == 0
}
}
fun increment() {
while (true) {
val c = state
val nextc = c + 1
if (compareAndSetState(c, nextc))
return
}
}
}
private val sync = Sync(initialValue)
@Suspendable
fun await() {
sync.acquireSharedInterruptibly(0)
}
@Suspendable
fun awaitLessThanOrEqual(number: Int) {
sync.acquireSharedInterruptibly(number)
}
fun countDown(number: Int = 1) {
require(number > 0)
sync.releaseShared(number)
}
fun countUp() {
sync.increment()
}
}

View File

@ -0,0 +1,47 @@
package net.corda.node.services.statemachine
import java.security.SecureRandom
/**
* A deduplication ID of a flow message.
*/
data class DeduplicationId(val toString: String) {
companion object {
/**
* Create a random deduplication ID. Note that this isn't deterministic, which means we will never dedupe it,
* unless we persist the ID somehow.
*/
fun createRandom(random: SecureRandom) = DeduplicationId("R-${random.nextLong()}")
/**
* Create a deduplication ID for a normal clean state message. This is used to have a deterministic way of
* creating IDs in case the message-generating flow logic is replayed on hard failure.
*
* A normal deduplication ID consists of:
* 1. A deduplication seed set per flow. This is either the flow's ID or in case of an initated flow the
* initiator's session ID.
* 2. The number of *clean* suspends since the start of the flow.
* 3. An optional additional index, for cases where several messages are sent as part of the state transition.
* Note that care must be taken with this index, it must be a deterministic counter. For example a naive
* iteration over a HashMap will produce a different list of indeces than a previous run, causing the
* message-id map to change, which means deduplication will not happen correctly.
*/
fun createForNormal(checkpoint: Checkpoint, index: Int): DeduplicationId {
return DeduplicationId("N-${checkpoint.deduplicationSeed}-${checkpoint.numberOfSuspends}-$index")
}
/**
* Create a deduplication ID for an error message. Note that these IDs live in a different namespace than normal
* IDs, as we don't want error conditions to affect the determinism of clean deduplication IDs. This allows the
* dirtiness state to be thrown away for resumption.
*
* An error deduplication ID consists of:
* 1. The error's ID. This is a unique value per "source" of error and is propagated.
* See [net.corda.core.flows.IdentifiableException].
* 2. The recipient's session ID.
*/
fun createForError(errorId: Long, recipientSessionId: SessionId): DeduplicationId {
return DeduplicationId("E-$errorId-${recipientSessionId.toLong}")
}
}
}

View File

@ -0,0 +1,128 @@
package net.corda.node.services.statemachine
import net.corda.core.flows.FlowLogic
import net.corda.core.identity.Party
import net.corda.core.internal.FlowIORequest
import net.corda.core.serialization.SerializedBytes
import net.corda.core.transactions.SignedTransaction
import net.corda.node.services.messaging.DeduplicationHandler
import java.util.*
/**
* Transitions in the flow state machine are triggered by [Event]s that may originate from the flow itself or from
* outside (e.g. in case of message delivery or external event).
*/
sealed class Event {
/**
* Check the current state for pending work. For example if the flow is waiting for a message from a particular
* session this event may cause a flow resume if we have a corresponding message. In general the state machine
* should be idempotent in the [DoRemainingWork] event, meaning a second subsequent event shouldn't modify the state
* or produce [Action]s.
*/
object DoRemainingWork : Event() { override fun toString() = "DoRemainingWork" }
/**
* Deliver a session message.
* @param sessionMessage the message itself.
* @param deduplicationHandler the handle to acknowledge the message after checkpointing.
* @param sender the sender [Party].
*/
data class DeliverSessionMessage(
val sessionMessage: ExistingSessionMessage,
val deduplicationHandler: DeduplicationHandler,
val sender: Party
) : Event()
/**
* Signal that an error has happened. This may be due to an uncaught exception in the flow or some external error.
* @param exception the exception itself.
*/
data class Error(val exception: Throwable) : Event()
/**
* Signal that a ledger transaction has committed. This is an event completing a [FlowIORequest.WaitForLedgerCommit]
* suspension.
* @param transaction the transaction that was committed.
*/
data class TransactionCommitted(val transaction: SignedTransaction) : Event()
/**
* Trigger a soft shutdown, removing the flow as soon as possible. This causes the flow to be removed as soon as
* this event is processed. Note that on restart the flow will resume as normal.
*/
object SoftShutdown : Event() { override fun toString() = "SoftShutdown" }
/**
* Start error propagation on a errored flow. This may be triggered by e.g. a [FlowHospital].
*/
object StartErrorPropagation : Event() { override fun toString() = "StartErrorPropagation" }
/**
*
* Scheduled by the flow.
*
* Initiate a flow. This causes a new session object to be created and returned to the flow. Note that no actual
* communication takes place at this time, only on the first send/receive operation on the session.
* @param party the [Party] to create a session with.
*/
data class InitiateFlow(val party: Party) : Event()
/**
* Signal the entering into a subflow.
*
* Scheduled and executed by the flow.
*
* @param subFlowClass the [Class] of the subflow, to be used to determine whether it's Initiating or inlined.
*/
data class EnterSubFlow(val subFlowClass: Class<FlowLogic<*>>) : Event()
/**
* Signal the leaving of a subflow.
*
* Scheduled by the flow.
*
*/
object LeaveSubFlow : Event() { override fun toString() = "LeaveSubFlow" }
/**
* Signal a flow suspension. This causes the flow's stack and the state machine's state together with the suspending
* IO request to be persisted into the database.
*
* Scheduled by the flow and executed inside the park closure.
*
* @param ioRequest the request triggering the suspension.
* @param maySkipCheckpoint indicates whether the persistence may be skipped.
* @param fiber the serialised stack of the flow.
*/
data class Suspend(
val ioRequest: FlowIORequest<*>,
val maySkipCheckpoint: Boolean,
val fiber: SerializedBytes<FlowStateMachineImpl<*>>
) : Event() {
override fun toString() =
"Suspend(" +
"ioRequest=$ioRequest, " +
"maySkipCheckpoint=$maySkipCheckpoint, " +
"fiber=${fiber.hash}, " +
")"
}
/**
* Signals clean flow finish.
*
* Scheduled by the flow.
*
* @param returnValue the return value of the flow.
* @param softLocksId the flow ID of the flow if it is holding soft locks, else null.
*/
data class FlowFinish(val returnValue: Any?, val softLocksId: UUID?) : Event()
/**
* Signals the completion of a [FlowAsyncOperation].
*
* Scheduling is triggered by the service that completes the future returned by the async operation.
*
* @param returnValue the result of the operation.
*/
data class AsyncOperationCompletion(val returnValue: Any?) : Event()
}

View File

@ -0,0 +1,18 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.StateMachineRunId
import net.corda.node.services.statemachine.transitions.StateMachine
/**
* An interface wrapping a fiber running a flow.
*/
interface FlowFiber {
val id: StateMachineRunId
val stateMachine: StateMachine
@Suspendable
fun scheduleEvent(event: Event)
fun snapshot(): StateMachineState
}

View File

@ -0,0 +1,18 @@
package net.corda.node.services.statemachine
/**
* A flow hospital is a class that is notified when a flow transitions into an error state due to an uncaught exception
* or internal error condition, and when it becomes clean again (e.g. due to a resume).
* Also see [net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor].
*/
interface FlowHospital {
/**
* The flow running in [flowFiber] has errored.
*/
fun flowErrored(flowFiber: FlowFiber)
/**
* The flow running in [flowFiber] has cleaned, possibly as a result of a flow hospital resume.
*/
fun flowCleaned(flowFiber: FlowFiber)
}

View File

@ -1,121 +0,0 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.crypto.SecureHash
import net.corda.core.identity.Party
import java.time.Instant
interface FlowIORequest {
// This is used to identify where we suspended, in case of message mismatch errors and other things where we
// don't have the original stack trace because it's in a suspended fiber.
val stackTraceInCaseOfProblems: StackSnapshot
}
interface WaitingRequest : FlowIORequest {
fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean
}
interface SessionedFlowIORequest : FlowIORequest {
val session: FlowSessionInternal
}
interface SendRequest : SessionedFlowIORequest {
val message: SessionMessage
}
interface ReceiveRequest : SessionedFlowIORequest, WaitingRequest {
val userReceiveType: Class<*>?
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = this.session === session
}
data class SendAndReceive(override val session: FlowSessionInternal,
override val message: SessionMessage,
override val userReceiveType: Class<*>?) : SendRequest, ReceiveRequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}
data class ReceiveOnly(
override val session: FlowSessionInternal,
override val userReceiveType: Class<*>?
) : ReceiveRequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}
class ReceiveAll(val requests: List<ReceiveRequest>) : WaitingRequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
private fun isComplete(received: LinkedHashMap<FlowSessionInternal, RequestMessage>): Boolean {
return received.keys == requests.map { it.session }.toSet()
}
private fun shouldResumeIfRelevant() = requests.all { hasSuccessfulEndMessage(it) }
private fun hasSuccessfulEndMessage(it: ReceiveRequest): Boolean {
return it.session.receivedMessages.map { it.message.payload }.any { it is DataSessionMessage || it is EndSessionMessage }
}
@Suspendable
fun suspendAndExpectReceive(suspend: Suspend): Map<FlowSessionInternal, RequestMessage> {
val receivedMessages = LinkedHashMap<FlowSessionInternal, RequestMessage>()
poll(receivedMessages)
return if (isComplete(receivedMessages)) {
receivedMessages
} else {
suspend(this)
poll(receivedMessages)
if (isComplete(receivedMessages)) {
receivedMessages
} else {
throw IllegalStateException(requests.filter { it.session !in receivedMessages.keys }.map { "Was expecting a message but instead got nothing for $it." }.joinToString { "\n" })
}
}
}
interface Suspend {
@Suspendable
operator fun invoke(request: FlowIORequest)
}
@Suspendable
private fun poll(receivedMessages: LinkedHashMap<FlowSessionInternal, RequestMessage>) {
return requests.filter { it.session !in receivedMessages.keys }.forEach { request ->
poll(request)?.let {
receivedMessages[request.session] = RequestMessage(request, it)
}
}
}
@Suspendable
private fun poll(request: ReceiveRequest): ExistingSessionMessage? {
return request.session.receivedMessages.poll()?.message
}
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = isRelevant(session) && shouldResumeIfRelevant()
private fun isRelevant(session: FlowSessionInternal) = requests.any { it.session === session }
data class RequestMessage(val request: ReceiveRequest, val message: ExistingSessionMessage)
}
data class SendOnly(override val session: FlowSessionInternal, override val message: SessionMessage) : SendRequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}
data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachineImpl<*>) : WaitingRequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = message.payload is ErrorSessionMessage
}
data class Sleep(val until: Instant, val fiber: FlowStateMachineImpl<*>) : FlowIORequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}
class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")

View File

@ -156,4 +156,4 @@ class FlowLogicRefFactoryImpl(private val classloader: ClassLoader) : SingletonS
return false
}
}
}
}

View File

@ -0,0 +1,89 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable
import com.esotericsoftware.kryo.KryoException
import net.corda.core.context.InvocationOrigin
import net.corda.core.flows.FlowException
import net.corda.core.identity.Party
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.serialize
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.trace
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.ReceivedMessage
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders
import java.io.NotSerializableException
/**
* A wrapper interface around flow messaging.
*/
interface FlowMessaging {
/**
* Send [message] to [party] using [deduplicationId]. Optionally [acknowledgementHandler] may be specified to
* listen on the send acknowledgement.
*/
@Suspendable
fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId)
/**
* Start the messaging using the [onMessage] message handler.
*/
fun start(onMessage: (ReceivedMessage, deduplicationHandler: DeduplicationHandler) -> Unit)
}
/**
* Implementation of [FlowMessaging] using a [ServiceHubInternal] to do the messaging and routing.
*/
class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging {
companion object {
val log = contextLogger()
val sessionTopic = "platform.session"
}
override fun start(onMessage: (ReceivedMessage, deduplicationHandler: DeduplicationHandler) -> Unit) {
serviceHub.networkService.addMessageHandler(sessionTopic) { receivedMessage, _, deduplicationHandler ->
onMessage(receivedMessage, deduplicationHandler)
}
}
@Suspendable
override fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId) {
log.trace { "Sending message $deduplicationId $message to party $party" }
val networkMessage = serviceHub.networkService.createMessage(sessionTopic, serializeSessionMessage(message).bytes, deduplicationId, message.additionalHeaders(party))
val partyInfo = serviceHub.networkMapCache.getPartyInfo(party) ?: throw IllegalArgumentException("Don't know about $party")
val address = serviceHub.networkService.getAddressOfParty(partyInfo)
val sequenceKey = when (message) {
is InitialSessionMessage -> message.initiatorSessionId
is ExistingSessionMessage -> message.recipientSessionId
}
serviceHub.networkService.send(networkMessage, address, sequenceKey = sequenceKey)
}
private fun SessionMessage.additionalHeaders(target: Party): Map<String, String> {
// This prevents a "deadlock" in case an initiated flow tries to start a session against a draining node that is also the initiator.
// It does not help in case more than 2 nodes are involved in a circle, so the kill switch via RPC should be used in that case.
val mightDeadlockDrainingTarget = FlowStateMachineImpl.currentStateMachine()?.context?.origin.let { it is InvocationOrigin.Peer && it.party == target.name }
return when {
this !is InitialSessionMessage || mightDeadlockDrainingTarget -> emptyMap()
else -> mapOf(P2PMessagingHeaders.Type.KEY to P2PMessagingHeaders.Type.SESSION_INIT_VALUE)
}
}
private fun serializeSessionMessage(message: SessionMessage): SerializedBytes<SessionMessage> {
return try {
message.serialize()
} catch (exception: Exception) {
// Handling Kryo and AMQP serialization problems. Unfortunately the two exception types do not share much of a common exception interface.
if ((exception is KryoException || exception is NotSerializableException)
&& message is ExistingSessionMessage && message.payload is ErrorSessionMessage) {
val error = message.payload.flowException
val rewrappedError = FlowException(error?.message)
message.copy(payload = message.payload.copy(flowException = rewrappedError)).serialize()
} else {
throw exception
}
}
}
}

View File

@ -1,20 +1,40 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.identity.Party
import net.corda.core.internal.FlowIORequest
import net.corda.core.internal.FlowStateMachine
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.serialize
import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.UntrustworthyData
import net.corda.core.utilities.checkPayloadIs
class FlowSessionImpl(override val counterparty: Party) : FlowSession() {
internal lateinit var stateMachine: FlowStateMachine<*>
internal lateinit var sessionFlow: FlowLogic<*>
class FlowSessionImpl(
override val counterparty: Party,
val sourceSessionId: SessionId
) : FlowSession() {
override fun toString() = "FlowSessionImpl(counterparty=$counterparty, sourceSessionId=$sourceSessionId)"
override fun equals(other: Any?): Boolean {
return (other as? FlowSessionImpl)?.sourceSessionId == sourceSessionId
}
override fun hashCode() = sourceSessionId.hashCode()
private fun getFlowStateMachine(): FlowStateMachine<*> {
return Fiber.currentFiber() as FlowStateMachine<*>
}
@Suspendable
override fun getCounterpartyFlowInfo(maySkipCheckpoint: Boolean): FlowInfo {
return stateMachine.getFlowInfo(counterparty, sessionFlow, maySkipCheckpoint)
val request = FlowIORequest.GetFlowInfo(NonEmptySet.of(this))
return getFlowStateMachine().suspend(request, maySkipCheckpoint)[this]!!
}
@Suspendable
@ -26,14 +46,15 @@ class FlowSessionImpl(override val counterparty: Party) : FlowSession() {
payload: Any,
maySkipCheckpoint: Boolean
): UntrustworthyData<R> {
return stateMachine.sendAndReceive(
receiveType,
counterparty,
payload,
sessionFlow,
retrySend = false,
maySkipCheckpoint = maySkipCheckpoint
enforceNotPrimitive(receiveType)
val request = FlowIORequest.SendAndReceive(
sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT)),
shouldRetrySend = false
)
val responseValues: Map<FlowSession, SerializedBytes<Any>> = getFlowStateMachine().suspend(request, maySkipCheckpoint)
val responseForCurrentSession = responseValues[this]!!
return responseForCurrentSession.checkPayloadIs(receiveType)
}
@Suspendable
@ -41,7 +62,9 @@ class FlowSessionImpl(override val counterparty: Party) : FlowSession() {
@Suspendable
override fun <R : Any> receive(receiveType: Class<R>, maySkipCheckpoint: Boolean): UntrustworthyData<R> {
return stateMachine.receive(receiveType, counterparty, sessionFlow, maySkipCheckpoint)
enforceNotPrimitive(receiveType)
val request = FlowIORequest.Receive(NonEmptySet.of(this))
return getFlowStateMachine().suspend(request, maySkipCheckpoint)[this]!!.checkPayloadIs(receiveType)
}
@Suspendable
@ -49,12 +72,17 @@ class FlowSessionImpl(override val counterparty: Party) : FlowSession() {
@Suspendable
override fun send(payload: Any, maySkipCheckpoint: Boolean) {
return stateMachine.send(counterparty, payload, sessionFlow, maySkipCheckpoint)
val request = FlowIORequest.Send(
sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT)),
shouldRetrySend = false
)
return getFlowStateMachine().suspend(request, maySkipCheckpoint)
}
@Suspendable
override fun send(payload: Any) = send(payload, maySkipCheckpoint = false)
override fun toString() = "Flow session with $counterparty"
private fun enforceNotPrimitive(type: Class<*>) {
require(!type.isPrimitive) { "Cannot receive primitive type $type" }
}
}

View File

@ -1,66 +0,0 @@
package net.corda.node.services.statemachine
import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.identity.Party
import net.corda.node.services.statemachine.FlowSessionState.Initiated
import net.corda.node.services.statemachine.FlowSessionState.Initiating
import java.util.concurrent.ConcurrentLinkedQueue
/**
* @param retryable Indicates that the session initialisation should be retried until an expected [SessionData] response
* is received. Note that this requires the party on the other end to be a distributed service and run an idempotent flow
* that only sends back a single [SessionData] message before termination.
*/
// TODO rename this
class FlowSessionInternal(
val flow: FlowLogic<*>,
val flowSession : FlowSession,
val ourSessionId: SessionId,
val initiatingParty: Party?,
var state: FlowSessionState,
var retryable: Boolean = false) {
val receivedMessages = ConcurrentLinkedQueue<ReceivedSessionMessage>()
val fiber: FlowStateMachineImpl<*> get() = flow.stateMachine as FlowStateMachineImpl<*>
override fun toString(): String {
return "${javaClass.simpleName}(flow=$flow, ourSessionId=$ourSessionId, initiatingParty=$initiatingParty, state=$state)"
}
fun getPeerSessionId(): SessionId {
val sessionState = state
return when (sessionState) {
is FlowSessionState.Initiated -> sessionState.peerSessionId
else -> throw IllegalStateException("We've somehow held onto a non-initiated session: $this")
}
}
}
data class ReceivedSessionMessage(val peerParty: Party, val message: ExistingSessionMessage)
/**
* [FlowSessionState] describes the session's state.
*
* [Uninitiated] is pre-handshake, where no communication has happened. [Initiating.otherParty] at this point holds a
* [Party] corresponding to either a specific peer or a service.
* [Initiating] is pre-handshake, where the initiating message has been sent.
* [Initiated] is post-handshake. At this point [Initiating.otherParty] will have been resolved to a specific peer
* [Initiated.peerParty], and the peer's sessionId has been initialised.
*/
sealed class FlowSessionState {
abstract val sendToParty: Party
data class Uninitiated(val otherParty: Party) : FlowSessionState() {
override val sendToParty: Party get() = otherParty
}
/** [otherParty] may be a specific peer or a service party */
data class Initiating(val otherParty: Party) : FlowSessionState() {
override val sendToParty: Party get() = otherParty
}
data class Initiated(val peerParty: Party, val peerSessionId: SessionId, val context: FlowInfo) : FlowSessionState() {
override val sendToParty: Party get() = peerParty
}
}

View File

@ -5,262 +5,251 @@ import co.paralleluniverse.fibers.Fiber.parkAndSerialize
import co.paralleluniverse.fibers.FiberScheduler
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand
import com.google.common.primitives.Primitives
import co.paralleluniverse.strands.channels.Channel
import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.newSecureRandom
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.identity.PartyAndCertificate
import net.corda.core.internal.*
import net.corda.core.internal.concurrent.OpenFuture
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.serialize
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.*
import net.corda.core.utilities.Try
import net.corda.core.utilities.debug
import net.corda.core.utilities.trace
import net.corda.node.services.api.FlowAppAuditEvent
import net.corda.node.services.api.FlowPermissionAuditEvent
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.logging.pushToLoggingContext
import net.corda.node.services.statemachine.FlowSessionState.Initiating
import net.corda.node.services.statemachine.transitions.FlowContinuation
import net.corda.node.services.statemachine.transitions.StateMachine
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseTransaction
import net.corda.nodeapi.internal.persistence.contextTransaction
import net.corda.nodeapi.internal.persistence.contextTransactionOrNull
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.io.IOException
import java.sql.SQLException
import java.time.Duration
import java.time.Instant
import java.util.*
import java.util.concurrent.TimeUnit
import kotlin.reflect.KProperty1
class FlowPermissionException(message: String) : FlowException(message)
class TransientReference<out A>(@Transient val value: A)
class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
override val logic: FlowLogic<R>,
scheduler: FiberScheduler,
val ourIdentity: Party,
override val context: InvocationContext) : Fiber<Unit>(id.toString(), scheduler), FlowStateMachine<R> {
scheduler: FiberScheduler
) : Fiber<Unit>(id.toString(), scheduler), FlowStateMachine<R>, FlowFiber {
companion object {
// Used to work around a small limitation in Quasar.
private val QUASAR_UNBLOCKER = Fiber::class.staticField<Any>("SERIALIZER_BLOCKER").value
/**
* Return the current [FlowStateMachineImpl] or null if executing outside of one.
*/
fun currentStateMachine(): FlowStateMachineImpl<*>? = Strand.currentStrand() as? FlowStateMachineImpl<*>
private val log: Logger = LoggerFactory.getLogger("net.corda.flow")
}
// These fields shouldn't be serialised, so they are marked @Transient.
@Transient override lateinit var serviceHub: ServiceHubInternal
@Transient override lateinit var ourIdentityAndCert: PartyAndCertificate
@Transient internal lateinit var database: CordaPersistence
@Transient internal lateinit var actionOnSuspend: (FlowIORequest) -> Unit
@Transient internal lateinit var actionOnEnd: (Try<R>, Boolean) -> Unit
@Transient internal var fromCheckpoint: Boolean = false
@Transient private var txTrampoline: DatabaseTransaction? = null
override val serviceHub get() = getTransientField(TransientValues::serviceHub)
data class TransientValues(
val eventQueue: Channel<Event>,
val resultFuture: CordaFuture<Any?>,
val database: CordaPersistence,
val transitionExecutor: TransitionExecutor,
val actionExecutor: ActionExecutor,
val stateMachine: StateMachine,
val serviceHub: ServiceHubInternal,
val checkpointSerializationContext: SerializationContext
)
internal var transientValues: TransientReference<TransientValues>? = null
internal var transientState: TransientReference<StateMachineState>? = null
private fun <A> getTransientField(field: KProperty1<TransientValues, A>): A {
val suppliedValues = transientValues ?: throw IllegalStateException("${field.name} wasn't supplied!")
return field.get(suppliedValues.value)
}
private fun extractThreadLocalTransaction(): TransientReference<DatabaseTransaction> {
val transaction = contextTransaction
contextTransactionOrNull = null
return TransientReference(transaction)
}
/**
* Return the logger for this state machine. The logger name incorporates [id] and so including it in the log message
* is not necessary.
*/
override val logger: Logger = LoggerFactory.getLogger("net.corda.flow.$id")
@Transient private var resultFutureTransient: OpenFuture<R>? = openFuture()
private val _resultFuture get() = resultFutureTransient ?: openFuture<R>().also { resultFutureTransient = it }
/** This future will complete when the call method returns. */
override val resultFuture: CordaFuture<R> get() = _resultFuture
// This state IS serialised, as we need it to know what the fiber is waiting for.
internal val openSessions = HashMap<Pair<FlowLogic<*>, Party>, FlowSessionInternal>()
internal var waitingForResponse: WaitingRequest? = null
override val logger = log
override val resultFuture: CordaFuture<R> get() = uncheckedCast(getTransientField(TransientValues::resultFuture))
override val context: InvocationContext get() = transientState!!.value.checkpoint.invocationContext
override val ourIdentity: Party get() = transientState!!.value.checkpoint.ourIdentity
internal var hasSoftLockedStates: Boolean = false
set(value) {
if (value) field = value else throw IllegalArgumentException("Can only set to true")
}
init {
logic.stateMachine = this
/**
* Processes an event by creating the associated transition and executing it using the given executor.
* Try to avoid using this directly, instead use [processEventsUntilFlowIsResumed] or [processEventImmediately]
* instead.
*/
@Suspendable
private fun processEvent(transitionExecutor: TransitionExecutor, event: Event): FlowContinuation {
val stateMachine = getTransientField(TransientValues::stateMachine)
val oldState = transientState!!.value
val actionExecutor = getTransientField(TransientValues::actionExecutor)
val transition = stateMachine.transition(event, oldState)
val (continuation, newState) = transitionExecutor.executeTransition(this, oldState, event, transition, actionExecutor)
transientState = TransientReference(newState)
return continuation
}
/**
* Processes the events in the event queue until a transition indicates that control should be returned to user code
* in the form of a regular resume or a throw of an exception. Alternatively the transition may abort the fiber
* completely.
*
* @param isDbTransactionOpenOnEntry indicates whether a DB transaction is expected to be present before the
* processing of the eventloop. Purely used for internal invariant checks.
* @param isDbTransactionOpenOnExit indicates whether a DB transaction is expected to be present once the eventloop
* processing finished. Purely used for internal invariant checks.
*/
@Suspendable
private fun processEventsUntilFlowIsResumed(isDbTransactionOpenOnEntry: Boolean, isDbTransactionOpenOnExit: Boolean): Any? {
checkDbTransaction(isDbTransactionOpenOnEntry)
val transitionExecutor = getTransientField(TransientValues::transitionExecutor)
val eventQueue = getTransientField(TransientValues::eventQueue)
try {
eventLoop@while (true) {
val nextEvent = eventQueue.receive()
val continuation = processEvent(transitionExecutor, nextEvent)
when (continuation) {
is FlowContinuation.Resume -> return continuation.result
is FlowContinuation.Throw -> {
continuation.throwable.fillInStackTrace()
throw continuation.throwable
}
FlowContinuation.ProcessEvents -> continue@eventLoop
FlowContinuation.Abort -> abortFiber()
}
}
} finally {
checkDbTransaction(isDbTransactionOpenOnExit)
}
}
/**
* Immediately processes the passed in event. Always called with an open database transaction.
*
* @param event the event to be processed.
* @param isDbTransactionOpenOnEntry indicates whether a DB transaction is expected to be present before the
* processing of the event. Purely used for internal invariant checks.
* @param isDbTransactionOpenOnExit indicates whether a DB transaction is expected to be present once the event
* processing finished. Purely used for internal invariant checks.
*/
@Suspendable
private fun processEventImmediately(event: Event, isDbTransactionOpenOnEntry: Boolean, isDbTransactionOpenOnExit: Boolean): FlowContinuation {
checkDbTransaction(isDbTransactionOpenOnEntry)
val transitionExecutor = getTransientField(TransientValues::transitionExecutor)
val continuation = processEvent(transitionExecutor, event)
checkDbTransaction(isDbTransactionOpenOnExit)
return continuation
}
private fun checkDbTransaction(isPresent: Boolean) {
if (isPresent) {
requireNotNull(contextTransactionOrNull != null)
} else {
require(contextTransactionOrNull == null)
}
}
@Suspendable
override fun run() {
createTransaction()
logic.stateMachine = this
context.pushToLoggingContext()
initialiseFlow()
logger.debug { "Calling flow: $logic" }
val startTime = System.nanoTime()
val result = try {
val r = logic.call()
// Only sessions which have done a single send and nothing else will block here
openSessions.values
.filter { it.state is Initiating }
.forEach { it.waitForConfirmation() }
r
} catch (e: FlowException) {
recordDuration(startTime, success = false)
// Check if the FlowException was propagated by looking at where the stack trace originates (see suspendAndExpectReceive).
val propagated = e.stackTrace[0].className == javaClass.name
processException(e, propagated)
logger.warn(if (propagated) "Flow ended due to receiving exception" else "Flow finished with exception", e)
return
} catch (t: Throwable) {
recordDuration(startTime, success = false)
logger.warn("Terminated by unexpected exception", t)
processException(t, false)
return
val resultOrError = try {
val result = logic.call()
suspend(FlowIORequest.WaitForSessionConfirmations, maySkipCheckpoint = true)
Try.Success(result)
} catch (throwable: Throwable) {
logger.warn("Flow threw exception", throwable)
Try.Failure<R>(throwable)
}
val softLocksId = if (hasSoftLockedStates) logic.runId.uuid else null
val finalEvent = when (resultOrError) {
is Try.Success -> {
Event.FlowFinish(resultOrError.value, softLocksId)
}
is Try.Failure -> {
Event.Error(resultOrError.exception)
}
}
// Immediately process the last event. This is to make sure the transition can assume that it has an open
// database transaction.
val continuation = processEventImmediately(
finalEvent,
isDbTransactionOpenOnEntry = true,
isDbTransactionOpenOnExit = false
)
if (continuation == FlowContinuation.ProcessEvents) {
// This can happen in case there was an error and there are further things to do e.g. to propagate it.
processEventsUntilFlowIsResumed(
isDbTransactionOpenOnEntry = false,
isDbTransactionOpenOnExit = false
)
}
recordDuration(startTime)
// This is to prevent actionOnEnd being called twice if it throws an exception
actionOnEnd(Try.Success(result), false)
_resultFuture.set(result)
logic.progressTracker?.currentStep = ProgressTracker.DONE
logger.debug { "Flow finished with result ${result.toString().abbreviate(300)}" }
}
private fun createTransaction() {
// Make sure we have a database transaction
database.createTransaction()
logger.trace { "Starting database transaction $contextTransactionOrNull on ${Strand.currentStrand()}" }
}
private fun processException(exception: Throwable, propagated: Boolean) {
actionOnEnd(Try.Failure(exception), propagated)
_resultFuture.setException(exception)
logic.progressTracker?.endWithError(exception)
}
internal fun commitTransaction() {
val transaction = contextTransaction
try {
logger.trace { "Committing database transaction $transaction on ${Strand.currentStrand()}." }
transaction.commit()
} catch (e: SQLException) {
// TODO: we will get here if the database is not available. Think about how to shutdown and restart cleanly.
logger.error("Transaction commit failed: ${e.message}", e)
System.exit(1)
} finally {
transaction.close()
}
}
@Suspendable
override fun initiateFlow(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSession {
val sessionKey = Pair(sessionFlow, otherParty)
if (openSessions.containsKey(sessionKey)) {
throw IllegalStateException(
"Attempted to initiateFlow() twice in the same InitiatingFlow $sessionFlow for the same party " +
"$otherParty. This isn't supported in this version of Corda. Alternatively you may " +
"initiate a new flow by calling initiateFlow() in an " +
"@${InitiatingFlow::class.java.simpleName} sub-flow."
private fun initialiseFlow() {
processEventsUntilFlowIsResumed(
isDbTransactionOpenOnEntry = false,
isDbTransactionOpenOnExit = true
)
}
@Suspendable
override fun <R> subFlow(subFlow: FlowLogic<R>): R {
processEventImmediately(
Event.EnterSubFlow(subFlow.javaClass),
isDbTransactionOpenOnEntry = true,
isDbTransactionOpenOnExit = true
)
return try {
subFlow.call()
} finally {
processEventImmediately(
Event.LeaveSubFlow,
isDbTransactionOpenOnEntry = true,
isDbTransactionOpenOnExit = true
)
}
val flowSession = FlowSessionImpl(otherParty)
createNewSession(otherParty, flowSession, sessionFlow)
flowSession.stateMachine = this
flowSession.sessionFlow = sessionFlow
return flowSession
}
@Suspendable
override fun getFlowInfo(otherParty: Party, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): FlowInfo {
val state = getConfirmedSession(otherParty, sessionFlow).state as FlowSessionState.Initiated
return state.context
override fun initiateFlow(party: Party): FlowSession {
val resume = processEventImmediately(
Event.InitiateFlow(party),
isDbTransactionOpenOnEntry = true,
isDbTransactionOpenOnExit = true
) as FlowContinuation.Resume
return resume.result as FlowSession
}
@Suspendable
override fun <T : Any> sendAndReceive(receiveType: Class<T>,
otherParty: Party,
payload: Any,
sessionFlow: FlowLogic<*>,
retrySend: Boolean,
maySkipCheckpoint: Boolean): UntrustworthyData<T> {
requireNonPrimitive(receiveType)
logger.debug { "sendAndReceive(${receiveType.name}, $otherParty, ${payload.toString().abbreviate(300)}) ..." }
val session = getConfirmedSessionIfPresent(otherParty, sessionFlow)
val receivedSessionMessage: ReceivedSessionMessage = if (session == null) {
val newSession = initiateSession(otherParty, sessionFlow, payload, waitForConfirmation = true, retryable = retrySend)
// Only do a receive here as the session init has carried the payload
receiveInternal(newSession, receiveType)
} else {
val sendData = createSessionData(session, payload)
sendAndReceiveInternal(session, sendData, receiveType)
private fun abortFiber(): Nothing {
while (true) {
Fiber.park()
}
val sessionData = receivedSessionMessage.message.checkDataSessionMessage()
logger.debug { "Received ${sessionData.payload.toString().abbreviate(300)}" }
return sessionData.checkPayloadIs(receiveType)
}
private fun ExistingSessionMessage.checkDataSessionMessage(): DataSessionMessage {
when (payload) {
is DataSessionMessage -> {
return payload
}
else -> {
throw IllegalStateException("Was expecting ${DataSessionMessage::class.java.simpleName} but got ${payload.javaClass.simpleName} instead")
}
}
}
@Suspendable
override fun <T : Any> receive(receiveType: Class<T>,
otherParty: Party,
sessionFlow: FlowLogic<*>,
maySkipCheckpoint: Boolean): UntrustworthyData<T> {
requireNonPrimitive(receiveType)
logger.debug { "receive(${receiveType.name}, $otherParty) ..." }
val session = getConfirmedSession(otherParty, sessionFlow)
val receivedSessionMessage = receiveInternal(session, receiveType).message.checkDataSessionMessage()
logger.debug { "Received ${receivedSessionMessage.payload.toString().abbreviate(300)}" }
return receivedSessionMessage.checkPayloadIs(receiveType)
}
private fun requireNonPrimitive(receiveType: Class<*>) {
require(!receiveType.isPrimitive) {
"Use the wrapper type ${Primitives.wrap(receiveType).name} instead of the primitive $receiveType.class"
}
}
@Suspendable
override fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean) {
logger.debug { "send($otherParty, ${payload.toString().abbreviate(300)})" }
val session = getConfirmedSessionIfPresent(otherParty, sessionFlow)
if (session == null) {
// Don't send the payload again if it was already piggy-backed on a session init
initiateSession(otherParty, sessionFlow, payload, waitForConfirmation = false)
} else {
sendInternal(session, ExistingSessionMessage(session.getPeerSessionId(), createSessionData(session, payload)))
}
}
@Suspendable
override fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): SignedTransaction {
logger.debug { "waitForLedgerCommit($hash) ..." }
suspend(WaitForLedgerCommit(hash, sessionFlow.stateMachine as FlowStateMachineImpl<*>))
val stx = serviceHub.validatedTransactions.getTransaction(hash)
if (stx != null) {
logger.debug { "Transaction $hash committed to ledger" }
return stx
}
// If the tx isn't committed then we may have been resumed due to an session ending in an error
for (session in openSessions.values) {
for (receivedMessage in session.receivedMessages) {
if (receivedMessage.message.payload is ErrorSessionMessage) {
session.erroredEnd(receivedMessage.message.payload.flowException)
}
}
}
throw IllegalStateException("We were resumed after waiting for $hash but it wasn't found in our local storage")
}
// Provide a mechanism to sleep within a Strand without locking any transactional state.
// This checkpoints, since we cannot undo any database writes up to this point.
@Suspendable
override fun sleepUntil(until: Instant) {
suspend(Sleep(until, this))
}
// TODO Dummy implementation of access to application specific permission controls and audit logging
@ -305,242 +294,50 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
}
@Suspendable
override fun receiveAll(sessions: Map<FlowSession, Class<out Any>>, sessionFlow: FlowLogic<*>): Map<FlowSession, UntrustworthyData<Any>> {
val requests = ArrayList<ReceiveOnly>()
for ((session, receiveType) in sessions) {
val sessionInternal = getConfirmedSession(session.counterparty, sessionFlow)
requests.add(ReceiveOnly(sessionInternal, receiveType))
}
val receivedMessages = ReceiveAll(requests).suspendAndExpectReceive(suspend)
val result = LinkedHashMap<FlowSession, UntrustworthyData<Any>>()
for ((sessionInternal, requestAndMessage) in receivedMessages) {
val message = requestAndMessage.message.confirmNoError(requestAndMessage.request.session)
result[sessionInternal.flowSession] = message.checkDataSessionMessage().checkPayloadIs(
requestAndMessage.request.userReceiveType as Class<out Any>
)
}
return result
}
internal fun pushToLoggingContext() = context.pushToLoggingContext()
/**
* This method will suspend the state machine and wait for incoming session init response from other party.
*/
@Suspendable
private fun FlowSessionInternal.waitForConfirmation() {
val sessionInitResponse = receiveInternal(this, null)
val payload = sessionInitResponse.message.payload
when (payload) {
is ConfirmSessionMessage -> {
state = FlowSessionState.Initiated(
sessionInitResponse.
peerParty,
payload.initiatedSessionId,
payload.initiatedFlowInfo)
}
is RejectSessionMessage -> {
throw UnexpectedFlowEndException("Party ${state.sendToParty} rejected session request: ${payload.message}")
}
else -> {
throw IllegalStateException("Was expecting ${ConfirmSessionMessage::class.java.simpleName} but got ${payload.javaClass.simpleName} instead")
}
}
}
private fun createSessionData(session: FlowSessionInternal, payload: Any): DataSessionMessage {
return DataSessionMessage(payload.serialize(context = SerializationDefaults.P2P_CONTEXT))
}
@Suspendable
private fun sendInternal(session: FlowSessionInternal, message: SessionMessage) = suspend(SendOnly(session, message))
@Suspendable
private fun receiveInternal(
session: FlowSessionInternal,
userReceiveType: Class<*>?): ReceivedSessionMessage {
return waitForMessage(ReceiveOnly(session, userReceiveType))
}
@Suspendable
private fun sendAndReceiveInternal(
session: FlowSessionInternal,
message: DataSessionMessage,
userReceiveType: Class<*>?): ReceivedSessionMessage {
val sessionMessage = ExistingSessionMessage(session.getPeerSessionId(), message)
return waitForMessage(SendAndReceive(session, sessionMessage, userReceiveType))
}
@Suspendable
private fun getConfirmedSessionIfPresent(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSessionInternal? {
val session = openSessions[Pair(sessionFlow, otherParty)] ?: return null
return when (session.state) {
is FlowSessionState.Uninitiated -> null
is FlowSessionState.Initiating -> {
session.waitForConfirmation()
session
}
is FlowSessionState.Initiated -> session
}
}
@Suspendable
private fun getConfirmedSession(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSessionInternal {
return getConfirmedSessionIfPresent(otherParty, sessionFlow) ?:
initiateSession(otherParty, sessionFlow, null, waitForConfirmation = true)
}
private fun createNewSession(
otherParty: Party,
flowSession: FlowSession,
sessionFlow: FlowLogic<*>
) {
logger.trace { "Creating a new session with $otherParty" }
val session = FlowSessionInternal(sessionFlow, flowSession, SessionId.createRandom(newSecureRandom()), null, FlowSessionState.Uninitiated(otherParty))
openSessions[Pair(sessionFlow, otherParty)] = session
}
@Suspendable
private fun initiateSession(
otherParty: Party,
sessionFlow: FlowLogic<*>,
firstPayload: Any?,
waitForConfirmation: Boolean,
retryable: Boolean = false
): FlowSessionInternal {
val session = openSessions[Pair(sessionFlow, otherParty)] ?: throw IllegalStateException("Expected an Uninitiated session for $otherParty")
val state = session.state as? FlowSessionState.Uninitiated ?: throw IllegalStateException("Tried to initiate a session $session, but it's already initiating/initiated")
logger.trace { "Initiating a new session with ${state.otherParty}" }
session.state = FlowSessionState.Initiating(state.otherParty)
session.retryable = retryable
val (version, initiatingFlowClass) = session.flow.javaClass.flowVersionAndInitiatingClass
val payloadBytes = firstPayload?.serialize(context = SerializationDefaults.P2P_CONTEXT)
logger.info("Initiating flow session with party ${otherParty.name}. Session id for tracing purposes is ${session.ourSessionId}.")
val sessionInit = InitialSessionMessage(session.ourSessionId, newSecureRandom().nextLong(), initiatingFlowClass.name, version, session.flow.javaClass.appName, payloadBytes)
sendInternal(session, sessionInit)
if (waitForConfirmation) {
session.waitForConfirmation()
}
return session
}
@Suspendable
private fun waitForMessage(receiveRequest: ReceiveRequest): ReceivedSessionMessage {
val receivedMessage = receiveRequest.suspendAndExpectReceive()
receivedMessage.message.confirmNoError(receiveRequest.session)
return receivedMessage
}
private val suspend : ReceiveAll.Suspend = object : ReceiveAll.Suspend {
@Suspendable
override fun invoke(request: FlowIORequest) {
suspend(request)
}
}
@Suspendable
private fun ReceiveRequest.suspendAndExpectReceive(): ReceivedSessionMessage {
val polledMessage = session.receivedMessages.poll()
return if (polledMessage != null) {
if (this is SendAndReceive) {
// Since we've already received the message, we downgrade to a send only to get the payload out and not
// inadvertently block
suspend(SendOnly(session, message))
}
polledMessage
} else {
// Suspend while we wait for a receive
suspend(this)
session.receivedMessages.poll() ?:
throw IllegalStateException("Was expecting a message but instead got nothing for $this")
}
}
private fun ExistingSessionMessage.confirmNoError(session: FlowSessionInternal): ExistingSessionMessage {
when (payload) {
is ConfirmSessionMessage,
is DataSessionMessage -> {
return this
}
is ErrorSessionMessage -> {
openSessions.values.remove(session)
session.erroredEnd(payload.flowException)
}
is RejectSessionMessage -> {
session.erroredEnd(UnexpectedFlowEndException("Counterparty sent session rejection message at unexpected time with message ${payload.message}"))
}
EndSessionMessage -> {
openSessions.values.remove(session)
throw UnexpectedFlowEndException("Counterparty flow on ${session.state.sendToParty} has completed without " +
"sending data")
}
}
}
private fun FlowSessionInternal.erroredEnd(exception: Throwable?): Nothing {
if (exception != null) {
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
exception.fillInStackTrace()
throw exception
} else {
throw UnexpectedFlowEndException("Counterparty flow on ${state.sendToParty} had an internal error and has terminated")
}
}
@Suspendable
private fun suspend(ioRequest: FlowIORequest) {
// We have to pass the thread local database transaction across via a transient field as the fiber park
// swaps them out.
txTrampoline = contextTransactionOrNull
contextTransactionOrNull = null
if (ioRequest is WaitingRequest)
waitingForResponse = ioRequest
var exceptionDuringSuspend: Throwable? = null
override fun <R : Any> suspend(ioRequest: FlowIORequest<R>, maySkipCheckpoint: Boolean): R {
val serializationContext = TransientReference(getTransientField(TransientValues::checkpointSerializationContext))
val transaction = extractThreadLocalTransaction()
parkAndSerialize { _, _ ->
logger.trace { "Suspended on $ioRequest" }
// restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB
try {
contextTransactionOrNull = txTrampoline
txTrampoline = null
actionOnSuspend(ioRequest)
} catch (t: Throwable) {
// Quasar does not terminate the fiber properly if an exception occurs during a suspend. We have to
// resume the fiber just so that we can throw it when it's running.
exceptionDuringSuspend = t
logger.trace("Resuming so fiber can it terminate with the exception thrown during suspend process", t)
resume(scheduler)
}
}
if (exceptionDuringSuspend == null && ioRequest is Sleep) {
// Sleep on the fiber. This will not sleep if it's in the past.
Strand.sleep(Duration.between(Instant.now(), ioRequest.until).toNanos(), TimeUnit.NANOSECONDS)
contextTransactionOrNull = transaction.value
val event = try {
Event.Suspend(
ioRequest = ioRequest,
maySkipCheckpoint = maySkipCheckpoint,
fiber = this.serialize(context = serializationContext.value)
)
} catch (throwable: Throwable) {
Event.Error(throwable)
}
// We must commit the database transaction before returning from this closure otherwise Quasar may schedule
// other fibers, so we process the event immediately
val continuation = processEventImmediately(
event,
isDbTransactionOpenOnEntry = true,
isDbTransactionOpenOnExit = false
)
require(continuation == FlowContinuation.ProcessEvents)
Fiber.unparkDeserialized(this, scheduler)
}
createTransaction()
// TODO Now that we're throwing outside of the suspend the FlowLogic can catch it. We need Quasar to terminate
// the fiber when exceptions occur inside a suspend.
exceptionDuringSuspend?.let { throw it }
logger.trace { "Resumed from $ioRequest" }
return uncheckedCast(processEventsUntilFlowIsResumed(
isDbTransactionOpenOnEntry = false,
isDbTransactionOpenOnExit = true
))
}
internal fun resume(scheduler: FiberScheduler) {
try {
if (fromCheckpoint) {
logger.info("Resumed from checkpoint")
fromCheckpoint = false
Fiber.unparkDeserialized(this, scheduler)
} else if (state == State.NEW) {
logger.trace("Started")
start()
} else {
Fiber.unpark(this, QUASAR_UNBLOCKER)
}
} catch (t: Throwable) {
logger.error("Error during resume", t)
}
@Suspendable
override fun scheduleEvent(event: Event) {
getTransientField(TransientValues::eventQueue).send(event)
}
override fun snapshot(): StateMachineState {
return transientState!!.value
}
override val stateMachine get() = getTransientField(TransientValues::stateMachine)
/**
* Records the duration of this flow from call() to completion or failure.
* Note that the duration will include the time the flow spent being parked, and not just the total
@ -582,15 +379,3 @@ val Class<out FlowLogic<*>>.appName: String
"<unknown>"
}
}
fun <T : Any> DataSessionMessage.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
val payloadData: T = try {
val serializer = SerializationDefaults.SERIALIZATION_FACTORY
serializer.deserialize(payload, type, SerializationDefaults.P2P_CONTEXT)
} catch (ex: Exception) {
throw IOException("Payload invalid", ex)
}
return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?:
throw UnexpectedFlowEndException("We were expecting a ${type.name} but we instead got a " +
"${payloadData.javaClass.name} (${payloadData})")
}

View File

@ -0,0 +1,20 @@
package net.corda.node.services.statemachine
import net.corda.core.utilities.debug
import net.corda.core.utilities.loggerFor
/**
* A simple [FlowHospital] implementation that immediately triggers error propagation when a flow dirties.
*/
object PropagatingFlowHospital : FlowHospital {
private val log = loggerFor<PropagatingFlowHospital>()
override fun flowErrored(flowFiber: FlowFiber) {
log.debug { "Flow ${flowFiber.id} dirtied ${flowFiber.snapshot().checkpoint.errorState}" }
flowFiber.scheduleEvent(Event.StartErrorPropagation)
}
override fun flowCleaned(flowFiber: FlowFiber) {
throw IllegalStateException("Flow ${flowFiber.id} cleaned after error propagation triggered")
}
}

View File

@ -0,0 +1,8 @@
package net.corda.node.services.statemachine
import net.corda.core.CordaException
/**
* An exception propagated and thrown in case a session initiation fails.
*/
class SessionRejectException(reason: String) : CordaException(reason)

View File

@ -0,0 +1,689 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.FiberExecutorScheduler
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.fibers.instrument.SuspendableHelper
import co.paralleluniverse.strands.channels.Channels
import com.codahale.metrics.Gauge
import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.Party
import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.ThreadBox
import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.internal.castIfPossible
import net.corda.core.internal.concurrent.OpenFuture
import net.corda.core.internal.concurrent.map
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.*
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.config.shouldCheckCheckpoints
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.services.statemachine.interceptors.*
import net.corda.node.services.statemachine.transitions.StateMachine
import net.corda.node.services.statemachine.transitions.StateMachineConfiguration
import net.corda.node.utilities.AffinityExecutor
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl
import net.corda.nodeapi.internal.serialization.withTokenContext
import org.apache.activemq.artemis.utils.ReusableLatch
import rx.Observable
import rx.subjects.PublishSubject
import java.security.SecureRandom
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import javax.annotation.concurrent.ThreadSafe
import kotlin.collections.ArrayList
import kotlin.streams.toList
/**
* The StateMachineManagerImpl will always invoke the flow fibers on the given [AffinityExecutor], regardless of which
* thread actually starts them via [startFlow].
*/
@ThreadSafe
class SingleThreadedStateMachineManager(
val serviceHub: ServiceHubInternal,
val checkpointStorage: CheckpointStorage,
val executor: ExecutorService,
val database: CordaPersistence,
val secureRandom: SecureRandom,
private val unfinishedFibers: ReusableLatch = ReusableLatch(),
private val classloader: ClassLoader = SingleThreadedStateMachineManager::class.java.classLoader
) : StateMachineManager, StateMachineManagerInternal {
companion object {
private val logger = contextLogger()
}
private class Flow(val fiber: FlowStateMachineImpl<*>, val resultFuture: OpenFuture<Any?>)
// A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines
// property.
private class InnerState {
val changesPublisher = PublishSubject.create<StateMachineManager.Change>()!!
// True if we're shutting down, so don't resume anything.
var stopping = false
val flows = HashMap<StateMachineRunId, Flow>()
val startedFutures = HashMap<StateMachineRunId, OpenFuture<Unit>>()
}
private val mutex = ThreadBox(InnerState())
private val scheduler = FiberExecutorScheduler("Same thread scheduler", executor)
// How many Fibers are running and not suspended. If zero and stopping is true, then we are halted.
private val liveFibers = ReusableLatch()
// Monitoring support.
private val metrics = serviceHub.monitoringService.metrics
private val sessionToFlow = ConcurrentHashMap<SessionId, StateMachineRunId>()
private val flowMessaging: FlowMessaging = FlowMessagingImpl(serviceHub)
private val fiberDeserializationChecker = if (serviceHub.configuration.shouldCheckCheckpoints()) FiberDeserializationChecker() else null
private val transitionExecutor = makeTransitionExecutor()
private var checkpointSerializationContext: SerializationContext? = null
private var tokenizableServices: List<Any>? = null
private var actionExecutor: ActionExecutor? = null
override val allStateMachines: List<FlowLogic<*>>
get() = mutex.locked { flows.values.map { it.fiber.logic } }
private val totalStartedFlows = metrics.counter("Flows.Started")
private val totalFinishedFlows = metrics.counter("Flows.Finished")
/**
* An observable that emits triples of the changing flow, the type of change, and a process-specific ID number
* which may change across restarts.
*
* We use assignment here so that multiple subscribers share the same wrapped Observable.
*/
override val changes: Observable<StateMachineManager.Change> = mutex.content.changesPublisher
override fun start(tokenizableServices: List<Any>) {
checkQuasarJavaAgentPresence()
this.tokenizableServices = tokenizableServices
val checkpointSerializationContext = SerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
SerializeAsTokenContextImpl(tokenizableServices, SerializationDefaults.SERIALIZATION_FACTORY, SerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
)
this.checkpointSerializationContext = checkpointSerializationContext
this.actionExecutor = makeActionExecutor(checkpointSerializationContext)
fiberDeserializationChecker?.start(checkpointSerializationContext)
val fibers = restoreFlowsFromCheckpoints()
metrics.register("Flows.InFlight", Gauge<Int> { mutex.content.flows.size })
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
(fiber as FlowStateMachineImpl<*>).logger.warn("Caught exception from flow", throwable)
}
serviceHub.networkMapCache.nodeReady.then {
resumeRestoredFlows(fibers)
flowMessaging.start { receivedMessage, deduplicationHandler ->
executor.execute {
onSessionMessage(receivedMessage, deduplicationHandler)
}
}
}
}
override fun resume() {
fiberDeserializationChecker?.start(checkpointSerializationContext!!)
val fibers = restoreFlowsFromCheckpoints()
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
(fiber as FlowStateMachineImpl<*>).logger.warn("Caught exception from flow", throwable)
}
serviceHub.networkMapCache.nodeReady.then {
resumeRestoredFlows(fibers)
}
mutex.locked {
stopping = false
}
}
override fun <A : FlowLogic<*>> findStateMachines(flowClass: Class<A>): List<Pair<A, CordaFuture<*>>> {
return mutex.locked {
flows.values.mapNotNull {
flowClass.castIfPossible(it.fiber.logic)?.let { it to it.stateMachine.resultFuture }
}
}
}
/**
* Start the shutdown process, bringing the [SingleThreadedStateMachineManager] to a controlled stop. When this method returns,
* all Fibers have been suspended and checkpointed, or have completed.
*
* @param allowedUnsuspendedFiberCount Optional parameter is used in some tests.
*/
override fun stop(allowedUnsuspendedFiberCount: Int) {
require(allowedUnsuspendedFiberCount >= 0)
mutex.locked {
if (stopping) throw IllegalStateException("Already stopping!")
stopping = true
for ((_, flow) in flows) {
flow.fiber.scheduleEvent(Event.SoftShutdown)
}
}
// Account for any expected Fibers in a test scenario.
liveFibers.countDown(allowedUnsuspendedFiberCount)
liveFibers.await()
fiberDeserializationChecker?.let {
val foundUnrestorableFibers = it.stop()
check(!foundUnrestorableFibers) { "Unrestorable checkpoints were created, please check the logs for details." }
}
}
/**
* Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and
* calls to [allStateMachines]
*/
override fun track(): DataFeed<List<FlowLogic<*>>, StateMachineManager.Change> {
return mutex.locked {
DataFeed(flows.values.map { it.fiber.logic }, changesPublisher.bufferUntilSubscribed())
}
}
override fun <A> startFlow(
flowLogic: FlowLogic<A>,
context: InvocationContext,
ourIdentity: Party?,
deduplicationHandler: DeduplicationHandler?
): CordaFuture<FlowStateMachine<A>> {
return startFlowInternal(
invocationContext = context,
flowLogic = flowLogic,
flowStart = FlowStart.Explicit,
ourIdentity = ourIdentity ?: getOurFirstIdentity(),
deduplicationHandler = deduplicationHandler,
isStartIdempotent = false
)
}
override fun killFlow(id: StateMachineRunId): Boolean {
return mutex.locked {
val flow = flows.remove(id)
if (flow != null) {
logger.debug("Killing flow known to physical node.")
decrementLiveFibers()
totalFinishedFlows.inc()
unfinishedFibers.countDown()
try {
flow.fiber.interrupt()
true
} finally {
database.transaction {
checkpointStorage.removeCheckpoint(id)
}
}
} else {
// TODO replace with a clustered delete after we'll support clustered nodes
logger.debug("Unable to kill a flow unknown to physical node. Might be processed by another physical node.")
false
}
}
}
override fun addSessionBinding(flowId: StateMachineRunId, sessionId: SessionId) {
val previousFlowId = sessionToFlow.put(sessionId, flowId)
if (previousFlowId != null) {
if (previousFlowId == flowId) {
logger.warn("Session binding from $sessionId to $flowId re-added")
} else {
throw IllegalStateException(
"Attempted to add session binding from session $sessionId to flow $flowId, " +
"however there was already a binding to $previousFlowId"
)
}
}
}
override fun removeSessionBindings(sessionIds: Set<SessionId>) {
val reRemovedSessionIds = HashSet<SessionId>()
for (sessionId in sessionIds) {
val flowId = sessionToFlow.remove(sessionId)
if (flowId == null) {
reRemovedSessionIds.add(sessionId)
}
}
if (reRemovedSessionIds.isNotEmpty()) {
logger.warn("Session binding from $reRemovedSessionIds re-removed")
}
}
override fun removeFlow(flowId: StateMachineRunId, removalReason: FlowRemovalReason, lastState: StateMachineState) {
mutex.locked {
val flow = flows.remove(flowId)
if (flow != null) {
decrementLiveFibers()
totalFinishedFlows.inc()
unfinishedFibers.countDown()
return when (removalReason) {
is FlowRemovalReason.OrderlyFinish -> removeFlowOrderly(flow, removalReason, lastState)
is FlowRemovalReason.ErrorFinish -> removeFlowError(flow, removalReason, lastState)
FlowRemovalReason.SoftShutdown -> flow.fiber.scheduleEvent(Event.SoftShutdown)
}
} else {
logger.warn("Flow $flowId re-finished")
}
}
}
override fun signalFlowHasStarted(flowId: StateMachineRunId) {
mutex.locked {
startedFutures.remove(flowId)?.set(Unit)
flows[flowId]?.let { flow ->
changesPublisher.onNext(StateMachineManager.Change.Add(flow.fiber.logic))
}
}
}
private val stateMachineConfiguration = StateMachineConfiguration.default
private fun checkQuasarJavaAgentPresence() {
check(SuspendableHelper.isJavaAgentActive(), {
"""Missing the '-javaagent' JVM argument. Make sure you run the tests with the Quasar java agent attached to your JVM.
#See https://docs.corda.net/troubleshooting.html - 'Fiber classes not instrumented' for more details.""".trimMargin("#")
})
}
private fun decrementLiveFibers() {
liveFibers.countDown()
}
private fun incrementLiveFibers() {
liveFibers.countUp()
}
private fun restoreFlowsFromCheckpoints(): List<Flow> {
return checkpointStorage.getAllCheckpoints().map { (id, serializedCheckpoint) ->
// If a flow is added before start() then don't attempt to restore it
mutex.locked { if (flows.containsKey(id)) return@map null }
val checkpoint = deserializeCheckpoint(serializedCheckpoint)
if (checkpoint == null) return@map null
createFlowFromCheckpoint(
id = id,
checkpoint = checkpoint,
initialDeduplicationHandler = null,
isAnyCheckpointPersisted = true,
isStartIdempotent = false
)
}.toList().filterNotNull()
}
private fun resumeRestoredFlows(flows: List<Flow>) {
for (flow in flows) {
addAndStartFlow(flow.fiber.id, flow)
}
}
private fun onSessionMessage(message: ReceivedMessage, deduplicationHandler: DeduplicationHandler) {
val peer = message.peer
val sessionMessage = try {
message.data.deserialize<SessionMessage>()
} catch (ex: Exception) {
logger.error("Received corrupt SessionMessage data from $peer")
deduplicationHandler.afterDatabaseTransaction()
return
}
val sender = serviceHub.networkMapCache.getPeerByLegalName(peer)
if (sender != null) {
when (sessionMessage) {
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, deduplicationHandler, sender)
is InitialSessionMessage -> onSessionInit(sessionMessage, message.platformVersion, deduplicationHandler, sender)
}
} else {
logger.error("Unknown peer $peer in $sessionMessage")
}
}
private fun onExistingSessionMessage(sessionMessage: ExistingSessionMessage, deduplicationHandler: DeduplicationHandler, sender: Party) {
try {
val recipientId = sessionMessage.recipientSessionId
val flowId = sessionToFlow[recipientId]
if (flowId == null) {
deduplicationHandler.afterDatabaseTransaction()
if (sessionMessage.payload is EndSessionMessage) {
logger.debug {
"Got ${EndSessionMessage::class.java.simpleName} for " +
"unknown session $recipientId, discarding..."
}
} else {
throw IllegalArgumentException("Cannot find flow corresponding to session ID $recipientId")
}
} else {
val flow = mutex.locked { flows[flowId] } ?: throw IllegalStateException("Cannot find fiber corresponding to ID $flowId")
flow.fiber.scheduleEvent(Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender))
}
} catch (exception: Exception) {
logger.error("Exception while routing $sessionMessage", exception)
throw exception
}
}
private fun onSessionInit(sessionMessage: InitialSessionMessage, senderPlatformVersion: Int, deduplicationHandler: DeduplicationHandler, sender: Party) {
fun createErrorMessage(initiatorSessionId: SessionId, message: String): ExistingSessionMessage {
val errorId = secureRandom.nextLong()
val payload = RejectSessionMessage(message, errorId)
return ExistingSessionMessage(initiatorSessionId, payload)
}
val replyError = try {
val initiatedFlowFactory = getInitiatedFlowFactory(sessionMessage)
val initiatedSessionId = SessionId.createRandom(secureRandom)
val senderSession = FlowSessionImpl(sender, initiatedSessionId)
val flowLogic = initiatedFlowFactory.createFlow(senderSession)
val initiatedFlowInfo = when (initiatedFlowFactory) {
is InitiatedFlowFactory.Core -> FlowInfo(serviceHub.myInfo.platformVersion, "corda")
is InitiatedFlowFactory.CorDapp -> FlowInfo(initiatedFlowFactory.flowVersion, initiatedFlowFactory.appName)
}
val senderCoreFlowVersion = when (initiatedFlowFactory) {
is InitiatedFlowFactory.Core -> senderPlatformVersion
is InitiatedFlowFactory.CorDapp -> null
}
startInitiatedFlow(flowLogic, deduplicationHandler, senderSession, initiatedSessionId, sessionMessage, senderCoreFlowVersion, initiatedFlowInfo)
null
} catch (exception: Exception) {
logger.warn("Exception while creating initiated flow", exception)
createErrorMessage(
sessionMessage.initiatorSessionId,
(exception as? SessionRejectException)?.message ?: "Unable to establish session"
)
}
if (replyError != null) {
flowMessaging.sendSessionMessage(sender, replyError, DeduplicationId.createRandom(secureRandom))
deduplicationHandler.afterDatabaseTransaction()
}
}
// TODO this is a temporary hack until we figure out multiple identities
private fun getOurFirstIdentity(): Party {
return serviceHub.myInfo.legalIdentities[0]
}
private fun getInitiatedFlowFactory(message: InitialSessionMessage): InitiatedFlowFactory<*> {
val initiatingFlowClass = try {
Class.forName(message.initiatorFlowClassName, true, classloader).asSubclass(FlowLogic::class.java)
} catch (e: ClassNotFoundException) {
throw SessionRejectException("Don't know ${message.initiatorFlowClassName}")
} catch (e: ClassCastException) {
throw SessionRejectException("${message.initiatorFlowClassName} is not a flow")
}
return serviceHub.getFlowFactory(initiatingFlowClass) ?:
throw SessionRejectException("$initiatingFlowClass is not registered")
}
private fun <A> startInitiatedFlow(
flowLogic: FlowLogic<A>,
initiatingMessageDeduplicationHandler: DeduplicationHandler,
peerSession: FlowSessionImpl,
initiatedSessionId: SessionId,
initiatingMessage: InitialSessionMessage,
senderCoreFlowVersion: Int?,
initiatedFlowInfo: FlowInfo
) {
val flowStart = FlowStart.Initiated(peerSession, initiatedSessionId, initiatingMessage, senderCoreFlowVersion, initiatedFlowInfo)
val ourIdentity = getOurFirstIdentity()
startFlowInternal(
InvocationContext.peer(peerSession.counterparty.name), flowLogic, flowStart, ourIdentity,
initiatingMessageDeduplicationHandler,
isStartIdempotent = false
)
}
private fun <A> startFlowInternal(
invocationContext: InvocationContext,
flowLogic: FlowLogic<A>,
flowStart: FlowStart,
ourIdentity: Party,
deduplicationHandler: DeduplicationHandler?,
isStartIdempotent: Boolean
): CordaFuture<FlowStateMachine<A>> {
val flowId = StateMachineRunId.createRandom()
val deduplicationSeed = when (flowStart) {
FlowStart.Explicit -> flowId.uuid.toString()
is FlowStart.Initiated ->
"${flowStart.initiatingMessage.initiatorSessionId.toLong}-" +
"${flowStart.initiatingMessage.initiationEntropy}"
}
// Before we construct the state machine state by freezing the FlowLogic we need to make sure that lazy properties
// have access to the fiber (and thereby the service hub)
val flowStateMachineImpl = FlowStateMachineImpl(flowId, flowLogic, scheduler)
val resultFuture = openFuture<Any?>()
flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture))
flowLogic.stateMachine = flowStateMachineImpl
val frozenFlowLogic = (flowLogic as FlowLogic<*>).serialize(context = checkpointSerializationContext!!)
val initialCheckpoint = Checkpoint.create(invocationContext, flowStart, flowLogic.javaClass, frozenFlowLogic, ourIdentity, deduplicationSeed).getOrThrow()
val startedFuture = openFuture<Unit>()
val initialState = StateMachineState(
checkpoint = initialCheckpoint,
pendingDeduplicationHandlers = deduplicationHandler?.let { listOf(it) } ?: emptyList(),
isFlowResumed = false,
isTransactionTracked = false,
isAnyCheckpointPersisted = false,
isStartIdempotent = isStartIdempotent,
isRemoved = false,
flowLogic = flowLogic
)
flowStateMachineImpl.transientState = TransientReference(initialState)
mutex.locked {
startedFutures[flowId] = startedFuture
}
totalStartedFlows.inc()
addAndStartFlow(flowId, Flow(flowStateMachineImpl, resultFuture))
return startedFuture.map { flowStateMachineImpl as FlowStateMachine<A> }
}
private fun deserializeCheckpoint(serializedCheckpoint: SerializedBytes<Checkpoint>): Checkpoint? {
return try {
serializedCheckpoint.deserialize(context = checkpointSerializationContext!!)
} catch (exception: Throwable) {
logger.error("Encountered unrestorable checkpoint!", exception)
null
}
}
private fun verifyFlowLogicIsSuspendable(logic: FlowLogic<Any?>) {
// Quasar requires (in Java 8) that at least the call method be annotated suspendable. Unfortunately, it's
// easy to forget to add this when creating a new flow, so we check here to give the user a better error.
//
// The Kotlin compiler can sometimes generate a synthetic bridge method from a single call declaration, which
// forwards to the void method and then returns Unit. However annotations do not get copied across to this
// bridge, so we have to do a more complex scan here.
val call = logic.javaClass.methods.first { !it.isSynthetic && it.name == "call" && it.parameterCount == 0 }
if (call.getAnnotation(Suspendable::class.java) == null) {
throw FlowException("${logic.javaClass.name}.call() is not annotated as @Suspendable. Please fix this.")
}
}
private fun createTransientValues(id: StateMachineRunId, resultFuture: CordaFuture<Any?>): FlowStateMachineImpl.TransientValues {
return FlowStateMachineImpl.TransientValues(
eventQueue = Channels.newChannel(stateMachineConfiguration.eventQueueSize, Channels.OverflowPolicy.BLOCK),
resultFuture = resultFuture,
database = database,
transitionExecutor = transitionExecutor,
actionExecutor = actionExecutor!!,
stateMachine = StateMachine(id, stateMachineConfiguration, secureRandom),
serviceHub = serviceHub,
checkpointSerializationContext = checkpointSerializationContext!!
)
}
private fun createFlowFromCheckpoint(
id: StateMachineRunId,
checkpoint: Checkpoint,
isAnyCheckpointPersisted: Boolean,
isStartIdempotent: Boolean,
initialDeduplicationHandler: DeduplicationHandler?
): Flow {
val flowState = checkpoint.flowState
val resultFuture = openFuture<Any?>()
val fiber = when (flowState) {
is FlowState.Unstarted -> {
val logic = flowState.frozenFlowLogic.deserialize(context = checkpointSerializationContext!!)
val state = StateMachineState(
checkpoint = checkpoint,
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
isFlowResumed = false,
isTransactionTracked = false,
isAnyCheckpointPersisted = isAnyCheckpointPersisted,
isStartIdempotent = isStartIdempotent,
isRemoved = false,
flowLogic = logic
)
val fiber = FlowStateMachineImpl(id, logic, scheduler)
fiber.transientValues = TransientReference(createTransientValues(id, resultFuture))
fiber.transientState = TransientReference(state)
fiber.logic.stateMachine = fiber
fiber
}
is FlowState.Started -> {
val fiber = flowState.frozenFiber.deserialize(context = checkpointSerializationContext!!)
val state = StateMachineState(
checkpoint = checkpoint,
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
isFlowResumed = false,
isTransactionTracked = false,
isAnyCheckpointPersisted = isAnyCheckpointPersisted,
isStartIdempotent = isStartIdempotent,
isRemoved = false,
flowLogic = fiber.logic
)
fiber.transientValues = TransientReference(createTransientValues(id, resultFuture))
fiber.transientState = TransientReference(state)
fiber.logic.stateMachine = fiber
fiber
}
}
verifyFlowLogicIsSuspendable(fiber.logic)
return Flow(fiber, resultFuture)
}
private fun addAndStartFlow(id: StateMachineRunId, flow: Flow) {
val checkpoint = flow.fiber.snapshot().checkpoint
for (sessionId in getFlowSessionIds(checkpoint)) {
sessionToFlow.put(sessionId, id)
}
mutex.locked {
if (stopping) {
startedFutures[id]?.setException(IllegalStateException("Will not start flow as SMM is stopping"))
logger.trace("Not resuming as SMM is stopping.")
} else {
incrementLiveFibers()
unfinishedFibers.countUp()
flows.put(id, flow)
flow.fiber.scheduleEvent(Event.DoRemainingWork)
when (checkpoint.flowState) {
is FlowState.Unstarted -> {
flow.fiber.start()
}
is FlowState.Started -> {
Fiber.unparkDeserialized(flow.fiber, scheduler)
}
}
}
}
}
private fun getFlowSessionIds(checkpoint: Checkpoint): Set<SessionId> {
val initiatedFlowStart = (checkpoint.flowState as? FlowState.Unstarted)?.flowStart as? FlowStart.Initiated
return if (initiatedFlowStart == null) {
checkpoint.sessions.keys
} else {
checkpoint.sessions.keys + initiatedFlowStart.initiatedSessionId
}
}
private fun makeActionExecutor(checkpointSerializationContext: SerializationContext): ActionExecutor {
return ActionExecutorImpl(
serviceHub,
checkpointStorage,
flowMessaging,
this,
checkpointSerializationContext,
metrics
)
}
private fun makeTransitionExecutor(): TransitionExecutor {
val interceptors = ArrayList<TransitionInterceptor>()
interceptors.add { HospitalisingInterceptor(PropagatingFlowHospital, it) }
if (serviceHub.configuration.devMode) {
interceptors.add { DumpHistoryOnErrorInterceptor(it) }
}
if (serviceHub.configuration.shouldCheckCheckpoints()) {
interceptors.add { FiberDeserializationCheckingInterceptor(fiberDeserializationChecker!!, it) }
}
if (logger.isDebugEnabled) {
interceptors.add { PrintingInterceptor(it) }
}
val transitionExecutor: TransitionExecutor = TransitionExecutorImpl(secureRandom, database)
return interceptors.fold(transitionExecutor) { executor, interceptor -> interceptor(executor) }
}
private fun InnerState.removeFlowOrderly(
flow: Flow,
removalReason: FlowRemovalReason.OrderlyFinish,
lastState: StateMachineState
) {
drainFlowEventQueue(flow)
// final sanity checks
require(lastState.pendingDeduplicationHandlers.isEmpty())
require(lastState.isRemoved)
require(lastState.checkpoint.subFlowStack.size == 1)
sessionToFlow.none { it.value == flow.fiber.id }
flow.resultFuture.set(removalReason.flowReturnValue)
lastState.flowLogic.progressTracker?.currentStep = ProgressTracker.DONE
changesPublisher.onNext(StateMachineManager.Change.Removed(lastState.flowLogic, Try.Success(removalReason.flowReturnValue)))
}
private fun InnerState.removeFlowError(
flow: Flow,
removalReason: FlowRemovalReason.ErrorFinish,
lastState: StateMachineState
) {
drainFlowEventQueue(flow)
val flowError = removalReason.flowErrors[0] // TODO what to do with several?
val exception = flowError.exception
(exception as? FlowException)?.originalErrorId = flowError.errorId
flow.resultFuture.setException(exception)
lastState.flowLogic.progressTracker?.endWithError(exception)
changesPublisher.onNext(StateMachineManager.Change.Removed(lastState.flowLogic, Try.Failure<Nothing>(exception)))
}
// The flow's event queue may be non-empty in case it shut down abruptly. We handle outstanding events here.
private fun drainFlowEventQueue(flow: Flow) {
while (true) {
val event = flow.fiber.transientValues!!.value.eventQueue.tryReceive() ?: return
when (event) {
is Event.DoRemainingWork -> {}
is Event.DeliverSessionMessage -> {
// Acknowledge the message so it doesn't leak in the broker.
event.deduplicationHandler.afterDatabaseTransaction()
when (event.sessionMessage.payload) {
EndSessionMessage -> {
logger.debug { "Unhandled message ${event.sessionMessage} due to flow shutting down" }
}
else -> {
logger.warn("Unhandled message ${event.sessionMessage} due to flow shutting down")
}
}
}
else -> {
logger.warn("Unhandled event $event due to flow shutting down")
}
}
}
}
}

View File

@ -1,11 +1,14 @@
package net.corda.node.services.statemachine
import net.corda.core.concurrent.CordaFuture
import net.corda.core.flows.FlowLogic
import net.corda.core.internal.FlowStateMachine
import net.corda.core.context.InvocationContext
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.Party
import net.corda.core.internal.FlowStateMachine
import net.corda.core.messaging.DataFeed
import net.corda.core.utilities.Try
import net.corda.node.services.messaging.DeduplicationHandler
import rx.Observable
/**
@ -23,7 +26,6 @@ import rx.Observable
* TODO: Think about how to bring the system to a clean stop so it can be upgraded without any serialised stacks on disk
* TODO: Timeouts
* TODO: Surfacing of exceptions via an API and/or management UI
* TODO: Ability to control checkpointing explicitly, for cases where you know replaying a message can't hurt
* TODO: Don't store all active flows in memory, load from the database on demand.
*/
interface StateMachineManager {
@ -37,13 +39,25 @@ interface StateMachineManager {
*/
fun stop(allowedUnsuspendedFiberCount: Int)
/**
* Resume state machine manager after having called [stop].
*/
fun resume()
/**
* Starts a new flow.
*
* @param flowLogic The flow's code.
* @param context The context of the flow.
* @param ourIdentity The identity to use for the flow.
* @param deduplicationHandler Allows exactly-once start of the flow, see [DeduplicationHandler].
*/
fun <A> startFlow(flowLogic: FlowLogic<A>, context: InvocationContext): CordaFuture<FlowStateMachine<A>>
fun <A> startFlow(
flowLogic: FlowLogic<A>,
context: InvocationContext,
ourIdentity: Party?,
deduplicationHandler: DeduplicationHandler?
): CordaFuture<FlowStateMachine<A>>
/**
* Represents an addition/removal of a state machine.
@ -73,4 +87,20 @@ interface StateMachineManager {
* Returns all currently live flows.
*/
val allStateMachines: List<FlowLogic<*>>
}
/**
* Attempts to kill a flow. This is not a clean termination and should be reserved for exceptional cases such as stuck fibers.
*
* @return whether the flow existed and was killed.
*/
fun killFlow(id: StateMachineRunId): Boolean
}
// These must be idempotent! A later failure in the state transition may error the flow state, and a replay may call
// these functions again
interface StateMachineManagerInternal {
fun signalFlowHasStarted(flowId: StateMachineRunId)
fun addSessionBinding(flowId: StateMachineRunId, sessionId: SessionId)
fun removeSessionBindings(sessionIds: Set<SessionId>)
fun removeFlow(flowId: StateMachineRunId, removalReason: FlowRemovalReason, lastState: StateMachineState)
}

View File

@ -1,666 +0,0 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.FiberExecutorScheduler
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.fibers.instrument.SuspendableHelper
import co.paralleluniverse.strands.Strand
import com.codahale.metrics.Gauge
import com.esotericsoftware.kryo.KryoException
import com.google.common.collect.HashMultimap
import com.google.common.util.concurrent.MoreExecutors
import net.corda.core.CordaException
import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext
import net.corda.core.context.InvocationOrigin
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.newSecureRandom
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.Party
import net.corda.core.internal.*
import net.corda.core.internal.concurrent.doneFuture
import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT
import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.core.utilities.trace
import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.services.api.Checkpoint
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.config.shouldCheckCheckpoints
import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.newNamedSingleThreadExecutor
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.bufferUntilDatabaseCommit
import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction
import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl
import net.corda.nodeapi.internal.serialization.withTokenContext
import org.apache.activemq.artemis.utils.ReusableLatch
import org.slf4j.Logger
import rx.Observable
import rx.subjects.PublishSubject
import java.io.NotSerializableException
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.TimeUnit.SECONDS
import javax.annotation.concurrent.ThreadSafe
/**
* The StateMachineManagerImpl will always invoke the flow fibers on the given [AffinityExecutor], regardless of which
* thread actually starts them via [startFlow].
*/
@ThreadSafe
class StateMachineManagerImpl(
val serviceHub: ServiceHubInternal,
val checkpointStorage: CheckpointStorage,
val executor: AffinityExecutor,
val database: CordaPersistence,
private val unfinishedFibers: ReusableLatch = ReusableLatch(),
private val classloader: ClassLoader = StateMachineManagerImpl::class.java.classLoader
) : StateMachineManager {
inner class FiberScheduler : FiberExecutorScheduler("Same thread scheduler", executor)
companion object {
private val logger = contextLogger()
internal val sessionTopic = "platform.session"
init {
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
(fiber as FlowStateMachineImpl<*>).logger.warn("Caught exception from flow", throwable)
}
}
}
// A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines
// property.
private class InnerState {
var started = false
val stateMachines = LinkedHashMap<FlowStateMachineImpl<*>, Checkpoint>()
val changesPublisher = PublishSubject.create<StateMachineManager.Change>()!!
val fibersWaitingForLedgerCommit = HashMultimap.create<SecureHash, FlowStateMachineImpl<*>>()!!
fun notifyChangeObservers(change: StateMachineManager.Change) {
changesPublisher.bufferUntilDatabaseCommit().onNext(change)
}
}
private val scheduler = FiberScheduler()
private val mutex = ThreadBox(InnerState())
// This thread (only enabled in dev mode) deserialises checkpoints in the background to shake out bugs in checkpoint restore.
private val checkpointCheckerThread = if (serviceHub.configuration.shouldCheckCheckpoints()) {
newNamedSingleThreadExecutor("CheckpointChecker")
} else {
null
}
@Volatile private var unrestorableCheckpoints = false
// True if we're shutting down, so don't resume anything.
@Volatile private var stopping = false
// How many Fibers are running and not suspended. If zero and stopping is true, then we are halted.
private val liveFibers = ReusableLatch()
// Monitoring support.
private val metrics = serviceHub.monitoringService.metrics
init {
metrics.register("Flows.InFlight", Gauge<Int> { mutex.content.stateMachines.size })
}
private val checkpointingMeter = metrics.meter("Flows.Checkpointing Rate")
private val totalStartedFlows = metrics.counter("Flows.Started")
private val totalFinishedFlows = metrics.counter("Flows.Finished")
private val openSessions = ConcurrentHashMap<SessionId, FlowSessionInternal>()
private val recentlyClosedSessions = ConcurrentHashMap<SessionId, Party>()
// Context for tokenized services in checkpoints
private lateinit var tokenizableServices: List<Any>
private val serializationContext by lazy {
SerializeAsTokenContextImpl(tokenizableServices, SERIALIZATION_FACTORY, CHECKPOINT_CONTEXT, serviceHub)
}
/** Returns a list of all state machines executing the given flow logic at the top level (subflows do not count) */
override fun <A : FlowLogic<*>> findStateMachines(flowClass: Class<A>): List<Pair<A, CordaFuture<*>>> {
return mutex.locked {
stateMachines.keys.mapNotNull {
flowClass.castIfPossible(it.logic)?.let { it to uncheckedCast<FlowStateMachine<*>, FlowStateMachineImpl<*>>(it.stateMachine).resultFuture }
}
}
}
override val allStateMachines: List<FlowLogic<*>>
get() = mutex.locked { stateMachines.keys.map { it.logic } }
/**
* An observable that emits triples of the changing flow, the type of change, and a process-specific ID number
* which may change across restarts.
*
* We use assignment here so that multiple subscribers share the same wrapped Observable.
*/
override val changes: Observable<StateMachineManager.Change> = mutex.content.changesPublisher.wrapWithDatabaseTransaction()
override fun start(tokenizableServices: List<Any>) {
this.tokenizableServices = tokenizableServices
checkQuasarJavaAgentPresence()
restoreFibersFromCheckpoints()
listenToLedgerTransactions()
serviceHub.networkMapCache.nodeReady.then { executor.execute(this::resumeRestoredFibers) }
}
private fun checkQuasarJavaAgentPresence() {
check(SuspendableHelper.isJavaAgentActive(), {
"""Missing the '-javaagent' JVM argument. Make sure you run the tests with the Quasar java agent attached to your JVM.
#See https://docs.corda.net/troubleshooting.html - 'Fiber classes not instrumented' for more details.""".trimMargin("#")
})
}
private fun listenToLedgerTransactions() {
// Observe the stream of committed, validated transactions and resume fibers that are waiting for them.
serviceHub.validatedTransactions.updates.subscribe { stx ->
val hash = stx.id
val fibers: Set<FlowStateMachineImpl<*>> = mutex.locked { fibersWaitingForLedgerCommit.removeAll(hash) }
if (fibers.isNotEmpty()) {
executor.executeASAP {
for (fiber in fibers) {
fiber.logger.trace { "Transaction $hash has committed to the ledger, resuming" }
fiber.waitingForResponse = null
resumeFiber(fiber)
}
}
}
}
}
private fun decrementLiveFibers() {
liveFibers.countDown()
}
private fun incrementLiveFibers() {
liveFibers.countUp()
}
/**
* Start the shutdown process, bringing the [StateMachineManagerImpl] to a controlled stop. When this method returns,
* all Fibers have been suspended and checkpointed, or have completed.
*
* @param allowedUnsuspendedFiberCount Optional parameter is used in some tests.
*/
override fun stop(allowedUnsuspendedFiberCount: Int) {
require(allowedUnsuspendedFiberCount >= 0)
mutex.locked {
if (stopping) throw IllegalStateException("Already stopping!")
stopping = true
}
// Account for any expected Fibers in a test scenario.
liveFibers.countDown(allowedUnsuspendedFiberCount)
liveFibers.await()
checkpointCheckerThread?.let { MoreExecutors.shutdownAndAwaitTermination(it, 5, SECONDS) }
check(!unrestorableCheckpoints) { "Unrestorable checkpoints where created, please check the logs for details." }
scheduler.shutdown()
}
/**
* Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and
* calls to [allStateMachines]
*/
override fun track(): DataFeed<List<FlowLogic<*>>, StateMachineManager.Change> {
return mutex.locked {
DataFeed(stateMachines.keys.map { it.logic }, changesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction())
}
}
private fun restoreFibersFromCheckpoints() {
mutex.locked {
checkpointStorage.forEach { checkpoint ->
// If a flow is added before start() then don't attempt to restore it
if (!stateMachines.containsValue(checkpoint)) {
deserializeFiber(checkpoint, logger)?.let {
initFiber(it)
stateMachines[it] = checkpoint
}
}
true
}
}
}
private fun resumeRestoredFibers() {
mutex.locked {
started = true
stateMachines.keys.forEach { resumeRestoredFiber(it) }
}
serviceHub.networkService.addMessageHandler(sessionTopic) { message, _ ->
executor.checkOnThread()
onSessionMessage(message)
}
}
private fun resumeRestoredFiber(fiber: FlowStateMachineImpl<*>) {
fiber.openSessions.values.forEach { openSessions[it.ourSessionId] = it }
val waitingForResponse = fiber.waitingForResponse
if (waitingForResponse != null) {
if (waitingForResponse is WaitForLedgerCommit) {
val stx = database.transaction {
serviceHub.validatedTransactions.getTransaction(waitingForResponse.hash)
}
if (stx != null) {
fiber.logger.info("Resuming fiber as tx ${waitingForResponse.hash} has committed")
fiber.waitingForResponse = null
resumeFiber(fiber)
} else {
fiber.logger.info("Restored, pending on ledger commit of ${waitingForResponse.hash}")
mutex.locked { fibersWaitingForLedgerCommit.put(waitingForResponse.hash, fiber) }
}
} else {
fiber.logger.info("Restored, pending on receive")
}
} else {
resumeFiber(fiber)
}
}
private fun onSessionMessage(message: ReceivedMessage) {
val peer = message.peer
val sessionMessage = try {
message.data.deserialize<SessionMessage>()
} catch (ex: Exception) {
logger.error("Received corrupt SessionMessage data from $peer")
return
}
val sender = serviceHub.networkMapCache.getPeerByLegalName(peer)
if (sender != null) {
when (sessionMessage) {
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender)
is InitialSessionMessage -> onSessionInit(sessionMessage, message, sender)
}
} else {
logger.error("Unknown peer $peer in $sessionMessage")
}
}
private fun onExistingSessionMessage(message: ExistingSessionMessage, sender: Party) {
val session = openSessions[message.recipientSessionId]
if (session != null) {
session.fiber.pushToLoggingContext()
session.fiber.logger.trace { "Received $message on $session from $sender" }
if (session.retryable) {
if (message.payload is ConfirmSessionMessage && session.state is FlowSessionState.Initiated) {
session.fiber.logger.trace { "Ignoring duplicate confirmation for session ${session.ourSessionId} session is idempotent" }
return
}
if (message.payload !is ConfirmSessionMessage) {
serviceHub.networkService.cancelRedelivery(session.ourSessionId.toLong)
}
}
if (message.payload is EndSessionMessage || message.payload is ErrorSessionMessage) {
openSessions.remove(message.recipientSessionId)
}
session.receivedMessages += ReceivedSessionMessage(sender, message)
if (resumeOnMessage(message, session)) {
// It's important that we reset here and not after the fiber's resumed, in case we receive another message
// before then.
session.fiber.waitingForResponse = null
updateCheckpoint(session.fiber)
session.fiber.logger.trace { "Resuming due to $message" }
resumeFiber(session.fiber)
}
} else {
val peerParty = recentlyClosedSessions.remove(message.recipientSessionId)
if (peerParty != null) {
if (message.payload is ConfirmSessionMessage) {
logger.trace { "Received session confirmation but associated fiber has already terminated, so sending session end" }
sendSessionMessage(peerParty, ExistingSessionMessage(message.payload.initiatedSessionId, EndSessionMessage))
} else {
logger.trace { "Ignoring session end message for already closed session: $message" }
}
} else {
logger.warn("Received a session message for unknown session: $message, from $sender")
}
}
}
// We resume the fiber if it's received a response for which it was waiting for or it's waiting for a ledger
// commit but a counterparty flow has ended with an error (in which case our flow also has to end)
private fun resumeOnMessage(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean {
val waitingForResponse = session.fiber.waitingForResponse
return waitingForResponse?.shouldResume(message, session) ?: false
}
private fun onSessionInit(sessionInit: InitialSessionMessage, receivedMessage: ReceivedMessage, sender: Party) {
logger.trace { "Received $sessionInit from $sender" }
val senderSessionId = sessionInit.initiatorSessionId
fun sendSessionReject(message: String) = sendSessionMessage(sender, ExistingSessionMessage(senderSessionId, RejectSessionMessage(message, errorId = sessionInit.initiatorSessionId.toLong)))
val (session, initiatedFlowFactory) = try {
val initiatedFlowFactory = getInitiatedFlowFactory(sessionInit)
val flowSession = FlowSessionImpl(sender)
val flow = initiatedFlowFactory.createFlow(flowSession)
val senderFlowVersion = when (initiatedFlowFactory) {
is InitiatedFlowFactory.Core -> receivedMessage.platformVersion // The flow version for the core flows is the platform version
is InitiatedFlowFactory.CorDapp -> sessionInit.flowVersion
}
val session = FlowSessionInternal(
flow,
flowSession,
SessionId.createRandom(newSecureRandom()),
sender,
FlowSessionState.Initiated(sender, senderSessionId, FlowInfo(senderFlowVersion, sessionInit.appName)))
if (sessionInit.firstPayload != null) {
session.receivedMessages += ReceivedSessionMessage(sender, ExistingSessionMessage(session.ourSessionId, DataSessionMessage(sessionInit.firstPayload)))
}
openSessions[session.ourSessionId] = session
val context = InvocationContext.peer(sender.name)
val fiber = createFiber(flow, context)
fiber.pushToLoggingContext()
logger.info("Accepting flow session from party ${sender.name}. Session id for tracing purposes is ${sessionInit.initiatorSessionId}.")
flowSession.sessionFlow = flow
flowSession.stateMachine = fiber
fiber.openSessions[Pair(flow, sender)] = session
updateCheckpoint(fiber)
session to initiatedFlowFactory
} catch (e: SessionRejectException) {
logger.warn("${e.logMessage}: $sessionInit")
sendSessionReject(e.rejectMessage)
return
} catch (e: Exception) {
logger.warn("Couldn't start flow session from $sessionInit", e)
sendSessionReject("Unable to establish session")
return
}
val (ourFlowVersion, appName) = when (initiatedFlowFactory) {
// The flow version for the core flows is the platform version
is InitiatedFlowFactory.Core -> serviceHub.myInfo.platformVersion to "corda"
is InitiatedFlowFactory.CorDapp -> initiatedFlowFactory.flowVersion to initiatedFlowFactory.appName
}
sendSessionMessage(sender, ExistingSessionMessage(senderSessionId, ConfirmSessionMessage(session.ourSessionId, FlowInfo(ourFlowVersion, appName))), session.fiber)
session.fiber.logger.debug { "Initiated by $sender using ${sessionInit.initiatorFlowClassName}" }
session.fiber.logger.trace { "Initiated from $sessionInit on $session" }
resumeFiber(session.fiber)
}
private fun getInitiatedFlowFactory(sessionInit: InitialSessionMessage): InitiatedFlowFactory<*> {
val initiatingFlowClass = try {
Class.forName(sessionInit.initiatorFlowClassName, true, classloader).asSubclass(FlowLogic::class.java)
} catch (e: ClassNotFoundException) {
throw SessionRejectException("Don't know ${sessionInit.initiatorFlowClassName}")
} catch (e: ClassCastException) {
throw SessionRejectException("${sessionInit.initiatorFlowClassName} is not a flow")
}
return serviceHub.getFlowFactory(initiatingFlowClass) ?:
throw SessionRejectException("$initiatingFlowClass is not registered")
}
private fun serializeFiber(fiber: FlowStateMachineImpl<*>): SerializedBytes<FlowStateMachineImpl<*>> {
return fiber.serialize(context = CHECKPOINT_CONTEXT.withTokenContext(serializationContext))
}
private fun deserializeFiber(checkpoint: Checkpoint, logger: Logger): FlowStateMachineImpl<*>? {
return try {
checkpoint.serializedFiber.deserialize(context = CHECKPOINT_CONTEXT.withTokenContext(serializationContext)).apply {
fromCheckpoint = true
}
} catch (t: Throwable) {
logger.error("Encountered unrestorable checkpoint!", t)
null
}
}
private fun <T> createFiber(logic: FlowLogic<T>, context: InvocationContext, ourIdentity: Party? = null): FlowStateMachineImpl<T> {
val fsm = FlowStateMachineImpl(
StateMachineRunId.createRandom(),
logic,
scheduler,
ourIdentity ?: serviceHub.myInfo.legalIdentities[0],
context)
initFiber(fsm)
return fsm
}
private fun initFiber(fiber: FlowStateMachineImpl<*>) {
verifyFlowLogicIsSuspendable(fiber.logic)
fiber.database = database
fiber.serviceHub = serviceHub
fiber.ourIdentityAndCert = serviceHub.myInfo.legalIdentitiesAndCerts.find { it.party == fiber.ourIdentity }
?: throw IllegalStateException("Identity specified by ${fiber.id} (${fiber.ourIdentity.name}) is not one of ours!")
fiber.actionOnSuspend = { ioRequest ->
updateCheckpoint(fiber)
// We commit on the fibers transaction that was copied across ThreadLocals during suspend
// This will free up the ThreadLocal so on return the caller can carry on with other transactions
fiber.commitTransaction()
processIORequest(ioRequest)
decrementLiveFibers()
}
fiber.actionOnEnd = { result, propagated ->
try {
mutex.locked {
stateMachines.remove(fiber)?.let { checkpointStorage.removeCheckpoint(it) }
notifyChangeObservers(StateMachineManager.Change.Removed(fiber.logic, result))
}
endAllFiberSessions(fiber, result, propagated)
} finally {
fiber.commitTransaction()
decrementLiveFibers()
totalFinishedFlows.inc()
unfinishedFibers.countDown()
}
}
mutex.locked {
totalStartedFlows.inc()
unfinishedFibers.countUp()
notifyChangeObservers(StateMachineManager.Change.Add(fiber.logic))
}
}
private fun verifyFlowLogicIsSuspendable(logic: FlowLogic<Any?>) {
// Quasar requires (in Java 8) that at least the call method be annotated suspendable. Unfortunately, it's
// easy to forget to add this when creating a new flow, so we check here to give the user a better error.
//
// The Kotlin compiler can sometimes generate a synthetic bridge method from a single call declaration, which
// forwards to the void method and then returns Unit. However annotations do not get copied across to this
// bridge, so we have to do a more complex scan here.
val call = logic.javaClass.methods.first { !it.isSynthetic && it.name == "call" && it.parameterCount == 0 }
if (call.getAnnotation(Suspendable::class.java) == null) {
throw FlowException("${logic.javaClass.name}.call() is not annotated as @Suspendable. Please fix this.")
}
}
private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>, result: Try<*>, propagated: Boolean) {
openSessions.values.removeIf { session ->
if (session.fiber == fiber) {
session.endSession(fiber.context, (result as? Try.Failure)?.exception, propagated)
true
} else {
false
}
}
}
private fun FlowSessionInternal.endSession(context: InvocationContext, exception: Throwable?, propagated: Boolean) {
val initiatedState = state as? FlowSessionState.Initiated ?: return
val sessionEnd = if (exception == null) {
EndSessionMessage
} else {
val errorResponse = if (exception is FlowException && (!propagated || initiatingParty != null)) {
// Only propagate this FlowException if our local flow threw it or it was propagated to us and we only
// pass it down invocation chain to the flow that initiated us, not to flows we've started sessions with.
exception
} else {
null
}
ErrorSessionMessage(errorResponse, 0)
}
sendSessionMessage(initiatedState.peerParty, ExistingSessionMessage(initiatedState.peerSessionId, sessionEnd), fiber)
recentlyClosedSessions[ourSessionId] = initiatedState.peerParty
}
/**
* Kicks off a brand new state machine of the given class.
* The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is
* restarted with checkpointed state machines in the storage service.
*
* Note that you must be on the [executor] thread.
*/
override fun <A> startFlow(flowLogic: FlowLogic<A>, context: InvocationContext): CordaFuture<FlowStateMachine<A>> {
// TODO: Check that logic has @Suspendable on its call method.
executor.checkOnThread()
val fiber = database.transaction {
val fiber = createFiber(flowLogic, context)
updateCheckpoint(fiber)
fiber
}
// If we are not started then our checkpoint will be picked up during start
mutex.locked {
if (started) {
resumeFiber(fiber)
}
}
return doneFuture(fiber)
}
private fun updateCheckpoint(fiber: FlowStateMachineImpl<*>) {
check(fiber.state != Strand.State.RUNNING) { "Fiber cannot be running when checkpointing" }
val newCheckpoint = Checkpoint(serializeFiber(fiber))
val previousCheckpoint = mutex.locked { stateMachines.put(fiber, newCheckpoint) }
if (previousCheckpoint != null) {
checkpointStorage.removeCheckpoint(previousCheckpoint)
}
checkpointStorage.addCheckpoint(newCheckpoint)
checkpointingMeter.mark()
checkpointCheckerThread?.execute {
// Immediately check that the checkpoint is valid by deserialising it. The idea is to plug any holes we have
// in our testing by failing any test where unrestorable checkpoints are created.
if (deserializeFiber(newCheckpoint, fiber.logger) == null) {
unrestorableCheckpoints = true
}
}
}
private fun resumeFiber(fiber: FlowStateMachineImpl<*>) {
// Avoid race condition when setting stopping to true and then checking liveFibers
incrementLiveFibers()
if (!stopping) {
executor.executeASAP {
fiber.resume(scheduler)
}
} else {
fiber.logger.trace("Not resuming as SMM is stopping.")
decrementLiveFibers()
}
}
private fun processIORequest(ioRequest: FlowIORequest) {
executor.checkOnThread()
when (ioRequest) {
is SendRequest -> processSendRequest(ioRequest)
is WaitForLedgerCommit -> processWaitForCommitRequest(ioRequest)
is Sleep -> processSleepRequest(ioRequest)
}
}
private fun processSendRequest(ioRequest: SendRequest) {
val retryId = if (ioRequest.message is InitialSessionMessage) {
with(ioRequest.session) {
openSessions[ourSessionId] = this
if (retryable) ourSessionId.toLong else null
}
} else null
sendSessionMessage(ioRequest.session.state.sendToParty, ioRequest.message, ioRequest.session.fiber, retryId)
if (ioRequest !is ReceiveRequest) {
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
resumeFiber(ioRequest.session.fiber)
}
}
private fun processWaitForCommitRequest(ioRequest: WaitForLedgerCommit) {
// Is it already committed?
val stx = database.transaction {
serviceHub.validatedTransactions.getTransaction(ioRequest.hash)
}
if (stx != null) {
resumeFiber(ioRequest.fiber)
} else {
// No, then register to wait.
//
// We assume this code runs on the server thread, which is the only place transactions are committed
// currently. When we liberalise our threading somewhat, handing of wait requests will need to be
// reworked to make the wait atomic in another way. Otherwise there is a race between checking the
// database and updating the waiting list.
mutex.locked {
fibersWaitingForLedgerCommit[ioRequest.hash] += ioRequest.fiber
}
}
}
private fun processSleepRequest(ioRequest: Sleep) {
// Resume the fiber now we have checkpointed, so we can sleep on the Fiber.
resumeFiber(ioRequest.fiber)
}
private fun sendSessionMessage(party: Party, message: SessionMessage, fiber: FlowStateMachineImpl<*>? = null, retryId: Long? = null) {
val partyInfo = serviceHub.networkMapCache.getPartyInfo(party)
?: throw IllegalArgumentException("Don't know about party $party")
val address = serviceHub.networkService.getAddressOfParty(partyInfo)
val logger = fiber?.logger ?: logger
logger.trace { "Sending $message to party $party @ $address" + if (retryId != null) " with retry $retryId" else "" }
val serialized = try {
message.serialize()
} catch (e: Exception) {
when (e) {
// Handling Kryo and AMQP serialization problems. Unfortunately the two exception types do not share much of a common exception interface.
is KryoException,
is NotSerializableException -> {
if (message is ExistingSessionMessage && message.payload is ErrorSessionMessage && message.payload.flowException != null) {
logger.warn("Something in ${message.payload.flowException.javaClass.name} is not serialisable. " +
"Instead sending back an exception which is serialisable to ensure session end occurs properly.", e)
// The subclass may have overridden toString so we use that
val exMessage = message.payload.flowException.message
message.copy(payload = message.payload.copy(flowException = FlowException(exMessage))).serialize()
} else {
throw e
}
}
else -> throw e
}
}
// This prevents a "deadlock" in case an initiated flow tries to start a session against a draining node that is also the initiator.
// It does not help in case more than 2 nodes are involved in a circle, so the kill switch via RPC should be used in that case.
val additionalHeaders = if (mightDeadlockDrainingSender(fiber, party)) emptyMap() else message.additionalHeaders()
serviceHub.networkService.apply {
send(createMessage(sessionTopic, serialized.bytes), address, retryId = retryId, additionalHeaders = additionalHeaders)
}
}
private fun mightDeadlockDrainingSender(fiber: FlowStateMachineImpl<*>?, target: Party): Boolean {
return fiber?.context?.origin.let { it is InvocationOrigin.Peer && it.party == target.name }
}
}
private fun SessionMessage.additionalHeaders(): Map<String, String> {
return when (this) {
is InitialSessionMessage -> mapOf(P2PMessagingHeaders.Type.KEY to P2PMessagingHeaders.Type.SESSION_INIT_VALUE)
else -> emptyMap()
}
}
class SessionRejectException(val rejectMessage: String, val logMessage: String) : CordaException(rejectMessage) {
constructor(message: String) : this(message, message)
}

View File

@ -0,0 +1,231 @@
package net.corda.node.services.statemachine
import net.corda.core.context.InvocationContext
import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowLogic
import net.corda.core.identity.Party
import net.corda.core.internal.FlowIORequest
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.Try
import net.corda.node.services.messaging.DeduplicationHandler
/**
* The state of the state machine, capturing the state of a flow. It consists of two parts, an *immutable* part that is
* persisted to the database ([Checkpoint]), and the rest, which is an in-memory-only state.
*
* @param checkpoint the persisted part of the state.
* @param flowLogic the [FlowLogic] associated with the flow. Note that this is mutable by the user.
* @param pendingDeduplicationHandlers the list of incomplete deduplication handlers.
* @param isFlowResumed true if the control is returned (or being returned) to "user-space" flow code. This is used
* to make [Event.DoRemainingWork] idempotent.
* @param isTransactionTracked true if a ledger transaction has been tracked as part of a
* [FlowIORequest.WaitForLedgerCommit]. This used is to make tracking idempotent.
* @param isAnyCheckpointPersisted true if at least a single checkpoint has been persisted. This is used to determine
* whether we should DELETE the checkpoint at the end of the flow.
* @param isStartIdempotent true if the start of the flow is idempotent, making the skipping of the initial checkpoint
* possible.
* @param isRemoved true if the flow has been removed from the state machine manager. This is used to avoid any further
* work.
*/
// TODO perhaps add a read-only environment to the state machine for things that don't change over time?
// TODO evaluate persistent datastructure libraries to replace the inefficient copying we currently do.
data class StateMachineState(
val checkpoint: Checkpoint,
val flowLogic: FlowLogic<*>,
val pendingDeduplicationHandlers: List<DeduplicationHandler>,
val isFlowResumed: Boolean,
val isTransactionTracked: Boolean,
val isAnyCheckpointPersisted: Boolean,
val isStartIdempotent: Boolean,
val isRemoved: Boolean
)
/**
* @param invocationContext the initiator of the flow.
* @param ourIdentity the identity the flow is run as.
* @param sessions map of source session ID to session state.
* @param subFlowStack the stack of currently executing subflows.
* @param flowState the state of the flow itself, including the frozen fiber/FlowLogic.
* @param errorState the "dirtiness" state including the involved errors and their propagation status.
* @param numberOfSuspends the number of flow suspends due to IO API calls.
* @param deduplicationSeed the basis seed for the deduplication ID. This is used to produce replayable IDs.
*/
data class Checkpoint(
val invocationContext: InvocationContext,
val ourIdentity: Party,
val sessions: SessionMap, // This must preserve the insertion order!
val subFlowStack: List<SubFlow>,
val flowState: FlowState,
val errorState: ErrorState,
val numberOfSuspends: Int,
val deduplicationSeed: String
) {
companion object {
fun create(
invocationContext: InvocationContext,
flowStart: FlowStart,
flowLogicClass: Class<FlowLogic<*>>,
frozenFlowLogic: SerializedBytes<FlowLogic<*>>,
ourIdentity: Party,
deduplicationSeed: String
): Try<Checkpoint> {
return SubFlow.create(flowLogicClass).map { topLevelSubFlow ->
Checkpoint(
invocationContext = invocationContext,
ourIdentity = ourIdentity,
sessions = emptyMap(),
subFlowStack = listOf(topLevelSubFlow),
flowState = FlowState.Unstarted(flowStart, frozenFlowLogic),
errorState = ErrorState.Clean,
numberOfSuspends = 0,
deduplicationSeed = deduplicationSeed
)
}
}
}
}
/**
* The state of a session.
*/
sealed class SessionState {
/**
* We haven't yet sent the initialisation message
*/
data class Uninitiated(
val party: Party,
val initiatingSubFlow: SubFlow.Initiating
) : SessionState()
/**
* We have sent the initialisation message but have not yet received a confirmation.
* @property rejectionError if non-null the initiation failed.
*/
data class Initiating(
val bufferedMessages: List<Pair<DeduplicationId, ExistingSessionMessagePayload>>,
val rejectionError: FlowError?
) : SessionState()
/**
* We have received a confirmation, the peer party and session id is resolved.
* @property errors if not empty the session is in an errored state.
*/
data class Initiated(
val peerParty: Party,
val peerFlowInfo: FlowInfo,
val receivedMessages: List<DataSessionMessage>,
val initiatedState: InitiatedSessionState,
val errors: List<FlowError>
) : SessionState()
}
typealias SessionMap = Map<SessionId, SessionState>
/**
* Tracks whether an initiated session state is live or has ended. This is a separate state, as we still need the rest
* of [SessionState.Initiated], even when the session has ended, for un-drained session messages and potential future
* [FlowInfo] requests.
*/
sealed class InitiatedSessionState {
data class Live(val peerSinkSessionId: SessionId) : InitiatedSessionState()
object Ended : InitiatedSessionState() { override fun toString() = "Ended" }
}
/**
* Represents the way the flow has started.
*/
sealed class FlowStart {
/**
* The flow was started explicitly e.g. through RPC or a scheduled state.
*/
object Explicit : FlowStart() { override fun toString() = "Explicit" }
/**
* The flow was started implicitly as part of session initiation.
*/
data class Initiated(
val peerSession: FlowSessionImpl,
val initiatedSessionId: SessionId,
val initiatingMessage: InitialSessionMessage,
val senderCoreFlowVersion: Int?,
val initiatedFlowInfo: FlowInfo
) : FlowStart() { override fun toString() = "Initiated" }
}
/**
* Represents the user-space related state of the flow.
*/
sealed class FlowState {
/**
* The flow's unstarted state. We should always be able to start a fresh flow fiber from this datastructure.
*
* @param flowStart How the flow was started.
* @param frozenFlowLogic The serialized user-provided [FlowLogic].
*/
data class Unstarted(
val flowStart: FlowStart,
val frozenFlowLogic: SerializedBytes<FlowLogic<*>>
) : FlowState() {
override fun toString() = "Unstarted(flowStart=$flowStart, frozenFlowLogic=${frozenFlowLogic.hash})"
}
/**
* The flow's started state, this means the user-code has suspended on an IO request.
*
* @param flowIORequest what IO request the flow has suspended on.
* @param frozenFiber the serialized fiber itself.
*/
data class Started(
val flowIORequest: FlowIORequest<*>,
val frozenFiber: SerializedBytes<FlowStateMachineImpl<*>>
) : FlowState() {
override fun toString() = "Started(flowIORequest=$flowIORequest, frozenFiber=${frozenFiber.hash})"
}
}
/**
* @param errorId the ID of the error. This is generated once for the source error and is propagated to neighbour
* sessions.
* @param exception the exception itself. Note that this may not contain information about the source error depending
* on whether the source error was a FlowException or otherwise.
*/
data class FlowError(val errorId: Long, val exception: Throwable)
/**
* The flow's error state.
*/
sealed class ErrorState {
abstract fun addErrors(newErrors: List<FlowError>): ErrorState
/**
* The flow is in a clean state.
*/
object Clean : ErrorState() {
override fun addErrors(newErrors: List<FlowError>): ErrorState {
return Errored(newErrors, 0, false)
}
override fun toString() = "Clean"
}
/**
* The flow has dirtied because of an uncaught exception from user code or other error condition during a state
* transition.
* @param errors the list of errors. Multiple errors may be associated with the errored flow e.g. when multiple
* sessions are errored and have been waited on.
* @param propagatedIndex the index of the first error that hasn't yet been propagated.
* @param propagating true if error propagation was triggered. If this is set the dirtiness is permanent as the
* sessions associated with the flow have been (or about to be) dirtied in counter-flows.
*/
data class Errored(
val errors: List<FlowError>,
val propagatedIndex: Int,
val propagating: Boolean
) : ErrorState() {
override fun addErrors(newErrors: List<FlowError>): ErrorState {
return copy(errors = errors + newErrors)
}
}
}

View File

@ -0,0 +1,74 @@
package net.corda.node.services.statemachine
import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.InitiatingFlow
import net.corda.core.utilities.Try
/**
* A [SubFlow] contains metadata about a currently executing sub-flow. At any point the flow execution is
* characterised with a stack of [SubFlow]s. This stack is used to determine the initiating-initiated flow mapping.
*
* Note that Initiat*ed*ness is an orthogonal property of the top-level subflow, so we don't store any information about
* it here.
*/
sealed class SubFlow {
abstract val flowClass: Class<out FlowLogic<*>>
/**
* An inlined subflow.
*/
data class Inlined(override val flowClass: Class<FlowLogic<*>>) : SubFlow()
/**
* An initiating subflow.
* @param [flowClass] the concrete class of the subflow.
* @param [classToInitiateWith] an ancestor class of [flowClass] with the [InitiatingFlow] annotation, to be sent
* to the initiated side.
* @param flowInfo the [FlowInfo] associated with the initiating flow.
*/
data class Initiating(
override val flowClass: Class<FlowLogic<*>>,
val classToInitiateWith: Class<in FlowLogic<*>>,
val flowInfo: FlowInfo
) : SubFlow()
companion object {
fun create(flowClass: Class<FlowLogic<*>>): Try<SubFlow> {
// Are we an InitiatingFlow?
val initiatingAnnotations = getInitiatingFlowAnnotations(flowClass)
return when (initiatingAnnotations.size) {
0 -> {
Try.Success(Inlined(flowClass))
}
1 -> {
val initiatingAnnotation = initiatingAnnotations[0]
val flowContext = FlowInfo(initiatingAnnotation.second.version, flowClass.appName)
Try.Success(Initiating(flowClass, initiatingAnnotation.first, flowContext))
}
else -> {
Try.Failure(IllegalArgumentException("${InitiatingFlow::class.java.name} can only be annotated " +
"once, however the following classes all have the annotation: " +
"${initiatingAnnotations.map { it.first }}"))
}
}
}
private fun <C> getSuperClasses(clazz: Class<C>): List<Class<in C>> {
var currentClass: Class<in C>? = clazz
val result = ArrayList<Class<in C>>()
while (currentClass != null) {
result.add(currentClass)
currentClass = currentClass.superclass
}
return result
}
private fun getInitiatingFlowAnnotations(flowClass: Class<FlowLogic<*>>): List<Pair<Class<in FlowLogic<*>>, InitiatingFlow>> {
return getSuperClasses(flowClass).mapNotNull { clazz ->
val initiatingAnnotation = clazz.getDeclaredAnnotation(InitiatingFlow::class.java)
initiatingAnnotation?.let { Pair(clazz, it) }
}
}
}
}

View File

@ -0,0 +1,25 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable
import net.corda.node.services.statemachine.transitions.FlowContinuation
import net.corda.node.services.statemachine.transitions.TransitionResult
/**
* An executor of state machine transitions. This is mostly a wrapper interface around an [ActionExecutor], but can be
* used to create interceptors of transitions.
*/
interface TransitionExecutor {
@Suspendable
fun executeTransition(
fiber: FlowFiber,
previousState: StateMachineState,
event: Event,
transition: TransitionResult,
actionExecutor: ActionExecutor
): Pair<FlowContinuation, StateMachineState>
}
/**
* An interceptor of a transition. These are currently explicitly hooked up in [SingleThreadedStateMachineManager].
*/
typealias TransitionInterceptor = (TransitionExecutor) -> TransitionExecutor

View File

@ -0,0 +1,67 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.utilities.contextLogger
import net.corda.node.services.statemachine.transitions.FlowContinuation
import net.corda.node.services.statemachine.transitions.TransitionResult
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.contextDatabase
import net.corda.nodeapi.internal.persistence.contextTransactionOrNull
import java.security.SecureRandom
/**
* This [TransitionExecutor] runs the transition actions using the passed in [ActionExecutor] and manually dirties the
* state on failure.
*
* If a failure happens when we're already transitioning into a errored state then the transition and the flow fiber is
* completely aborted to avoid error loops.
*/
class TransitionExecutorImpl(
val secureRandom: SecureRandom,
val database: CordaPersistence
) : TransitionExecutor {
private companion object {
val log = contextLogger()
}
@Suspendable
override fun executeTransition(
fiber: FlowFiber,
previousState: StateMachineState,
event: Event,
transition: TransitionResult,
actionExecutor: ActionExecutor
): Pair<FlowContinuation, StateMachineState> {
contextDatabase = database
for (action in transition.actions) {
try {
actionExecutor.executeAction(fiber, action)
} catch (exception: Throwable) {
contextTransactionOrNull?.close()
if (transition.newState.checkpoint.errorState is ErrorState.Errored) {
// If we errored while transitioning to an error state then we cannot record the additional
// error as that may result in an infinite loop, e.g. error propagation fails -> record error -> propagate fails again.
// Instead we just keep around the old error state and wait for a new schedule, perhaps
// triggered from a flow hospital
log.error("Error while executing $action during transition to errored state, aborting transition", exception)
return Pair(FlowContinuation.Abort, previousState.copy(isFlowResumed = false))
} else {
// Otherwise error the state manually keeping the old flow state and schedule a DoRemainingWork
// to trigger error propagation
log.error("Error while executing $action, erroring state", exception)
val newState = previousState.copy(
checkpoint = previousState.checkpoint.copy(
errorState = previousState.checkpoint.errorState.addErrors(
listOf(FlowError(secureRandom.nextLong(), exception))
)
),
isFlowResumed = false
)
fiber.scheduleEvent(Event.DoRemainingWork)
return Pair(FlowContinuation.ProcessEvents, newState)
}
}
}
return Pair(transition.continuation, transition.newState)
}
}

View File

@ -0,0 +1,51 @@
package net.corda.node.services.statemachine.interceptors
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.StateMachineRunId
import net.corda.core.utilities.contextLogger
import net.corda.node.services.statemachine.*
import net.corda.node.services.statemachine.transitions.FlowContinuation
import net.corda.node.services.statemachine.transitions.TransitionResult
import java.time.Instant
import java.util.concurrent.ConcurrentHashMap
/**
* This interceptor records a trace of all of the flows' states and transitions. If the flow dirties it dumps the trace
* transition to the logger.
*/
class DumpHistoryOnErrorInterceptor(val delegate: TransitionExecutor) : TransitionExecutor {
companion object {
private val log = contextLogger()
}
private val records = ConcurrentHashMap<StateMachineRunId, ArrayList<TransitionDiagnosticRecord>>()
@Suspendable
override fun executeTransition(
fiber: FlowFiber,
previousState: StateMachineState,
event: Event,
transition: TransitionResult,
actionExecutor: ActionExecutor
): Pair<FlowContinuation, StateMachineState> {
val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor)
val transitionRecord = TransitionDiagnosticRecord(Instant.now(), fiber.id, previousState, nextState, event, transition, continuation)
val record = records.compute(fiber.id) { _, record ->
(record ?: ArrayList()).apply { add(transitionRecord) }
}
if (nextState.checkpoint.errorState is ErrorState.Errored) {
log.warn("Flow ${fiber.id} errored, dumping all transitions:\n${record!!.joinToString("\n")}")
for (error in nextState.checkpoint.errorState.errors) {
log.warn("Flow ${fiber.id} error", error.exception)
}
}
if (nextState.isRemoved) {
records.remove(fiber.id)
}
return Pair(continuation, nextState)
}
}

View File

@ -0,0 +1,95 @@
package net.corda.node.services.statemachine.interceptors
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize
import net.corda.core.utilities.contextLogger
import net.corda.node.services.statemachine.*
import net.corda.node.services.statemachine.transitions.FlowContinuation
import net.corda.node.services.statemachine.transitions.TransitionResult
import java.util.concurrent.LinkedBlockingQueue
import kotlin.concurrent.thread
/**
* This interceptor checks whether a checkpointed fiber state can be deserialised in a separate thread.
*/
class FiberDeserializationCheckingInterceptor(
val fiberDeserializationChecker: FiberDeserializationChecker,
val delegate: TransitionExecutor
) : TransitionExecutor {
@Suspendable
override fun executeTransition(
fiber: FlowFiber,
previousState: StateMachineState,
event: Event,
transition: TransitionResult,
actionExecutor: ActionExecutor
): Pair<FlowContinuation, StateMachineState> {
val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor)
val previousFlowState = previousState.checkpoint.flowState
val nextFlowState = nextState.checkpoint.flowState
if (nextFlowState is FlowState.Started) {
if (previousFlowState !is FlowState.Started || previousFlowState.frozenFiber != nextFlowState.frozenFiber) {
fiberDeserializationChecker.submitCheck(nextFlowState.frozenFiber)
}
}
return Pair(continuation, nextState)
}
}
/**
* A fiber deserialisation checker thread. It checks the queued up serialised checkpoints to see if they can be
* deserialised. This is only run in development mode to allow detecting of corrupt serialised checkpoints before they
* are actually used.
*/
class FiberDeserializationChecker {
companion object {
val log = contextLogger()
}
private sealed class Job {
class Check(val serializedFiber: SerializedBytes<FlowStateMachineImpl<*>>) : Job()
object Finish : Job()
}
private var checkerThread: Thread? = null
private val jobQueue = LinkedBlockingQueue<Job>()
private var foundUnrestorableFibers: Boolean = false
fun start(checkpointSerializationContext: SerializationContext) {
require(checkerThread == null)
checkerThread = thread(name = "FiberDeserializationChecker") {
while (true) {
val job = jobQueue.take()
when (job) {
is Job.Check -> {
try {
job.serializedFiber.deserialize(context = checkpointSerializationContext)
} catch (throwable: Throwable) {
log.error("Encountered unrestorable checkpoint!", throwable)
foundUnrestorableFibers = true
}
}
Job.Finish -> {
return@thread
}
}
}
}
}
fun submitCheck(serializedFiber: SerializedBytes<FlowStateMachineImpl<*>>) {
jobQueue.add(Job.Check(serializedFiber))
}
/**
* Returns true if some unrestorable checkpoints were encountered, false otherwise
*/
fun stop(): Boolean {
jobQueue.add(Job.Finish)
checkerThread?.join()
checkerThread = null
return foundUnrestorableFibers
}
}

View File

@ -0,0 +1,46 @@
package net.corda.node.services.statemachine.interceptors
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.StateMachineRunId
import net.corda.node.services.statemachine.*
import net.corda.node.services.statemachine.transitions.FlowContinuation
import net.corda.node.services.statemachine.transitions.TransitionResult
import java.util.concurrent.ConcurrentHashMap
/**
* This interceptor notifies the passed in [flowHospital] in case a flow went through a clean->errored or a errored->clean
* transition.
*/
class HospitalisingInterceptor(
private val flowHospital: FlowHospital,
private val delegate: TransitionExecutor
) : TransitionExecutor {
private val hospitalisedFlows = ConcurrentHashMap<StateMachineRunId, FlowFiber>()
@Suspendable
override fun executeTransition(
fiber: FlowFiber,
previousState: StateMachineState,
event: Event,
transition: TransitionResult,
actionExecutor: ActionExecutor
): Pair<FlowContinuation, StateMachineState> {
val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor)
when (nextState.checkpoint.errorState) {
ErrorState.Clean -> {
if (hospitalisedFlows.remove(fiber.id) != null) {
flowHospital.flowCleaned(fiber)
}
}
is ErrorState.Errored -> {
if (hospitalisedFlows.putIfAbsent(fiber.id, fiber) == null) {
flowHospital.flowErrored(fiber)
}
}
}
if (nextState.isRemoved) {
hospitalisedFlows.remove(fiber.id)
}
return Pair(continuation, nextState)
}
}

View File

@ -0,0 +1,24 @@
package net.corda.node.services.statemachine.interceptors
import co.paralleluniverse.fibers.Suspendable
import com.codahale.metrics.MetricRegistry
import net.corda.node.services.statemachine.*
import net.corda.node.services.statemachine.transitions.FlowContinuation
import net.corda.node.services.statemachine.transitions.TransitionResult
class MetricInterceptor(val metrics: MetricRegistry, val delegate: TransitionExecutor): TransitionExecutor {
@Suspendable
override fun executeTransition(fiber: FlowFiber, previousState: StateMachineState, event: Event, transition: TransitionResult, actionExecutor: ActionExecutor): Pair<FlowContinuation, StateMachineState> {
val metricActionInterceptor = MetricActionInterceptor(metrics, actionExecutor)
return delegate.executeTransition(fiber, previousState, event, transition, metricActionInterceptor)
}
}
class MetricActionInterceptor(val metrics: MetricRegistry, val delegate: ActionExecutor): ActionExecutor {
@Suspendable
override fun executeAction(fiber: FlowFiber, action: Action) {
val context = metrics.timer("Flows.Actions.${action.javaClass.simpleName}").time()
delegate.executeAction(fiber, action)
context.stop()
}
}

View File

@ -0,0 +1,31 @@
package net.corda.node.services.statemachine.interceptors
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.utilities.contextLogger
import net.corda.node.services.statemachine.*
import net.corda.node.services.statemachine.transitions.FlowContinuation
import net.corda.node.services.statemachine.transitions.TransitionResult
import java.time.Instant
/**
* This interceptor simply prints all state machine transitions. Useful for debugging.
*/
class PrintingInterceptor(val delegate: TransitionExecutor) : TransitionExecutor {
companion object {
val log = contextLogger()
}
@Suspendable
override fun executeTransition(
fiber: FlowFiber,
previousState: StateMachineState,
event: Event,
transition: TransitionResult,
actionExecutor: ActionExecutor
): Pair<FlowContinuation, StateMachineState> {
val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor)
val transitionRecord = TransitionDiagnosticRecord(Instant.now(), fiber.id, previousState, nextState, event, transition, continuation)
log.info("Transition for flow ${fiber.id} $transitionRecord")
return Pair(continuation, nextState)
}
}

View File

@ -0,0 +1,51 @@
package net.corda.node.services.statemachine.interceptors
import net.corda.core.flows.StateMachineRunId
import net.corda.node.services.statemachine.transitions.FlowContinuation
import net.corda.node.services.statemachine.Event
import net.corda.node.services.statemachine.StateMachineState
import net.corda.node.services.statemachine.transitions.TransitionResult
import net.corda.node.utilities.ObjectDiffer
import java.time.Instant
/**
* This is a diagnostic record that stores information about a state machine transition and provides pretty printing
* by diffing the two states.
*/
data class TransitionDiagnosticRecord(
val timestamp: Instant,
val flowId: StateMachineRunId,
val previousState: StateMachineState,
val nextState: StateMachineState,
val event: Event,
val transition: TransitionResult,
val continuation: FlowContinuation
) {
override fun toString(): String {
val diffIntended = ObjectDiffer.diff(previousState, transition.newState)
val diffNext = ObjectDiffer.diff(previousState, nextState)
return (
listOf(
"",
" --- Transition of flow $flowId ---",
" Timestamp: $timestamp",
" Event: $event",
" Actions: ",
" ${transition.actions.joinToString("\n ")}",
" Continuation: ${transition.continuation}"
) +
if (diffIntended != diffNext) {
listOf(
" Diff between previous and intended state:",
"${diffIntended?.toPaths()?.joinToString("")}"
)
} else {
emptyList()
} + listOf(
" Diff between previous and next state:",
"${diffNext?.toPaths()?.joinToString("")}"
)
).joinToString("\n")
}
}

View File

@ -0,0 +1,186 @@
package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.UnexpectedFlowEndException
import net.corda.node.services.statemachine.*
/**
* This transition handles incoming session messages. It handles the following cases:
* - DataSessionMessage: these arrive to initiated and confirmed sessions and are expected to be received by the flow.
* - ConfirmSessionMessage: these arrive as a response to an InitialSessionMessage and include information about the
* counterparty flow's session ID as well as their [FlowInfo].
* - ErrorSessionMessage: these arrive to initiated and confirmed sessions and put the corresponding session into an
* "errored" state. This means that whenever that session is subsequently interacted with the error will be thrown
* in the flow.
* - RejectSessionMessage: these arrive as a response to an InitialSessionMessage when the initiation failed. It
* behaves similarly to ErrorSessionMessage aside from the type of exceptions stored/raised.
* - EndSessionMessage: these are sent when the counterparty flow has finished. They put the corresponding session into
* an "ended" state. This means that subsequent sends on this session will fail, and receives will start failing
* after the buffer of already received messages is drained.
*/
class DeliverSessionMessageTransition(
override val context: TransitionContext,
override val startingState: StateMachineState,
val event: Event.DeliverSessionMessage
) : Transition {
override fun transition(): TransitionResult {
return builder {
// Add the DeduplicationHandler to the pending ones ASAP so in case an error happens we still know
// about the message. Note that in case of an error during deliver this message *will be acked*.
// For example if the session corresponding to the message is not found the message is still acked to free
// up the broker but the flow will error.
currentState = currentState.copy(
pendingDeduplicationHandlers = currentState.pendingDeduplicationHandlers + event.deduplicationHandler
)
// Check whether we have a session corresponding to the message.
val existingSession = startingState.checkpoint.sessions[event.sessionMessage.recipientSessionId]
if (existingSession == null) {
freshErrorTransition(CannotFindSessionException(event.sessionMessage.recipientSessionId))
} else {
val payload = event.sessionMessage.payload
// Dispatch based on what kind of message it is.
val _exhaustive = when (payload) {
is ConfirmSessionMessage -> confirmMessageTransition(existingSession, payload)
is DataSessionMessage -> dataMessageTransition(existingSession, payload)
is ErrorSessionMessage -> errorMessageTransition(existingSession, payload)
is RejectSessionMessage -> rejectMessageTransition(existingSession, payload)
is EndSessionMessage -> endMessageTransition()
}
}
if (!isErrored()) {
persistCheckpoint()
}
// Schedule a DoRemainingWork to check whether the flow needs to be woken up.
actions.add(Action.ScheduleEvent(Event.DoRemainingWork))
FlowContinuation.ProcessEvents
}
}
private fun TransitionBuilder.confirmMessageTransition(sessionState: SessionState, message: ConfirmSessionMessage) {
// We received a confirmation message. The corresponding session state must be Initiating.
when (sessionState) {
is SessionState.Initiating -> {
// Create the new session state that is now Initiated.
val initiatedSession = SessionState.Initiated(
peerParty = event.sender,
peerFlowInfo = message.initiatedFlowInfo,
receivedMessages = emptyList(),
initiatedState = InitiatedSessionState.Live(message.initiatedSessionId),
errors = emptyList()
)
val newCheckpoint = currentState.checkpoint.copy(
sessions = currentState.checkpoint.sessions + (event.sessionMessage.recipientSessionId to initiatedSession)
)
// Send messages that were buffered pending confirmation of session.
val sendActions = sessionState.bufferedMessages.map { (deduplicationId, bufferedMessage) ->
val existingMessage = ExistingSessionMessage(message.initiatedSessionId, bufferedMessage)
Action.SendExisting(initiatedSession.peerParty, existingMessage, deduplicationId)
}
actions.addAll(sendActions)
currentState = currentState.copy(checkpoint = newCheckpoint)
}
else -> freshErrorTransition(UnexpectedEventInState())
}
}
private fun TransitionBuilder.dataMessageTransition(sessionState: SessionState, message: DataSessionMessage) {
// We received a data message. The corresponding session must be Initiated.
return when (sessionState) {
is SessionState.Initiated -> {
// Buffer the message in the session's receivedMessages buffer.
val newSessionState = sessionState.copy(
receivedMessages = sessionState.receivedMessages + message
)
currentState = currentState.copy(
checkpoint = currentState.checkpoint.copy(
sessions = startingState.checkpoint.sessions + (event.sessionMessage.recipientSessionId to newSessionState)
)
)
}
else -> freshErrorTransition(UnexpectedEventInState())
}
}
private fun TransitionBuilder.errorMessageTransition(sessionState: SessionState, payload: ErrorSessionMessage) {
val exception: Throwable = if (payload.flowException == null) {
UnexpectedFlowEndException("Counter-flow errored", cause = null, originalErrorId = payload.errorId)
} else {
payload.flowException.originalErrorId = payload.errorId
payload.flowException
}
return when (sessionState) {
is SessionState.Initiated -> {
val checkpoint = currentState.checkpoint
val sessionId = event.sessionMessage.recipientSessionId
val flowError = FlowError(payload.errorId, exception)
val newSessionState = sessionState.copy(errors = sessionState.errors + flowError)
currentState = currentState.copy(
checkpoint = checkpoint.copy(
sessions = checkpoint.sessions + (sessionId to newSessionState)
)
)
}
else -> freshErrorTransition(UnexpectedEventInState())
}
}
private fun TransitionBuilder.rejectMessageTransition(sessionState: SessionState, payload: RejectSessionMessage) {
val exception = UnexpectedFlowEndException(payload.message, cause = null, originalErrorId = payload.errorId)
return when (sessionState) {
is SessionState.Initiating -> {
if (sessionState.rejectionError != null) {
// Double reject
freshErrorTransition(UnexpectedEventInState())
} else {
val checkpoint = currentState.checkpoint
val sessionId = event.sessionMessage.recipientSessionId
val flowError = FlowError(payload.errorId, exception)
currentState = currentState.copy(
checkpoint = checkpoint.copy(
sessions = checkpoint.sessions + (sessionId to sessionState.copy(rejectionError = flowError))
)
)
}
}
else -> freshErrorTransition(UnexpectedEventInState())
}
}
private fun TransitionBuilder.persistCheckpoint() {
// We persist the message as soon as it arrives.
actions.addAll(arrayOf(
Action.CreateTransaction,
Action.PersistCheckpoint(context.id, currentState.checkpoint),
Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers),
Action.CommitTransaction,
Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers)
))
currentState = currentState.copy(
pendingDeduplicationHandlers = emptyList(),
isAnyCheckpointPersisted = true
)
}
private fun TransitionBuilder.endMessageTransition() {
val sessionId = event.sessionMessage.recipientSessionId
val sessions = currentState.checkpoint.sessions
val sessionState = sessions[sessionId]
if (sessionState == null) {
return freshErrorTransition(CannotFindSessionException(sessionId))
}
when (sessionState) {
is SessionState.Initiated -> {
val newSessionState = sessionState.copy(initiatedState = InitiatedSessionState.Ended)
currentState = currentState.copy(
checkpoint = currentState.checkpoint.copy(
sessions = sessions + (sessionId to newSessionState)
)
)
}
else -> {
freshErrorTransition(UnexpectedEventInState())
}
}
}
}

View File

@ -0,0 +1,37 @@
package net.corda.node.services.statemachine.transitions
import net.corda.node.services.statemachine.*
/**
* This transition checks the current state of the flow and determines whether anything needs to be done.
*/
class DoRemainingWorkTransition(
override val context: TransitionContext,
override val startingState: StateMachineState
) : Transition {
override fun transition(): TransitionResult {
val checkpoint = startingState.checkpoint
// If the flow is removed or has been resumed don't do work.
if (startingState.isFlowResumed || startingState.isRemoved) {
return TransitionResult(startingState)
}
// Check whether the flow is errored
return when (checkpoint.errorState) {
is ErrorState.Clean -> cleanTransition()
is ErrorState.Errored -> erroredTransition(checkpoint.errorState)
}
}
// If the flow is clean check the FlowState
private fun cleanTransition(): TransitionResult {
val checkpoint = startingState.checkpoint
return when (checkpoint.flowState) {
is FlowState.Unstarted -> UnstartedFlowTransition(context, startingState, checkpoint.flowState).transition()
is FlowState.Started -> StartedFlowTransition(context, startingState, checkpoint.flowState).transition()
}
}
private fun erroredTransition(errorState: ErrorState.Errored): TransitionResult {
return ErrorFlowTransition(context, startingState, errorState).transition()
}
}

View File

@ -0,0 +1,125 @@
package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.FlowException
import net.corda.node.services.statemachine.*
/**
* This transition defines what should happen when a flow has errored.
*
* In general there are two flow-level error conditions:
*
* - Internal exceptions. These may arise due to problems in the flow framework or errors during state machine
* transitions e.g. network or database failure.
* - User-raised exceptions. These are exceptions that are (re)raised in user code, allowing the user to catch them.
* These may come from illegal flow API calls, and FlowExceptions or other counterparty failures that are re-raised
* when the flow tries to use the corresponding sessions.
*
* Both internal exceptions and uncaught user-raised exceptions cause the flow to be errored. This flags the flow as
* unable to be resumed. When a flow is in this state an external source (e.g. Flow hospital) may decide to
*
* 1. Retry it (not implemented yet). This throws away the errored state and re-tries from the last clean checkpoint.
* 2. Start error propagation. This seals the flow as errored permanently and propagates the associated error(s) to
* all live sessions. This causes these sessions to errored on the other side, which may in turn cause the
* counter-flows themselves to errored.
*
* See [net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor] for how to detect flow errors.
*
* Note that in general we handle multiple errors at a time as several error conditions may arise at the same time and
* new errors may arise while the flow is in the errored state already.
*/
class ErrorFlowTransition(
override val context: TransitionContext,
override val startingState: StateMachineState,
private val errorState: ErrorState.Errored
) : Transition {
override fun transition(): TransitionResult {
val allErrors: List<FlowError> = errorState.errors
val remainingErrorsToPropagate: List<FlowError> = allErrors.subList(errorState.propagatedIndex, allErrors.size)
val errorMessages: List<ErrorSessionMessage> = remainingErrorsToPropagate.map(this::createErrorMessageFromError)
return builder {
// If we're errored and propagating do the actual propagation and update the index.
if (remainingErrorsToPropagate.isNotEmpty() && errorState.propagating) {
val (initiatedSessions, newSessions) = bufferErrorMessagesInInitiatingSessions(startingState.checkpoint.sessions, errorMessages)
val newCheckpoint = startingState.checkpoint.copy(
errorState = errorState.copy(propagatedIndex = allErrors.size),
sessions = newSessions
)
currentState = currentState.copy(checkpoint = newCheckpoint)
actions.add(Action.PropagateErrors(errorMessages, initiatedSessions))
}
// If we're errored but not propagating keep processing events.
if (remainingErrorsToPropagate.isNotEmpty() && !errorState.propagating) {
return@builder FlowContinuation.ProcessEvents
}
// If we haven't been removed yet remove the flow.
if (!currentState.isRemoved) {
actions.add(Action.CreateTransaction)
if (currentState.isAnyCheckpointPersisted) {
actions.add(Action.RemoveCheckpoint(context.id))
}
actions.addAll(arrayOf(
Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers),
Action.ReleaseSoftLocks(context.id.uuid),
Action.CommitTransaction,
Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers),
Action.RemoveSessionBindings(currentState.checkpoint.sessions.keys)
))
currentState = currentState.copy(
pendingDeduplicationHandlers = emptyList(),
isRemoved = true
)
val removalReason = FlowRemovalReason.ErrorFinish(allErrors)
actions.add(Action.RemoveFlow(context.id, removalReason, currentState))
FlowContinuation.Abort
} else {
// Otherwise keep processing events. This branch happens when there are some outstanding initiating
// sessions that prevent the removal of the flow.
FlowContinuation.ProcessEvents
}
}
}
private fun createErrorMessageFromError(error: FlowError): ErrorSessionMessage {
val exception = error.exception
// If the exception doesn't contain an originalErrorId that means it's a fresh FlowException that should
// propagate to the neighbouring flows. If it has the ID filled in that means it's a rethrown FlowException and
// shouldn't be propagated.
return if (exception is FlowException && exception.originalErrorId == null) {
ErrorSessionMessage(flowException = exception, errorId = error.errorId)
} else {
ErrorSessionMessage(flowException = null, errorId = error.errorId)
}
}
// Buffer error messages in Initiating sessions, return the initialised ones.
private fun bufferErrorMessagesInInitiatingSessions(
sessions: Map<SessionId, SessionState>,
errorMessages: List<ErrorSessionMessage>
): Pair<List<SessionState.Initiated>, Map<SessionId, SessionState>> {
val newSessions = sessions.mapValues { (sourceSessionId, sessionState) ->
if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) {
// *prepend* the error messages in order to error the other sessions ASAP. The other messages will
// be delivered all the same, they just won't trigger flow resumption because of dirtiness.
val errorMessagesWithDeduplication = errorMessages.map {
DeduplicationId.createForError(it.errorId, sourceSessionId) to it
}
sessionState.copy(bufferedMessages = errorMessagesWithDeduplication + sessionState.bufferedMessages)
} else {
sessionState
}
}
val initiatedSessions = sessions.values.mapNotNull { session ->
if (session is SessionState.Initiated && session.errors.isEmpty()) {
session
} else {
null
}
}
return Pair(initiatedSessions, newSessions)
}
}

View File

@ -0,0 +1,410 @@
package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowSession
import net.corda.core.flows.UnexpectedFlowEndException
import net.corda.core.internal.FlowIORequest
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.toNonEmptySet
import net.corda.node.services.statemachine.*
/**
* This transition describes what should happen with a specific [FlowIORequest]. Note that at this time the request
* is persisted (unless checkpoint was skipped) and the user-space DB transaction is commited.
*
* Before this transition we either did a checkpoint or the checkpoint was restored from the database.
*/
class StartedFlowTransition(
override val context: TransitionContext,
override val startingState: StateMachineState,
val started: FlowState.Started
) : Transition {
override fun transition(): TransitionResult {
val flowIORequest = started.flowIORequest
val checkpoint = startingState.checkpoint
val errorsToThrow = collectRelevantErrorsToThrow(flowIORequest, checkpoint)
if (errorsToThrow.isNotEmpty()) {
return TransitionResult(
newState = startingState.copy(isFlowResumed = true),
// throw the first exception. TODO should this aggregate all of them somehow?
actions = listOf(Action.CreateTransaction),
continuation = FlowContinuation.Throw(errorsToThrow[0])
)
}
return when (flowIORequest) {
is FlowIORequest.Send -> sendTransition(flowIORequest)
is FlowIORequest.Receive -> receiveTransition(flowIORequest)
is FlowIORequest.SendAndReceive -> sendAndReceiveTransition(flowIORequest)
is FlowIORequest.WaitForLedgerCommit -> waitForLedgerCommitTransition(flowIORequest)
is FlowIORequest.Sleep -> sleepTransition(flowIORequest)
is FlowIORequest.GetFlowInfo -> getFlowInfoTransition(flowIORequest)
is FlowIORequest.WaitForSessionConfirmations -> waitForSessionConfirmationsTransition()
is FlowIORequest.ExecuteAsyncOperation<*> -> executeAsyncOperation(flowIORequest)
}
}
private fun waitForSessionConfirmationsTransition(): TransitionResult {
return builder {
if (currentState.checkpoint.sessions.values.any { it is SessionState.Initiating }) {
FlowContinuation.ProcessEvents
} else {
resumeFlowLogic(Unit)
}
}
}
private fun getFlowInfoTransition(flowIORequest: FlowIORequest.GetFlowInfo): TransitionResult {
val sessionIdToSession = LinkedHashMap<SessionId, FlowSessionImpl>()
for (session in flowIORequest.sessions) {
sessionIdToSession[(session as FlowSessionImpl).sourceSessionId] = session
}
return builder {
// Initialise uninitialised sessions in order to receive the associated FlowInfo. Some or all sessions may
// not be initialised yet.
sendInitialSessionMessagesIfNeeded(sessionIdToSession.keys)
val flowInfoMap = getFlowInfoFromSessions(sessionIdToSession)
if (flowInfoMap == null) {
FlowContinuation.ProcessEvents
} else {
resumeFlowLogic(flowInfoMap)
}
}
}
private fun TransitionBuilder.getFlowInfoFromSessions(sessionIdToSession: Map<SessionId, FlowSessionImpl>): Map<FlowSession, FlowInfo>? {
val checkpoint = currentState.checkpoint
val resultMap = LinkedHashMap<FlowSession, FlowInfo>()
for ((sessionId, session) in sessionIdToSession) {
val sessionState = checkpoint.sessions[sessionId]
if (sessionState is SessionState.Initiated) {
resultMap[session] = sessionState.peerFlowInfo
} else {
return null
}
}
return resultMap
}
private fun sleepTransition(flowIORequest: FlowIORequest.Sleep): TransitionResult {
return builder {
actions.add(Action.SleepUntil(flowIORequest.wakeUpAfter))
resumeFlowLogic(Unit)
}
}
private fun waitForLedgerCommitTransition(flowIORequest: FlowIORequest.WaitForLedgerCommit): TransitionResult {
return if (!startingState.isTransactionTracked) {
TransitionResult(
newState = startingState.copy(isTransactionTracked = true),
actions = listOf(
Action.CreateTransaction,
Action.TrackTransaction(flowIORequest.hash),
Action.CommitTransaction
)
)
} else {
TransitionResult(startingState)
}
}
private fun sendAndReceiveTransition(flowIORequest: FlowIORequest.SendAndReceive): TransitionResult {
val sessionIdToMessage = LinkedHashMap<SessionId, SerializedBytes<Any>>()
val sessionIdToSession = LinkedHashMap<SessionId, FlowSessionImpl>()
for ((session, message) in flowIORequest.sessionToMessage) {
val sessionId = (session as FlowSessionImpl).sourceSessionId
sessionIdToMessage[sessionId] = message
sessionIdToSession[sessionId] = session
}
return builder {
sendToSessionsTransition(sessionIdToMessage)
if (isErrored()) {
FlowContinuation.ProcessEvents
} else {
val receivedMap = receiveFromSessionsTransition(sessionIdToSession)
if (receivedMap == null) {
// We don't yet have the messages, change the suspension to be on Receive
val newIoRequest = FlowIORequest.Receive(flowIORequest.sessionToMessage.keys.toNonEmptySet())
currentState = currentState.copy(
checkpoint = currentState.checkpoint.copy(
flowState = FlowState.Started(newIoRequest, started.frozenFiber)
)
)
FlowContinuation.ProcessEvents
} else {
resumeFlowLogic(receivedMap)
}
}
}
}
private fun receiveTransition(flowIORequest: FlowIORequest.Receive): TransitionResult {
return builder {
val sessionIdToSession = LinkedHashMap<SessionId, FlowSessionImpl>()
for (session in flowIORequest.sessions) {
sessionIdToSession[(session as FlowSessionImpl).sourceSessionId] = session
}
// send initialises to uninitialised sessions
sendInitialSessionMessagesIfNeeded(sessionIdToSession.keys)
val receivedMap = receiveFromSessionsTransition(sessionIdToSession)
if (receivedMap == null) {
FlowContinuation.ProcessEvents
} else {
resumeFlowLogic(receivedMap)
}
}
}
private fun TransitionBuilder.receiveFromSessionsTransition(
sourceSessionIdToSessionMap: Map<SessionId, FlowSessionImpl>
): Map<FlowSession, SerializedBytes<Any>>? {
val checkpoint = currentState.checkpoint
val pollResult = pollSessionMessages(checkpoint.sessions, sourceSessionIdToSessionMap.keys) ?: return null
val resultMap = LinkedHashMap<FlowSession, SerializedBytes<Any>>()
for ((sessionId, message) in pollResult.messages) {
val session = sourceSessionIdToSessionMap[sessionId]!!
resultMap[session] = message
}
currentState = currentState.copy(
checkpoint = checkpoint.copy(sessions = pollResult.newSessionMap)
)
return resultMap
}
data class PollResult(
val messages: Map<SessionId, SerializedBytes<Any>>,
val newSessionMap: SessionMap
)
private fun pollSessionMessages(sessions: SessionMap, sessionIds: Set<SessionId>): PollResult? {
val newSessionMessages = LinkedHashMap(sessions)
val resultMessages = LinkedHashMap<SessionId, SerializedBytes<Any>>()
var someNotFound = false
for (sessionId in sessionIds) {
val sessionState = sessions[sessionId]
when (sessionState) {
is SessionState.Initiated -> {
val messages = sessionState.receivedMessages
if (messages.isEmpty()) {
someNotFound = true
} else {
newSessionMessages[sessionId] = sessionState.copy(receivedMessages = messages.subList(1, messages.size).toList())
resultMessages[sessionId] = messages[0].payload
}
}
else -> {
someNotFound = true
}
}
}
return if (someNotFound) {
return null
} else {
PollResult(resultMessages, newSessionMessages)
}
}
private fun TransitionBuilder.sendInitialSessionMessagesIfNeeded(sourceSessions: Set<SessionId>) {
val checkpoint = startingState.checkpoint
val newSessions = LinkedHashMap<SessionId, SessionState>(checkpoint.sessions)
var index = 0
for (sourceSessionId in sourceSessions) {
val sessionState = checkpoint.sessions[sourceSessionId]
if (sessionState == null) {
return freshErrorTransition(CannotFindSessionException(sourceSessionId))
}
if (sessionState !is SessionState.Uninitiated) {
continue
}
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++)
val initialMessage = createInitialSessionMessage(sessionState.initiatingSubFlow, sourceSessionId, null)
actions.add(Action.SendInitial(sessionState.party, initialMessage, deduplicationId))
newSessions[sourceSessionId] = SessionState.Initiating(
bufferedMessages = emptyList(),
rejectionError = null
)
}
currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions))
}
private fun sendTransition(flowIORequest: FlowIORequest.Send): TransitionResult {
return builder {
val sessionIdToMessage = flowIORequest.sessionToMessage.mapKeys {
sessionToSessionId(it.key)
}
sendToSessionsTransition(sessionIdToMessage)
if (isErrored()) {
FlowContinuation.ProcessEvents
} else {
resumeFlowLogic(Unit)
}
}
}
private fun TransitionBuilder.sendToSessionsTransition(sourceSessionIdToMessage: Map<SessionId, SerializedBytes<Any>>) {
val checkpoint = startingState.checkpoint
val newSessions = LinkedHashMap(checkpoint.sessions)
var index = 0
for ((sourceSessionId, message) in sourceSessionIdToMessage) {
val existingSessionState = checkpoint.sessions[sourceSessionId]
if (existingSessionState == null) {
return freshErrorTransition(CannotFindSessionException(sourceSessionId))
} else {
val sessionMessage = DataSessionMessage(message)
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++)
val _exhaustive = when (existingSessionState) {
is SessionState.Uninitiated -> {
val initialMessage = createInitialSessionMessage(existingSessionState.initiatingSubFlow, sourceSessionId, message)
actions.add(Action.SendInitial(existingSessionState.party, initialMessage, deduplicationId))
newSessions[sourceSessionId] = SessionState.Initiating(
bufferedMessages = emptyList(),
rejectionError = null
)
Unit
}
is SessionState.Initiating -> {
// We're initiating this session, buffer the message
val newBufferedMessages = existingSessionState.bufferedMessages + Pair(deduplicationId, sessionMessage)
newSessions[sourceSessionId] = existingSessionState.copy(bufferedMessages = newBufferedMessages)
}
is SessionState.Initiated -> {
when (existingSessionState.initiatedState) {
is InitiatedSessionState.Live -> {
val sinkSessionId = existingSessionState.initiatedState.peerSinkSessionId
val existingMessage = ExistingSessionMessage(sinkSessionId, sessionMessage)
actions.add(Action.SendExisting(existingSessionState.peerParty, existingMessage, deduplicationId))
Unit
}
InitiatedSessionState.Ended -> {
return freshErrorTransition(IllegalStateException("Tried to send to ended session $sourceSessionId"))
}
}
}
}
}
}
currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions))
}
private fun sessionToSessionId(session: FlowSession): SessionId {
return (session as FlowSessionImpl).sourceSessionId
}
private fun collectErroredSessionErrors(sessionIds: Collection<SessionId>, checkpoint: Checkpoint): List<Throwable> {
return sessionIds.flatMap { sessionId ->
val sessionState = checkpoint.sessions[sessionId]!!
when (sessionState) {
is SessionState.Uninitiated -> emptyList()
is SessionState.Initiating -> {
if (sessionState.rejectionError == null) {
emptyList()
} else {
listOf(sessionState.rejectionError.exception)
}
}
is SessionState.Initiated -> sessionState.errors.map(FlowError::exception)
}
}
}
private fun collectErroredInitiatingSessionErrors(checkpoint: Checkpoint): List<Throwable> {
return checkpoint.sessions.values.mapNotNull { sessionState ->
(sessionState as? SessionState.Initiating)?.rejectionError?.exception
}
}
private fun collectEndedSessionErrors(sessionIds: Collection<SessionId>, checkpoint: Checkpoint): List<Throwable> {
return sessionIds.mapNotNull { sessionId ->
val sessionState = checkpoint.sessions[sessionId]!!
when (sessionState) {
is SessionState.Initiated -> {
if (sessionState.initiatedState is InitiatedSessionState.Ended) {
UnexpectedFlowEndException(
"Tried to access ended session $sessionId",
cause = null,
originalErrorId = context.secureRandom.nextLong()
)
} else {
null
}
}
else -> null
}
}
}
private fun collectEndedEmptySessionErrors(sessionIds: Collection<SessionId>, checkpoint: Checkpoint): List<Throwable> {
return sessionIds.mapNotNull { sessionId ->
val sessionState = checkpoint.sessions[sessionId]!!
when (sessionState) {
is SessionState.Initiated -> {
if (sessionState.initiatedState is InitiatedSessionState.Ended &&
sessionState.receivedMessages.isEmpty()) {
UnexpectedFlowEndException(
"Tried to access ended session $sessionId with empty buffer",
cause = null,
originalErrorId = context.secureRandom.nextLong()
)
} else {
null
}
}
else -> null
}
}
}
private fun collectRelevantErrorsToThrow(flowIORequest: FlowIORequest<*>, checkpoint: Checkpoint): List<Throwable> {
return when (flowIORequest) {
is FlowIORequest.Send -> {
val sessionIds = flowIORequest.sessionToMessage.keys.map(this::sessionToSessionId)
collectErroredSessionErrors(sessionIds, checkpoint) + collectEndedSessionErrors(sessionIds, checkpoint)
}
is FlowIORequest.Receive -> {
val sessionIds = flowIORequest.sessions.map(this::sessionToSessionId)
collectErroredSessionErrors(sessionIds, checkpoint) + collectEndedEmptySessionErrors(sessionIds, checkpoint)
}
is FlowIORequest.SendAndReceive -> {
val sessionIds = flowIORequest.sessionToMessage.keys.map(this::sessionToSessionId)
collectErroredSessionErrors(sessionIds, checkpoint) + collectEndedSessionErrors(sessionIds, checkpoint)
}
is FlowIORequest.WaitForLedgerCommit -> {
collectErroredSessionErrors(checkpoint.sessions.keys, checkpoint)
}
is FlowIORequest.GetFlowInfo -> {
collectErroredSessionErrors(flowIORequest.sessions.map(this::sessionToSessionId), checkpoint)
}
is FlowIORequest.Sleep -> {
emptyList()
}
is FlowIORequest.WaitForSessionConfirmations -> {
collectErroredInitiatingSessionErrors(checkpoint)
}
is FlowIORequest.ExecuteAsyncOperation<*> -> {
emptyList()
}
}
}
private fun createInitialSessionMessage(
initiatingSubFlow: SubFlow.Initiating,
sourceSessionId: SessionId,
payload: SerializedBytes<Any>?
): InitialSessionMessage {
return InitialSessionMessage(
initiatorSessionId = sourceSessionId,
// We add additional entropy to add to the initiated side's deduplication seed.
initiationEntropy = context.secureRandom.nextLong(),
initiatorFlowClassName = initiatingSubFlow.classToInitiateWith.name,
flowVersion = initiatingSubFlow.flowInfo.flowVersion,
appName = initiatingSubFlow.flowInfo.appName,
firstPayload = payload
)
}
private fun executeAsyncOperation(flowIORequest: FlowIORequest.ExecuteAsyncOperation<*>): TransitionResult {
return builder {
actions.add(Action.ExecuteAsyncOperation(flowIORequest.operation))
FlowContinuation.ProcessEvents
}
}
}

View File

@ -0,0 +1,30 @@
package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.*
import net.corda.node.services.statemachine.*
import java.security.SecureRandom
/**
* @property eventQueueSize the size of a flow's event queue. If the queue gets full the thread scheduling the event
* will block. An example scenario would be if the flow is waiting for a lot of messages at once, but is slow at
* processing each.
*/
data class StateMachineConfiguration(
val eventQueueSize: Int
) {
companion object {
val default = StateMachineConfiguration(
eventQueueSize = 16
)
}
}
class StateMachine(
val id: StateMachineRunId,
val configuration: StateMachineConfiguration,
val secureRandom: SecureRandom
) {
fun transition(event: Event, state: StateMachineState): TransitionResult {
return TopLevelTransition(TransitionContext(id, configuration, secureRandom), state, event).transition()
}
}

View File

@ -0,0 +1,233 @@
package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.InitiatingFlow
import net.corda.core.internal.FlowIORequest
import net.corda.core.utilities.Try
import net.corda.node.services.statemachine.*
/**
* This is the top level event-handling transition function capable of handling any [Event].
*
* It is a *pure* function taking a state machine state and an event, returning the next state along with a list of IO
* actions to execute.
*/
class TopLevelTransition(
override val context: TransitionContext,
override val startingState: StateMachineState,
val event: Event
) : Transition {
override fun transition(): TransitionResult {
return when (event) {
is Event.DoRemainingWork -> DoRemainingWorkTransition(context, startingState).transition()
is Event.DeliverSessionMessage -> DeliverSessionMessageTransition(context, startingState, event).transition()
is Event.Error -> errorTransition(event)
is Event.TransactionCommitted -> transactionCommittedTransition(event)
is Event.SoftShutdown -> softShutdownTransition()
is Event.StartErrorPropagation -> startErrorPropagationTransition()
is Event.EnterSubFlow -> enterSubFlowTransition(event)
is Event.LeaveSubFlow -> leaveSubFlowTransition()
is Event.Suspend -> suspendTransition(event)
is Event.FlowFinish -> flowFinishTransition(event)
is Event.InitiateFlow -> initiateFlowTransition(event)
is Event.AsyncOperationCompletion -> asyncOperationCompletionTransition(event)
}
}
private fun errorTransition(event: Event.Error): TransitionResult {
return builder {
freshErrorTransition(event.exception)
FlowContinuation.ProcessEvents
}
}
private fun transactionCommittedTransition(event: Event.TransactionCommitted): TransitionResult {
return builder {
val checkpoint = currentState.checkpoint
if (currentState.isTransactionTracked &&
checkpoint.flowState is FlowState.Started &&
checkpoint.flowState.flowIORequest is FlowIORequest.WaitForLedgerCommit &&
checkpoint.flowState.flowIORequest.hash == event.transaction.id) {
currentState = currentState.copy(isTransactionTracked = false)
if (isErrored()) {
return@builder FlowContinuation.ProcessEvents
}
resumeFlowLogic(event.transaction)
} else {
freshErrorTransition(UnexpectedEventInState())
FlowContinuation.ProcessEvents
}
}
}
private fun softShutdownTransition(): TransitionResult {
val lastState = startingState.copy(isRemoved = true)
return TransitionResult(
newState = lastState,
actions = listOf(
Action.RemoveSessionBindings(startingState.checkpoint.sessions.keys),
Action.RemoveFlow(context.id, FlowRemovalReason.SoftShutdown, lastState)
),
continuation = FlowContinuation.Abort
)
}
private fun startErrorPropagationTransition(): TransitionResult {
return builder {
val errorState = currentState.checkpoint.errorState
when (errorState) {
ErrorState.Clean -> freshErrorTransition(UnexpectedEventInState())
is ErrorState.Errored -> {
currentState = currentState.copy(
checkpoint = currentState.checkpoint.copy(
errorState = errorState.copy(propagating = true)
)
)
actions.add(Action.ScheduleEvent(Event.DoRemainingWork))
}
}
FlowContinuation.ProcessEvents
}
}
private fun enterSubFlowTransition(event: Event.EnterSubFlow): TransitionResult {
return builder {
val subFlow = SubFlow.create(event.subFlowClass)
when (subFlow) {
is Try.Success -> {
currentState = currentState.copy(
checkpoint = currentState.checkpoint.copy(
subFlowStack = currentState.checkpoint.subFlowStack + subFlow.value
)
)
}
is Try.Failure -> {
freshErrorTransition(subFlow.exception)
}
}
FlowContinuation.ProcessEvents
}
}
private fun leaveSubFlowTransition(): TransitionResult {
return builder {
val checkpoint = currentState.checkpoint
if (checkpoint.subFlowStack.isEmpty()) {
freshErrorTransition(UnexpectedEventInState())
} else {
currentState = currentState.copy(
checkpoint = checkpoint.copy(
subFlowStack = checkpoint.subFlowStack.subList(0, checkpoint.subFlowStack.size - 1).toList()
)
)
}
FlowContinuation.ProcessEvents
}
}
private fun suspendTransition(event: Event.Suspend): TransitionResult {
return builder {
val newCheckpoint = currentState.checkpoint.copy(
flowState = FlowState.Started(event.ioRequest, event.fiber),
numberOfSuspends = currentState.checkpoint.numberOfSuspends + 1
)
actions.addAll(arrayOf(
Action.PersistCheckpoint(context.id, newCheckpoint),
Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers),
Action.CommitTransaction,
Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers),
Action.ScheduleEvent(Event.DoRemainingWork)
))
currentState = currentState.copy(
checkpoint = newCheckpoint,
pendingDeduplicationHandlers = emptyList(),
isFlowResumed = false,
isAnyCheckpointPersisted = true
)
FlowContinuation.ProcessEvents
}
}
private fun flowFinishTransition(event: Event.FlowFinish): TransitionResult {
return builder {
val checkpoint = currentState.checkpoint
when (checkpoint.errorState) {
ErrorState.Clean -> {
val pendingDeduplicationHandlers = currentState.pendingDeduplicationHandlers
currentState = currentState.copy(
checkpoint = checkpoint.copy(
numberOfSuspends = checkpoint.numberOfSuspends + 1
),
pendingDeduplicationHandlers = emptyList(),
isFlowResumed = false,
isRemoved = true
)
val allSourceSessionIds = checkpoint.sessions.keys
if (currentState.isAnyCheckpointPersisted) {
actions.add(Action.RemoveCheckpoint(context.id))
}
actions.addAll(arrayOf(
Action.PersistDeduplicationFacts(pendingDeduplicationHandlers),
Action.ReleaseSoftLocks(event.softLocksId),
Action.CommitTransaction,
Action.AcknowledgeMessages(pendingDeduplicationHandlers),
Action.RemoveSessionBindings(allSourceSessionIds),
Action.RemoveFlow(context.id, FlowRemovalReason.OrderlyFinish(event.returnValue), currentState)
))
sendEndMessages()
// Resume to end fiber
FlowContinuation.Resume(null)
}
is ErrorState.Errored -> {
currentState = currentState.copy(isFlowResumed = false)
actions.add(Action.RollbackTransaction)
FlowContinuation.ProcessEvents
}
}
}
}
private fun TransitionBuilder.sendEndMessages() {
val sendEndMessageActions = currentState.checkpoint.sessions.values.mapIndexed { index, state ->
if (state is SessionState.Initiated && state.initiatedState is InitiatedSessionState.Live) {
val message = ExistingSessionMessage(state.initiatedState.peerSinkSessionId, EndSessionMessage)
val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index)
Action.SendExisting(state.peerParty, message, deduplicationId)
} else {
null
}
}.filterNotNull()
actions.addAll(sendEndMessageActions)
}
private fun initiateFlowTransition(event: Event.InitiateFlow): TransitionResult {
return builder {
val checkpoint = currentState.checkpoint
val initiatingSubFlow = getClosestAncestorInitiatingSubFlow(checkpoint)
if (initiatingSubFlow == null) {
freshErrorTransition(IllegalStateException("Tried to initiate in a flow not annotated with @${InitiatingFlow::class.java.simpleName}"))
return@builder FlowContinuation.ProcessEvents
}
val sourceSessionId = SessionId.createRandom(context.secureRandom)
val sessionImpl = FlowSessionImpl(event.party, sourceSessionId)
val newSessions = checkpoint.sessions + (sourceSessionId to SessionState.Uninitiated(event.party, initiatingSubFlow))
currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions))
actions.add(Action.AddSessionBinding(context.id, sourceSessionId))
FlowContinuation.Resume(sessionImpl)
}
}
private fun getClosestAncestorInitiatingSubFlow(checkpoint: Checkpoint): SubFlow.Initiating? {
for (subFlow in checkpoint.subFlowStack.asReversed()) {
if (subFlow is SubFlow.Initiating) {
return subFlow
}
}
return null
}
private fun asyncOperationCompletionTransition(event: Event.AsyncOperationCompletion): TransitionResult {
return builder {
resumeFlowLogic(event.returnValue)
}
}
}

View File

@ -0,0 +1,32 @@
package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.StateMachineRunId
import net.corda.node.services.statemachine.StateMachineState
import java.security.SecureRandom
/**
* An interface used to separate out different parts of the state machine transition function.
*/
interface Transition {
/** The context of the transition. */
val context: TransitionContext
/** The state the transition is starting in. */
val startingState: StateMachineState
/** The (almost) pure transition function. The only side-effect we allow is random number generation. */
fun transition(): TransitionResult
/**
* A helper
*/
fun builder(build: TransitionBuilder.() -> FlowContinuation): TransitionResult {
val builder = TransitionBuilder(context, startingState)
val continuation = build(builder)
return TransitionResult(builder.currentState, builder.actions, continuation)
}
}
class TransitionContext(
val id: StateMachineRunId,
val configuration: StateMachineConfiguration,
val secureRandom: SecureRandom
)

View File

@ -0,0 +1,74 @@
package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.IdentifiableException
import net.corda.node.services.statemachine.*
// This is a file defining some common utilities for creating state machine transitions.
/**
* A builder that helps creating [Transition]s. This allows for a more imperative style of specifying the transition.
*/
class TransitionBuilder(val context: TransitionContext, initialState: StateMachineState) {
/** The current state machine state of the builder */
var currentState = initialState
/** The list of actions to execute */
val actions = ArrayList<Action>()
/** Check if [currentState] state is errored */
fun isErrored(): Boolean = currentState.checkpoint.errorState is ErrorState.Errored
/**
* Transition the builder into an error state because of a fresh error that happened.
* Existing actions and the current state are thrown away, and the initial state is dirtied.
*
* @param error the error.
*/
fun freshErrorTransition(error: Throwable) {
val flowError = FlowError(
errorId = (error as? IdentifiableException)?.errorId ?: context.secureRandom.nextLong(),
exception = error
)
errorTransition(flowError)
}
/**
* Transition the builder into an error state because of a list of errors that happened.
* Existing actions and the current state are thrown away, and the initial state is dirtied.
*
* @param error the error.
*/
fun errorsTransition(errors: List<FlowError>) {
currentState = currentState.copy(
checkpoint = currentState.checkpoint.copy(
errorState = currentState.checkpoint.errorState.addErrors(errors)
),
isFlowResumed = false
)
actions.clear()
actions.addAll(arrayOf(
Action.RollbackTransaction,
Action.ScheduleEvent(Event.DoRemainingWork)
))
}
/**
* Transition the builder into an error state because of a non-fresh error has happened.
* Existing actions and the current state are thrown away, and the initial state is dirtied.
*
* @param error the error.
*/
fun errorTransition(error: FlowError) {
errorsTransition(listOf(error))
}
fun resumeFlowLogic(result: Any?): FlowContinuation {
actions.add(Action.CreateTransaction)
currentState = currentState.copy(isFlowResumed = true)
return FlowContinuation.Resume(result)
}
}
class CannotFindSessionException(sessionId: SessionId) : IllegalStateException("Couldn't find session with id $sessionId")
class UnexpectedEventInState : IllegalStateException("Unexpected event")

View File

@ -0,0 +1,46 @@
package net.corda.node.services.statemachine.transitions
import net.corda.node.services.statemachine.Action
import net.corda.node.services.statemachine.StateMachineState
/**
* A datastructure capturing the intended new state of the flow, the actions to be executed as part of the transition
* and a [FlowContinuation].
*
* Read this datastructure as an instruction to the state machine executor:
* "Transition to [newState] *if* [actions] execute cleanly. If so, use [continuation] to decide what to do next. If
* there was an error it's up to you what to do".
* Also see [net.corda.node.services.statemachine.TransitionExecutorImpl] on how this is interpreted.
*/
data class TransitionResult(
val newState: StateMachineState,
val actions: List<Action> = emptyList(),
val continuation: FlowContinuation = FlowContinuation.ProcessEvents
)
/**
* A datastructure describing what to do after a transition has succeeded.
*/
sealed class FlowContinuation {
/**
* Return to user code with the supplied [result].
*/
data class Resume(val result: Any?) : FlowContinuation() {
override fun toString() = "Resume(result=${result?.javaClass})"
}
/**
* Throw an exception [throwable] in user code.
*/
data class Throw(val throwable: Throwable) : FlowContinuation()
/**
* Keep processing pending events.
*/
object ProcessEvents : FlowContinuation() { override fun toString() = "ProcessEvents" }
/**
* Immediately abort the flow. Note that this does not imply an error condition.
*/
object Abort : FlowContinuation() { override fun toString() = "Abort" }
}

View File

@ -0,0 +1,80 @@
package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.FlowInfo
import net.corda.node.services.statemachine.*
/**
* This transition is responsible for starting the flow from a FlowLogic instance. It creates the first checkpoint and
* initialises the initiated session in case the flow is an initiated one.
*/
class UnstartedFlowTransition(
override val context: TransitionContext,
override val startingState: StateMachineState,
val unstarted: FlowState.Unstarted
) : Transition {
override fun transition(): TransitionResult {
return builder {
if (!currentState.isAnyCheckpointPersisted && !currentState.isStartIdempotent) {
createInitialCheckpoint()
}
actions.add(Action.SignalFlowHasStarted(context.id))
if (unstarted.flowStart is FlowStart.Initiated) {
initialiseInitiatedSession(unstarted.flowStart)
}
currentState = currentState.copy(isFlowResumed = true)
actions.add(Action.CreateTransaction)
FlowContinuation.Resume(null)
}
}
// Initialise initiated session, store initial payload, send confirmation back.
private fun TransitionBuilder.initialiseInitiatedSession(flowStart: FlowStart.Initiated) {
val initiatingMessage = flowStart.initiatingMessage
val initiatedState = SessionState.Initiated(
peerParty = flowStart.peerSession.counterparty,
initiatedState = InitiatedSessionState.Live(initiatingMessage.initiatorSessionId),
peerFlowInfo = FlowInfo(
flowVersion = flowStart.senderCoreFlowVersion ?: initiatingMessage.flowVersion,
appName = initiatingMessage.appName
),
receivedMessages = if (initiatingMessage.firstPayload == null) {
emptyList()
} else {
listOf(DataSessionMessage(initiatingMessage.firstPayload))
},
errors = emptyList()
)
val confirmationMessage = ConfirmSessionMessage(flowStart.initiatedSessionId, flowStart.initiatedFlowInfo)
val sessionMessage = ExistingSessionMessage(initiatingMessage.initiatorSessionId, confirmationMessage)
currentState = currentState.copy(
checkpoint = currentState.checkpoint.copy(
sessions = mapOf(flowStart.initiatedSessionId to initiatedState)
)
)
actions.add(
Action.SendExisting(
flowStart.peerSession.counterparty,
sessionMessage,
DeduplicationId.createForNormal(currentState.checkpoint, 0)
)
)
}
// Create initial checkpoint and acknowledge triggering messages.
private fun TransitionBuilder.createInitialCheckpoint() {
actions.addAll(arrayOf(
Action.CreateTransaction,
Action.PersistCheckpoint(context.id, currentState.checkpoint),
Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers),
Action.CommitTransaction,
Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers)
))
currentState = currentState.copy(
pendingDeduplicationHandlers = emptyList(),
isAnyCheckpointPersisted = true
)
}
}

View File

@ -197,9 +197,17 @@ class NodeVaultService(
if (!netUpdate.isEmpty()) {
recordUpdate(netUpdate)
mutex.locked {
// flowId required by SoftLockManager to perform auto-registration of soft locks for new states
// flowId was required by SoftLockManager to perform auto-registration of soft locks for new states
val uuid = (Strand.currentStrand() as? FlowStateMachineImpl<*>)?.id?.uuid
val vaultUpdate = if (uuid != null) netUpdate.copy(flowId = uuid) else netUpdate
if (uuid != null) {
val fungible = netUpdate.produced.filter { it.state.data is FungibleAsset<*> }
if (fungible.isNotEmpty()) {
val stateRefs = fungible.map { it.ref }.toNonEmptySet()
log.trace { "Reserving soft locks for flow id $uuid and states $stateRefs" }
softLockReserve(uuid, stateRefs)
}
}
updatesPublisher.onNext(vaultUpdate)
}
}

View File

@ -1,58 +0,0 @@
package net.corda.node.services.vault
import net.corda.core.contracts.FungibleAsset
import net.corda.core.contracts.StateRef
import net.corda.core.flows.FlowLogic
import net.corda.core.node.services.VaultService
import net.corda.core.utilities.*
import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.node.services.statemachine.StateMachineManager
import java.util.*
class VaultSoftLockManager private constructor(private val vault: VaultService) {
companion object {
private val log = contextLogger()
@JvmStatic
fun install(vault: VaultService, smm: StateMachineManager) {
val manager = VaultSoftLockManager(vault)
smm.changes.subscribe { change ->
if (change is StateMachineManager.Change.Removed) {
val logic = change.logic
// Don't run potentially expensive query if the flow didn't lock any states:
if ((logic.stateMachine as FlowStateMachineImpl<*>).hasSoftLockedStates) {
manager.unregisterSoftLocks(logic.runId.uuid, logic)
}
}
}
// Discussion
//
// The intent of the following approach is to support what might be a common pattern in a flow:
// 1. Create state
// 2. Do something with state
// without possibility of another flow intercepting the state between 1 and 2,
// since we cannot lock the state before it exists. e.g. Issue and then Move some Cash.
//
// The downside is we could have a long running flow that holds a lock for a long period of time.
// However, the lock can be programmatically released, like any other soft lock,
// should we want a long running flow that creates a visible state mid way through.
vault.rawUpdates.subscribe { (_, produced, flowId) ->
if (flowId != null) {
val fungible = produced.filter { it.state.data is FungibleAsset<*> }
if (fungible.isNotEmpty()) {
manager.registerSoftLocks(flowId, fungible.map { it.ref }.toNonEmptySet())
}
}
}
}
}
private fun registerSoftLocks(flowId: UUID, stateRefs: NonEmptySet<StateRef>) {
log.trace { "Reserving soft locks for flow id $flowId and states $stateRefs" }
vault.softLockReserve(flowId, stateRefs)
}
private fun unregisterSoftLocks(flowId: UUID, logic: FlowLogic<*>) {
log.trace { "Releasing soft locks for flow ${logic.javaClass.simpleName} with flow id $flowId" }
vault.softLockRelease(flowId)
}
}

View File

@ -0,0 +1,144 @@
package net.corda.node.utilities
import java.lang.reflect.Method
import java.lang.reflect.Modifier
import java.lang.reflect.Type
import java.time.Instant
/**
* A tree describing the diff between two objects.
*
* For example:
* data class A(val field1: Int, val field2: String, val field3: Unit)
* fun main(args: Array<String>) {
* val someA = A(1, "hello", Unit)
* val someOtherA = A(2, "bello", Unit)
* println(ObjectDiffer.diff(someA, someOtherA))
* }
*
* Will give back Step(branches=[(field1, Last(a=1, b=2)), (field2, Last(a=hello, b=bello))])
*/
sealed class DiffTree {
/**
* Describes a "step" from the object root. It contains a list of field-subtree pairs.
*/
data class Step(val branches: List<Pair<String, DiffTree>>) : DiffTree()
/**
* Describes the leaf of the diff. This is either where the diffing was cutoff (e.g. primitives) or where it failed.
*/
data class Last(val a: Any?, val b: Any?) : DiffTree()
/**
* Flattens the [DiffTree] into a list of [DiffPath]s
*/
fun toPaths(): List<DiffPath> {
return when (this) {
is Step -> branches.flatMap { (step, tree) -> tree.toPaths().map { it.copy(path = listOf(step) + it.path) } }
is Last -> listOf(DiffPath(emptyList(), a, b))
}
}
}
/**
* A diff focused on a single [DiffTree.Last] diff, including the path leading there.
*/
data class DiffPath(
val path: List<String>,
val a: Any?,
val b: Any?
) {
override fun toString(): String {
return "${path.joinToString(".")}: \n $a\n $b\n"
}
}
/**
* This is a very simple differ used to diff objects of any kind, to be used for diagnostic.
*/
object ObjectDiffer {
fun diff(a: Any?, b: Any?): DiffTree? {
if (a == null || b == null) {
if (a == b) {
return null
} else {
return DiffTree.Last(a, b)
}
}
if (a != b) {
if (a.javaClass.isPrimitive || a.javaClass in diffCutoffClasses) {
return DiffTree.Last(a, b)
}
// TODO deduplicate this code
if (a is Map<*, *> && b is Map<*, *>) {
val allKeys = a.keys + b.keys
val branches = allKeys.mapNotNull { key -> diff(a.get(key), b.get(key))?.let { key.toString() to it } }
if (branches.isEmpty()) {
return null
} else {
return DiffTree.Step(branches)
}
}
if (a is java.util.Map<*, *> && b is java.util.Map<*, *>) {
val allKeys = a.keySet() + b.keySet()
val branches = allKeys.mapNotNull { key -> diff(a.get(key), b.get(key))?.let { key.toString() to it } }
if (branches.isEmpty()) {
return null
} else {
return DiffTree.Step(branches)
}
}
val aFields = getFieldFoci(a)
val bFields = getFieldFoci(b)
try {
if (aFields != bFields) {
return DiffTree.Last(a, b)
} else {
// TODO need to account for cases where the fields don't match up (different subclasses)
val branches = aFields.map { field -> diff(field.get(a), field.get(b))?.let { field.name to it } }.filterNotNull()
if (branches.isEmpty()) {
return DiffTree.Last(a, b)
} else {
return DiffTree.Step(branches)
}
}
} catch (throwable: Exception) {
Exception("Error while diffing $a with $b", throwable).printStackTrace(System.out)
return DiffTree.Last(a, b)
}
} else {
return null
}
}
// List of types to cutoff the diffing at.
private val diffCutoffClasses: Set<Class<*>> = setOf(
String::class.java,
Class::class.java,
Instant::class.java
)
// A type capturing the accessor to a field. This is a separate abstraction to simple reflection as we identify
// getX() and isX() calls as fields as well.
private data class FieldFocus(val name: String, val type: Type, val getter: Method) {
fun get(obj: Any): Any? {
return getter.invoke(obj)
}
}
private fun getFieldFoci(obj: Any) : List<FieldFocus> {
val foci = ArrayList<FieldFocus>()
for (method in obj.javaClass.declaredMethods) {
if (Modifier.isStatic(method.modifiers)) {
continue
}
if (method.name.startsWith("get") && method.name.length > 3 && method.parameterCount == 0) {
val fieldName = method.name[3].toLowerCase() + method.name.substring(4)
foci.add(FieldFocus(fieldName, method.returnType, method))
} else if (method.name.startsWith("is") && method.parameterCount == 0) {
foci.add(FieldFocus(method.name, method.returnType, method))
}
}
return foci
}
}

View File

@ -49,10 +49,10 @@ class InMemoryMessagingTests {
val bits = "test-content".toByteArray()
var finalDelivery: Message? = null
node2.network.addMessageHandler("test.topic") { msg, _ ->
node2.network.addMessageHandler("test.topic") { msg, _, _ ->
node2.network.send(msg, node3.network.myAddress)
}
node3.network.addMessageHandler("test.topic") { msg, _ ->
node3.network.addMessageHandler("test.topic") { msg, _, _ ->
finalDelivery = msg
}
@ -73,7 +73,7 @@ class InMemoryMessagingTests {
val bits = "test-content".toByteArray()
var counter = 0
listOf(node1, node2, node3).forEach { it.network.addMessageHandler("test.topic") { _, _ -> counter++ } }
listOf(node1, node2, node3).forEach { it.network.addMessageHandler("test.topic") { _, _, _ -> counter++ } }
node1.network.send(node2.network.createMessage("test.topic", data = bits), rigorousMock<AllPossibleRecipients>())
mockNet.runNetwork(rounds = 1)
assertEquals(3, counter)
@ -89,9 +89,10 @@ class InMemoryMessagingTests {
val node2 = mockNet.createNode()
var received = 0
node1.network.addMessageHandler("valid_message") { _, _ ->
node1.network.addMessageHandler("valid_message") { _, _, _ ->
received++
}
val invalidMessage = node2.network.createMessage("invalid_message", data = ByteArray(1))
val validMessage = node2.network.createMessage("valid_message", data = ByteArray(1))
node2.network.send(invalidMessage, node1.network.myAddress)

View File

@ -744,6 +744,12 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) {
private val database: CordaPersistence,
private val delegate: WritableTransactionStorage
) : WritableTransactionStorage, SingletonSerializeAsToken() {
override fun trackTransaction(id: SecureHash): CordaFuture<SignedTransaction> {
return database.transaction {
delegate.trackTransaction(id)
}
}
override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> {
return database.transaction {
delegate.track()

View File

@ -1,6 +1,5 @@
package net.corda.node.services.events
import com.google.common.util.concurrent.MoreExecutors
import com.nhaarman.mockito_kotlin.*
import net.corda.core.contracts.*
import net.corda.core.crypto.SecureHash
@ -22,6 +21,7 @@ import net.corda.testing.internal.doLookup
import net.corda.testing.internal.rigorousMock
import net.corda.testing.node.MockServices
import net.corda.testing.node.TestClock
import org.junit.Ignore
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TestWatcher
@ -42,8 +42,14 @@ open class NodeSchedulerServiceTestBase {
protected val testClock = TestClock(rigorousMock<Clock>().also {
doReturn(mark).whenever(it).instant()
})
private val database = rigorousMock<CordaPersistence>().also {
doAnswer {
val block: DatabaseTransaction.() -> Any? = uncheckedCast(it.arguments[0])
rigorousMock<DatabaseTransaction>().block()
}.whenever(it).transaction(any())
}
protected val flowStarter = rigorousMock<FlowStarter>().also {
doReturn(openFuture<FlowStateMachine<*>>()).whenever(it).startFlow(any<FlowLogic<*>>(), any())
doReturn(openFuture<FlowStateMachine<*>>()).whenever(it).startFlow(any<FlowLogic<*>>(), any(), any())
}
private val flowsDraingMode = rigorousMock<NodePropertiesStore.FlowsDrainingModeOperations>().also {
doReturn(false).whenever(it).isEnabled()
@ -76,7 +82,7 @@ open class NodeSchedulerServiceTestBase {
protected fun assertStarted(flowLogic: FlowLogic<*>) {
// Like in assertWaitingFor, use timeout to make verify wait as we often race the call to startFlow:
verify(flowStarter, timeout(5000)).startFlow(same(flowLogic)!!, any())
verify(flowStarter, timeout(5000)).startFlow(same(flowLogic)!!, any(), any())
}
protected fun assertStarted(event: Event) = assertStarted(event.flowLogic)
@ -95,7 +101,6 @@ class NodeSchedulerServiceTest : NodeSchedulerServiceTestBase() {
database,
flowStarter,
servicesForResolution,
serverThread = MoreExecutors.directExecutor(),
flowLogicRefFactory = flowLogicRefFactory,
nodeProperties = nodeProperties,
drainingModePollPeriod = Duration.ofSeconds(5),
@ -209,7 +214,6 @@ class NodeSchedulerPersistenceTest : NodeSchedulerServiceTestBase() {
db,
flowStarter,
servicesForResolution,
serverThread = MoreExecutors.directExecutor(),
flowLogicRefFactory = flowLogicRefFactory,
nodeProperties = nodeProperties,
drainingModePollPeriod = Duration.ofSeconds(5),
@ -262,6 +266,7 @@ class NodeSchedulerPersistenceTest : NodeSchedulerServiceTestBase() {
newDatabase.close()
}
@Ignore("Temporarily")
@Test
fun `test that if schedule is updated then the flow is invoked on the correct schedule`() {
val dataSourceProps = MockServices.makeTestDataSourceProperties()
@ -293,4 +298,4 @@ class NodeSchedulerPersistenceTest : NodeSchedulerServiceTestBase() {
scheduler.join()
database.close()
}
}
}

View File

@ -151,7 +151,9 @@ class ArtemisMessagingTest {
createMessagingServer().start()
val messagingClient = createMessagingClient(platformVersion = platformVersion)
messagingClient.addMessageHandler(TOPIC) { message, _ ->
messagingClient.addMessageHandler(TOPIC) { message, _, handle ->
database.transaction { handle.insideDatabaseTransaction() }
handle.afterDatabaseTransaction() // We ACK first so that if it fails we won't get a duplicate in [receivedMessages]
receivedMessages.add(message)
}
startNodeMessagingClient()

View File

@ -1,33 +1,40 @@
package net.corda.node.services.persistence
import com.google.common.primitives.Ints
import net.corda.core.context.InvocationContext
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.node.services.api.Checkpoint
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.transactions.PersistentUniquenessProvider
import net.corda.core.serialization.serialize
import net.corda.node.internal.configureDatabase
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.statemachine.Checkpoint
import net.corda.node.services.statemachine.FlowStart
import net.corda.node.services.transactions.PersistentUniquenessProvider
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.testing.internal.LogHelper
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import net.corda.testing.core.TestIdentity
import net.corda.testing.internal.LogHelper
import net.corda.testing.internal.rigorousMock
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import kotlin.streams.toList
internal fun CheckpointStorage.checkpoints(): List<Checkpoint> {
val checkpoints = mutableListOf<Checkpoint>()
forEach {
checkpoints += it
true
}
return checkpoints
internal fun CheckpointStorage.checkpoints(): List<SerializedBytes<Checkpoint>> {
val checkpoints = getAllCheckpoints().toList()
return checkpoints.map { it.second }
}
class DBCheckpointStorageTests {
private companion object {
val ALICE = TestIdentity(ALICE_NAME, 70).party
}
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule()
@ -50,9 +57,9 @@ class DBCheckpointStorageTests {
@Test
fun `add new checkpoint`() {
val checkpoint = newCheckpoint()
val (id, checkpoint) = newCheckpoint()
database.transaction {
checkpointStorage.addCheckpoint(checkpoint)
checkpointStorage.addCheckpoint(id, checkpoint)
}
database.transaction {
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
@ -65,12 +72,12 @@ class DBCheckpointStorageTests {
@Test
fun `remove checkpoint`() {
val checkpoint = newCheckpoint()
val (id, checkpoint) = newCheckpoint()
database.transaction {
checkpointStorage.addCheckpoint(checkpoint)
checkpointStorage.addCheckpoint(id, checkpoint)
}
database.transaction {
checkpointStorage.removeCheckpoint(checkpoint)
checkpointStorage.removeCheckpoint(id)
}
database.transaction {
assertThat(checkpointStorage.checkpoints()).isEmpty()
@ -83,12 +90,12 @@ class DBCheckpointStorageTests {
@Test
fun `add and remove checkpoint in single commit operate`() {
val checkpoint = newCheckpoint()
val checkpoint2 = newCheckpoint()
val (id, checkpoint) = newCheckpoint()
val (id2, checkpoint2) = newCheckpoint()
database.transaction {
checkpointStorage.addCheckpoint(checkpoint)
checkpointStorage.addCheckpoint(checkpoint2)
checkpointStorage.removeCheckpoint(checkpoint)
checkpointStorage.addCheckpoint(id, checkpoint)
checkpointStorage.addCheckpoint(id2, checkpoint2)
checkpointStorage.removeCheckpoint(id)
}
database.transaction {
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint2)
@ -101,16 +108,16 @@ class DBCheckpointStorageTests {
@Test
fun `add two checkpoints then remove first one`() {
val firstCheckpoint = newCheckpoint()
val (id, firstCheckpoint) = newCheckpoint()
database.transaction {
checkpointStorage.addCheckpoint(firstCheckpoint)
checkpointStorage.addCheckpoint(id, firstCheckpoint)
}
val secondCheckpoint = newCheckpoint()
val (id2, secondCheckpoint) = newCheckpoint()
database.transaction {
checkpointStorage.addCheckpoint(secondCheckpoint)
checkpointStorage.addCheckpoint(id2, secondCheckpoint)
}
database.transaction {
checkpointStorage.removeCheckpoint(firstCheckpoint)
checkpointStorage.removeCheckpoint(id)
}
database.transaction {
assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint)
@ -123,9 +130,9 @@ class DBCheckpointStorageTests {
@Test
fun `add checkpoint and then remove after 'restart'`() {
val originalCheckpoint = newCheckpoint()
val (id, originalCheckpoint) = newCheckpoint()
database.transaction {
checkpointStorage.addCheckpoint(originalCheckpoint)
checkpointStorage.addCheckpoint(id, originalCheckpoint)
}
newCheckpointStorage()
val reconstructedCheckpoint = database.transaction {
@ -135,7 +142,7 @@ class DBCheckpointStorageTests {
assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint)
}
database.transaction {
checkpointStorage.removeCheckpoint(reconstructedCheckpoint)
checkpointStorage.removeCheckpoint(id)
}
database.transaction {
assertThat(checkpointStorage.checkpoints()).isEmpty()
@ -148,7 +155,14 @@ class DBCheckpointStorageTests {
}
}
private var checkpointCount = 1
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)))
private fun newCheckpoint(): Pair<StateMachineRunId, SerializedBytes<Checkpoint>> {
val id = StateMachineRunId.createRandom()
val logic: FlowLogic<*> = object : FlowLogic<Unit>() {
override fun call() {}
}
val frozenLogic = logic.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT)
val checkpoint = Checkpoint.create(InvocationContext.shell(), FlowStart.Explicit, logic.javaClass, frozenLogic, ALICE, "").getOrThrow()
return id to checkpoint.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT)
}
}

View File

@ -2,6 +2,7 @@ package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand
import co.paralleluniverse.strands.concurrent.Semaphore
import net.corda.core.concurrent.CordaFuture
import net.corda.core.contracts.ContractState
@ -67,10 +68,6 @@ class FlowFrameworkTests {
private lateinit var alice: Party
private lateinit var bob: Party
private fun StartedNode<*>.flushSmm() {
(this.smm as StateMachineManagerImpl).executor.flush()
}
@Before
fun start() {
mockNet = InternalMockNetwork(
@ -109,6 +106,19 @@ class FlowFrameworkTests {
assertThat(flow.lazyTime).isNotNull()
}
class ThrowingActionExecutor(private val exception: Exception, val delegate: ActionExecutor) : ActionExecutor {
var thrown = false
@Suspendable
override fun executeAction(fiber: FlowFiber, action: Action) {
if (thrown) {
delegate.executeAction(fiber, action)
} else {
thrown = true
throw exception
}
}
}
@Test
fun `exception while fiber suspended`() {
bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) }
@ -116,16 +126,15 @@ class FlowFrameworkTests {
val fiber = aliceNode.services.startFlow(flow) as FlowStateMachineImpl
// Before the flow runs change the suspend action to throw an exception
val exceptionDuringSuspend = Exception("Thrown during suspend")
fiber.actionOnSuspend = {
throw exceptionDuringSuspend
}
val throwingActionExecutor = ThrowingActionExecutor(exceptionDuringSuspend, fiber.transientValues!!.value.actionExecutor)
fiber.transientValues = TransientReference(fiber.transientValues!!.value.copy(actionExecutor = throwingActionExecutor))
mockNet.runNetwork()
assertThatThrownBy {
fiber.resultFuture.getOrThrow()
}.isSameAs(exceptionDuringSuspend)
assertThat(aliceNode.smm.allStateMachines).isEmpty()
// Make sure the fiber does actually terminate
assertThat(fiber.isTerminated).isTrue()
assertThat(fiber.state).isEqualTo(Strand.State.WAITING)
}
@Test
@ -148,7 +157,6 @@ class FlowFrameworkTests {
aliceNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) }
bobNode.services.startFlow(ReceiveFlow(alice).nonTerminating()) // Prepare checkpointed receive flow
// Make sure the add() has finished initial processing.
bobNode.flushSmm()
bobNode.internals.disableDBCloseOnStop()
bobNode.dispose() // kill receiver
val restoredFlow = bobNode.restartAndGetRestoredFlow<ReceiveFlow>()
@ -174,7 +182,6 @@ class FlowFrameworkTests {
assertEquals(1, bobNode.checkpointStorage.checkpoints().size)
}
// Make sure the add() has finished initial processing.
bobNode.flushSmm()
bobNode.internals.disableDBCloseOnStop()
// Restart node and thus reload the checkpoint and resend the message with same UUID
bobNode.dispose()
@ -187,7 +194,6 @@ class FlowFrameworkTests {
val (firstAgain, fut1) = node2b.getSingleFlow<PingPongFlow>()
// Run the network which will also fire up the second flow. First message should get deduped. So message data stays in sync.
mockNet.runNetwork()
node2b.flushSmm()
fut1.getOrThrow()
val receivedCount = receivedSessionMessages.count { it.isPayloadTransfer }
@ -216,6 +222,8 @@ class FlowFrameworkTests {
val payload = "Hello World"
aliceNode.services.startFlow(SendFlow(payload, bob, charlie))
mockNet.runNetwork()
bobNode.internals.acceptableLiveFiberCountOnStop = 1
charlieNode.internals.acceptableLiveFiberCountOnStop = 1
val bobFlow = bobNode.getSingleFlow<InitiatedReceiveFlow>().first
val charlieFlow = charlieNode.getSingleFlow<InitiatedReceiveFlow>().first
assertThat(bobFlow.receivedPayloads[0]).isEqualTo(payload)
@ -234,9 +242,6 @@ class FlowFrameworkTests {
aliceNode sent normalEnd to charlieNode
//There's no session end from the other flows as they're manually suspended
)
bobNode.internals.acceptableLiveFiberCountOnStop = 1
charlieNode.internals.acceptableLiveFiberCountOnStop = 1
}
@Test
@ -338,7 +343,9 @@ class FlowFrameworkTests {
mockNet.runNetwork()
assertThat(erroringFlowSteps.get()).containsExactly(
erroringFlowFuture.getOrThrow()
val flowSteps = erroringFlowSteps.get()
assertThat(flowSteps).containsExactly(
Notification.createOnNext(ExceptionFlow.START_STEP),
Notification.createOnError(erroringFlowFuture.get().exceptionThrown)
)
@ -378,8 +385,8 @@ class FlowFrameworkTests {
assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty()
}
assertThat(receivingFiber.isTerminated).isTrue()
assertThat((erroringFlow.get().stateMachine as FlowStateMachineImpl).isTerminated).isTrue()
assertThat(receivingFiber.state).isEqualTo(Strand.State.WAITING)
assertThat((erroringFlow.get().stateMachine as FlowStateMachineImpl).state).isEqualTo(Strand.State.WAITING)
assertThat(erroringFlowSteps.get()).containsExactly(
Notification.createOnNext(ExceptionFlow.START_STEP),
Notification.createOnError(erroringFlow.get().exceptionThrown)
@ -396,7 +403,7 @@ class FlowFrameworkTests {
}
@Test
fun `FlowException propagated in invocation chain`() {
fun `FlowException only propagated to parent`() {
val charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME))
val charlie = charlieNode.info.singleIdentity()
@ -404,9 +411,8 @@ class FlowFrameworkTests {
bobNode.registerFlowFactory(ReceiveFlow::class) { ReceiveFlow(charlie) }
val receivingFiber = aliceNode.services.startFlow(ReceiveFlow(bob))
mockNet.runNetwork()
assertThatExceptionOfType(MyFlowException::class.java)
assertThatExceptionOfType(UnexpectedFlowEndException::class.java)
.isThrownBy { receivingFiber.resultFuture.getOrThrow() }
.withMessage("Chain")
}
@Test
@ -558,10 +564,8 @@ class FlowFrameworkTests {
@Test
fun `customised client flow which has annotated @InitiatingFlow again`() {
val result = aliceNode.services.startFlow(IncorrectCustomSendFlow("Hello", bob)).resultFuture
mockNet.runNetwork()
assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy {
result.getOrThrow()
aliceNode.services.startFlow(IncorrectCustomSendFlow("Hello", bob)).resultFuture
}.withMessageContaining(InitiatingFlow::class.java.simpleName)
}
@ -635,24 +639,6 @@ class FlowFrameworkTests {
assertThat(result.getOrThrow()).isEqualTo("HelloHello")
}
@Test
fun `double initiateFlow throws`() {
val future = aliceNode.services.startFlow(DoubleInitiatingFlow()).resultFuture
mockNet.runNetwork()
assertThatExceptionOfType(IllegalStateException::class.java)
.isThrownBy { future.getOrThrow() }
.withMessageContaining("Attempted to initiateFlow() twice")
}
@InitiatingFlow
private class DoubleInitiatingFlow : FlowLogic<Unit>() {
@Suspendable
override fun call() {
initiateFlow(ourIdentity)
initiateFlow(ourIdentity)
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////
//region Helpers
@ -685,7 +671,6 @@ class FlowFrameworkTests {
private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage {
return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
}
private fun sessionConfirm(flowVersion: Int = 1) = ExistingSessionMessage(SessionId(0), ConfirmSessionMessage(SessionId(0), FlowInfo(flowVersion, "")))
private fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(0), DataSessionMessage(payload.serialize()))
private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0)
@ -694,7 +679,7 @@ class FlowFrameworkTests {
private fun StartedNode<*>.sendSessionMessage(message: SessionMessage, destination: Party) {
services.networkService.apply {
val address = getAddressOfParty(PartyInfo.SingleNode(destination, emptyList()))
send(createMessage(StateMachineManagerImpl.sessionTopic, message.serialize().bytes), address)
send(createMessage(FlowMessagingImpl.sessionTopic, message.serialize().bytes), address)
}
}
@ -720,7 +705,7 @@ class FlowFrameworkTests {
}
private fun Observable<MessageTransfer>.toSessionTransfers(): Observable<SessionTransfer> {
return filter { it.getMessage().topic == StateMachineManagerImpl.sessionTopic }.map {
return filter { it.getMessage().topic == FlowMessagingImpl.sessionTopic }.map {
val from = it.sender.id
val message = it.messageData.deserialize<SessionMessage>()
SessionTransfer(from, sanitise(message), it.recipients)

View File

@ -5,10 +5,8 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.internal.InputStreamAndHash
import net.corda.core.node.ServiceHub
import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.getOrThrow
import net.corda.node.services.api.StartedNodeServices
import net.corda.testing.common.internal.testNetworkParameters
import net.corda.testing.contracts.DummyContract
import net.corda.testing.contracts.DummyState
@ -22,7 +20,6 @@ import net.corda.testing.node.StartedMockNode
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Before
import org.junit.Ignore
import org.junit.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
@ -90,12 +87,11 @@ class MaxTransactionSizeTests {
assertEquals(hash1, bigFile1.sha256)
SendLargeTransactionFlow(notary, bob, hash1, hash2, hash3, hash4, verify = false)
}
val ex = assertFailsWith<UnexpectedFlowEndException> {
assertFailsWith<UnexpectedFlowEndException> {
val future = aliceNode.startFlow(flow)
mockNet.runNetwork()
future.getOrThrow()
}
assertThat(ex).hasMessageContaining("Counterparty flow on O=Bob Plc, L=Rome, C=IT had an internal error and has terminated")
}
@StartableByRPC
@ -135,4 +131,4 @@ class MaxTransactionSizeTests {
otherSide.send(Unit)
}
}
}
}

View File

@ -210,7 +210,7 @@ class NotaryServiceTests {
private fun runNotarisationAndInterceptClientPayload(payloadModifier: (NotarisationPayload) -> NotarisationPayload) {
aliceNode.setMessagingServiceSpy(object : MessagingServiceSpy(aliceNode.network) {
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, additionalHeaders: Map<String, String>) {
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) {
val messageData = message.data.deserialize<Any>() as? InitialSessionMessage
val payload = messageData?.firstPayload!!.deserialize()

View File

@ -84,9 +84,12 @@ class VaultSoftLockManagerTest {
private val mockNet = InternalMockNetwork(cordappPackages = listOf(ContractImpl::class.packageName), defaultFactory = { args ->
object : InternalMockNetwork.MockNode(args) {
override fun makeVaultService(keyManagementService: KeyManagementService, services: ServicesForResolution, hibernateConfig: HibernateConfiguration): VaultServiceInternal {
val node = this
val realVault = super.makeVaultService(keyManagementService, services, hibernateConfig)
return object : VaultServiceInternal by realVault {
override fun softLockRelease(lockId: UUID, stateRefs: NonEmptySet<StateRef>?) {
// Should be called before flow is removed
assertEquals(1, node.started!!.smm.allStateMachines.size)
mockVault.softLockRelease(lockId, stateRefs) // No need to also call the real one for these tests.
}
}

View File

@ -18,6 +18,7 @@ ext['artemis.version'] = "$artemis_version"
ext['hibernate.version'] = "$hibernate_version"
ext['selenium.version'] = "$selenium_version"
ext['jackson.version'] = "$jackson_version"
ext['dropwizard-metrics.version'] = "$metrics_version"
apply plugin: 'java'
apply plugin: 'kotlin'

View File

@ -18,6 +18,7 @@ buildscript {
ext['artemis.version'] = "$artemis_version"
ext['hibernate.version'] = "$hibernate_version"
ext['jackson.version'] = "$jackson_version"
ext['dropwizard-metrics.version'] = "$metrics_version"
apply plugin: 'java'

View File

@ -18,10 +18,8 @@ import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.trace
import net.corda.node.services.messaging.Message
import net.corda.node.services.messaging.MessageHandlerRegistration
import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.services.messaging.*
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.utilities.AffinityExecutor
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.testing.node.internal.InMemoryMessage
@ -290,9 +288,16 @@ class InMemoryMessagingNetwork private constructor(
private data class InMemoryReceivedMessage(override val topic: String,
override val data: ByteSequence,
override val platformVersion: Int,
override val uniqueMessageId: String,
override val uniqueMessageId: DeduplicationId,
override val debugTimestamp: Instant,
override val peer: CordaX500Name) : ReceivedMessage
override val peer: CordaX500Name,
override val senderUUID: String? = null,
override val senderSeqNo: Long? = null,
/** Note this flag is never set in the in memory network. */
override val isSessionInit: Boolean = false) : ReceivedMessage {
override val additionalHeaders: Map<String, String> = emptyMap()
}
/**
* A class that provides an abstraction over the nodes' messaging service that also contains the ability to
@ -319,7 +324,7 @@ class InMemoryMessagingNetwork private constructor(
private val peerHandle: PeerHandle,
private val executor: AffinityExecutor,
private val database: CordaPersistence) : SingletonSerializeAsToken(), InternalMockMessagingService {
private inner class Handler(val topicSession: String, val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration
private inner class Handler(val topicSession: String, val callback: MessageHandler) : MessageHandlerRegistration
@Volatile
private var running = true
@ -330,7 +335,7 @@ class InMemoryMessagingNetwork private constructor(
}
private val state = ThreadBox(InnerState())
private val processedMessages: MutableSet<String> = Collections.synchronizedSet(HashSet<String>())
private val processedMessages: MutableSet<DeduplicationId> = Collections.synchronizedSet(HashSet<DeduplicationId>())
override val myAddress: PeerHandle get() = peerHandle
@ -353,7 +358,7 @@ class InMemoryMessagingNetwork private constructor(
}
}
override fun addMessageHandler(topic: String, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
override fun addMessageHandler(topic: String, callback: MessageHandler): MessageHandlerRegistration {
check(running)
val (handler, transfers) = state.locked {
val handler = Handler(topic, callback).apply { handlers.add(this) }
@ -374,7 +379,7 @@ class InMemoryMessagingNetwork private constructor(
state.locked { check(handlers.remove(registration as Handler)) }
}
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, additionalHeaders: Map<String, String>) {
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) {
check(running)
msgSend(this, message, target)
if (!sendManuallyPumped) {
@ -400,7 +405,7 @@ class InMemoryMessagingNetwork private constructor(
override fun cancelRedelivery(retryId: Long) {}
/** Returns the given (topic & session, data) pair as a newly created message object. */
override fun createMessage(topic: String, data: ByteArray, deduplicationId: String): Message {
override fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId, additionalHeaders: Map<String, String>): Message {
return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId)
}
@ -465,7 +470,7 @@ class InMemoryMessagingNetwork private constructor(
database.transaction {
for (handler in deliverTo) {
try {
handler.callback(transfer.toReceivedMessage(), handler)
handler.callback(transfer.toReceivedMessage(), handler, DummyDeduplicationHandler())
} catch (e: Exception) {
log.error("Caught exception in handler for $this/${handler.topicSession}", e)
}
@ -489,5 +494,12 @@ class InMemoryMessagingNetwork private constructor(
message.debugTimestamp,
sender.name)
}
private class DummyDeduplicationHandler : DeduplicationHandler {
override fun afterDatabaseTransaction() {
}
override fun insideDatabaseTransaction() {
}
}
}

View File

@ -2,6 +2,7 @@ package net.corda.testing.node.internal
import net.corda.core.utilities.ByteSequence
import net.corda.node.services.messaging.Message
import net.corda.node.services.statemachine.DeduplicationId
import java.time.Instant
/**
@ -9,7 +10,11 @@ import java.time.Instant
*/
data class InMemoryMessage(override val topic: String,
override val data: ByteSequence,
override val uniqueMessageId: String,
override val debugTimestamp: Instant = Instant.now()) : Message {
override val uniqueMessageId: DeduplicationId,
override val debugTimestamp: Instant = Instant.now(),
override val senderUUID: String? = null) : Message {
override val additionalHeaders: Map<String, String> = emptyMap()
override fun toString() = "$topic#${String(data.bytes)}"
}

View File

@ -15,7 +15,6 @@ import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.createDirectories
import net.corda.core.internal.createDirectory
import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.RPCOps
import net.corda.core.messaging.SingleMessageRecipient
@ -231,7 +230,7 @@ open class InternalMockNetwork(private val cordappPackages: List<String>,
private val entropyRoot = args.entropyRoot
var counter = entropyRoot
override val log get() = staticLog
override val serverThread: AffinityExecutor =
override val serverThread: AffinityExecutor.ServiceAffinityExecutor =
if (mockNet.threadPerNode) {
ServiceAffinityExecutor("Mock node $id thread", 1)
} else {

View File

@ -1,18 +1,25 @@
package net.corda.testing.node.internal
import net.corda.core.concurrent.CordaFuture
import net.corda.core.crypto.SecureHash
import net.corda.core.internal.concurrent.doneFuture
import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.toFuture
import net.corda.core.transactions.SignedTransaction
import net.corda.node.services.api.WritableTransactionStorage
import rx.Observable
import rx.subjects.PublishSubject
import java.util.HashMap
import java.util.*
/**
* A class which provides an implementation of [WritableTransactionStorage] which is used in [MockServices]
*/
open class MockTransactionStorage : WritableTransactionStorage, SingletonSerializeAsToken() {
override fun trackTransaction(id: SecureHash): CordaFuture<SignedTransaction> {
return txns[id]?.let { doneFuture(it) } ?: _updatesPublisher.filter { it.id == id }.toFuture()
}
override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> {
return DataFeed(txns.values.toList(), _updatesPublisher)
}