State machine rewrite

This commit is contained in:
Andras Slemmer 2017-10-05 10:23:38 +01:00
parent 10635dfbfd
commit 63027a077d
91 changed files with 4760 additions and 1967 deletions

View File

@ -491,6 +491,7 @@ public final class net.corda.core.contracts.StateAndContract extends java.lang.O
public final class net.corda.core.contracts.Structures extends java.lang.Object
@org.jetbrains.annotations.NotNull public static final net.corda.core.crypto.SecureHash hash(net.corda.core.contracts.ContractState)
@org.jetbrains.annotations.NotNull public static final net.corda.core.contracts.Amount withoutIssuer(net.corda.core.contracts.Amount)
public static final int MAX_ISSUER_REF_SIZE = 512
##
@net.corda.core.serialization.CordaSerializable public abstract class net.corda.core.contracts.TimeWindow extends java.lang.Object
public <init>()
@ -827,6 +828,9 @@ public final class net.corda.core.crypto.CryptoUtils extends java.lang.Object
public final boolean verify(byte[])
@org.jetbrains.annotations.NotNull public final net.corda.core.crypto.DigitalSignature withoutKey()
##
public final class net.corda.core.crypto.DummySecureRandom extends java.security.SecureRandom
public static final net.corda.core.crypto.DummySecureRandom INSTANCE
##
public abstract class net.corda.core.crypto.MerkleTree extends java.lang.Object
@org.jetbrains.annotations.NotNull public abstract net.corda.core.crypto.SecureHash getHash()
public static final net.corda.core.crypto.MerkleTree$Companion Companion
@ -1140,11 +1144,14 @@ public static final class net.corda.core.flows.FinalityFlow$Companion extends ja
@org.jetbrains.annotations.NotNull public net.corda.core.utilities.ProgressTracker childProgressTracker()
public static final net.corda.core.flows.FinalityFlow$Companion$NOTARISING INSTANCE
##
@net.corda.core.serialization.CordaSerializable public class net.corda.core.flows.FlowException extends net.corda.core.CordaException
@net.corda.core.serialization.CordaSerializable public class net.corda.core.flows.FlowException extends net.corda.core.CordaException implements net.corda.core.flows.IdentifiableException
public <init>()
public <init>(String)
public <init>(String, Throwable)
public <init>(Throwable)
@org.jetbrains.annotations.Nullable public Long getErrorId()
@org.jetbrains.annotations.Nullable public final Long getOriginalErrorId()
public final void setOriginalErrorId(Long)
##
@net.corda.core.serialization.CordaSerializable public final class net.corda.core.flows.FlowInfo extends java.lang.Object
public <init>(int, String)
@ -1209,7 +1216,6 @@ public abstract class net.corda.core.flows.FlowLogic extends java.lang.Object
public final void checkFlowPermission(String, Map)
@co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.Nullable public final net.corda.core.flows.FlowStackSnapshot flowStackSnapshot()
@org.jetbrains.annotations.Nullable public static final net.corda.core.flows.FlowLogic getCurrentTopLevel()
@kotlin.Deprecated @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public final net.corda.core.flows.FlowInfo getFlowInfo(net.corda.core.identity.Party)
@org.jetbrains.annotations.NotNull public final org.slf4j.Logger getLogger()
@org.jetbrains.annotations.NotNull public final net.corda.core.identity.Party getOurIdentity()
@org.jetbrains.annotations.NotNull public final net.corda.core.identity.PartyAndCertificate getOurIdentityAndCert()
@ -1219,16 +1225,18 @@ public abstract class net.corda.core.flows.FlowLogic extends java.lang.Object
@org.jetbrains.annotations.NotNull public final net.corda.core.internal.FlowStateMachine getStateMachine()
@co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public final net.corda.core.flows.FlowSession initiateFlow(net.corda.core.identity.Party)
@co.paralleluniverse.fibers.Suspendable public final void persistFlowStackSnapshot()
@kotlin.Deprecated @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public net.corda.core.utilities.UntrustworthyData receive(Class, net.corda.core.identity.Party)
@co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public List receiveAll(Class, List)
@co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public Map receiveAll(Map)
@co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public List receiveAll(Class, List, boolean)
@co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public Map receiveAllMap(Map)
@co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public Map receiveAllMap(Map, boolean)
public final void recordAuditEvent(String, String, Map)
@kotlin.Deprecated @co.paralleluniverse.fibers.Suspendable public void send(net.corda.core.identity.Party, Object)
@kotlin.Deprecated @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public net.corda.core.utilities.UntrustworthyData sendAndReceive(Class, net.corda.core.identity.Party, Object)
public final void setStateMachine(net.corda.core.internal.FlowStateMachine)
@co.paralleluniverse.fibers.Suspendable @kotlin.jvm.JvmStatic public static final void sleep(java.time.Duration)
@co.paralleluniverse.fibers.Suspendable @kotlin.jvm.JvmStatic public static final void sleep(java.time.Duration, boolean)
@co.paralleluniverse.fibers.Suspendable public Object subFlow(net.corda.core.flows.FlowLogic)
@org.jetbrains.annotations.Nullable public final net.corda.core.messaging.DataFeed track()
@org.jetbrains.annotations.Nullable public final net.corda.core.messaging.DataFeed trackStepsTree()
@org.jetbrains.annotations.Nullable public final net.corda.core.messaging.DataFeed trackStepsTreeIndex()
@co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public final net.corda.core.transactions.SignedTransaction waitForLedgerCommit(net.corda.core.crypto.SecureHash)
@co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public final net.corda.core.transactions.SignedTransaction waitForLedgerCommit(net.corda.core.crypto.SecureHash, boolean)
public static final net.corda.core.flows.FlowLogic$Companion Companion
@ -1236,6 +1244,7 @@ public abstract class net.corda.core.flows.FlowLogic extends java.lang.Object
public static final class net.corda.core.flows.FlowLogic$Companion extends java.lang.Object
@org.jetbrains.annotations.Nullable public final net.corda.core.flows.FlowLogic getCurrentTopLevel()
@co.paralleluniverse.fibers.Suspendable @kotlin.jvm.JvmStatic public final void sleep(java.time.Duration)
@co.paralleluniverse.fibers.Suspendable @kotlin.jvm.JvmStatic public final void sleep(java.time.Duration, boolean)
##
@net.corda.core.serialization.CordaSerializable @net.corda.core.DoNotImplement public interface net.corda.core.flows.FlowLogicRef
##
@ -1277,6 +1286,9 @@ public static final class net.corda.core.flows.FlowStackSnapshot$Frame extends j
public int hashCode()
@org.jetbrains.annotations.NotNull public String toString()
##
public interface net.corda.core.flows.IdentifiableException
@javax.annotation.Nullable public Long getErrorId()
##
@net.corda.core.serialization.CordaSerializable public final class net.corda.core.flows.IllegalFlowLogicException extends java.lang.IllegalArgumentException
public <init>(Class, String)
##
@ -1434,9 +1446,10 @@ public final class net.corda.core.flows.TransactionParts extends java.lang.Objec
public int hashCode()
public String toString()
##
@net.corda.core.serialization.CordaSerializable public final class net.corda.core.flows.UnexpectedFlowEndException extends net.corda.core.CordaRuntimeException
public <init>(String)
public <init>(String, Throwable)
@net.corda.core.serialization.CordaSerializable public final class net.corda.core.flows.UnexpectedFlowEndException extends net.corda.core.CordaRuntimeException implements net.corda.core.flows.IdentifiableException
public <init>(String, Throwable, long)
@org.jetbrains.annotations.NotNull public Long getErrorId()
public final long getOriginalErrorId()
##
@net.corda.core.serialization.CordaSerializable @net.corda.core.DoNotImplement public abstract class net.corda.core.identity.AbstractParty extends java.lang.Object
public <init>(java.security.PublicKey)
@ -1593,18 +1606,27 @@ public final class net.corda.core.messaging.CordaRPCOpsKt extends java.lang.Obje
@net.corda.core.DoNotImplement public interface net.corda.core.messaging.FlowProgressHandle extends net.corda.core.messaging.FlowHandle
public abstract void close()
@org.jetbrains.annotations.NotNull public abstract rx.Observable getProgress()
@org.jetbrains.annotations.Nullable public abstract net.corda.core.messaging.DataFeed getStepsTreeFeed()
@org.jetbrains.annotations.Nullable public abstract net.corda.core.messaging.DataFeed getStepsTreeIndexFeed()
##
@net.corda.core.serialization.CordaSerializable @net.corda.core.DoNotImplement public final class net.corda.core.messaging.FlowProgressHandleImpl extends java.lang.Object implements net.corda.core.messaging.FlowProgressHandle
public <init>(net.corda.core.flows.StateMachineRunId, net.corda.core.concurrent.CordaFuture, rx.Observable)
public <init>(net.corda.core.flows.StateMachineRunId, net.corda.core.concurrent.CordaFuture, rx.Observable, net.corda.core.messaging.DataFeed)
public <init>(net.corda.core.flows.StateMachineRunId, net.corda.core.concurrent.CordaFuture, rx.Observable, net.corda.core.messaging.DataFeed, net.corda.core.messaging.DataFeed)
public void close()
@org.jetbrains.annotations.NotNull public final net.corda.core.flows.StateMachineRunId component1()
@org.jetbrains.annotations.NotNull public final net.corda.core.concurrent.CordaFuture component2()
@org.jetbrains.annotations.NotNull public final rx.Observable component3()
@org.jetbrains.annotations.Nullable public final net.corda.core.messaging.DataFeed component4()
@org.jetbrains.annotations.Nullable public final net.corda.core.messaging.DataFeed component5()
@org.jetbrains.annotations.NotNull public final net.corda.core.messaging.FlowProgressHandleImpl copy(net.corda.core.flows.StateMachineRunId, net.corda.core.concurrent.CordaFuture, rx.Observable)
@org.jetbrains.annotations.NotNull public final net.corda.core.messaging.FlowProgressHandleImpl copy(net.corda.core.flows.StateMachineRunId, net.corda.core.concurrent.CordaFuture, rx.Observable, net.corda.core.messaging.DataFeed, net.corda.core.messaging.DataFeed)
public boolean equals(Object)
@org.jetbrains.annotations.NotNull public net.corda.core.flows.StateMachineRunId getId()
@org.jetbrains.annotations.NotNull public rx.Observable getProgress()
@org.jetbrains.annotations.NotNull public net.corda.core.concurrent.CordaFuture getReturnValue()
@org.jetbrains.annotations.Nullable public net.corda.core.messaging.DataFeed getStepsTreeFeed()
@org.jetbrains.annotations.Nullable public net.corda.core.messaging.DataFeed getStepsTreeIndexFeed()
public int hashCode()
public String toString()
##
@ -1692,6 +1714,7 @@ public @interface net.corda.core.messaging.RPCReturnsObservables
public final int getPlatformVersion()
public final long getSerial()
public int hashCode()
@org.jetbrains.annotations.NotNull public final net.corda.core.identity.PartyAndCertificate identityAndCertFromX500Name(net.corda.core.identity.CordaX500Name)
@org.jetbrains.annotations.NotNull public final net.corda.core.identity.Party identityFromX500Name(net.corda.core.identity.CordaX500Name)
public final boolean isLegalIdentity(net.corda.core.identity.Party)
public String toString()
@ -1728,6 +1751,7 @@ public @interface net.corda.core.messaging.RPCReturnsObservables
##
@net.corda.core.DoNotImplement public interface net.corda.core.node.StateLoader
@org.jetbrains.annotations.NotNull public abstract net.corda.core.contracts.TransactionState loadState(net.corda.core.contracts.StateRef)
@org.jetbrains.annotations.NotNull public abstract Set loadStates(Set)
##
public final class net.corda.core.node.StatesToRecord extends java.lang.Enum
protected <init>(String, int)
@ -1735,8 +1759,10 @@ public final class net.corda.core.node.StatesToRecord extends java.lang.Enum
public static net.corda.core.node.StatesToRecord[] values()
##
@net.corda.core.DoNotImplement public interface net.corda.core.node.services.AttachmentStorage
public abstract boolean hasAttachment(net.corda.core.crypto.SecureHash)
@org.jetbrains.annotations.NotNull public abstract net.corda.core.crypto.SecureHash importAttachment(java.io.InputStream)
@org.jetbrains.annotations.NotNull public abstract net.corda.core.crypto.SecureHash importAttachment(java.io.InputStream, String, String)
@org.jetbrains.annotations.NotNull public abstract net.corda.core.crypto.SecureHash importOrGetAttachment(java.io.InputStream)
@org.jetbrains.annotations.Nullable public abstract net.corda.core.contracts.Attachment openAttachment(net.corda.core.crypto.SecureHash)
@org.jetbrains.annotations.NotNull public abstract List queryAttachments(net.corda.core.node.services.vault.AttachmentQueryCriteria, net.corda.core.node.services.vault.AttachmentSort)
##
@ -1877,6 +1903,7 @@ public final class net.corda.core.node.services.TimeWindowChecker extends java.l
@org.jetbrains.annotations.Nullable public abstract net.corda.core.transactions.SignedTransaction getTransaction(net.corda.core.crypto.SecureHash)
@org.jetbrains.annotations.NotNull public abstract rx.Observable getUpdates()
@org.jetbrains.annotations.NotNull public abstract net.corda.core.messaging.DataFeed track()
@org.jetbrains.annotations.NotNull public abstract net.corda.core.concurrent.CordaFuture trackTransaction(net.corda.core.crypto.SecureHash)
##
@net.corda.core.DoNotImplement public interface net.corda.core.node.services.TransactionVerifierService
@org.jetbrains.annotations.NotNull public abstract net.corda.core.concurrent.CordaFuture verify(net.corda.core.transactions.LedgerTransaction)
@ -1890,6 +1917,9 @@ public final class net.corda.core.node.services.TimeWindowChecker extends java.l
@org.jetbrains.annotations.NotNull public final net.corda.core.crypto.TransactionSignature sign(net.corda.core.crypto.SecureHash)
@org.jetbrains.annotations.NotNull public final net.corda.core.crypto.DigitalSignature$WithKey sign(byte[])
public final void validateTimeWindow(net.corda.core.contracts.TimeWindow)
public static final net.corda.core.node.services.TrustedAuthorityNotaryService$Companion Companion
##
public static final class net.corda.core.node.services.TrustedAuthorityNotaryService$Companion extends java.lang.Object
##
@net.corda.core.serialization.CordaSerializable public final class net.corda.core.node.services.UniquenessException extends net.corda.core.CordaException
public <init>(net.corda.core.node.services.UniquenessProvider$Conflict)
@ -2609,6 +2639,7 @@ public final class net.corda.core.schemas.CommonSchemaV1 extends net.corda.core.
public static final net.corda.core.schemas.CommonSchemaV1 INSTANCE
##
@javax.persistence.MappedSuperclass @net.corda.core.serialization.CordaSerializable public static class net.corda.core.schemas.CommonSchemaV1$FungibleState extends net.corda.core.schemas.PersistentState
public <init>()
public <init>(Set, net.corda.core.identity.AbstractParty, long, net.corda.core.identity.AbstractParty, byte[])
@org.jetbrains.annotations.NotNull public final net.corda.core.identity.AbstractParty getIssuer()
@org.jetbrains.annotations.NotNull public final byte[] getIssuerRef()
@ -2622,6 +2653,7 @@ public final class net.corda.core.schemas.CommonSchemaV1 extends net.corda.core.
public final void setQuantity(long)
##
@javax.persistence.MappedSuperclass @net.corda.core.serialization.CordaSerializable public static class net.corda.core.schemas.CommonSchemaV1$LinearState extends net.corda.core.schemas.PersistentState
public <init>()
public <init>(Set, String, UUID)
public <init>(net.corda.core.contracts.UniqueIdentifier, Set)
@org.jetbrains.annotations.Nullable public final String getExternalId()
@ -3141,6 +3173,7 @@ public static final class net.corda.core.utilities.Id$Companion extends java.lan
@kotlin.jvm.JvmStatic @org.jetbrains.annotations.NotNull public final net.corda.core.utilities.Id newInstance(Object, String, java.time.Instant)
##
public final class net.corda.core.utilities.KotlinUtilsKt extends java.lang.Object
@org.jetbrains.annotations.NotNull public static final org.slf4j.Logger contextLogger(Object)
public static final void debug(org.slf4j.Logger, kotlin.jvm.functions.Function0)
public static final int exactAdd(int, int)
public static final long exactAdd(long, long)
@ -3226,6 +3259,7 @@ public static final class net.corda.core.utilities.OpaqueBytes$Companion extends
@net.corda.core.serialization.CordaSerializable public final class net.corda.core.utilities.ProgressTracker extends java.lang.Object
public final void endWithError(Throwable)
@org.jetbrains.annotations.NotNull public final List getAllSteps()
@org.jetbrains.annotations.NotNull public final List getAllStepsLabels()
@org.jetbrains.annotations.NotNull public final rx.Observable getChanges()
@org.jetbrains.annotations.Nullable public final net.corda.core.utilities.ProgressTracker getChildProgressTracker(net.corda.core.utilities.ProgressTracker$Step)
@org.jetbrains.annotations.NotNull public final net.corda.core.utilities.ProgressTracker$Step getCurrentStep()
@ -3234,12 +3268,16 @@ public static final class net.corda.core.utilities.OpaqueBytes$Companion extends
@org.jetbrains.annotations.Nullable public final net.corda.core.utilities.ProgressTracker getParent()
public final int getStepIndex()
@org.jetbrains.annotations.NotNull public final net.corda.core.utilities.ProgressTracker$Step[] getSteps()
@org.jetbrains.annotations.NotNull public final rx.Observable getStepsTreeChanges()
public final int getStepsTreeIndex()
@org.jetbrains.annotations.NotNull public final rx.Observable getStepsTreeIndexChanges()
@org.jetbrains.annotations.NotNull public final net.corda.core.utilities.ProgressTracker getTopLevelTracker()
@org.jetbrains.annotations.NotNull public final net.corda.core.utilities.ProgressTracker$Step nextStep()
public final void setChildProgressTracker(net.corda.core.utilities.ProgressTracker$Step, net.corda.core.utilities.ProgressTracker)
public final void setCurrentStep(net.corda.core.utilities.ProgressTracker$Step)
##
@net.corda.core.serialization.CordaSerializable public abstract static class net.corda.core.utilities.ProgressTracker$Change extends java.lang.Object
@org.jetbrains.annotations.NotNull public final net.corda.core.utilities.ProgressTracker getProgressTracker()
##
@net.corda.core.serialization.CordaSerializable public static final class net.corda.core.utilities.ProgressTracker$Change$Position extends net.corda.core.utilities.ProgressTracker$Change
public <init>(net.corda.core.utilities.ProgressTracker, net.corda.core.utilities.ProgressTracker$Step)
@ -3292,6 +3330,10 @@ public static final class net.corda.core.utilities.OpaqueBytes$Companion extends
public interface net.corda.core.utilities.PropertyDelegate
public abstract Object getValue(Object, kotlin.reflect.KProperty)
##
public final class net.corda.core.utilities.SgxSupport extends java.lang.Object
public static final boolean isInsideEnclave()
public static final net.corda.core.utilities.SgxSupport INSTANCE
##
@net.corda.core.serialization.CordaSerializable public abstract class net.corda.core.utilities.Try extends java.lang.Object
@org.jetbrains.annotations.NotNull public final net.corda.core.utilities.Try combine(net.corda.core.utilities.Try, kotlin.jvm.functions.Function2)
@org.jetbrains.annotations.NotNull public final net.corda.core.utilities.Try flatMap(kotlin.jvm.functions.Function1)
@ -3338,6 +3380,7 @@ public static interface net.corda.core.utilities.UntrustworthyData$Validator ext
@co.paralleluniverse.fibers.Suspendable public abstract Object validate(Object)
##
public final class net.corda.core.utilities.UntrustworthyDataKt extends java.lang.Object
@org.jetbrains.annotations.NotNull public static final net.corda.core.utilities.UntrustworthyData checkPayloadIs(net.corda.core.serialization.SerializedBytes, Class)
public static final Object unwrap(net.corda.core.utilities.UntrustworthyData, kotlin.jvm.functions.Function1)
##
public final class net.corda.core.utilities.UuidGenerator extends java.lang.Object

2
.idea/compiler.xml generated
View File

@ -57,6 +57,8 @@
<module name="finance_integrationTest" target="1.8" />
<module name="finance_main" target="1.8" />
<module name="finance_test" target="1.8" />
<module name="flow-hook_main" target="1.8" />
<module name="flow-hook_test" target="1.8" />
<module name="gradle-plugins-cordapp_main" target="1.8" />
<module name="gradle-plugins-cordapp_test" target="1.8" />
<module name="gradle-plugins-cordform-common_main" target="1.8" />

View File

@ -0,0 +1,16 @@
package net.corda.core.flows;
import javax.annotation.Nullable;
/**
* An exception that may be identified with an ID. If an exception originates in a counter-flow this ID will be
* propagated. This allows correlation of error conditions across different flows.
*/
public interface IdentifiableException {
/**
* @return the ID of the error, or null if the error doesn't have it set (yet).
*/
default @Nullable Long getErrorId() {
return null;
}
}

View File

@ -7,16 +7,27 @@ import net.corda.core.CordaRuntimeException
/**
* Exception which can be thrown by a [FlowLogic] at any point in its logic to unexpectedly bring it to a permanent end.
* The exception will propagate to all counterparty flows and will be thrown on their end the next time they wait on a
* [FlowSession.receive] or [FlowSession.sendAndReceive]. Any flow which no longer needs to do a receive, or has already ended,
* will not receive the exception (if this is required then have them wait for a confirmation message).
* [FlowSession.receive] or [FlowSession.sendAndReceive]. Any flow which no longer needs to do a receive, or has already
* ended, will not receive the exception (if this is required then have them wait for a confirmation message).
*
* If the *rethrown* [FlowException] is uncaught in counterparty flows and propagation triggers then the exception is
* downgraded to an [UnexpectedFlowEndException]. This means only immediate counterparty flows will receive information
* about what the exception was.
*
* [FlowException] (or a subclass) can be a valid expected response from a flow, particularly ones which act as a service.
* It is recommended a [FlowLogic] document the [FlowException] types it can throw.
*
* @property originalErrorId the ID backing [getErrorId]. If null it will be set dynamically by the flow framework when
* the exception is handled. This ID is propagated to counterparty flows, even when the [FlowException] is
* downgraded to an [UnexpectedFlowEndException]. This is so the error conditions may be correlated later on.
*/
open class FlowException(message: String?, cause: Throwable?) : CordaException(message, cause) {
open class FlowException(message: String?, cause: Throwable?) :
CordaException(message, cause), IdentifiableException {
constructor(message: String?) : this(message, null)
constructor(cause: Throwable?) : this(cause?.toString(), cause)
constructor() : this(null, null)
var originalErrorId: Long? = null
override fun getErrorId(): Long? = originalErrorId
}
// DOCEND 1
@ -25,6 +36,7 @@ open class FlowException(message: String?, cause: Throwable?) : CordaException(m
* that we were not expecting), or the other side had an internal error, or the other side terminated when we
* were waiting for a response.
*/
class UnexpectedFlowEndException(message: String?, cause: Throwable?) : CordaRuntimeException(message, cause) {
constructor(msg: String) : this(msg, null)
class UnexpectedFlowEndException(message: String, cause: Throwable?, val originalErrorId: Long) :
CordaRuntimeException(message, cause), IdentifiableException {
override fun getErrorId(): Long = originalErrorId
}

View File

@ -5,6 +5,7 @@ import co.paralleluniverse.strands.Strand
import net.corda.core.crypto.SecureHash
import net.corda.core.identity.Party
import net.corda.core.identity.PartyAndCertificate
import net.corda.core.internal.FlowIORequest
import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.abbreviate
import net.corda.core.internal.uncheckedCast
@ -12,10 +13,10 @@ import net.corda.core.messaging.DataFeed
import net.corda.core.node.NodeInfo
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.serialize
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.UntrustworthyData
import net.corda.core.utilities.debug
import net.corda.core.utilities.*
import org.slf4j.Logger
import java.time.Duration
import java.time.Instant
@ -75,12 +76,19 @@ abstract class FlowLogic<out T> {
*/
@Suspendable
@JvmStatic
@JvmOverloads
@Throws(FlowException::class)
fun sleep(duration: Duration) {
fun sleep(duration: Duration, maySkipCheckpoint: Boolean = false) {
if (duration.compareTo(Duration.ofMinutes(5)) > 0) {
throw FlowException("Attempt to sleep for longer than 5 minutes is not supported. Consider using SchedulableState.")
}
(Strand.currentStrand() as? FlowStateMachine<*>)?.sleepUntil(Instant.now() + duration) ?: Strand.sleep(duration.toMillis())
val fiber = (Strand.currentStrand() as? FlowStateMachine<*>)
if (fiber == null) {
Strand.sleep(duration.toMillis())
} else {
val request = FlowIORequest.Sleep(wakeUpAfter = fiber.serviceHub.clock.instant() + duration)
fiber.suspend(request, maySkipCheckpoint = maySkipCheckpoint)
}
}
}
@ -92,7 +100,7 @@ abstract class FlowLogic<out T> {
/**
* Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is
* only available once the flow has started, which means it cannnot be accessed in the constructor. Either
* only available once the flow has started, which means it cannot be accessed in the constructor. Either
* access this lazily or from inside [call].
*/
val serviceHub: ServiceHub get() = stateMachine.serviceHub
@ -102,7 +110,7 @@ abstract class FlowLogic<out T> {
* that this function does not communicate in itself, the counter-flow will be kicked off by the first send/receive.
*/
@Suspendable
fun initiateFlow(party: Party): FlowSession = stateMachine.initiateFlow(party, flowUsedForSessions)
fun initiateFlow(party: Party): FlowSession = stateMachine.initiateFlow(party)
/**
* Specifies the identity, with certificate, to use for this flow. This will be one of the multiple identities that
@ -112,7 +120,10 @@ abstract class FlowLogic<out T> {
* Note: The current implementation returns the single identity of the node. This will change once multiple identities
* is implemented.
*/
val ourIdentityAndCert: PartyAndCertificate get() = stateMachine.ourIdentityAndCert
val ourIdentityAndCert: PartyAndCertificate get() {
return serviceHub.myInfo.legalIdentitiesAndCerts.find { it.party == stateMachine.ourIdentity }
?: throw IllegalStateException("Identity specified by ${stateMachine.id} (${stateMachine.ourIdentity}) is not one of ours!")
}
/**
* Specifies the identity to use for this flow. This will be one of the multiple identities that belong to this node.
@ -122,102 +133,23 @@ abstract class FlowLogic<out T> {
* Note: The current implementation returns the single identity of the node. This will change once multiple identities
* is implemented.
*/
val ourIdentity: Party get() = ourIdentityAndCert.party
/**
* Returns a [FlowInfo] object describing the flow [otherParty] is using. With [FlowInfo.flowVersion] it
* provides the necessary information needed for the evolution of flows and enabling backwards compatibility.
*
* This method can be called before any send or receive has been done with [otherParty]. In such a case this will force
* them to start their flow.
*/
@Deprecated("Use FlowSession.getFlowInfo()", level = DeprecationLevel.WARNING)
@Suspendable
fun getFlowInfo(otherParty: Party): FlowInfo = stateMachine.getFlowInfo(otherParty, flowUsedForSessions, maySkipCheckpoint = false)
/**
* Serializes and queues the given [payload] object for sending to the [otherParty]. Suspends until a response
* is received, which must be of the given [R] type.
*
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly
* corrupted data in order to exploit your code.
*
* Note that this function is not just a simple send+receive pair: it is more efficient and more correct to
* use this when you expect to do a message swap than do use [send] and then [receive] in turn.
*
* @return an [UntrustworthyData] wrapper around the received object.
*/
@Deprecated("Use FlowSession.sendAndReceive()", level = DeprecationLevel.WARNING)
inline fun <reified R : Any> sendAndReceive(otherParty: Party, payload: Any): UntrustworthyData<R> {
return sendAndReceive(R::class.java, otherParty, payload)
}
/**
* Serializes and queues the given [payload] object for sending to the [otherParty]. Suspends until a response
* is received, which must be of the given [receiveType]. Remember that when receiving data from other parties the data
* should not be trusted until it's been thoroughly verified for consistency and that all expectations are
* satisfied, as a malicious peer may send you subtly corrupted data in order to exploit your code.
*
* Note that this function is not just a simple send+receive pair: it is more efficient and more correct to
* use this when you expect to do a message swap than do use [send] and then [receive] in turn.
*
* @return an [UntrustworthyData] wrapper around the received object.
*/
@Deprecated("Use FlowSession.sendAndReceive()", level = DeprecationLevel.WARNING)
@Suspendable
open fun <R : Any> sendAndReceive(receiveType: Class<R>, otherParty: Party, payload: Any): UntrustworthyData<R> {
return stateMachine.sendAndReceive(receiveType, otherParty, payload, flowUsedForSessions, retrySend = false, maySkipCheckpoint = false)
}
/**
* Similar to [sendAndReceive] but also instructs the `payload` to be redelivered until the expected message is received.
*
* Note that this method should NOT be used for regular party-to-party communication, use [sendAndReceive] instead.
* It is only intended for the case where the [otherParty] is running a distributed service with an idempotent
* flow which only accepts a single request and sends back a single response e.g. a notary or certain types of
* oracle services. If one or more nodes in the service cluster go down mid-session, the message will be redelivered
* to a different one, so there is no need to wait until the initial node comes back up to obtain a response.
*/
@Deprecated("Use FlowSession.sendAndReceiveWithRetry()", level = DeprecationLevel.WARNING)
internal inline fun <reified R : Any> sendAndReceiveWithRetry(otherParty: Party, payload: Any): UntrustworthyData<R> {
return stateMachine.sendAndReceive(R::class.java, otherParty, payload, flowUsedForSessions, retrySend = true, maySkipCheckpoint = false)
}
val ourIdentity: Party get() = stateMachine.ourIdentity
@Suspendable
internal fun <R : Any> FlowSession.sendAndReceiveWithRetry(receiveType: Class<R>, payload: Any): UntrustworthyData<R> {
return stateMachine.sendAndReceive(receiveType, counterparty, payload, flowUsedForSessions, retrySend = true, maySkipCheckpoint = false)
val request = FlowIORequest.SendAndReceive(
sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT)),
shouldRetrySend = true
)
return stateMachine.suspend(request, maySkipCheckpoint = true)[this]!!.checkPayloadIs(receiveType)
}
@Suspendable
internal inline fun <reified R : Any> FlowSession.sendAndReceiveWithRetry(payload: Any): UntrustworthyData<R> {
return stateMachine.sendAndReceive(R::class.java, counterparty, payload, flowUsedForSessions, retrySend = true, maySkipCheckpoint = false)
return sendAndReceiveWithRetry(R::class.java, payload)
}
/**
* Suspends until the specified [otherParty] sends us a message of type [R].
*
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly
* corrupted data in order to exploit your code.
*/
@Deprecated("Use FlowSession.receive()", level = DeprecationLevel.WARNING)
inline fun <reified R : Any> receive(otherParty: Party): UntrustworthyData<R> = receive(R::class.java, otherParty)
/**
* Suspends until the specified [otherParty] sends us a message of type [receiveType].
*
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly
* corrupted data in order to exploit your code.
*
* @return an [UntrustworthyData] wrapper around the received object.
*/
@Deprecated("Use FlowSession.receive()", level = DeprecationLevel.WARNING)
@Suspendable
open fun <R : Any> receive(receiveType: Class<R>, otherParty: Party): UntrustworthyData<R> {
return stateMachine.receive(receiveType, otherParty, flowUsedForSessions, maySkipCheckpoint = false)
}
/** Suspends until a message has been received for each session in the specified [sessions].
*
@ -230,8 +162,14 @@ abstract class FlowLogic<out T> {
* @returns a [Map] containing the objects received, wrapped in an [UntrustworthyData], by the [FlowSession]s who sent them.
*/
@Suspendable
open fun receiveAll(sessions: Map<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>> {
return stateMachine.receiveAll(sessions, this)
@JvmOverloads
open fun receiveAllMap(sessions: Map<FlowSession, Class<out Any>>, maySkipCheckpoint: Boolean = false): Map<FlowSession, UntrustworthyData<Any>> {
enforceNoPrimitiveInReceive(sessions.values)
val replies = stateMachine.suspend(
ioRequest = FlowIORequest.Receive(sessions.keys.toNonEmptySet()),
maySkipCheckpoint = maySkipCheckpoint
)
return replies.mapValues { (session, payload) -> payload.checkPayloadIs(sessions[session]!!) }
}
/**
@ -246,22 +184,11 @@ abstract class FlowLogic<out T> {
* @returns a [List] containing the objects received, wrapped in an [UntrustworthyData], with the same order of [sessions].
*/
@Suspendable
open fun <R : Any> receiveAll(receiveType: Class<R>, sessions: List<FlowSession>): List<UntrustworthyData<R>> {
@JvmOverloads
open fun <R : Any> receiveAll(receiveType: Class<R>, sessions: List<FlowSession>, maySkipCheckpoint: Boolean = false): List<UntrustworthyData<R>> {
enforceNoPrimitiveInReceive(listOf(receiveType))
enforceNoDuplicates(sessions)
return castMapValuesToKnownType(receiveAll(associateSessionsToReceiveType(receiveType, sessions)))
}
/**
* Queues the given [payload] for sending to the [otherParty] and continues without suspending.
*
* Note that the other party may receive the message at some arbitrary later point or not at all: if [otherParty]
* is offline then message delivery will be retried until it comes back or until the message is older than the
* network's event horizon time.
*/
@Deprecated("Use FlowSession.send()", level = DeprecationLevel.WARNING)
@Suspendable
open fun send(otherParty: Party, payload: Any) {
stateMachine.send(otherParty, payload, flowUsedForSessions, maySkipCheckpoint = false)
return castMapValuesToKnownType(receiveAllMap(associateSessionsToReceiveType(receiveType, sessions), maySkipCheckpoint))
}
/**
@ -281,11 +208,8 @@ abstract class FlowLogic<out T> {
open fun <R> subFlow(subLogic: FlowLogic<R>): R {
subLogic.stateMachine = stateMachine
maybeWireUpProgressTracking(subLogic)
if (!subLogic.javaClass.isAnnotationPresent(InitiatingFlow::class.java)) {
subLogic.flowUsedForSessions = flowUsedForSessions
}
logger.debug { "Calling subflow: $subLogic" }
val result = subLogic.call()
val result = stateMachine.subFlow(subLogic)
logger.debug { "Subflow finished with result ${result.toString().abbreviate(300)}" }
// It's easy to forget this when writing flows so we just step it to the DONE state when it completes.
subLogic.progressTracker?.currentStep = ProgressTracker.DONE
@ -382,7 +306,8 @@ abstract class FlowLogic<out T> {
@Suspendable
@JvmOverloads
fun waitForLedgerCommit(hash: SecureHash, maySkipCheckpoint: Boolean = false): SignedTransaction {
return stateMachine.waitForLedgerCommit(hash, this, maySkipCheckpoint = maySkipCheckpoint)
val request = FlowIORequest.WaitForLedgerCommit(hash)
return stateMachine.suspend(request, maySkipCheckpoint = maySkipCheckpoint)
}
/**
@ -423,10 +348,6 @@ abstract class FlowLogic<out T> {
_stateMachine = value
}
// This is the flow used for managing sessions. It defaults to the current flow but if this is an inlined sub-flow
// then it will point to the flow it's been inlined to.
private var flowUsedForSessions: FlowLogic<*> = this
private fun maybeWireUpProgressTracking(subLogic: FlowLogic<*>) {
val ours = progressTracker
val theirs = subLogic.progressTracker
@ -443,6 +364,11 @@ abstract class FlowLogic<out T> {
require(sessions.size == sessions.toSet().size) { "A flow session can only appear once as argument." }
}
private fun enforceNoPrimitiveInReceive(types: Collection<Class<*>>) {
val primitiveTypes = types.filter { it.isPrimitive }
require(primitiveTypes.isEmpty()) { "Cannot receive primitive type(s) $primitiveTypes" }
}
private fun <R> associateSessionsToReceiveType(receiveType: Class<R>, sessions: List<FlowSession>): Map<FlowSession, Class<R>> {
return sessions.associateByTo(LinkedHashMap(), { it }, { receiveType })
}

View File

@ -0,0 +1,85 @@
package net.corda.core.internal
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowSession
import net.corda.core.serialization.SerializedBytes
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.NonEmptySet
import java.time.Instant
/**
* A [FlowIORequest] represents an IO request of a flow when it suspends. It is persisted in checkpoints.
*/
sealed class FlowIORequest<out R : Any> {
/**
* Send messages to sessions.
*
* @property sessionToMessage a map from session to message-to-be-sent.
* @property shouldRetrySend specifies whether the send should be retried.
*/
data class Send(
val sessionToMessage: Map<FlowSession, SerializedBytes<Any>>,
val shouldRetrySend: Boolean
) : FlowIORequest<Unit>() {
override fun toString() = "Send(" +
"sessionToMessage=${sessionToMessage.mapValues { it.value.hash }}, " +
"shouldRetrySend=$shouldRetrySend" +
")"
}
/**
* Receive messages from sessions.
*
* @property sessions the sessions to receive messages from.
* @return a map from session to received message.
*/
data class Receive(
val sessions: NonEmptySet<FlowSession>
) : FlowIORequest<Map<FlowSession, SerializedBytes<Any>>>()
/**
* Send and receive messages from the specified sessions.
*
* @property sessionToMessage a map from session to message-to-be-sent. The keys also specify which sessions to
* receive from.
* @property shouldRetrySend specifies whether the send should be retried.
* @return a map from session to received message.
*/
data class SendAndReceive(
val sessionToMessage: Map<FlowSession, SerializedBytes<Any>>,
val shouldRetrySend: Boolean
) : FlowIORequest<Map<FlowSession, SerializedBytes<Any>>>() {
override fun toString() = "SendAndReceive(${sessionToMessage.mapValues { (key, value) ->
"$key=${value.hash}" }}, shouldRetrySend=$shouldRetrySend)"
}
/**
* Wait for a transaction to be committed to the database.
*
* @property hash the hash of the transaction.
* @return the committed transaction.
*/
data class WaitForLedgerCommit(val hash: SecureHash) : FlowIORequest<SignedTransaction>()
/**
* Get the FlowInfo of the specified sessions.
*
* @property sessions the sessions to get the FlowInfo of.
* @return a map from session to FlowInfo.
*/
data class GetFlowInfo(val sessions: NonEmptySet<FlowSession>) : FlowIORequest<Map<FlowSession, FlowInfo>>()
/**
* Suspend the flow until the specified time.
*
* @property wakeUpAfter the time to sleep until.
*/
data class Sleep(val wakeUpAfter: Instant) : FlowIORequest<Unit>()
/**
* Suspend the flow until all Initiating sessions are confirmed.
*/
object WaitForSessionConfirmations : FlowIORequest<Unit>()
}

View File

@ -1,64 +1,42 @@
package net.corda.core.internal
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.DoNotImplement
import net.corda.core.concurrent.CordaFuture
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.identity.PartyAndCertificate
import net.corda.core.context.InvocationContext
import net.corda.core.node.ServiceHub
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.UntrustworthyData
import org.slf4j.Logger
import java.time.Instant
/** This is an internal interface that is implemented by code in the node module. You should look at [FlowLogic]. */
interface FlowStateMachine<R> {
@DoNotImplement
interface FlowStateMachine<FLOWRETURN> {
@Suspendable
fun getFlowInfo(otherParty: Party, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): FlowInfo
fun <SUSPENDRETURN : Any> suspend(ioRequest: FlowIORequest<SUSPENDRETURN>, maySkipCheckpoint: Boolean): SUSPENDRETURN
@Suspendable
fun initiateFlow(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSession
@Suspendable
fun <T : Any> sendAndReceive(receiveType: Class<T>,
otherParty: Party,
payload: Any,
sessionFlow: FlowLogic<*>,
retrySend: Boolean,
maySkipCheckpoint: Boolean): UntrustworthyData<T>
@Suspendable
fun <T : Any> receive(receiveType: Class<T>, otherParty: Party, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): UntrustworthyData<T>
@Suspendable
fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean)
@Suspendable
fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): SignedTransaction
@Suspendable
fun sleepUntil(until: Instant)
fun initiateFlow(party: Party): FlowSession
fun checkFlowPermission(permissionName: String, extraAuditData: Map<String, String>)
fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map<String, String>)
@Suspendable
fun <SUBFLOWRETURN> subFlow(subFlow: FlowLogic<SUBFLOWRETURN>): SUBFLOWRETURN
@Suspendable
fun flowStackSnapshot(flowClass: Class<out FlowLogic<*>>): FlowStackSnapshot?
@Suspendable
fun persistFlowStackSnapshot(flowClass: Class<out FlowLogic<*>>)
val logic: FlowLogic<R>
val logic: FlowLogic<FLOWRETURN>
val serviceHub: ServiceHub
val logger: Logger
val id: StateMachineRunId
val resultFuture: CordaFuture<R>
val resultFuture: CordaFuture<FLOWRETURN>
val context: InvocationContext
val ourIdentityAndCert: PartyAndCertificate
@Suspendable
fun receiveAll(sessions: Map<FlowSession, Class<out Any>>, sessionFlow: FlowLogic<*>): Map<FlowSession, UntrustworthyData<Any>>
val ourIdentity: Party
}

View File

@ -1,6 +1,7 @@
package net.corda.core.node.services
import net.corda.core.DoNotImplement
import net.corda.core.concurrent.CordaFuture
import net.corda.core.crypto.SecureHash
import net.corda.core.messaging.DataFeed
import net.corda.core.transactions.SignedTransaction
@ -26,4 +27,9 @@ interface TransactionStorage {
* Returns all currently stored transactions and further fresh ones.
*/
fun track(): DataFeed<List<SignedTransaction>, SignedTransaction>
/**
* Returns a future that completes with the transaction corresponding to [id] once it has been committed
*/
fun trackTransaction(id: SecureHash): CordaFuture<SignedTransaction>
}

View File

@ -2,6 +2,9 @@ package net.corda.core.utilities
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.FlowException
import net.corda.core.internal.castIfPossible
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes
import java.io.Serializable
/**
@ -29,3 +32,15 @@ class UntrustworthyData<out T>(@PublishedApi internal val fromUntrustedWorld: T)
}
inline fun <T, R> UntrustworthyData<T>.unwrap(validator: (T) -> R): R = validator(fromUntrustedWorld)
fun <T : Any> SerializedBytes<Any>.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
val payloadData: T = try {
val serializer = SerializationDefaults.SERIALIZATION_FACTORY
serializer.deserialize(this, type, SerializationDefaults.P2P_CONTEXT)
} catch (ex: Exception) {
throw IllegalArgumentException("Payload invalid", ex)
}
return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?:
throw IllegalArgumentException("We were expecting a ${type.name} but we instead got a " +
"${payloadData.javaClass.name} (${payloadData})")
}

View File

@ -85,7 +85,7 @@ class CordappSmokeTest {
class SendBackInitiatorFlowContext(private val otherPartySession: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
// An initiated flow calling getFlowContext on its initiator will get the context from the session-init
// An initiated flow calling getFlowInfo on its initiator will get the context from the session-init
val sessionInitContext = otherPartySession.getCounterpartyFlowInfo()
otherPartySession.send(sessionInitContext)
}

View File

@ -14,9 +14,9 @@ import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import static net.corda.testing.CoreTestUtils.singleIdentity;
import static net.corda.testing.NodeTestUtils.startFlow;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.junit.Assert.fail;
import static net.corda.testing.NodeTestUtils.startFlow;
public class FlowsInJavaTest {
private final MockNetwork mockNet = new MockNetwork();
@ -62,9 +62,8 @@ public class FlowsInJavaTest {
fail("ExecutionException should have been thrown");
} catch (ExecutionException e) {
assertThat(e.getCause())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("primitive")
.hasMessageContaining(receiveType.getName());
.hasMessageContaining(Primitives.unwrap(receiveType).getName());
}
}
@ -102,6 +101,18 @@ public class FlowsInJavaTest {
}
}
@InitiatedBy(PrimitiveReceiveFlow.class)
private static class PrimitiveSendFlow extends FlowLogic<Void> {
public PrimitiveSendFlow(FlowSession session) {
}
@Suspendable
@Override
public Void call() throws FlowException {
return null;
}
}
@InitiatingFlow
private static class PrimitiveReceiveFlow extends FlowLogic<Void> {
private final Party otherParty;

View File

@ -79,7 +79,7 @@ infix fun <T : Any> KClass<T>.from(session: FlowSession): Pair<FlowSession, Clas
fun FlowLogic<*>.receiveAll(session: Pair<FlowSession, Class<out Any>>, vararg sessions: Pair<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>> {
val allSessions = arrayOf(session, *sessions)
allSessions.enforceNoDuplicates()
return receiveAll(mapOf(*allSessions))
return receiveAllMap(mapOf(*allSessions))
}
/**

View File

@ -409,68 +409,6 @@ Our side of the flow must mirror these calls. We could do this as follows:
:end-before: DOCEND 08
:dedent: 12
Why sessions?
^^^^^^^^^^^^^
Before ``FlowSession`` s were introduced the send/receive API looked a bit different. They were functions on
``FlowLogic`` and took the address ``Party`` as argument. The platform internally maintained a mapping from ``Party`` to
session, hiding sessions from the user completely.
Although this is a convenient API it introduces subtle issues where a message that was originally meant for a specific
session may end up in another.
Consider the following contrived example using the old ``Party`` based API:
.. container:: codeset
.. literalinclude:: ../../docs/source/example-code/src/main/kotlin/net/corda/docs/LaunchSpaceshipFlow.kt
:language: kotlin
:start-after: DOCSTART LaunchSpaceshipFlow
:end-before: DOCEND LaunchSpaceshipFlow
The intention of the flows is very clear: LaunchSpaceshipFlow asks the president whether a spaceship should be launched.
It is expecting a boolean reply. The president in return first tells the secretary that they need coffee, which is also
communicated with a boolean. Afterwards the president replies to the launcher that they don't want to launch.
However the above can go horribly wrong when the ``launcher`` happens to be the same party ``getSecretary`` returns. In
this case the boolean meant for the secretary will be received by the launcher!
This indicates that ``Party`` is not a good identifier for the communication sequence, and indeed the ``Party`` based
API may introduce ways for an attacker to fish for information and even trigger unintended control flow like in the
above case.
Hence we introduced ``FlowSession``, which identifies the communication sequence. With ``FlowSession`` s the above set
of flows would look like this:
.. container:: codeset
.. literalinclude:: ../../docs/source/example-code/src/main/kotlin/net/corda/docs/LaunchSpaceshipFlow.kt
:language: kotlin
:start-after: DOCSTART LaunchSpaceshipFlowCorrect
:end-before: DOCEND LaunchSpaceshipFlowCorrect
Note how the president is now explicit about which session it wants to send to.
Porting from the old Party-based API
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In the old API the first ``send`` or ``receive`` to a ``Party`` was the one kicking off the counter-flow. This is now
explicit in the ``initiateFlow`` function call. To port existing code:
.. container:: codeset
.. literalinclude:: ../../docs/source/example-code/src/main/kotlin/net/corda/docs/FlowCookbook.kt
:language: kotlin
:start-after: DOCSTART FlowSession porting
:end-before: DOCEND FlowSession porting
:dedent: 8
.. literalinclude:: ../../docs/source/example-code/src/main/java/net/corda/docs/FlowCookbookJava.java
:language: java
:start-after: DOCSTART FlowSession porting
:end-before: DOCEND FlowSession porting
:dedent: 12
Subflows
--------

View File

@ -575,13 +575,6 @@ public class FlowCookbookJava {
SignedTransaction notarisedTx2 = subFlow(new FinalityFlow(fullySignedTx, additionalParties, FINALISATION.childProgressTracker()));
// DOCEND 10
// DOCSTART FlowSession porting
send(regulator, new Object()); // Old API
// becomes
FlowSession session = initiateFlow(regulator);
session.send(new Object());
// DOCEND FlowSession porting
return null;
}
}

View File

@ -553,13 +553,6 @@ class InitiatorFlow(val arg1: Boolean, val arg2: Int, private val counterparty:
val additionalParties: Set<Party> = setOf(regulator)
val notarisedTx2: SignedTransaction = subFlow(FinalityFlow(fullySignedTx, additionalParties, FINALISATION.childProgressTracker()))
// DOCEND 10
// DOCSTART FlowSession porting
send(regulator, Any()) // Old API
// becomes
val session = initiateFlow(regulator)
session.send(Any())
// DOCEND FlowSession porting
}
}

View File

@ -1,99 +0,0 @@
package net.corda.docs
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.flows.InitiatedBy
import net.corda.core.flows.InitiatingFlow
import net.corda.core.identity.Party
import net.corda.core.utilities.unwrap
// DOCSTART LaunchSpaceshipFlow
@InitiatingFlow
class LaunchSpaceshipFlow : FlowLogic<Unit>() {
@Suspendable
override fun call() {
val shouldLaunchSpaceship = receive<Boolean>(getPresident()).unwrap { it }
if (shouldLaunchSpaceship) {
launchSpaceship()
}
}
fun launchSpaceship() {
}
fun getPresident(): Party {
TODO()
}
}
@InitiatedBy(LaunchSpaceshipFlow::class)
@InitiatingFlow
class PresidentSpaceshipFlow(val launcher: Party) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
val needCoffee = true
send(getSecretary(), needCoffee)
val shouldLaunchSpaceship = false
send(launcher, shouldLaunchSpaceship)
}
fun getSecretary(): Party {
TODO()
}
}
@InitiatedBy(PresidentSpaceshipFlow::class)
class SecretaryFlow(val president: Party) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
// ignore
}
}
// DOCEND LaunchSpaceshipFlow
// DOCSTART LaunchSpaceshipFlowCorrect
@InitiatingFlow
class LaunchSpaceshipFlowCorrect : FlowLogic<Unit>() {
@Suspendable
override fun call() {
val presidentSession = initiateFlow(getPresident())
val shouldLaunchSpaceship = presidentSession.receive<Boolean>().unwrap { it }
if (shouldLaunchSpaceship) {
launchSpaceship()
}
}
fun launchSpaceship() {
}
fun getPresident(): Party {
TODO()
}
}
@InitiatedBy(LaunchSpaceshipFlowCorrect::class)
@InitiatingFlow
class PresidentSpaceshipFlowCorrect(val launcherSession: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
val needCoffee = true
val secretarySession = initiateFlow(getSecretary())
secretarySession.send(needCoffee)
val shouldLaunchSpaceship = false
launcherSession.send(shouldLaunchSpaceship)
}
fun getSecretary(): Party {
TODO()
}
}
@InitiatedBy(PresidentSpaceshipFlowCorrect::class)
class SecretaryFlowCorrect(val presidentSession: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
// ignore
}
}
// DOCEND LaunchSpaceshipFlowCorrect

View File

@ -10,11 +10,13 @@ import net.corda.core.identity.Party
import net.corda.core.messaging.MessageRecipients
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.node.internal.StartedNode
import net.corda.node.services.messaging.Message
import net.corda.node.services.statemachine.SessionData
import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.node.services.statemachine.ExistingSessionMessage
import net.corda.testing.node.InMemoryMessagingNetwork
import net.corda.testing.node.MessagingServiceSpy
import net.corda.testing.node.MockNetwork
@ -84,11 +86,11 @@ class TutorialMockNetwork {
nodeB.setMessagingServiceSpy(object : MessagingServiceSpy(nodeB.network) {
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, acknowledgementHandler: (() -> Unit)?) {
val messageData = message.data.deserialize<Any>()
if (messageData is SessionData && messageData.payload.deserialize() == 1) {
val alteredMessageData = SessionData(messageData.recipientSessionId, 99.serialize()).serialize().bytes
messagingService.send(InMemoryMessagingNetwork.InMemoryMessage(message.topicSession, alteredMessageData, message.uniqueMessageId), target, retryId)
val messageData = message.data.deserialize<Any>() as? ExistingSessionMessage
val payload = messageData?.payload
if (payload is DataSessionMessage && payload.payload.deserialize() == 1) {
val alteredMessageData = messageData.copy(payload = payload.copy(99.serialize())).serialize().bytes
messagingService.send(InMemoryMessagingNetwork.InMemoryMessage(message.topic, OpaqueBytes(alteredMessageData), message.uniqueMessageId), target, retryId)
} else {
messagingService.send(message, target, retryId)
}

View File

@ -38,8 +38,8 @@ or if the semantics of a particular receive changes.
The ``InitiatingFlow`` annotation (see :doc:`flow-state-machine` for more information on the flow annotations) has a ``version``
property, which if not specified defaults to 1. This flow version is included in the flow session handshake and exposed
to both parties in the communication via ``FlowLogic.getFlowContext``. This takes in a ``Party`` and will return a
``FlowContext`` object which describes the flow running on the other side. In particular it has the ``flowVersion`` property
to both parties in the communication via ``FlowLogic.getFlowInfo``. This takes in a ``Party`` and will return a
``FlowInfo`` object which describes the flow running on the other side. In particular it has the ``flowVersion`` property
which can be used to programmatically evolve flows across versions.
.. container:: codeset
@ -48,7 +48,7 @@ which can be used to programmatically evolve flows across versions.
@Suspendable
override fun call() {
val flowVersionOfOtherParty = getFlowContext(otherParty).flowVersion
val flowVersionOfOtherParty = getFlowInfo(otherParty).flowVersion
val receivedString = if (flowVersionOfOtherParty == 1) {
receive<Int>(otherParty).unwrap { it.toString() }
} else {
@ -63,7 +63,7 @@ running the older flow (or rather older CorDapps containing the older flow).
.. warning:: It's important that ``InitiatingFlow.version`` be incremented each time the flow protocol changes in an
incompatible way.
``FlowContext`` also has ``appName`` which is the name of the CorDapp hosting the flow. This can be used to determine
``FlowInfo`` also has ``appName`` which is the name of the CorDapp hosting the flow. This can be used to determine
implementation details of the CorDapp. See :doc:`cordapp-build-systems` for more information on the CorDapp filename.
.. note:: Currently changing any of the properties of a ``CordaSerializable`` type is also backwards incompatible and

View File

@ -0,0 +1,53 @@
buildscript {
// For sharing constants between builds
Properties constants = new Properties()
file("$projectDir/../../constants.properties").withInputStream { constants.load(it) }
ext.kotlin_version = constants.getProperty("kotlinVersion")
ext.javaassist_version = "3.12.1.GA"
repositories {
mavenLocal()
mavenCentral()
jcenter()
}
dependencies {
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
}
}
repositories {
mavenLocal()
mavenCentral()
jcenter()
}
apply plugin: 'kotlin'
apply plugin: 'kotlin-kapt'
apply plugin: 'idea'
description 'A javaagent to allow hooking into Kryo'
dependencies {
compile project(':node')
compile "org.jetbrains.kotlin:kotlin-stdlib-jre8:$kotlin_version"
compile "org.jetbrains.kotlin:kotlin-reflect:$kotlin_version"
compile "javassist:javassist:$javaassist_version"
compile "com.esotericsoftware:kryo:4.0.0"
compile "co.paralleluniverse:quasar-core:$quasar_version:jdk8"
}
jar {
archiveName = "${project.name}.jar"
manifest {
attributes(
'Premain-Class': 'net.corda.flowhook.FlowHookAgent',
'Can-Redefine-Classes': 'true',
'Can-Retransform-Classes': 'true',
'Can-Set-Native-Method-Prefix': 'true',
'Implementation-Title': "FlowHook",
'Implementation-Version': rootProject.version
)
}
}

View File

@ -0,0 +1,209 @@
package net.corda.flowhook
import co.paralleluniverse.fibers.Fiber
import net.corda.core.internal.uncheckedCast
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.persistence.DatabaseTransaction
import java.sql.Connection
import java.time.Instant
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.concurrent.thread
/**
* This is a debugging helper class that dumps the map of Fiber->DB connection, or more precisely, the
* Fiber->(DB tx -> DB connection) map, as there may be multiple transactions per fiber.
*/
data class MonitorEvent(val type: MonitorEventType, val keys: List<Any>, val extra: Any? = null)
data class FullMonitorEvent(val timestamp: Instant, val trace: List<StackTraceElement>, val event: MonitorEvent) {
override fun toString() = event.toString()
}
enum class MonitorEventType {
TransactionCreated,
ConnectionRequested,
ConnectionAcquired,
ConnectionReleased,
FiberStarted,
FiberParking,
FiberException,
FiberResumed,
FiberEnded,
ExecuteTransition
}
object FiberMonitor {
private val log = contextLogger()
private val jobQueue = LinkedBlockingQueue<Job>()
private val started = AtomicBoolean(false)
private var trackerThread: Thread? = null
val correlator = MonitorEventCorrelator()
sealed class Job {
data class NewEvent(val event: FullMonitorEvent) : Job()
object Finish : Job()
}
fun newEvent(event: MonitorEvent) {
if (trackerThread != null) {
jobQueue.add(Job.NewEvent(FullMonitorEvent(Instant.now(), Exception().stackTrace.toList(), event)))
}
}
fun start() {
if (started.compareAndSet(false, true)) {
require(trackerThread == null)
trackerThread = thread(name = "Fiber monitor", isDaemon = true) {
while (true) {
val job = jobQueue.poll(1, TimeUnit.SECONDS)
when (job) {
is Job.NewEvent -> processEvent(job)
Job.Finish -> return@thread
}
}
}
}
}
private fun processEvent(job: Job.NewEvent) {
correlator.addEvent(job.event)
checkLeakedTransactions(job.event.event)
checkLeakedConnections(job.event.event)
}
inline fun <reified R, A : Any> R.getField(name: String): A {
val field = R::class.java.getDeclaredField(name)
field.isAccessible = true
return uncheckedCast(field.get(this))
}
fun <A : Any> Any.getFieldFromObject(name: String): A {
val field = javaClass.getDeclaredField(name)
field.isAccessible = true
return uncheckedCast(field.get(this))
}
fun getThreadLocalMapEntryValues(locals: Any): List<Any> {
val table: Array<Any?> = locals.getFieldFromObject("table")
return table.mapNotNull { it?.getFieldFromObject<Any>("value") }
}
fun getStashedThreadLocals(fiber: Fiber<*>): List<Any> {
val fiberLocals: Any = fiber.getField("fiberLocals")
val inheritableFiberLocals: Any = fiber.getField("inheritableFiberLocals")
return getThreadLocalMapEntryValues(fiberLocals) + getThreadLocalMapEntryValues(inheritableFiberLocals)
}
fun getTransactionStack(transaction: DatabaseTransaction): List<DatabaseTransaction> {
val transactions = ArrayList<DatabaseTransaction>()
var currentTransaction: DatabaseTransaction? = transaction
while (currentTransaction != null) {
transactions.add(currentTransaction)
currentTransaction = currentTransaction.outerTransaction
}
return transactions
}
private fun checkLeakedTransactions(event: MonitorEvent) {
if (event.type == MonitorEventType.FiberParking) {
val fiber = event.keys.mapNotNull { it as? Fiber<*> }.first()
val threadLocals = getStashedThreadLocals(fiber)
val transactions = threadLocals.mapNotNull { it as? DatabaseTransaction }.flatMap { getTransactionStack(it) }
val leakedTransactions = transactions.filter { it.connectionCreated && !it.connection.isClosed }
if (leakedTransactions.isNotEmpty()) {
log.warn("Leaked open database transactions on yield $leakedTransactions")
}
}
}
private fun checkLeakedConnections(event: MonitorEvent) {
if (event.type == MonitorEventType.FiberParking) {
val events = correlator.events[event.keys[0]]!!
val acquiredConnections = events.mapNotNullTo(HashSet()) {
if (it.event.type == MonitorEventType.ConnectionAcquired) {
it.event.keys.mapNotNull { it as? Connection }.first()
} else {
null
}
}
val releasedConnections = events.mapNotNullTo(HashSet()) {
if (it.event.type == MonitorEventType.ConnectionReleased) {
it.event.keys.mapNotNull { it as? Connection }.first()
} else {
null
}
}
val leakedConnections = (acquiredConnections - releasedConnections).filter { !it.isClosed }
if (leakedConnections.isNotEmpty()) {
log.warn("Leaked open connections $leakedConnections")
}
}
}
}
class MonitorEventCorrelator {
private val _events = HashMap<Any, ArrayList<FullMonitorEvent>>()
val events: Map<Any, ArrayList<FullMonitorEvent>> get() = _events
fun getUnique() = events.values.toSet().associateBy { it.flatMap { it.event.keys }.toSet() }
fun getByType() = events.entries.groupBy { it.key.javaClass }
fun addEvent(fullMonitorEvent: FullMonitorEvent) {
val list = link(fullMonitorEvent.event.keys)
list.add(fullMonitorEvent)
for (key in fullMonitorEvent.event.keys) {
_events[key] = list
}
}
fun link(keys: List<Any>): ArrayList<FullMonitorEvent> {
val eventLists = HashSet<ArrayList<FullMonitorEvent>>()
for (key in keys) {
val list = _events[key]
if (list != null) {
eventLists.add(list)
}
}
return when {
eventLists.isEmpty() -> ArrayList()
eventLists.size == 1 -> eventLists.first()
else -> mergeAll(eventLists)
}
}
fun mergeAll(lists: Collection<List<FullMonitorEvent>>): ArrayList<FullMonitorEvent> {
return lists.fold(ArrayList()) { merged, next -> merge(merged, next) }
}
fun merge(a: List<FullMonitorEvent>, b: List<FullMonitorEvent>): ArrayList<FullMonitorEvent> {
val merged = ArrayList<FullMonitorEvent>()
var aIndex = 0
var bIndex = 0
while (true) {
if (aIndex >= a.size) {
merged.addAll(b.subList(bIndex, b.size))
return merged
}
if (bIndex >= b.size) {
merged.addAll(a.subList(aIndex, a.size))
return merged
}
val aElem = a[aIndex]
val bElem = b[bIndex]
if (aElem.timestamp < bElem.timestamp) {
merged.add(aElem)
aIndex++
} else {
merged.add(bElem)
bIndex++
}
}
}
}

View File

@ -0,0 +1,15 @@
package net.corda.flowhook
import java.lang.instrument.Instrumentation
@Suppress("UNUSED")
class FlowHookAgent {
companion object {
@JvmStatic
fun premain(argumentsString: String?, instrumentation: Instrumentation) {
FiberMonitor.start()
instrumentation.addTransformer(Hooker(FlowHookContainer))
}
}
}

View File

@ -0,0 +1,104 @@
package net.corda.flowhook
import co.paralleluniverse.fibers.Fiber
import net.corda.node.services.statemachine.ActionExecutor
import net.corda.node.services.statemachine.Event
import net.corda.node.services.statemachine.FlowFiber
import net.corda.node.services.statemachine.StateMachineState
import net.corda.node.services.statemachine.transitions.TransitionResult
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseTransactionManager
import rx.subjects.Subject
import java.sql.Connection
@Suppress("UNUSED")
object FlowHookContainer {
@JvmStatic
@Hook("co.paralleluniverse.fibers.Fiber")
fun park() {
FiberMonitor.newEvent(MonitorEvent(MonitorEventType.FiberParking, keys = listOf(Fiber.currentFiber())))
}
@JvmStatic
@Hook("net.corda.node.services.statemachine.FlowStateMachineImpl")
fun run() {
FiberMonitor.newEvent(MonitorEvent(MonitorEventType.FiberStarted, keys = listOf(Fiber.currentFiber())))
}
@JvmStatic
@Hook("co.paralleluniverse.fibers.Fiber")
fun onCompleted() {
FiberMonitor.newEvent(MonitorEvent(MonitorEventType.FiberEnded, keys = listOf(Fiber.currentFiber())))
}
@JvmStatic
@Hook("co.paralleluniverse.fibers.Fiber")
fun onException(exception: Throwable) {
FiberMonitor.newEvent(MonitorEvent(MonitorEventType.FiberException, keys = listOf(Fiber.currentFiber()), extra = exception))
}
@JvmStatic
@Hook("co.paralleluniverse.fibers.Fiber")
fun onResumed() {
FiberMonitor.newEvent(MonitorEvent(MonitorEventType.FiberResumed, keys = listOf(Fiber.currentFiber())))
}
@JvmStatic
@Hook("net.corda.node.utilities.DatabaseTransaction", passThis = true, position = HookPosition.After)
fun DatabaseTransaction(
transaction: Any,
isolation: Int,
threadLocal: ThreadLocal<*>,
transactionBoundaries: Subject<*, *>,
cordaPersistence: CordaPersistence
) {
val keys = ArrayList<Any>().apply {
add(transaction)
Fiber.currentFiber()?.let { add(it) }
}
FiberMonitor.newEvent(MonitorEvent(MonitorEventType.TransactionCreated, keys = keys))
}
@JvmStatic
@Hook("com.zaxxer.hikari.HikariDataSource")
fun getConnection(): (Connection) -> Unit {
val transactionOrThread = currentTransactionOrThread()
FiberMonitor.newEvent(MonitorEvent(MonitorEventType.ConnectionRequested, keys = listOf(transactionOrThread)))
return { connection ->
FiberMonitor.newEvent(MonitorEvent(MonitorEventType.ConnectionAcquired, keys = listOf(transactionOrThread, connection)))
}
}
@JvmStatic
@Hook("com.zaxxer.hikari.pool.ProxyConnection", passThis = true, position = HookPosition.After)
fun close(connection: Any) {
connection as Connection
val transactionOrThread = currentTransactionOrThread()
FiberMonitor.newEvent(MonitorEvent(MonitorEventType.ConnectionReleased, keys = listOf(transactionOrThread, connection)))
}
@JvmStatic
@Hook("net.corda.node.services.statemachine.TransitionExecutorImpl")
fun executeTransition(
fiber: FlowFiber,
previousState: StateMachineState,
event: Event,
transition: TransitionResult,
actionExecutor: ActionExecutor
) {
FiberMonitor.newEvent(MonitorEvent(MonitorEventType.ExecuteTransition, keys = listOf(fiber), extra = object {
val previousState = previousState
val event = event
val transition = transition
}))
}
private fun currentTransactionOrThread(): Any {
return try {
DatabaseTransactionManager.currentOrNull()
} catch (exception: IllegalStateException) {
null
} ?: Thread.currentThread()
}
}

View File

@ -0,0 +1,130 @@
package net.corda.flowhook
import javassist.ClassPool
import javassist.CtBehavior
import javassist.CtClass
import java.io.ByteArrayInputStream
import java.lang.instrument.ClassFileTransformer
import java.lang.reflect.Method
import java.security.ProtectionDomain
class Hooker(hookContainer: Any) : ClassFileTransformer {
private val classPool = ClassPool.getDefault()
private val hooks = createHooks(hookContainer)
private fun createHooks(hookContainer: Any): Hooks {
val hooks = HashMap<String, HashMap<Signature, Pair<Method, Hook>>>()
for (method in hookContainer.javaClass.methods) {
val hookAnnotation = method.getAnnotation(Hook::class.java)
if (hookAnnotation != null) {
val signature = if (hookAnnotation.passThis) {
if (method.parameterTypes.isEmpty() || method.parameterTypes[0] != Any::class.java) {
println("Method should accept an object as first parameter for 'this' $method")
continue
}
Signature(method.name, method.parameterTypes.toList().drop(1).map { it.canonicalName })
} else {
Signature(method.name, method.parameterTypes.map { it.canonicalName })
}
hooks.getOrPut(hookAnnotation.clazz) { HashMap() }.put(signature, Pair(method, hookAnnotation))
}
}
return hooks
}
override fun transform(
loader: ClassLoader?,
className: String,
classBeingRedefined: Class<*>?,
protectionDomain: ProtectionDomain?,
classfileBuffer: ByteArray
): ByteArray? {
if (className.startsWith("java") || className.startsWith("sun") || className.startsWith("javassist") || className.startsWith("kotlin")) {
return null
}
return try {
val clazz = classPool.makeClass(ByteArrayInputStream(classfileBuffer))
instrumentClass(clazz)?.toBytecode()
} catch (throwable: Throwable) {
println("SOMETHING WENT WRONG")
throwable.printStackTrace(System.out)
null
}
}
private fun instrumentClass(clazz: CtClass): CtClass? {
val hookMethods = hooks[clazz.name] ?: return null
val usedHookMethods = HashSet<Method>()
var isAnyInstrumented = false
for (method in clazz.declaredBehaviors) {
val hookMethod = instrumentBehaviour(method, hookMethods)
if (hookMethod != null) {
isAnyInstrumented = true
usedHookMethods.add(hookMethod)
}
}
val unusedHookMethods = hookMethods.values.mapTo(HashSet()) { it.first } - usedHookMethods
if (unusedHookMethods.isNotEmpty()) {
println("Unused hook methods $unusedHookMethods")
}
return if (isAnyInstrumented) {
clazz
} else {
null
}
}
private fun instrumentBehaviour(method: CtBehavior, methodHooks: MethodHooks): Method? {
val signature = Signature(method.name, method.parameterTypes.map { it.name })
val (hookMethod, annotation) = methodHooks[signature] ?: return null
val invocationString = if (annotation.passThis) {
"${hookMethod.declaringClass.canonicalName}.${hookMethod.name}(this, \$\$)"
} else {
"${hookMethod.declaringClass.canonicalName}.${hookMethod.name}(\$\$)"
}
val overriddenPosition = if (method.methodInfo.isConstructor && annotation.passThis && annotation.position == HookPosition.Before) {
println("passThis=true and position=${HookPosition.Before} for a constructor. " +
"You can only inspect 'this' at the end of the constructor! Hooking *after*.. $method")
HookPosition.After
} else {
annotation.position
}
val insertHook: (CtBehavior.(code: String) -> Unit) = when (overriddenPosition) {
HookPosition.Before -> CtBehavior::insertBefore
HookPosition.After -> CtBehavior::insertAfter
}
when {
Function0::class.java.isAssignableFrom(hookMethod.returnType) -> {
method.addLocalVariable("after", classPool.get("kotlin.jvm.functions.Function0"))
method.insertHook("after = $invocationString;")
method.insertAfter("after.invoke();")
}
Function1::class.java.isAssignableFrom(hookMethod.returnType) -> {
method.addLocalVariable("after", classPool.get("kotlin.jvm.functions.Function1"))
method.insertHook("after = $invocationString;")
method.insertAfter("after.invoke((\$w)\$_);")
}
else -> {
method.insertHook("$invocationString;")
}
}
return hookMethod
}
}
enum class HookPosition {
Before,
After
}
@Target(AnnotationTarget.FUNCTION)
annotation class Hook(val clazz: String, val position: HookPosition = HookPosition.Before, val passThis: Boolean = false)
private data class Signature(val functionName: String, val parameterTypes: List<String>)
private typealias MethodHooks = Map<Signature, Pair<Method, Hook>>
private typealias Hooks = Map<String, MethodHooks>

View File

@ -215,3 +215,15 @@ fun <T : Any> rx.Observable<T>.wrapWithDatabaseTransaction(db: CordaPersistence?
}
}
}
fun parserTransactionIsolationLevel(property: String?): Int =
when (property) {
"none" -> Connection.TRANSACTION_NONE
"readUncommitted" -> Connection.TRANSACTION_READ_UNCOMMITTED
"readCommitted" -> Connection.TRANSACTION_READ_COMMITTED
"repeatableRead" -> Connection.TRANSACTION_REPEATABLE_READ
"serializable" -> Connection.TRANSACTION_SERIALIZABLE
else -> {
Connection.TRANSACTION_REPEATABLE_READ
}
}

View File

@ -14,8 +14,12 @@ class DatabaseTransaction(
) {
val id: UUID = UUID.randomUUID()
private var _connectionCreated = false
val connectionCreated get() = _connectionCreated
val connection: Connection by lazy(LazyThreadSafetyMode.NONE) {
cordaPersistence.dataSource.connection.apply {
cordaPersistence.dataSource.connection
.apply {
_connectionCreated = true
autoCommit = false
transactionIsolation = isolation
}
@ -30,20 +34,22 @@ class DatabaseTransaction(
val session: Session by sessionDelegate
private lateinit var hibernateTransaction: Transaction
private val outerTransaction: DatabaseTransaction? = threadLocal.get()
val outerTransaction: DatabaseTransaction? = threadLocal.get()
fun commit() {
if (sessionDelegate.isInitialized()) {
hibernateTransaction.commit()
}
if (_connectionCreated) {
connection.commit()
}
}
fun rollback() {
if (sessionDelegate.isInitialized() && session.isOpen) {
session.clear()
}
if (!connection.isClosed) {
if (_connectionCreated && !connection.isClosed) {
connection.rollback()
}
}
@ -52,7 +58,9 @@ class DatabaseTransaction(
if (sessionDelegate.isInitialized() && session.isOpen) {
session.close()
}
if (_connectionCreated) {
connection.close()
}
threadLocal.set(outerTransaction)
if (outerTransaction == null) {
transactionBoundaries.onNext(DatabaseTransactionManager.Boundary(id))

View File

@ -3,7 +3,7 @@ package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.util.DefaultClassResolver
import net.corda.core.serialization.*
import net.corda.node.services.statemachine.SessionData
import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput
import net.corda.nodeapi.internal.serialization.amqp.Envelope
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory
@ -47,17 +47,17 @@ class ListsSerializationTest {
@Test
fun `check list can be serialized as part of SessionData`() {
run {
val sessionData = SessionData(123, listOf(1).serialize())
val sessionData = DataSessionMessage(listOf(1).serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(listOf(1), sessionData.payload.deserialize())
}
run {
val sessionData = SessionData(123, listOf(1, 2).serialize())
val sessionData = DataSessionMessage(listOf(1, 2).serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(listOf(1, 2), sessionData.payload.deserialize())
}
run {
val sessionData = SessionData(123, emptyList<Int>().serialize())
val sessionData = DataSessionMessage(emptyList<Int>().serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(emptyList<Int>(), sessionData.payload.deserialize())
}

View File

@ -6,7 +6,7 @@ import net.corda.core.identity.CordaX500Name
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.node.services.statemachine.SessionData
import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.nodeapi.internal.serialization.kryo.KryoHeaderV0_1
import net.corda.testing.SerializationEnvironmentRule
import net.corda.testing.amqpSpecific
@ -41,7 +41,7 @@ class MapsSerializationTest {
@Test
fun `check list can be serialized as part of SessionData`() {
val sessionData = SessionData(123, smallMap.serialize())
val sessionData = DataSessionMessage(smallMap.serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(smallMap, sessionData.payload.deserialize())
}

View File

@ -4,10 +4,10 @@ import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.util.DefaultClassResolver
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.node.services.statemachine.SessionData
import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.nodeapi.internal.serialization.kryo.KryoHeaderV0_1
import net.corda.testing.kryoSpecific
import net.corda.testing.SerializationEnvironmentRule
import net.corda.testing.kryoSpecific
import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals
import org.junit.Rule
@ -34,17 +34,17 @@ class SetsSerializationTest {
@Test
fun `check set can be serialized as part of SessionData`() {
run {
val sessionData = SessionData(123, setOf(1).serialize())
val sessionData = DataSessionMessage(setOf(1).serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(setOf(1), sessionData.payload.deserialize())
}
run {
val sessionData = SessionData(123, setOf(1, 2).serialize())
val sessionData = DataSessionMessage(setOf(1, 2).serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(setOf(1, 2), sessionData.payload.deserialize())
}
run {
val sessionData = SessionData(123, emptySet<Int>().serialize())
val sessionData = DataSessionMessage(emptySet<Int>().serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(emptySet<Int>(), sessionData.payload.deserialize())
}

View File

@ -17,11 +17,11 @@ import net.corda.nodeapi.User
import net.corda.testing.*
import net.corda.testing.driver.NodeHandle
import net.corda.testing.driver.driver
import net.corda.testing.node.NotarySpec
import net.corda.testing.internal.performance.div
import net.corda.testing.internal.performance.startPublishingFixedRateInjector
import net.corda.testing.internal.performance.startReporter
import net.corda.testing.internal.performance.startTightLoopInjector
import net.corda.testing.node.NotarySpec
import org.junit.Before
import org.junit.ClassRule
import org.junit.Ignore
@ -78,7 +78,7 @@ class NodePerformanceTests : IntegrationTest() {
queueBound = 50
) {
val timing = Stopwatch.createStarted().apply {
connection.proxy.startFlow(::EmptyFlow).returnValue.get()
connection.proxy.startFlow(::EmptyFlow).returnValue.getOrThrow()
}.stop().elapsed(TimeUnit.MICROSECONDS)
timings.add(timing)
}
@ -100,13 +100,27 @@ class NodePerformanceTests : IntegrationTest() {
a as NodeHandle.InProcess
val metricRegistry = startReporter(shutdownManager, a.node.services.monitoringService.metrics)
a.rpcClientToNode().use("A", "A") { connection ->
startPublishingFixedRateInjector(metricRegistry, 8, 5.minutes, 2000L / TimeUnit.SECONDS) {
startPublishingFixedRateInjector(metricRegistry, 1, 5.minutes, 2000L / TimeUnit.SECONDS) {
connection.proxy.startFlow(::EmptyFlow).returnValue.get()
}
}
}
}
@Test
fun `issue flow rate`() {
driver(startNodesInProcess = true, extraCordappPackagesToScan = listOf("net.corda.finance")) {
val a = startNode(rpcUsers = listOf(User("A", "A", setOf(startFlow<CashIssueFlow>())))).get()
a as NodeHandle.InProcess
val metricRegistry = startReporter(shutdownManager, a.node.services.monitoringService.metrics)
a.rpcClientToNode().use("A", "A") { connection ->
startPublishingFixedRateInjector(metricRegistry, 1, 5.minutes, 2000L / TimeUnit.SECONDS) {
connection.proxy.startFlow(::CashIssueFlow, 1.DOLLARS, OpaqueBytes.of(0), ALICE).returnValue.get()
}
}
}
}
@Test
fun `self pay rate`() {
val user = User("A", "A", setOf(startFlow<CashIssueFlow>(), startFlow<CashPaymentFlow>()))

View File

@ -1,9 +1,9 @@
package net.corda.services.messaging
import net.corda.core.concurrent.CordaFuture
import net.corda.core.crypto.random63BitValue
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.concurrent.map
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.randomOrNull
import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.SingleMessageRecipient
@ -14,7 +14,9 @@ import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.seconds
import net.corda.node.internal.Node
import net.corda.node.internal.StartedNode
import net.corda.node.services.messaging.*
import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.services.messaging.send
import net.corda.node.services.transactions.RaftValidatingNotaryService
import net.corda.testing.*
import net.corda.testing.driver.DriverDSLExposedInterface
@ -28,6 +30,7 @@ import org.junit.Test
import java.util.*
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
class P2PMessagingTest : IntegrationTest() {
@ -54,19 +57,12 @@ class P2PMessagingTest : IntegrationTest() {
alice.network.getAddressOfParty(getPartyInfo(notaryParty)!!)
}
val dummyTopic = "dummy.topic"
val responseMessage = "response"
val crashingNodes = simulateCrashingNodes(distributedServiceNodes, dummyTopic, responseMessage)
val crashingNodes = simulateCrashingNodes(distributedServiceNodes, responseMessage)
// Send a single request with retry
val responseFuture = with(alice.network) {
val request = TestRequest(replyTo = myAddress)
val responseFuture = onNext<Any>(dummyTopic, request.sessionID)
val msg = createMessage(TopicSession(dummyTopic), data = request.serialize().bytes)
send(msg, serviceAddress, retryId = request.sessionID)
responseFuture
}
val responseFuture = alice.receiveFrom(serviceAddress, retryId = 0)
crashingNodes.firstRequestReceived.await(5, TimeUnit.SECONDS)
// The request wasn't successful.
assertThat(responseFuture.isDone).isFalse()
@ -87,22 +83,15 @@ class P2PMessagingTest : IntegrationTest() {
alice.network.getAddressOfParty(getPartyInfo(notaryParty)!!)
}
val dummyTopic = "dummy.topic"
val responseMessage = "response"
val crashingNodes = simulateCrashingNodes(distributedServiceNodes, dummyTopic, responseMessage)
val sessionId = random63BitValue()
val crashingNodes = simulateCrashingNodes(distributedServiceNodes, responseMessage)
// Send a single request with retry
with(alice.network) {
val request = TestRequest(sessionId, myAddress)
val msg = createMessage(TopicSession(dummyTopic), data = request.serialize().bytes)
send(msg, serviceAddress, retryId = request.sessionID)
}
alice.receiveFrom(serviceAddress, retryId = 0)
// Wait until the first request is received
crashingNodes.firstRequestReceived.await(5, TimeUnit.SECONDS)
crashingNodes.firstRequestReceived.await()
// Stop alice's node after we ensured that the first request was delivered and ignored.
alice.dispose()
val numberOfRequestsReceived = crashingNodes.requestsReceived.get()
@ -112,7 +101,12 @@ class P2PMessagingTest : IntegrationTest() {
// Restart the node and expect a response
val aliceRestarted = startAlice()
val response = aliceRestarted.network.onNext<Any>(dummyTopic, sessionId).getOrThrow(5.seconds)
val responseFuture = openFuture<Any>()
aliceRestarted.network.runOnNextMessage("test.response") {
responseFuture.set(it.data.deserialize())
}
val response = responseFuture.getOrThrow()
assertThat(crashingNodes.requestsReceived.get()).isGreaterThan(numberOfRequestsReceived)
assertThat(response).isEqualTo(responseMessage)
@ -138,11 +132,12 @@ class P2PMessagingTest : IntegrationTest() {
)
/**
* Sets up the [distributedServiceNodes] to respond to [dummyTopic] requests. All nodes will receive requests and
* either ignore them or respond, depending on the value of [CrashingNodes.ignoreRequests], initially set to true.
* This may be used to simulate scenarios where nodes receive request messages but crash before sending back a response.
* Sets up the [distributedServiceNodes] to respond to "test.request" requests. All nodes will receive requests and
* either ignore them or respond to "test.response", depending on the value of [CrashingNodes.ignoreRequests],
* initially set to true. This may be used to simulate scenarios where nodes receive request messages but crash
* before sending back a response.
*/
private fun simulateCrashingNodes(distributedServiceNodes: List<StartedNode<*>>, dummyTopic: String, responseMessage: String): CrashingNodes {
private fun simulateCrashingNodes(distributedServiceNodes: List<StartedNode<*>>, responseMessage: String): CrashingNodes {
val crashingNodes = CrashingNodes(
requestsReceived = AtomicInteger(0),
firstRequestReceived = CountDownLatch(1),
@ -151,7 +146,7 @@ class P2PMessagingTest : IntegrationTest() {
distributedServiceNodes.forEach {
val nodeName = it.info.chooseIdentity().name
it.network.addMessageHandler(dummyTopic) { netMessage, _ ->
it.network.addMessageHandler("test.request") { netMessage, _, handler ->
crashingNodes.requestsReceived.incrementAndGet()
crashingNodes.firstRequestReceived.countDown()
// The node which receives the first request will ignore all requests
@ -163,9 +158,10 @@ class P2PMessagingTest : IntegrationTest() {
} else {
println("sending response")
val request = netMessage.data.deserialize<TestRequest>()
val response = it.network.createMessage(dummyTopic, request.sessionID, responseMessage.serialize().bytes)
val response = it.network.createMessage("test.response", responseMessage.serialize().bytes)
it.network.send(response, request.replyTo)
}
handler.acknowledge()
}
}
return crashingNodes
@ -193,19 +189,41 @@ class P2PMessagingTest : IntegrationTest() {
}
private fun StartedNode<*>.respondWith(message: Any) {
network.addMessageHandler(javaClass.name) { netMessage, _ ->
network.addMessageHandler("test.request") { netMessage, _, handle ->
val request = netMessage.data.deserialize<TestRequest>()
val response = network.createMessage(javaClass.name, request.sessionID, message.serialize().bytes)
val response = network.createMessage("test.response", message.serialize().bytes)
network.send(response, request.replyTo)
handle.acknowledge()
}
}
private fun StartedNode<*>.receiveFrom(target: MessageRecipients): CordaFuture<Any> {
val request = TestRequest(replyTo = network.myAddress)
return network.sendRequest(javaClass.name, request, target)
private fun StartedNode<*>.receiveFrom(target: MessageRecipients, retryId: Long? = null): CordaFuture<Any> {
val response = openFuture<Any>()
network.runOnNextMessage("test.response") { netMessage ->
response.set(netMessage.data.deserialize())
}
network.send("test.request", TestRequest(replyTo = network.myAddress), target, retryId = retryId)
return response
}
/**
* Registers a handler for the given topic and session that runs the given callback with the message and then removes
* itself. This is useful for one-shot handlers that aren't supposed to stick around permanently. Note that this callback
* doesn't take the registration object, unlike the callback to [MessagingService.addMessageHandler].
*
* @param topic identifier for the topic and session to listen for messages arriving on.
*/
inline fun MessagingService.runOnNextMessage(topic: String, crossinline callback: (ReceivedMessage) -> Unit) {
val consumed = AtomicBoolean()
addMessageHandler(topic) { msg, reg, handle ->
removeMessageHandler(reg)
check(!consumed.getAndSet(true)) { "Called more than once" }
check(msg.topic == topic) { "Topic/session mismatch: ${msg.topic} vs $topic" }
callback(msg)
handle.acknowledge()
}
}
@CordaSerializable
private data class TestRequest(override val sessionID: Long = random63BitValue(),
override val replyTo: SingleMessageRecipient) : ServiceRequestMessage
private data class TestRequest(val replyTo: SingleMessageRecipient)
}

View File

@ -12,6 +12,7 @@ import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext
import net.corda.core.crypto.SignedData
import net.corda.core.crypto.sign
import net.corda.core.crypto.newSecureRandom
import net.corda.core.flows.*
import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.Party
@ -280,6 +281,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
checkpointStorage,
serverThread,
database,
newSecureRandom(),
busyNodeLatch,
cordappLoader.appClassLoader
)
@ -556,7 +558,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
val database = configureDatabase(props, configuration.database, identityService, schemaService)
// Now log the vendor string as this will also cause a connection to be tested eagerly.
database.transaction {
log.info("Connected to ${database.dataSource.connection.metaData.databaseProductName} database.")
log.info("Connected to ${connection.metaData.databaseProductName} database.")
}
runOnStop += database::close
return database.transaction {

View File

@ -12,4 +12,3 @@ sealed class InitiatedFlowFactory<out F : FlowLogic<*>> {
val appName: String,
override val factory: (FlowSession) -> F) : InitiatedFlowFactory<F>()
}

View File

@ -1,42 +1,28 @@
package net.corda.node.services.api
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.StateMachineRunId
import net.corda.core.serialization.SerializedBytes
import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.node.services.statemachine.Checkpoint
import java.util.stream.Stream
/**
* Thread-safe storage of fiber checkpoints.
*/
interface CheckpointStorage {
/**
* Add a new checkpoint to the store.
*/
fun addCheckpoint(checkpoint: Checkpoint)
fun addCheckpoint(id: StateMachineRunId, checkpoint: SerializedBytes<Checkpoint>)
/**
* Remove existing checkpoint from the store. It is an error to attempt to remove a checkpoint which doesn't exist
* in the store. Doing so will throw an [IllegalArgumentException].
*/
fun removeCheckpoint(checkpoint: Checkpoint)
fun removeCheckpoint(id: StateMachineRunId)
/**
* Allows the caller to process safely in a thread safe fashion the set of all checkpoints.
* The checkpoints are only valid during the lifetime of a single call to the block, to allow memory management.
* Return false from the block to terminate further iteration.
* Stream all checkpoints from the store. If this is backed by a database the stream will be valid until the
* underlying database connection is open, so any processing should happen before it is closed.
*/
fun forEach(block: (Checkpoint) -> Boolean)
}
// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
class Checkpoint(val serializedFiber: SerializedBytes<FlowStateMachineImpl<*>>) {
val id: SecureHash get() = serializedFiber.hash
override fun equals(other: Any?): Boolean = other === this || other is Checkpoint && other.id == this.id
override fun hashCode(): Int = id.hashCode()
override fun toString(): String = "${javaClass.simpleName}(id=$id)"
fun getAllCheckpoints(): Stream<Pair<StateMachineRunId, SerializedBytes<Checkpoint>>>
}

View File

@ -1,18 +1,15 @@
package net.corda.node.services.messaging
import net.corda.core.concurrent.CordaFuture
import net.corda.core.crypto.newSecureRandom
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.node.services.PartyInfo
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.ByteSequence
import net.corda.node.services.statemachine.DeduplicationId
import java.time.Instant
import java.util.*
import java.util.concurrent.atomic.AtomicBoolean
import javax.annotation.concurrent.ThreadSafe
/**
@ -27,29 +24,6 @@ import javax.annotation.concurrent.ThreadSafe
*/
@ThreadSafe
interface MessagingService {
companion object {
/**
* Session ID to use for services listening for the first message in a session (before a
* specific session ID has been established).
*/
val DEFAULT_SESSION_ID = 0L
}
/**
* The provided function will be invoked for each received message whose topic matches the given string. The callback
* will run on threads provided by the messaging service, and the callback is expected to be thread safe as a result.
*
* The returned object is an opaque handle that may be used to un-register handlers later with [removeMessageHandler].
* The handle is passed to the callback as well, to avoid race conditions whereby the callback wants to unregister
* itself and yet addMessageHandler hasn't returned the handle yet.
*
* @param topic identifier for the general subject of the message, for example "platform.network_map.fetch".
* The topic can be the empty string to match all messages (session ID must be [DEFAULT_SESSION_ID]).
* @param sessionID identifier for the session the message is part of. For services listening before
* a session is established, use [DEFAULT_SESSION_ID].
*/
fun addMessageHandler(topic: String = "", sessionID: Long = DEFAULT_SESSION_ID, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
/**
* The provided function will be invoked for each received message whose topic and session matches. The callback
* will run on the main server thread provided when the messaging service is constructed, and a database
@ -59,9 +33,9 @@ interface MessagingService {
* The handle is passed to the callback as well, to avoid race conditions whereby the callback wants to unregister
* itself and yet addMessageHandler hasn't returned the handle yet.
*
* @param topicSession identifier for the topic and session to listen for messages arriving on.
* @param topic identifier for the topic to listen for messages arriving on.
*/
fun addMessageHandler(topicSession: TopicSession, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
fun addMessageHandler(topic: String, callback: MessageHandler): MessageHandlerRegistration
/**
* Removes a handler given the object returned from [addMessageHandler]. The callback will no longer be invoked once
@ -110,8 +84,6 @@ interface MessagingService {
* implementation.
*
* @param addressedMessages The list of messages together with the recipients, retry ids and sequence keys.
* @param retryId if provided the message will be scheduled for redelivery until [cancelRedelivery] is called for this id.
* Note that this feature should only be used when the target is an idempotent distributed service, e.g. a notary.
* @param acknowledgementHandler if non-null this handler will be called once all sent messages have been committed
* by the broker. Note that if specified [send] itself may return earlier than the commit.
*/
@ -123,9 +95,9 @@ interface MessagingService {
/**
* Returns an initialised [Message] with the current time, etc, already filled in.
*
* @param topicSession identifier for the topic and session the message is sent to.
* @param topic identifier for the topic the message is sent to.
*/
fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID = UUID.randomUUID()): Message
fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId = DeduplicationId.createRandom(newSecureRandom())): Message
/** Given information about either a specific node or a service returns its corresponding address */
fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients
@ -134,86 +106,12 @@ interface MessagingService {
val myAddress: SingleMessageRecipient
}
/**
* Returns an initialised [Message] with the current time, etc, already filled in.
*
* @param topic identifier for the general subject of the message, for example "platform.network_map.fetch".
* Must not be blank.
* @param sessionID identifier for the session the message is part of. For messages sent to services before the
* construction of a session, use [DEFAULT_SESSION_ID].
*/
fun MessagingService.createMessage(topic: String, sessionID: Long = MessagingService.DEFAULT_SESSION_ID, data: ByteArray): Message
= createMessage(TopicSession(topic, sessionID), data)
/**
* Registers a handler for the given topic and session ID that runs the given callback with the message and then removes
* itself. This is useful for one-shot handlers that aren't supposed to stick around permanently. Note that this callback
* doesn't take the registration object, unlike the callback to [MessagingService.addMessageHandler], as the handler is
* automatically deregistered before the callback runs.
*
* @param topic identifier for the general subject of the message, for example "platform.network_map.fetch".
* The topic can be the empty string to match all messages (session ID must be [DEFAULT_SESSION_ID]).
* @param sessionID identifier for the session the message is part of. For services listening before
* a session is established, use [DEFAULT_SESSION_ID].
*/
fun MessagingService.runOnNextMessage(topic: String, sessionID: Long, callback: (ReceivedMessage) -> Unit)
= runOnNextMessage(TopicSession(topic, sessionID), callback)
/**
* Registers a handler for the given topic and session that runs the given callback with the message and then removes
* itself. This is useful for one-shot handlers that aren't supposed to stick around permanently. Note that this callback
* doesn't take the registration object, unlike the callback to [MessagingService.addMessageHandler].
*
* @param topicSession identifier for the topic and session to listen for messages arriving on.
*/
inline fun MessagingService.runOnNextMessage(topicSession: TopicSession, crossinline callback: (ReceivedMessage) -> Unit) {
val consumed = AtomicBoolean()
addMessageHandler(topicSession) { msg, reg ->
removeMessageHandler(reg)
check(!consumed.getAndSet(true)) { "Called more than once" }
check(msg.topicSession == topicSession) { "Topic/session mismatch: ${msg.topicSession} vs $topicSession" }
callback(msg)
}
}
/**
* Returns a [CordaFuture] of the next message payload ([Message.data]) which is received on the given topic and sessionId.
* The payload is deserialized to an object of type [M]. Any exceptions thrown will be captured by the future.
*/
fun <M : Any> MessagingService.onNext(topic: String, sessionId: Long): CordaFuture<M> {
val messageFuture = openFuture<M>()
runOnNextMessage(topic, sessionId) { message ->
messageFuture.capture {
uncheckedCast(message.data.deserialize<Any>())
}
}
return messageFuture
}
fun MessagingService.send(topic: String, sessionID: Long, payload: Any, to: MessageRecipients, uuid: UUID = UUID.randomUUID()) {
send(TopicSession(topic, sessionID), payload, to, uuid)
}
fun MessagingService.send(topicSession: TopicSession, payload: Any, to: MessageRecipients, uuid: UUID = UUID.randomUUID(), retryId: Long? = null) {
send(createMessage(topicSession, payload.serialize().bytes, uuid), to, retryId)
}
fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: DeduplicationId = DeduplicationId.createRandom(newSecureRandom()), retryId: Long? = null)
= send(createMessage(topicSession, payload.serialize().bytes, deduplicationId), to, retryId)
interface MessageHandlerRegistration
/**
* An identifier for the endpoint [MessagingService] message handlers listen at.
*
* @param topic identifier for the general subject of the message, for example "platform.network_map.fetch".
* The topic can be the empty string to match all messages (session ID must be [DEFAULT_SESSION_ID]).
* @param sessionID identifier for the session the message is part of. For services listening before
* a session is established, use [DEFAULT_SESSION_ID].
*/
@CordaSerializable
data class TopicSession(val topic: String, val sessionID: Long = MessagingService.DEFAULT_SESSION_ID) {
fun isBlank() = topic.isBlank() && sessionID == MessagingService.DEFAULT_SESSION_ID
override fun toString(): String = "$topic.$sessionID"
}
/**
* A message is defined, at this level, to be a (topic, timestamp, byte arrays) triple, where the topic is a string in
* Java-style reverse dns form, with "platform." being a prefix reserved by the platform for its own use. Vendor
@ -226,10 +124,10 @@ data class TopicSession(val topic: String, val sessionID: Long = MessagingServic
*/
@CordaSerializable
interface Message {
val topicSession: TopicSession
val data: ByteArray
val topic: String
val data: ByteSequence
val debugTimestamp: Instant
val uniqueMessageId: UUID
val uniqueMessageId: DeduplicationId
}
// TODO Have ReceivedMessage point to the TLS certificate of the peer, and [peer] would simply be the subject DN of that.
@ -248,3 +146,20 @@ object TopicStringValidator {
/** @throws IllegalArgumentException if the given topic contains invalid characters */
fun check(tag: String) = require(regex.matcher(tag).matches())
}
/**
* Represents a to-be-acknowledged message. It has an associated deduplication ID.
*/
interface AcknowledgeHandle {
/**
* Acknowledge the message.
*/
fun acknowledge()
/**
* Store the deduplication ID. TODO this should be moved into the flow state machine completely.
*/
fun persistDeduplicationId()
}
typealias MessageHandler = (ReceivedMessage, MessageHandlerRegistration, AcknowledgeHandle) -> Unit

View File

@ -10,13 +10,11 @@ import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.sequence
import net.corda.core.utilities.trace
import net.corda.core.utilities.*
import net.corda.node.VersionInfo
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.statemachine.StateMachineManagerImpl
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.FlowMessagingImpl
import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.AppendOnlyPersistentMap
import net.corda.node.utilities.PersistentMap
@ -33,15 +31,16 @@ import org.apache.activemq.artemis.api.core.client.ClientMessage
import java.security.PublicKey
import java.time.Instant
import java.util.*
import java.util.concurrent.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CountDownLatch
import java.util.concurrent.ScheduledFuture
import java.util.concurrent.TimeUnit
import javax.annotation.concurrent.ThreadSafe
import javax.persistence.Column
import javax.persistence.Entity
import javax.persistence.Id
import javax.persistence.Lob
// TODO: Stop the wallet explorer and other clients from using this class and get rid of persistentInbox
/**
* This class implements the [MessagingService] API using Apache Artemis, the successor to their ActiveMQ product.
* Artemis is a message queue broker and here we run a client connecting to the specified broker instance
@ -77,20 +76,19 @@ class P2PMessagingClient(config: NodeConfiguration,
// that will handle messages, like a URL) with the terminology used by underlying MQ libraries, to avoid
// confusion.
private val topicProperty = SimpleString("platform-topic")
private val sessionIdProperty = SimpleString("session-id")
private val cordaVendorProperty = SimpleString("corda-vendor")
private val releaseVersionProperty = SimpleString("release-version")
private val platformVersionProperty = SimpleString("platform-version")
private val amqDelayMillis = System.getProperty("amq.delivery.delay.ms", "0").toInt()
private val messageMaxRetryCount: Int = 3
fun createProcessedMessage(): AppendOnlyPersistentMap<UUID, Instant, ProcessedMessage, String> {
fun createProcessedMessages(): AppendOnlyPersistentMap<DeduplicationId, Instant, ProcessedMessage, String> {
return AppendOnlyPersistentMap(
toPersistentEntityKey = { it.toString() },
fromPersistentEntity = { Pair(UUID.fromString(it.uuid), it.insertionTime) },
toPersistentEntity = { key: UUID, value: Instant ->
toPersistentEntityKey = { it.toString },
fromPersistentEntity = { Pair(DeduplicationId(it.id), it.insertionTime) },
toPersistentEntity = { key: DeduplicationId, value: Instant ->
ProcessedMessage().apply {
uuid = key.toString()
id = key.toString
insertionTime = value
}
},
@ -118,9 +116,9 @@ class P2PMessagingClient(config: NodeConfiguration,
)
}
private class NodeClientMessage(override val topicSession: TopicSession, override val data: ByteArray, override val uniqueMessageId: UUID) : Message {
private class NodeClientMessage(override val topic: String, override val data: ByteSequence, override val uniqueMessageId: DeduplicationId) : Message {
override val debugTimestamp: Instant = Instant.now()
override fun toString() = "$topicSession#${String(data)}"
override fun toString() = "$topic#${String(data.bytes)}"
}
}
@ -136,8 +134,7 @@ class P2PMessagingClient(config: NodeConfiguration,
private val scheduledMessageRedeliveries = ConcurrentHashMap<Long, ScheduledFuture<*>>()
/** A registration to handle messages of different types */
data class Handler(val topicSession: TopicSession,
val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration
data class HandlerRegistration(val topic: String, val callback: Any) : MessageHandlerRegistration
private val cordaVendor = SimpleString(versionInfo.vendor)
private val releaseVersion = SimpleString(versionInfo.releaseVersion)
@ -148,16 +145,17 @@ class P2PMessagingClient(config: NodeConfiguration,
private val messageRedeliveryDelaySeconds = config.messageRedeliveryDelaySeconds.toLong()
private val artemis = ArtemisMessagingClient(config, serverAddress)
private val state = ThreadBox(InnerState())
private val handlers = CopyOnWriteArrayList<Handler>()
private val processedMessages = createProcessedMessage()
private val handlers = ConcurrentHashMap<String, MessageHandler>()
private val processedMessages = createProcessedMessages()
@Entity
@javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_ids")
class ProcessedMessage(
@Id
@Column(name = "message_id", length = 36)
var uuid: String = "",
@Column(name = "message_id", length = 64)
var id: String = "",
@Column(name = "insertion_time")
var insertionTime: Instant = Instant.now()
@ -167,7 +165,7 @@ class P2PMessagingClient(config: NodeConfiguration,
@javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_retry")
class RetryMessage(
@Id
@Column(name = "message_id", length = 36)
@Column(name = "message_id", length = 64)
var key: Long = 0,
@Lob
@ -214,22 +212,7 @@ class P2PMessagingClient(config: NodeConfiguration,
val message: ReceivedMessage? = artemisToCordaMessage(artemisMessage)
if (message != null)
deliver(message)
// Ack the message so it won't be redelivered. We should only really do this when there were no
// transient failures. If we caught an exception in the handler, we could back off and retry delivery
// a few times before giving up and redirecting the message to a dead-letter address for admin or
// developer inspection. Artemis has the features to do this for us, we just need to enable them.
//
// TODO: Setup Artemis delayed redelivery and dead letter addresses.
//
// ACKing a message calls back into the session which isn't thread safe, so we have to ensure it
// doesn't collide with a send here. Note that stop() could have been called whilst we were
// processing a message but if so, it'll be parked waiting for us to count down the latch, so
// the session itself is still around and we can still ack messages as a result.
state.locked {
artemisMessage.acknowledge()
}
deliver(artemisMessage, message)
return true
}
@ -255,14 +238,13 @@ class P2PMessagingClient(config: NodeConfiguration,
private fun artemisToCordaMessage(message: ClientMessage): ReceivedMessage? {
try {
val topic = message.required(topicProperty) { getStringProperty(it) }
val sessionID = message.required(sessionIdProperty) { getLongProperty(it) }
val user = requireNotNull(message.getStringProperty(HDR_VALIDATED_USER)) { "Message is not authenticated" }
val platformVersion = message.required(platformVersionProperty) { getIntProperty(it) }
// Use the magic deduplication property built into Artemis as our message identity too
val uuid = message.required(HDR_DUPLICATE_DETECTION_ID) { UUID.fromString(message.getStringProperty(it)) }
log.trace { "Received message from: ${message.address} user: $user topic: $topic sessionID: $sessionID uuid: $uuid" }
val uniqueMessageId = message.required(HDR_DUPLICATE_DETECTION_ID) { DeduplicationId(message.getStringProperty(it)) }
log.trace { "Received message from: ${message.address} user: $user topic: $topic id: $uniqueMessageId" }
return ArtemisReceivedMessage(TopicSession(topic, sessionID), CordaX500Name.parse(user), platformVersion, uuid, message)
return ArtemisReceivedMessage(topic, CordaX500Name.parse(user), platformVersion, uniqueMessageId, message)
} catch (e: Exception) {
log.error("Unable to process message, ignoring it: $message", e)
return null
@ -274,21 +256,19 @@ class P2PMessagingClient(config: NodeConfiguration,
return extractor(key)
}
private class ArtemisReceivedMessage(override val topicSession: TopicSession,
private class ArtemisReceivedMessage(override val topic: String,
override val peer: CordaX500Name,
override val platformVersion: Int,
override val uniqueMessageId: UUID,
override val uniqueMessageId: DeduplicationId,
private val message: ClientMessage) : ReceivedMessage {
override val data: ByteArray by lazy { ByteArray(message.bodySize).apply { message.bodyBuffer.readBytes(this) } }
override val data: ByteSequence by lazy { OpaqueBytes(ByteArray(message.bodySize).apply { message.bodyBuffer.readBytes(this) }) }
override val debugTimestamp: Instant get() = Instant.ofEpochMilli(message.timestamp)
override fun toString() = "${topicSession.topic}#${data.sequence()}"
override fun toString() = "$topic#$data"
}
private fun deliver(msg: ReceivedMessage): Boolean {
private fun deliver(artemisMessage: ClientMessage, msg: ReceivedMessage) {
state.checkNotLocked()
// Because handlers is a COW list, the loop inside filter will operate on a snapshot. Handlers being added
// or removed whilst the filter is executing will not affect anything.
val deliverTo = handlers.filter { it.topicSession.isBlank() || it.topicSession == msg.topicSession }
val deliverTo = handlers[msg.topic]
try {
// This will perform a BLOCKING call onto the executor. Thus if the handlers are slow, we will
// be slow, and Artemis can handle that case intelligently. We don't just invoke the handler
@ -298,31 +278,34 @@ class P2PMessagingClient(config: NodeConfiguration,
//
// Note that handlers may re-enter this class. We aren't holding any locks and methods like
// start/run/stop have re-entrancy assertions at the top, so it is OK.
nodeExecutor.fetchFrom {
database.transaction {
if (msg.uniqueMessageId in processedMessages) {
log.trace { "Discard duplicate message ${msg.uniqueMessageId} for ${msg.topicSession}" }
} else {
if (deliverTo.isEmpty()) {
// TODO: Implement dead letter queue, and send it there.
log.warn("Received message ${msg.uniqueMessageId} for ${msg.topicSession} that doesn't have any registered handlers yet")
} else {
callHandlers(msg, deliverTo)
if (deliverTo != null) {
val isDuplicate = database.transaction { msg.uniqueMessageId in processedMessages }
if (isDuplicate) {
log.trace { "Discard duplicate message ${msg.uniqueMessageId} for ${msg.topic}" }
return
}
// TODO We will at some point need to decide a trimming policy for the id's
val acknowledgeHandle = object : AcknowledgeHandle {
override fun persistDeduplicationId() {
processedMessages[msg.uniqueMessageId] = Instant.now()
}
// ACKing a message calls back into the session which isn't thread safe, so we have to ensure it
// doesn't collide with a send here. Note that stop() could have been called whilst we were
// processing a message but if so, it'll be parked waiting for us to count down the latch, so
// the session itself is still around and we can still ack messages as a result.
override fun acknowledge() {
state.locked {
artemisMessage.individualAcknowledge()
artemis.started!!.session.commit()
}
}
}
deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), acknowledgeHandle)
} else {
log.warn("Received message ${msg.uniqueMessageId} for ${msg.topic} that doesn't have any registered handlers yet")
}
} catch (e: Exception) {
log.error("Caught exception whilst executing message handler for ${msg.topicSession}", e)
}
return true
}
private fun callHandlers(msg: ReceivedMessage, deliverTo: List<Handler>) {
for (handler in deliverTo) {
handler.callback(msg, handler)
log.error("Caught exception whilst executing message handler for ${msg.topic}", e)
}
}
@ -370,20 +353,19 @@ class P2PMessagingClient(config: NodeConfiguration,
putStringProperty(cordaVendorProperty, cordaVendor)
putStringProperty(releaseVersionProperty, releaseVersion)
putIntProperty(platformVersionProperty, versionInfo.platformVersion)
putStringProperty(topicProperty, SimpleString(message.topicSession.topic))
putLongProperty(sessionIdProperty, message.topicSession.sessionID)
writeBodyBufferBytes(message.data)
putStringProperty(topicProperty, SimpleString(message.topic))
writeBodyBufferBytes(message.data.bytes)
// Use the magic deduplication property built into Artemis as our message identity too
putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString()))
putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString))
// For demo purposes - if set then add a delay to messages in order to demonstrate that the flows are doing as intended
if (amqDelayMillis > 0 && message.topicSession.topic == StateMachineManagerImpl.sessionTopic.topic) {
if (amqDelayMillis > 0 && message.topic == FlowMessagingImpl.sessionTopic) {
putLongProperty(HDR_SCHEDULED_DELIVERY_TIME, System.currentTimeMillis() + amqDelayMillis)
}
}
log.trace {
"Send to: $mqAddress topic: ${message.topicSession.topic} " +
"sessionID: ${message.topicSession.sessionID} uuid: ${message.uniqueMessageId}"
"Send to: $mqAddress topic: ${message.topic} " +
"sessionID: ${message.topic} id: ${message.uniqueMessageId}"
}
artemis.producer.send(mqAddress, artemisMessage)
retryId?.let {
@ -467,30 +449,26 @@ class P2PMessagingClient(config: NodeConfiguration,
}
}
override fun addMessageHandler(topic: String,
sessionID: Long,
callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
return addMessageHandler(TopicSession(topic, sessionID), callback)
override fun addMessageHandler(topic: String, callback: MessageHandler): MessageHandlerRegistration {
require(!topic.isBlank()) { "Topic must not be blank, as the empty topic is a special case." }
handlers.compute(topic) { _, handler ->
if (handler != null) {
throw IllegalStateException("Cannot add another acking handler for $topic, there is already an acking one")
}
override fun addMessageHandler(topicSession: TopicSession,
callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
require(!topicSession.isBlank()) { "Topic must not be blank, as the empty topic is a special case." }
val handler = Handler(topicSession, callback)
handlers.add(handler)
return handler
callback
}
return HandlerRegistration(topic, callback)
}
override fun removeMessageHandler(registration: MessageHandlerRegistration) {
handlers.remove(registration)
registration as HandlerRegistration
handlers.remove(registration.topic)
}
override fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID): Message {
// TODO: We could write an object that proxies directly to an underlying MQ message here and avoid copying.
return NodeClientMessage(topicSession, data, uuid)
override fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId): Message {
return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId)
}
// TODO Rethink PartyInfo idea and merging PeerAddress/ServiceAddress (the only difference is that Service address doesn't hold host and port)
override fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients {
return when (partyInfo) {
is PartyInfo.SingleNode -> NodeAddress(partyInfo.party.owningKey, partyInfo.addresses.first())

View File

@ -27,6 +27,7 @@ class RPCMessagingClient(private val config: SSLConfiguration, serverAddress: Ne
}
fun stop() = synchronized(this) {
rpcServer?.close()
artemis.stop()
}
}

View File

@ -1,27 +0,0 @@
package net.corda.node.services.messaging
import net.corda.core.concurrent.CordaFuture
import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.serialization.CordaSerializable
/**
* Abstract superclass for request messages sent to services which expect a reply.
*/
@CordaSerializable
interface ServiceRequestMessage {
val sessionID: Long
val replyTo: SingleMessageRecipient
}
/**
* Sends a [ServiceRequestMessage] to [target] and returns a [CordaFuture] of the response.
* @param R The type of the response.
*/
fun <R : Any> MessagingService.sendRequest(topic: String,
request: ServiceRequestMessage,
target: MessageRecipients): CordaFuture<R> {
val responseFuture = onNext<R>(topic, request.sessionID)
send(topic, MessagingService.DEFAULT_SESSION_ID, request, target)
return responseFuture
}

View File

@ -1,10 +1,14 @@
package net.corda.node.services.persistence
import net.corda.core.flows.StateMachineRunId
import net.corda.core.serialization.SerializedBytes
import net.corda.node.services.api.Checkpoint
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.statemachine.Checkpoint
import net.corda.nodeapi.internal.persistence.DatabaseTransactionManager
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import net.corda.nodeapi.internal.persistence.currentDBSession
import java.util.*
import java.util.stream.Stream
import javax.persistence.Column
import javax.persistence.Entity
import javax.persistence.Id
@ -27,32 +31,29 @@ class DBCheckpointStorage : CheckpointStorage {
var checkpoint: ByteArray = ByteArray(0)
)
override fun addCheckpoint(checkpoint: Checkpoint) {
currentDBSession().save(DBCheckpoint().apply {
checkpointId = checkpoint.id.toString()
this.checkpoint = checkpoint.serializedFiber.bytes
override fun addCheckpoint(id: StateMachineRunId, checkpoint: SerializedBytes<Checkpoint>) {
currentDBSession().saveOrUpdate(DBCheckpoint().apply {
checkpointId = id.uuid.toString()
this.checkpoint = checkpoint.bytes
})
}
override fun removeCheckpoint(checkpoint: Checkpoint) {
val session = currentDBSession()
override fun removeCheckpoint(id: StateMachineRunId) {
val session = DatabaseTransactionManager.current().session
val criteriaBuilder = session.criteriaBuilder
val delete = criteriaBuilder.createCriteriaDelete(DBCheckpoint::class.java)
val root = delete.from(DBCheckpoint::class.java)
delete.where(criteriaBuilder.equal(root.get<String>(DBCheckpoint::checkpointId.name), checkpoint.id.toString()))
delete.where(criteriaBuilder.equal(root.get<String>(DBCheckpoint::checkpointId.name), id.uuid.toString()))
session.createQuery(delete).executeUpdate()
}
override fun forEach(block: (Checkpoint) -> Boolean) {
override fun getAllCheckpoints(): Stream<Pair<StateMachineRunId, SerializedBytes<Checkpoint>>> {
val session = currentDBSession()
val criteriaQuery = session.criteriaBuilder.createQuery(DBCheckpoint::class.java)
val root = criteriaQuery.from(DBCheckpoint::class.java)
criteriaQuery.select(root)
for (row in session.createQuery(criteriaQuery).resultList) {
val checkpoint = Checkpoint(SerializedBytes(row.checkpoint))
if (!block(checkpoint)) {
break
}
return session.createQuery(criteriaQuery).stream().map {
StateMachineRunId(UUID.fromString(it.checkpointId)) to SerializedBytes<Checkpoint>(it.checkpoint)
}
}
}

View File

@ -1,13 +1,20 @@
package net.corda.node.services.persistence
import net.corda.core.concurrent.CordaFuture
import net.corda.core.crypto.SecureHash
import net.corda.core.internal.ThreadBox
import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.crypto.SecureHash
import net.corda.core.internal.concurrent.doneFuture
import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.*
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.toFuture
import net.corda.core.transactions.SignedTransaction
import net.corda.node.services.api.WritableTransactionStorage
import net.corda.node.utilities.*
import net.corda.node.utilities.AppendOnlyPersistentMap
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import net.corda.nodeapi.internal.persistence.bufferUntilDatabaseCommit
import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction
@ -48,22 +55,37 @@ class DBTransactionStorage : WritableTransactionStorage, SingletonSerializeAsTok
}
}
private val txStorage = createTransactionsMap()
private val txStorage = ThreadBox(createTransactionsMap())
override fun addTransaction(transaction: SignedTransaction): Boolean =
txStorage.addWithDuplicatesAllowed(transaction.id, transaction).apply {
txStorage.locked {
addWithDuplicatesAllowed(transaction.id, transaction).apply {
updatesPublisher.bufferUntilDatabaseCommit().onNext(transaction)
}
}
override fun getTransaction(id: SecureHash): SignedTransaction? = txStorage[id]
override fun getTransaction(id: SecureHash): SignedTransaction? = txStorage.content[id]
private val updatesPublisher = PublishSubject.create<SignedTransaction>().toSerialized()
override val updates: Observable<SignedTransaction> = updatesPublisher.wrapWithDatabaseTransaction()
override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> =
DataFeed(txStorage.allPersisted().map { it.second }.toList(), updatesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction())
override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> {
return txStorage.locked {
DataFeed(allPersisted().map { it.second }.toList(), updatesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction())
}
}
override fun trackTransaction(id: SecureHash): CordaFuture<SignedTransaction> {
return txStorage.locked {
val existingTransaction = get(id)
if (existingTransaction == null) {
updatesPublisher.filter { it.id == id }.toFuture()
} else {
doneFuture(existingTransaction)
}
}
}
@VisibleForTesting
val transactions: Iterable<SignedTransaction>
get() = txStorage.allPersisted().map { it.second }.toList()
val transactions: Iterable<SignedTransaction> get() = txStorage.content.allPersisted().map { it.second }.toList()
}

View File

@ -0,0 +1,126 @@
package net.corda.node.services.statemachine
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.Party
import net.corda.node.services.messaging.AcknowledgeHandle
import java.time.Instant
/**
* [Action]s are reified IO actions to execute as part of state machine transitions.
*/
sealed class Action {
/**
* Track a transaction hash and notify the state machine once the corresponding transaction has committed.
*/
data class TrackTransaction(val hash: SecureHash) : Action()
/**
* Send an initial session message to [party].
*/
data class SendInitial(
val party: Party,
val initialise: InitialSessionMessage,
val deduplicationId: DeduplicationId
) : Action()
/**
* Send a session message to a [peerParty] with which we have an established session.
*/
data class SendExisting(
val peerParty: Party,
val message: ExistingSessionMessage,
val deduplicationId: DeduplicationId
) : Action()
/**
* Persist the specified [checkpoint].
*/
data class PersistCheckpoint(val id: StateMachineRunId, val checkpoint: Checkpoint) : Action()
/**
* Remove the checkpoint corresponding to [id].
*/
data class RemoveCheckpoint(val id: StateMachineRunId) : Action()
/**
* Persist the deduplication IDs of [acknowledgeHandles].
*/
data class PersistDeduplicationIds(val acknowledgeHandles: List<AcknowledgeHandle>) : Action()
/**
* Acknowledge messages in [acknowledgeHandles].
*/
data class AcknowledgeMessages(val acknowledgeHandles: List<AcknowledgeHandle>) : Action()
/**
* Propagate [errorMessages] to [sessions].
* @param sessions a map from source session IDs to initiated sessions.
*/
data class PropagateErrors(
val errorMessages: List<ErrorSessionMessage>,
val sessions: List<SessionState.Initiated>
) : Action()
/**
* Create a session binding from [sessionId] to [flowId] to allow routing of incoming messages.
*/
data class AddSessionBinding(val flowId: StateMachineRunId, val sessionId: SessionId) : Action()
/**
* Remove the session bindings corresponding to [sessionIds].
*/
data class RemoveSessionBindings(val sessionIds: Set<SessionId>) : Action()
/**
* Signal that the flow corresponding to [flowId] is considered started.
*/
data class SignalFlowHasStarted(val flowId: StateMachineRunId) : Action()
/**
* Remove the flow corresponding to [flowId].
*/
data class RemoveFlow(
val flowId: StateMachineRunId,
val removalReason: FlowRemovalReason,
val lastState: StateMachineState
) : Action()
/**
* Schedule [event] to self.
*/
data class ScheduleEvent(val event: Event) : Action()
/**
* Sleep until [time].
*/
data class SleepUntil(val time: Instant) : Action()
/**
* Create a new database transaction.
*/
object CreateTransaction : Action() { override fun toString() = "CreateTransaction" }
/**
* Roll back the current database transaction.
*/
object RollbackTransaction : Action() { override fun toString() = "RollbackTransaction" }
/**
* Commit the current database transaction.
*/
object CommitTransaction : Action() { override fun toString() = "CommitTransaction" }
}
/**
* Reason for flow removal.
*/
sealed class FlowRemovalReason {
data class OrderlyFinish(val flowReturnValue: Any?) : FlowRemovalReason()
data class ErrorFinish(val flowErrors: List<FlowError>) : FlowRemovalReason()
object SoftShutdown : FlowRemovalReason() { override fun toString() = "SoftShutdown" }
// TODO Should we remove errored flows? How will the flow hospital work? Perhaps keep them in memory for a while, flush
// them after a timeout, reload them on flow hospital request. In any case if we ever want to remove them
// (e.g. temporarily) then add a case for that here.
}

View File

@ -0,0 +1,15 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable
/**
* An executor of a single [Action].
*/
interface ActionExecutor {
/**
* Execute [action] by [fiber].
* Precondition: [executeAction] is run inside an open database transaction.
*/
@Suspendable
fun executeAction(fiber: FlowFiber, action: Action)
}

View File

@ -0,0 +1,190 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.internal.concurrent.thenMatch
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.serialize
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.trace
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal
import net.corda.nodeapi.internal.persistence.DatabaseTransactionManager
import java.time.Duration
import java.time.Instant
import java.util.concurrent.TimeUnit
/**
* This is the bottom execution engine of flow side-effects.
*/
class ActionExecutorImpl(
private val services: ServiceHubInternal,
private val checkpointStorage: CheckpointStorage,
private val flowMessaging: FlowMessaging,
private val stateMachineManager: StateMachineManagerInternal,
private val checkpointSerializationContext: SerializationContext
) : ActionExecutor {
private companion object {
val log = contextLogger()
}
@Suspendable
override fun executeAction(fiber: FlowFiber, action: Action) {
log.trace { "Flow ${fiber.id} executing $action" }
return when (action) {
is Action.TrackTransaction -> executeTrackTransaction(fiber, action)
is Action.PersistCheckpoint -> executePersistCheckpoint(action)
is Action.PersistDeduplicationIds -> executePersistDeduplicationIds(action)
is Action.AcknowledgeMessages -> executeAcknowledgeMessages(action)
is Action.PropagateErrors -> executePropagateErrors(action)
is Action.ScheduleEvent -> executeScheduleEvent(fiber, action)
is Action.SleepUntil -> executeSleepUntil(action)
is Action.RemoveCheckpoint -> executeRemoveCheckpoint(action)
is Action.SendInitial -> executeSendInitial(action)
is Action.SendExisting -> executeSendExisting(action)
is Action.AddSessionBinding -> executeAddSessionBinding(action)
is Action.RemoveSessionBindings -> executeRemoveSessionBindings(action)
is Action.SignalFlowHasStarted -> executeSignalFlowHasStarted(action)
is Action.RemoveFlow -> executeRemoveFlow(action)
is Action.CreateTransaction -> executeCreateTransaction()
is Action.RollbackTransaction -> executeRollbackTransaction()
is Action.CommitTransaction -> executeCommitTransaction()
}
}
@Suspendable
private fun executeTrackTransaction(fiber: FlowFiber, action: Action.TrackTransaction) {
services.validatedTransactions.trackTransaction(action.hash).thenMatch(
success = { transaction ->
fiber.scheduleEvent(Event.TransactionCommitted(transaction))
},
failure = { exception ->
fiber.scheduleEvent(Event.Error(exception))
}
)
}
@Suspendable
private fun executePersistCheckpoint(action: Action.PersistCheckpoint) {
val checkpointBytes = serializeCheckpoint(action.checkpoint)
checkpointStorage.addCheckpoint(action.id, checkpointBytes)
}
@Suspendable
private fun executePersistDeduplicationIds(action: Action.PersistDeduplicationIds) {
for (handle in action.acknowledgeHandles) {
handle.persistDeduplicationId()
}
}
@Suspendable
private fun executeAcknowledgeMessages(action: Action.AcknowledgeMessages) {
action.acknowledgeHandles.forEach {
it.acknowledge()
}
}
@Suspendable
private fun executePropagateErrors(action: Action.PropagateErrors) {
action.errorMessages.forEach { error ->
val exception = error.flowException
log.debug("Propagating error", exception)
}
val pendingSendAcks = CountUpDownLatch(0)
for (sessionState in action.sessions) {
// We cannot propagate if the session isn't live.
if (sessionState.initiatedState !is InitiatedSessionState.Live) {
continue
}
// Don't propagate errors to the originating session
for (errorMessage in action.errorMessages) {
val sinkSessionId = sessionState.initiatedState.peerSinkSessionId
val existingMessage = ExistingSessionMessage(sinkSessionId, errorMessage)
val deduplicationId = DeduplicationId.createForError(errorMessage.errorId, sinkSessionId)
pendingSendAcks.countUp()
flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, deduplicationId) {
pendingSendAcks.countDown()
}
}
}
// TODO we simply block here, perhaps this should be explicit in the worker state
pendingSendAcks.await()
}
@Suspendable
private fun executeScheduleEvent(fiber: FlowFiber, action: Action.ScheduleEvent) {
fiber.scheduleEvent(action.event)
}
@Suspendable
private fun executeSleepUntil(action: Action.SleepUntil) {
// TODO introduce explicit sleep state + wakeup event instead of relying on Fiber.sleep. This is so shutdown
// conditions may "interrupt" the sleep instead of waiting until wakeup.
val duration = Duration.between(Instant.now(), action.time)
Fiber.sleep(duration.toNanos(), TimeUnit.NANOSECONDS)
}
@Suspendable
private fun executeRemoveCheckpoint(action: Action.RemoveCheckpoint) {
checkpointStorage.removeCheckpoint(action.id)
}
@Suspendable
private fun executeSendInitial(action: Action.SendInitial) {
flowMessaging.sendSessionMessage(action.party, action.initialise, action.deduplicationId, null)
}
@Suspendable
private fun executeSendExisting(action: Action.SendExisting) {
flowMessaging.sendSessionMessage(action.peerParty, action.message, action.deduplicationId, null)
}
@Suspendable
private fun executeAddSessionBinding(action: Action.AddSessionBinding) {
stateMachineManager.addSessionBinding(action.flowId, action.sessionId)
}
@Suspendable
private fun executeRemoveSessionBindings(action: Action.RemoveSessionBindings) {
stateMachineManager.removeSessionBindings(action.sessionIds)
}
@Suspendable
private fun executeSignalFlowHasStarted(action: Action.SignalFlowHasStarted) {
stateMachineManager.signalFlowHasStarted(action.flowId)
}
@Suspendable
private fun executeRemoveFlow(action: Action.RemoveFlow) {
stateMachineManager.removeFlow(action.flowId, action.removalReason, action.lastState)
}
@Suspendable
private fun executeCreateTransaction() {
if (DatabaseTransactionManager.currentOrNull() != null) {
throw IllegalStateException("Refusing to create a second transaction")
}
DatabaseTransactionManager.newTransaction()
}
@Suspendable
private fun executeRollbackTransaction() {
DatabaseTransactionManager.currentOrNull()?.close()
}
@Suspendable
private fun executeCommitTransaction() {
try {
DatabaseTransactionManager.current().commit()
} finally {
DatabaseTransactionManager.current().close()
DatabaseTransactionManager.setThreadLocalTx(null)
}
}
private fun serializeCheckpoint(checkpoint: Checkpoint): SerializedBytes<Checkpoint> {
return checkpoint.serialize(context = checkpointSerializationContext)
}
}

View File

@ -0,0 +1,66 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.strands.concurrent.AbstractQueuedSynchronizer
import co.paralleluniverse.fibers.Suspendable
/**
* Quasar-compatible latch that may be incremented.
*/
class CountUpDownLatch(initialValue: Int) {
// See quasar CountDownLatch
private class Sync(initialValue: Int) : AbstractQueuedSynchronizer() {
init {
state = initialValue
}
override fun tryAcquireShared(arg: Int): Int {
if (arg >= 0) {
return if (state == arg) 1 else -1
} else {
return if (state <= -arg) 1 else -1
}
}
override fun tryReleaseShared(arg: Int): Boolean {
while (true) {
val c = state
if (c == 0)
return false
val nextc = c - Math.min(c, arg)
if (compareAndSetState(c, nextc))
return nextc == 0
}
}
fun increment() {
while (true) {
val c = state
val nextc = c + 1
if (compareAndSetState(c, nextc))
return
}
}
}
private val sync = Sync(initialValue)
@Suspendable
fun await() {
sync.acquireSharedInterruptibly(0)
}
@Suspendable
fun awaitLessThanOrEqual(number: Int) {
sync.acquireSharedInterruptibly(number)
}
fun countDown(number: Int = 1) {
require(number > 0)
sync.releaseShared(number)
}
fun countUp() {
sync.increment()
}
}

View File

@ -0,0 +1,47 @@
package net.corda.node.services.statemachine
import java.security.SecureRandom
/**
* A deduplication ID of a flow message.
*/
data class DeduplicationId(val toString: String) {
companion object {
/**
* Create a random deduplication ID. Note that this isn't deterministic, which means we will never dedupe it,
* unless we persist the ID somehow.
*/
fun createRandom(random: SecureRandom) = DeduplicationId("R-${random.nextLong()}")
/**
* Create a deduplication ID for a normal clean state message. This is used to have a deterministic way of
* creating IDs in case the message-generating flow logic is replayed on hard failure.
*
* A normal deduplication ID consists of:
* 1. A deduplication seed set per flow. This is either the flow's ID or in case of an initated flow the
* initiator's session ID.
* 2. The number of *clean* suspends since the start of the flow.
* 3. An optional additional index, for cases where several messages are sent as part of the state transition.
* Note that care must be taken with this index, it must be a deterministic counter. For example a naive
* iteration over a HashMap will produce a different list of indeces than a previous run, causing the
* message-id map to change, which means deduplication will not happen correctly.
*/
fun createForNormal(checkpoint: Checkpoint, index: Int): DeduplicationId {
return DeduplicationId("N-${checkpoint.deduplicationSeed}-${checkpoint.numberOfSuspends}-$index")
}
/**
* Create a deduplication ID for an error message. Note that these IDs live in a different namespace than normal
* IDs, as we don't want error conditions to affect the determinism of clean deduplication IDs. This allows the
* dirtiness state to be thrown away for resumption.
*
* An error deduplication ID consists of:
* 1. The error's ID. This is a unique value per "source" of error and is propagated.
* See [net.corda.core.flows.IdentifiableException].
* 2. The recipient's session ID.
*/
fun createForError(errorId: Long, recipientSessionId: SessionId): DeduplicationId {
return DeduplicationId("E-$errorId-${recipientSessionId.toLong}")
}
}
}

View File

@ -0,0 +1,117 @@
package net.corda.node.services.statemachine
import net.corda.core.flows.FlowLogic
import net.corda.core.identity.Party
import net.corda.core.internal.FlowIORequest
import net.corda.core.serialization.SerializedBytes
import net.corda.core.transactions.SignedTransaction
import net.corda.node.services.messaging.AcknowledgeHandle
/**
* Transitions in the flow state machine are triggered by [Event]s that may originate from the flow itself or from
* outside (e.g. in case of message delivery or external event).
*/
sealed class Event {
/**
* Check the current state for pending work. For example if the flow is waiting for a message from a particular
* session this event may cause a flow resume if we have a corresponding message. In general the state machine
* should be idempotent in the [DoRemainingWork] event, meaning a second subsequent event shouldn't modify the state
* or produce [Action]s.
*/
object DoRemainingWork : Event() { override fun toString() = "DoRemainingWork" }
/**
* Deliver a session message.
* @param sessionMessage the message itself.
* @param acknowledgeHandle the handle to acknowledge the message after checkpointing.
* @param sender the sender [Party].
*/
data class DeliverSessionMessage(
val sessionMessage: ExistingSessionMessage,
val acknowledgeHandle: AcknowledgeHandle,
val sender: Party
) : Event()
/**
* Signal that an error has happened. This may be due to an uncaught exception in the flow or some external error.
* @param exception the exception itself.
*/
data class Error(val exception: Throwable) : Event()
/**
* Signal that a ledger transaction has committed. This is an event completing a [FlowIORequest.WaitForLedgerCommit]
* suspension.
* @param transaction the transaction that was committed.
*/
data class TransactionCommitted(val transaction: SignedTransaction) : Event()
/**
* Trigger a soft shutdown, removing the flow as soon as possible. This causes the flow to be removed as soon as
* this event is processed. Note that on restart the flow will resume as normal.
*/
object SoftShutdown : Event() { override fun toString() = "SoftShutdown" }
/**
* Start error propagation on a errored flow. This may be triggered by e.g. a [FlowHospital].
*/
object StartErrorPropagation : Event() { override fun toString() = "StartErrorPropagation" }
/**
*
* Scheduled by the flow.
*
* Initiate a flow. This causes a new session object to be created and returned to the flow. Note that no actual
* communication takes place at this time, only on the first send/receive operation on the session.
* @param party the [Party] to create a session with.
*/
data class InitiateFlow(val party: Party) : Event()
/**
* Signal the entering into a subflow.
*
* Scheduled and executed by the flow.
*
* @param subFlowClass the [Class] of the subflow, to be used to determine whether it's Initiating or inlined.
*/
data class EnterSubFlow(val subFlowClass: Class<FlowLogic<*>>) : Event()
/**
* Signal the leaving of a subflow.
*
* Scheduled by the flow.
*
*/
object LeaveSubFlow : Event() { override fun toString() = "LeaveSubFlow" }
/**
* Signal a flow suspension. This causes the flow's stack and the state machine's state together with the suspending
* IO request to be persisted into the database.
*
* Scheduled by the flow and executed inside the park closure.
*
* @param ioRequest the request triggering the suspension.
* @param maySkipCheckpoint indicates whether the persistence may be skipped.
* @param fiber the serialised stack of the flow.
*/
data class Suspend(
val ioRequest: FlowIORequest<*>,
val maySkipCheckpoint: Boolean,
val fiber: SerializedBytes<FlowStateMachineImpl<*>>
) : Event() {
override fun toString() =
"Suspend(" +
"ioRequest=$ioRequest, " +
"maySkipCheckpoint=$maySkipCheckpoint, " +
"fiber=${fiber.hash}, " +
")"
}
/**
* Signals clean flow finish.
*
* Scheduled by the flow.
*
* @param returnValue the return value of the flow.
*/
data class FlowFinish(val returnValue: Any?) : Event()
}

View File

@ -0,0 +1,18 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.StateMachineRunId
import net.corda.node.services.statemachine.transitions.StateMachine
/**
* An interface wrapping a fiber running a flow.
*/
interface FlowFiber {
val id: StateMachineRunId
val stateMachine: StateMachine
@Suspendable
fun scheduleEvent(event: Event)
fun snapshot(): StateMachineState
}

View File

@ -0,0 +1,18 @@
package net.corda.node.services.statemachine
/**
* A flow hospital is a class that is notified when a flow transitions into an error state due to an uncaught exception
* or internal error condition, and when it becomes clean again (e.g. due to a resume).
* Also see [net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor].
*/
interface FlowHospital {
/**
* The flow running in [flowFiber] has errored.
*/
fun flowErrored(flowFiber: FlowFiber)
/**
* The flow running in [flowFiber] has cleaned, possibly as a result of a flow hospital resume.
*/
fun flowCleaned(flowFiber: FlowFiber)
}

View File

@ -1,121 +0,0 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.crypto.SecureHash
import java.time.Instant
interface FlowIORequest {
// This is used to identify where we suspended, in case of message mismatch errors and other things where we
// don't have the original stack trace because it's in a suspended fiber.
val stackTraceInCaseOfProblems: StackSnapshot
}
interface WaitingRequest : FlowIORequest {
fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean
}
interface SessionedFlowIORequest : FlowIORequest {
val session: FlowSessionInternal
}
interface SendRequest : SessionedFlowIORequest {
val message: SessionMessage
}
interface ReceiveRequest<T : SessionMessage> : SessionedFlowIORequest, WaitingRequest {
val receiveType: Class<T>
val userReceiveType: Class<*>?
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = this.session === session
}
data class SendAndReceive<T : SessionMessage>(override val session: FlowSessionInternal,
override val message: SessionMessage,
override val receiveType: Class<T>,
override val userReceiveType: Class<*>?) : SendRequest, ReceiveRequest<T> {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}
data class ReceiveOnly<T : SessionMessage>(override val session: FlowSessionInternal,
override val receiveType: Class<T>,
override val userReceiveType: Class<*>?) : ReceiveRequest<T> {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}
class ReceiveAll(val requests: List<ReceiveRequest<SessionData>>) : WaitingRequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
private fun isComplete(received: LinkedHashMap<FlowSessionInternal, RequestMessage>): Boolean {
return received.keys == requests.map { it.session }.toSet()
}
private fun shouldResumeIfRelevant() = requests.all { hasSuccessfulEndMessage(it) }
private fun hasSuccessfulEndMessage(it: ReceiveRequest<SessionData>): Boolean {
return it.session.receivedMessages.map { it.message }.any { it is SessionData || it is SessionEnd }
}
@Suspendable
fun suspendAndExpectReceive(suspend: Suspend): Map<FlowSessionInternal, RequestMessage> {
val receivedMessages = LinkedHashMap<FlowSessionInternal, RequestMessage>()
poll(receivedMessages)
return if (isComplete(receivedMessages)) {
receivedMessages
} else {
suspend(this)
poll(receivedMessages)
if (isComplete(receivedMessages)) {
receivedMessages
} else {
throw IllegalStateException(requests.filter { it.session !in receivedMessages.keys }.map { "Was expecting a ${it.receiveType.simpleName} but instead got nothing for $it." }.joinToString { "\n" })
}
}
}
interface Suspend {
@Suspendable
operator fun invoke(request: FlowIORequest)
}
@Suspendable
private fun poll(receivedMessages: LinkedHashMap<FlowSessionInternal, RequestMessage>) {
return requests.filter { it.session !in receivedMessages.keys }.forEach { request ->
poll(request)?.let {
receivedMessages[request.session] = RequestMessage(request, it)
}
}
}
@Suspendable
private fun poll(request: ReceiveRequest<SessionData>): ReceivedSessionMessage<*>? {
return request.session.receivedMessages.poll()
}
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = isRelevant(session) && shouldResumeIfRelevant()
private fun isRelevant(session: FlowSessionInternal) = requests.any { it.session === session }
data class RequestMessage(val request: ReceiveRequest<SessionData>, val message: ReceivedSessionMessage<*>)
}
data class SendOnly(override val session: FlowSessionInternal, override val message: SessionMessage) : SendRequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}
data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachineImpl<*>) : WaitingRequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = message is ErrorSessionEnd
}
data class Sleep(val until: Instant, val fiber: FlowStateMachineImpl<*>) : FlowIORequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}
class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")

View File

@ -0,0 +1,75 @@
package net.corda.node.services.statemachine
import com.esotericsoftware.kryo.KryoException
import net.corda.core.flows.FlowException
import net.corda.core.identity.Party
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.serialize
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.trace
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.AcknowledgeHandle
import net.corda.node.services.messaging.ReceivedMessage
import java.io.NotSerializableException
/**
* A wrapper interface around flow messaging.
*/
interface FlowMessaging {
/**
* Send [message] to [party] using [deduplicationId]. Optionally [acknowledgementHandler] may be specified to
* listen on the send acknowledgement.
*/
fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId, acknowledgementHandler: (() -> Unit)?)
/**
* Start the messaging using the [onMessage] message handler.
*/
fun start(onMessage: (ReceivedMessage, acknowledgeHandle: AcknowledgeHandle) -> Unit)
}
/**
* Implementation of [FlowMessaging] using a [ServiceHubInternal] to do the messaging and routing.
*/
class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging {
companion object {
val log = contextLogger()
val sessionTopic = "platform.session"
}
override fun start(onMessage: (ReceivedMessage, acknowledgeHandle: AcknowledgeHandle) -> Unit) {
serviceHub.networkService.addMessageHandler(sessionTopic) { receivedMessage, _, acknowledgeHandle ->
onMessage(receivedMessage, acknowledgeHandle)
}
}
override fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId, acknowledgementHandler: (() -> Unit)?) {
log.trace { "Sending message $deduplicationId $message to party $party" }
val networkMessage = serviceHub.networkService.createMessage(sessionTopic, serializeSessionMessage(message).bytes, deduplicationId)
val partyInfo = serviceHub.networkMapCache.getPartyInfo(party) ?: throw IllegalArgumentException("Don't know about $party")
val address = serviceHub.networkService.getAddressOfParty(partyInfo)
val sequenceKey = when (message) {
is InitialSessionMessage -> message.initiatorSessionId
is ExistingSessionMessage -> message.recipientSessionId
}
serviceHub.networkService.send(networkMessage, address, sequenceKey = sequenceKey, acknowledgementHandler = acknowledgementHandler)
}
private fun serializeSessionMessage(message: SessionMessage): SerializedBytes<SessionMessage> {
return try {
message.serialize()
} catch (exception: Exception) {
// Handling Kryo and AMQP serialization problems. Unfortunately the two exception types do not share much of a common exception interface.
if ((exception is KryoException || exception is NotSerializableException)
&& message is ExistingSessionMessage && message.payload is ErrorSessionMessage) {
val error = message.payload.flowException
val rewrappedError = FlowException(error?.message)
message.copy(payload = message.payload.copy(flowException = rewrappedError)).serialize()
} else {
throw exception
}
}
}
}

View File

@ -1,20 +1,39 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.identity.Party
import net.corda.core.internal.FlowIORequest
import net.corda.core.internal.FlowStateMachine
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.serialize
import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.UntrustworthyData
import net.corda.core.utilities.checkPayloadIs
class FlowSessionImpl(override val counterparty: Party) : FlowSession() {
internal lateinit var stateMachine: FlowStateMachine<*>
internal lateinit var sessionFlow: FlowLogic<*>
class FlowSessionImpl(
override val counterparty: Party,
val sourceSessionId: SessionId
) : FlowSession() {
override fun toString() = "FlowSessionImpl(counterparty=$counterparty, sourceSessionId=$sourceSessionId)"
override fun equals(other: Any?): Boolean {
return (other as? FlowSessionImpl)?.sourceSessionId == sourceSessionId
}
override fun hashCode() = sourceSessionId.hashCode()
private fun getFlowStateMachine(): FlowStateMachine<*> {
return Fiber.currentFiber() as FlowStateMachine<*>
}
@Suspendable
override fun getCounterpartyFlowInfo(maySkipCheckpoint: Boolean): FlowInfo {
return stateMachine.getFlowInfo(counterparty, sessionFlow, maySkipCheckpoint)
val request = FlowIORequest.GetFlowInfo(NonEmptySet.of(this))
return getFlowStateMachine().suspend(request, maySkipCheckpoint)[this]!!
}
@Suspendable
@ -26,14 +45,12 @@ class FlowSessionImpl(override val counterparty: Party) : FlowSession() {
payload: Any,
maySkipCheckpoint: Boolean
): UntrustworthyData<R> {
return stateMachine.sendAndReceive(
receiveType,
counterparty,
payload,
sessionFlow,
retrySend = false,
maySkipCheckpoint = maySkipCheckpoint
enforceNotPrimitive(receiveType)
val request = FlowIORequest.SendAndReceive(
sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT)),
shouldRetrySend = false
)
return getFlowStateMachine().suspend(request, maySkipCheckpoint)[this]!!.checkPayloadIs(receiveType)
}
@Suspendable
@ -41,7 +58,9 @@ class FlowSessionImpl(override val counterparty: Party) : FlowSession() {
@Suspendable
override fun <R : Any> receive(receiveType: Class<R>, maySkipCheckpoint: Boolean): UntrustworthyData<R> {
return stateMachine.receive(receiveType, counterparty, sessionFlow, maySkipCheckpoint)
enforceNotPrimitive(receiveType)
val request = FlowIORequest.Receive(NonEmptySet.of(this))
return getFlowStateMachine().suspend(request, maySkipCheckpoint)[this]!!.checkPayloadIs(receiveType)
}
@Suspendable
@ -49,12 +68,18 @@ class FlowSessionImpl(override val counterparty: Party) : FlowSession() {
@Suspendable
override fun send(payload: Any, maySkipCheckpoint: Boolean) {
return stateMachine.send(counterparty, payload, sessionFlow, maySkipCheckpoint)
val request = FlowIORequest.Send(
sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT)),
shouldRetrySend = false
)
return getFlowStateMachine().suspend(request, maySkipCheckpoint)
}
@Suspendable
override fun send(payload: Any) = send(payload, maySkipCheckpoint = false)
override fun toString() = "Flow session with $counterparty"
private fun enforceNotPrimitive(type: Class<*>) {
require(!type.isPrimitive) { "Cannot receive primitive type $type" }
}
}

View File

@ -1,56 +0,0 @@
package net.corda.node.services.statemachine
import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.identity.Party
import net.corda.node.services.statemachine.FlowSessionState.Initiated
import net.corda.node.services.statemachine.FlowSessionState.Initiating
import java.util.concurrent.ConcurrentLinkedQueue
/**
* @param retryable Indicates that the session initialisation should be retried until an expected [SessionData] response
* is received. Note that this requires the party on the other end to be a distributed service and run an idempotent flow
* that only sends back a single [SessionData] message before termination.
*/
// TODO rename this
class FlowSessionInternal(
val flow: FlowLogic<*>,
val flowSession : FlowSession,
val ourSessionId: Long,
val initiatingParty: Party?,
var state: FlowSessionState,
var retryable: Boolean = false) {
val receivedMessages = ConcurrentLinkedQueue<ReceivedSessionMessage<*>>()
val fiber: FlowStateMachineImpl<*> get() = flow.stateMachine as FlowStateMachineImpl<*>
override fun toString(): String {
return "${javaClass.simpleName}(flow=$flow, ourSessionId=$ourSessionId, initiatingParty=$initiatingParty, state=$state)"
}
}
/**
* [FlowSessionState] describes the session's state.
*
* [Uninitiated] is pre-handshake, where no communication has happened. [Initiating.otherParty] at this point holds a
* [Party] corresponding to either a specific peer or a service.
* [Initiating] is pre-handshake, where the initiating message has been sent.
* [Initiated] is post-handshake. At this point [Initiating.otherParty] will have been resolved to a specific peer
* [Initiated.peerParty], and the peer's sessionId has been initialised.
*/
sealed class FlowSessionState {
abstract val sendToParty: Party
data class Uninitiated(val otherParty: Party) : FlowSessionState() {
override val sendToParty: Party get() = otherParty
}
/** [otherParty] may be a specific peer or a service party */
data class Initiating(val otherParty: Party) : FlowSessionState() {
override val sendToParty: Party get() = otherParty
}
data class Initiated(val peerParty: Party, val peerSessionId: Long, val context: FlowInfo) : FlowSessionState() {
override val sendToParty: Party get() = peerParty
}
}

View File

@ -5,249 +5,184 @@ import co.paralleluniverse.fibers.Fiber.parkAndSerialize
import co.paralleluniverse.fibers.FiberScheduler
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand
import com.google.common.primitives.Primitives
import co.paralleluniverse.strands.channels.Channel
import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.random63BitValue
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.identity.PartyAndCertificate
import net.corda.core.internal.*
import net.corda.core.internal.concurrent.OpenFuture
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.internal.FlowIORequest
import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.isRegularFile
import net.corda.core.internal.uncheckedCast
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.serialize
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.*
import net.corda.core.utilities.Try
import net.corda.core.utilities.debug
import net.corda.core.utilities.trace
import net.corda.node.services.api.FlowAppAuditEvent
import net.corda.node.services.api.FlowPermissionAuditEvent
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.logging.pushToLoggingContext
import net.corda.node.services.statemachine.FlowSessionState.Initiating
import net.corda.node.services.statemachine.transitions.FlowContinuation
import net.corda.node.services.statemachine.transitions.StateMachine
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseTransaction
import net.corda.nodeapi.internal.persistence.DatabaseTransactionManager
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.nio.file.Paths
import java.sql.SQLException
import java.time.Duration
import java.time.Instant
import java.util.*
import java.util.concurrent.TimeUnit
import kotlin.reflect.KProperty1
class FlowPermissionException(message: String) : FlowException(message)
class TransientReference<out A>(@Transient val value: A)
class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
override val logic: FlowLogic<R>,
scheduler: FiberScheduler,
val ourIdentity: Party,
override val context: InvocationContext) : Fiber<Unit>(id.toString(), scheduler), FlowStateMachine<R> {
scheduler: FiberScheduler
// Store the Party rather than the full cert path with PartyAndCertificate
) : Fiber<Unit>(id.toString(), scheduler), FlowStateMachine<R>, FlowFiber {
companion object {
// Used to work around a small limitation in Quasar.
private val QUASAR_UNBLOCKER = Fiber::class.staticField<Any>("SERIALIZER_BLOCKER").value
/**
* Return the current [FlowStateMachineImpl] or null if executing outside of one.
*/
fun currentStateMachine(): FlowStateMachineImpl<*>? = Strand.currentStrand() as? FlowStateMachineImpl<*>
private val log: Logger = LoggerFactory.getLogger("net.corda.flow")
@Suspendable
private fun abortFiber(): Nothing {
Fiber.park()
throw IllegalStateException("Ended fiber unparked")
}
// These fields shouldn't be serialised, so they are marked @Transient.
@Transient override lateinit var serviceHub: ServiceHubInternal
@Transient override lateinit var ourIdentityAndCert: PartyAndCertificate
@Transient internal lateinit var database: CordaPersistence
@Transient internal lateinit var actionOnSuspend: (FlowIORequest) -> Unit
@Transient internal lateinit var actionOnEnd: (Try<R>, Boolean) -> Unit
@Transient internal var fromCheckpoint: Boolean = false
@Transient private var txTrampoline: DatabaseTransaction? = null
private fun extractThreadLocalTransaction(): TransientReference<DatabaseTransaction> {
val transaction = DatabaseTransactionManager.current()
DatabaseTransactionManager.setThreadLocalTx(null)
return TransientReference(transaction)
}
}
override val serviceHub get() = getTransientField(TransientValues::serviceHub)
data class TransientValues(
val eventQueue: Channel<Event>,
val resultFuture: CordaFuture<Any?>,
val database: CordaPersistence,
val transitionExecutor: TransitionExecutor,
val actionExecutor: ActionExecutor,
val stateMachine: StateMachine,
val serviceHub: ServiceHubInternal,
val checkpointSerializationContext: SerializationContext
)
internal var transientValues: TransientReference<TransientValues>? = null
internal var transientState: TransientReference<StateMachineState>? = null
private fun <A> getTransientField(field: KProperty1<TransientValues, A>): A {
val suppliedValues = transientValues ?: throw IllegalStateException("${field.name} wasn't supplied!")
return field.get(suppliedValues.value)
}
/**
* Return the logger for this state machine. The logger name incorporates [id] and so including it in the log message
* is not necessary.
*/
override val logger: Logger = LoggerFactory.getLogger("net.corda.flow.$id")
@Transient private var resultFutureTransient: OpenFuture<R>? = openFuture()
private val _resultFuture get() = resultFutureTransient ?: openFuture<R>().also { resultFutureTransient = it }
/** This future will complete when the call method returns. */
override val resultFuture: CordaFuture<R> get() = _resultFuture
// This state IS serialised, as we need it to know what the fiber is waiting for.
internal val openSessions = HashMap<Pair<FlowLogic<*>, Party>, FlowSessionInternal>()
internal var waitingForResponse: WaitingRequest? = null
override val logger = log
override val resultFuture: CordaFuture<R> get() = uncheckedCast(getTransientField(TransientValues::resultFuture))
override val context: InvocationContext get() = transientState!!.value.checkpoint.invocationContext
override val ourIdentity: Party get() = transientState!!.value.checkpoint.ourIdentity
internal var hasSoftLockedStates: Boolean = false
set(value) {
if (value) field = value else throw IllegalArgumentException("Can only set to true")
}
init {
logic.stateMachine = this
@Suspendable
private fun processEvent(transitionExecutor: TransitionExecutor, event: Event): FlowContinuation {
val stateMachine = getTransientField(TransientValues::stateMachine)
val oldState = transientState!!.value
val actionExecutor = getTransientField(TransientValues::actionExecutor)
val transition = stateMachine.transition(event, oldState)
val (continuation, newState) = transitionExecutor.executeTransition(this, oldState, event, transition, actionExecutor)
transientState = TransientReference(newState)
return continuation
}
@Suspendable
private fun processEventsUntilFlowIsResumed(): Any? {
val transitionExecutor = getTransientField(TransientValues::transitionExecutor)
val eventQueue = getTransientField(TransientValues::eventQueue)
eventLoop@while (true) {
val nextEvent = eventQueue.receive()
val continuation = processEvent(transitionExecutor, nextEvent)
when (continuation) {
is FlowContinuation.Resume -> return continuation.result
is FlowContinuation.Throw -> {
continuation.throwable.fillInStackTrace()
throw continuation.throwable
}
FlowContinuation.ProcessEvents -> continue@eventLoop
FlowContinuation.Abort -> abortFiber()
}
}
}
@Suspendable
override fun run() {
createTransaction()
logic.stateMachine = this
context.pushToLoggingContext()
initialiseFlow()
logger.debug { "Calling flow: $logic" }
val startTime = System.nanoTime()
val result = try {
val r = logic.call()
// Only sessions which have done a single send and nothing else will block here
openSessions.values
.filter { it.state is Initiating }
.forEach { it.waitForConfirmation() }
r
} catch (e: FlowException) {
recordDuration(startTime, success = false)
// Check if the FlowException was propagated by looking at where the stack trace originates (see suspendAndExpectReceive).
val propagated = e.stackTrace[0].className == javaClass.name
processException(e, propagated)
logger.warn(if (propagated) "Flow ended due to receiving exception" else "Flow finished with exception", e)
return
} catch (t: Throwable) {
recordDuration(startTime, success = false)
logger.warn("Terminated by unexpected exception", t)
processException(t, false)
return
val resultOrError = try {
val result = logic.call()
// TODO expose maySkipCheckpoint here
suspend(FlowIORequest.WaitForSessionConfirmations, maySkipCheckpoint = false)
Try.Success(result)
} catch (throwable: Throwable) {
logger.warn("Flow threw exception", throwable)
Try.Failure<R>(throwable)
}
val finalEvent = when (resultOrError) {
is Try.Success -> {
Event.FlowFinish(resultOrError.value)
}
is Try.Failure -> {
Event.Error(resultOrError.exception)
}
}
processEvent(getTransientField(TransientValues::transitionExecutor), finalEvent)
processEventsUntilFlowIsResumed()
recordDuration(startTime)
// This is to prevent actionOnEnd being called twice if it throws an exception
actionOnEnd(Try.Success(result), false)
_resultFuture.set(result)
logic.progressTracker?.currentStep = ProgressTracker.DONE
logger.debug { "Flow finished with result ${result.toString().abbreviate(300)}" }
}
private fun createTransaction() {
// Make sure we have a database transaction
database.createTransaction()
logger.trace { "Starting database transaction ${DatabaseTransactionManager.currentOrNull()} on ${Strand.currentStrand()}" }
@Suspendable
private fun initialiseFlow() {
processEventsUntilFlowIsResumed()
}
private fun processException(exception: Throwable, propagated: Boolean) {
actionOnEnd(Try.Failure(exception), propagated)
_resultFuture.setException(exception)
logic.progressTracker?.endWithError(exception)
}
internal fun commitTransaction() {
val transaction = DatabaseTransactionManager.current()
try {
logger.trace { "Committing database transaction $transaction on ${Strand.currentStrand()}." }
transaction.commit()
} catch (e: SQLException) {
// TODO: we will get here if the database is not available. Think about how to shutdown and restart cleanly.
logger.error("Transaction commit failed: ${e.message}", e)
System.exit(1)
@Suspendable
override fun <R> subFlow(subFlow: FlowLogic<R>): R {
processEvent(getTransientField(TransientValues::transitionExecutor), Event.EnterSubFlow(subFlow.javaClass))
return try {
subFlow.call()
} finally {
transaction.close()
processEvent(getTransientField(TransientValues::transitionExecutor), Event.LeaveSubFlow)
}
}
@Suspendable
override fun initiateFlow(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSession {
val sessionKey = Pair(sessionFlow, otherParty)
if (openSessions.containsKey(sessionKey)) {
throw IllegalStateException(
"Attempted to initiateFlow() twice in the same InitiatingFlow $sessionFlow for the same party " +
"$otherParty. This isn't supported in this version of Corda. Alternatively you may " +
"initiate a new flow by calling initiateFlow() in an " +
"@${InitiatingFlow::class.java.simpleName} sub-flow."
)
}
val flowSession = FlowSessionImpl(otherParty)
createNewSession(otherParty, flowSession, sessionFlow)
flowSession.stateMachine = this
flowSession.sessionFlow = sessionFlow
return flowSession
}
@Suspendable
override fun getFlowInfo(otherParty: Party, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): FlowInfo {
val state = getConfirmedSession(otherParty, sessionFlow).state as FlowSessionState.Initiated
return state.context
}
@Suspendable
override fun <T : Any> sendAndReceive(receiveType: Class<T>,
otherParty: Party,
payload: Any,
sessionFlow: FlowLogic<*>,
retrySend: Boolean,
maySkipCheckpoint: Boolean): UntrustworthyData<T> {
requireNonPrimitive(receiveType)
logger.debug { "sendAndReceive(${receiveType.name}, $otherParty, ${payload.toString().abbreviate(300)}) ..." }
val session = getConfirmedSessionIfPresent(otherParty, sessionFlow)
val receivedSessionData: ReceivedSessionMessage<SessionData> = if (session == null) {
val newSession = initiateSession(otherParty, sessionFlow, payload, waitForConfirmation = true, retryable = retrySend)
// Only do a receive here as the session init has carried the payload
receiveInternal(newSession, receiveType)
} else {
val sendData = createSessionData(session, payload)
sendAndReceiveInternal(session, sendData, receiveType)
}
logger.debug { "Received ${receivedSessionData.message.payload.toString().abbreviate(300)}" }
return receivedSessionData.checkPayloadIs(receiveType)
}
@Suspendable
override fun <T : Any> receive(receiveType: Class<T>,
otherParty: Party,
sessionFlow: FlowLogic<*>,
maySkipCheckpoint: Boolean): UntrustworthyData<T> {
requireNonPrimitive(receiveType)
logger.debug { "receive(${receiveType.name}, $otherParty) ..." }
val session = getConfirmedSession(otherParty, sessionFlow)
val sessionData = receiveInternal<SessionData>(session, receiveType)
logger.debug { "Received ${sessionData.message.payload.toString().abbreviate(300)}" }
return sessionData.checkPayloadIs(receiveType)
}
private fun requireNonPrimitive(receiveType: Class<*>) {
require(!receiveType.isPrimitive) {
"Use the wrapper type ${Primitives.wrap(receiveType).name} instead of the primitive $receiveType.class"
}
}
@Suspendable
override fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean) {
logger.debug { "send($otherParty, ${payload.toString().abbreviate(300)})" }
val session = getConfirmedSessionIfPresent(otherParty, sessionFlow)
if (session == null) {
// Don't send the payload again if it was already piggy-backed on a session init
initiateSession(otherParty, sessionFlow, payload, waitForConfirmation = false)
} else {
sendInternal(session, createSessionData(session, payload))
}
}
@Suspendable
override fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): SignedTransaction {
logger.debug { "waitForLedgerCommit($hash) ..." }
suspend(WaitForLedgerCommit(hash, sessionFlow.stateMachine as FlowStateMachineImpl<*>))
val stx = serviceHub.validatedTransactions.getTransaction(hash)
if (stx != null) {
logger.debug { "Transaction $hash committed to ledger" }
return stx
}
// If the tx isn't committed then we may have been resumed due to an session ending in an error
for (session in openSessions.values) {
for (receivedMessage in session.receivedMessages) {
if (receivedMessage.message is ErrorSessionEnd) {
session.erroredEnd(receivedMessage.message)
}
}
}
throw IllegalStateException("We were resumed after waiting for $hash but it wasn't found in our local storage")
}
// Provide a mechanism to sleep within a Strand without locking any transactional state.
// This checkpoints, since we cannot undo any database writes up to this point.
@Suspendable
override fun sleepUntil(until: Instant) {
suspend(Sleep(until, this))
override fun initiateFlow(party: Party): FlowSession {
val resume = processEvent(
getTransientField(TransientValues::transitionExecutor),
Event.InitiateFlow(party)
) as FlowContinuation.Resume
return resume.result as FlowSession
}
// TODO Dummy implementation of access to application specific permission controls and audit logging
@ -292,231 +227,43 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
}
@Suspendable
override fun receiveAll(sessions: Map<FlowSession, Class<out Any>>, sessionFlow: FlowLogic<*>): Map<FlowSession, UntrustworthyData<Any>> {
val requests = ArrayList<ReceiveOnly<SessionData>>()
for ((session, receiveType) in sessions) {
val sessionInternal = getConfirmedSession(session.counterparty, sessionFlow)
requests.add(ReceiveOnly(sessionInternal, SessionData::class.java, receiveType))
}
val receivedMessages = ReceiveAll(requests).suspendAndExpectReceive(suspend)
val result = LinkedHashMap<FlowSession, UntrustworthyData<Any>>()
for ((sessionInternal, requestAndMessage) in receivedMessages) {
val message = requestAndMessage.message.confirmReceiveType(requestAndMessage.request)
result[sessionInternal.flowSession] = message.checkPayloadIs(requestAndMessage.request.userReceiveType as Class<out Any>)
}
return result
}
internal fun pushToLoggingContext() = context.pushToLoggingContext()
/**
* This method will suspend the state machine and wait for incoming session init response from other party.
*/
@Suspendable
private fun FlowSessionInternal.waitForConfirmation() {
val (peerParty, sessionInitResponse) = receiveInternal<SessionInitResponse>(this, null)
if (sessionInitResponse is SessionConfirm) {
state = FlowSessionState.Initiated(
peerParty,
sessionInitResponse.initiatedSessionId,
FlowInfo(sessionInitResponse.flowVersion, sessionInitResponse.appName))
} else {
sessionInitResponse as SessionReject
throw UnexpectedFlowEndException("Party ${state.sendToParty} rejected session request: ${sessionInitResponse.errorMessage}")
}
}
private fun createSessionData(session: FlowSessionInternal, payload: Any): SessionData {
val sessionState = session.state
val peerSessionId = when (sessionState) {
is FlowSessionState.Initiated -> sessionState.peerSessionId
else -> throw IllegalStateException("We've somehow held onto a non-initiated session: $session")
}
return SessionData(peerSessionId, payload.serialize(context = SerializationDefaults.P2P_CONTEXT))
}
@Suspendable
private fun sendInternal(session: FlowSessionInternal, message: SessionMessage) = suspend(SendOnly(session, message))
private inline fun <reified M : ExistingSessionMessage> receiveInternal(
session: FlowSessionInternal,
userReceiveType: Class<*>?): ReceivedSessionMessage<M> {
return waitForMessage(ReceiveOnly(session, M::class.java, userReceiveType))
}
private inline fun <reified M : ExistingSessionMessage> sendAndReceiveInternal(
session: FlowSessionInternal,
message: SessionMessage,
userReceiveType: Class<*>?): ReceivedSessionMessage<M> {
return waitForMessage(SendAndReceive(session, message, M::class.java, userReceiveType))
}
@Suspendable
private fun getConfirmedSessionIfPresent(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSessionInternal? {
val session = openSessions[Pair(sessionFlow, otherParty)] ?: return null
return when (session.state) {
is FlowSessionState.Uninitiated -> null
is FlowSessionState.Initiating -> {
session.waitForConfirmation()
session
}
is FlowSessionState.Initiated -> session
}
}
@Suspendable
private fun getConfirmedSession(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSessionInternal {
return getConfirmedSessionIfPresent(otherParty, sessionFlow) ?:
initiateSession(otherParty, sessionFlow, null, waitForConfirmation = true)
}
private fun createNewSession(
otherParty: Party,
flowSession: FlowSession,
sessionFlow: FlowLogic<*>
) {
logger.trace { "Creating a new session with $otherParty" }
val session = FlowSessionInternal(sessionFlow, flowSession, random63BitValue(), null, FlowSessionState.Uninitiated(otherParty))
openSessions[Pair(sessionFlow, otherParty)] = session
}
@Suspendable
private fun initiateSession(
otherParty: Party,
sessionFlow: FlowLogic<*>,
firstPayload: Any?,
waitForConfirmation: Boolean,
retryable: Boolean = false
): FlowSessionInternal {
val session = openSessions[Pair(sessionFlow, otherParty)] ?: throw IllegalStateException("Expected an Uninitiated session for $otherParty")
val state = session.state as? FlowSessionState.Uninitiated ?: throw IllegalStateException("Tried to initiate a session $session, but it's already initiating/initiated")
logger.trace { "Initiating a new session with ${state.otherParty}" }
session.state = FlowSessionState.Initiating(state.otherParty)
session.retryable = retryable
val (version, initiatingFlowClass) = session.flow.javaClass.flowVersionAndInitiatingClass
val payloadBytes = firstPayload?.serialize(context = SerializationDefaults.P2P_CONTEXT)
logger.info("Initiating flow session with party ${otherParty.name}. Session id for tracing purposes is ${session.ourSessionId}.")
val sessionInit = SessionInit(session.ourSessionId, initiatingFlowClass.name, version, session.flow.javaClass.appName, payloadBytes)
sendInternal(session, sessionInit)
if (waitForConfirmation) {
session.waitForConfirmation()
}
return session
}
@Suspendable
private fun <M : ExistingSessionMessage> waitForMessage(receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> {
return receiveRequest.suspendAndExpectReceive().confirmReceiveType(receiveRequest)
}
private val suspend : ReceiveAll.Suspend = object : ReceiveAll.Suspend {
@Suspendable
override fun invoke(request: FlowIORequest) {
suspend(request)
}
}
@Suspendable
private fun ReceiveRequest<*>.suspendAndExpectReceive(): ReceivedSessionMessage<*> {
val polledMessage = session.receivedMessages.poll()
return if (polledMessage != null) {
if (this is SendAndReceive) {
// Since we've already received the message, we downgrade to a send only to get the payload out and not
// inadvertently block
suspend(SendOnly(session, message))
}
polledMessage
} else {
// Suspend while we wait for a receive
suspend(this)
session.receivedMessages.poll() ?:
throw IllegalStateException("Was expecting a ${receiveType.simpleName} but instead got nothing for $this")
}
}
private fun <M : ExistingSessionMessage> ReceivedSessionMessage<*>.confirmReceiveType(
receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> {
val session = receiveRequest.session
val receiveType = receiveRequest.receiveType
if (receiveType.isInstance(message)) {
return uncheckedCast(this)
} else if (message is SessionEnd) {
openSessions.values.remove(session)
if (message is ErrorSessionEnd) {
session.erroredEnd(message)
} else {
val expectedType = receiveRequest.userReceiveType?.name ?: receiveType.simpleName
throw UnexpectedFlowEndException("Counterparty flow on ${session.state.sendToParty} has completed without " +
"sending a $expectedType")
}
} else {
throw IllegalStateException("Was expecting a ${receiveType.simpleName} but instead got $message for $receiveRequest")
}
}
private fun FlowSessionInternal.erroredEnd(end: ErrorSessionEnd): Nothing {
if (end.errorResponse != null) {
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
(end.errorResponse as java.lang.Throwable).fillInStackTrace()
throw end.errorResponse
} else {
throw UnexpectedFlowEndException("Counterparty flow on ${state.sendToParty} had an internal error and has terminated")
}
}
@Suspendable
private fun suspend(ioRequest: FlowIORequest) {
// We have to pass the thread local database transaction across via a transient field as the fiber park
// swaps them out.
txTrampoline = DatabaseTransactionManager.setThreadLocalTx(null)
if (ioRequest is WaitingRequest)
waitingForResponse = ioRequest
var exceptionDuringSuspend: Throwable? = null
override fun <R : Any> suspend(ioRequest: FlowIORequest<R>, maySkipCheckpoint: Boolean): R {
val serializationContext = TransientReference(getTransientField(TransientValues::checkpointSerializationContext))
val transaction = extractThreadLocalTransaction()
val transitionExecutor = TransientReference(getTransientField(TransientValues::transitionExecutor))
parkAndSerialize { _, _ ->
logger.trace { "Suspended on $ioRequest" }
// restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB
try {
DatabaseTransactionManager.setThreadLocalTx(txTrampoline)
txTrampoline = null
actionOnSuspend(ioRequest)
} catch (t: Throwable) {
// Quasar does not terminate the fiber properly if an exception occurs during a suspend. We have to
// resume the fiber just so that we can throw it when it's running.
exceptionDuringSuspend = t
logger.trace("Resuming so fiber can it terminate with the exception thrown during suspend process", t)
resume(scheduler)
}
DatabaseTransactionManager.setThreadLocalTx(transaction.value)
val event = try {
Event.Suspend(
ioRequest = ioRequest,
maySkipCheckpoint = maySkipCheckpoint,
fiber = this.serialize(context = serializationContext.value)
)
} catch (throwable: Throwable) {
Event.Error(throwable)
}
if (exceptionDuringSuspend == null && ioRequest is Sleep) {
// Sleep on the fiber. This will not sleep if it's in the past.
Strand.sleep(Duration.between(Instant.now(), ioRequest.until).toNanos(), TimeUnit.NANOSECONDS)
}
createTransaction()
// TODO Now that we're throwing outside of the suspend the FlowLogic can catch it. We need Quasar to terminate
// the fiber when exceptions occur inside a suspend.
exceptionDuringSuspend?.let { throw it }
logger.trace { "Resumed from $ioRequest" }
}
internal fun resume(scheduler: FiberScheduler) {
try {
if (fromCheckpoint) {
logger.info("Resumed from checkpoint")
fromCheckpoint = false
// We must commit the database transaction before returning from this closure, otherwise Quasar may schedule
// other fibers
require(processEvent(transitionExecutor.value, event) == FlowContinuation.ProcessEvents)
Fiber.unparkDeserialized(this, scheduler)
} else if (state == State.NEW) {
logger.trace("Started")
start()
} else {
Fiber.unpark(this, QUASAR_UNBLOCKER)
}
} catch (t: Throwable) {
logger.error("Error during resume", t)
return processEventsUntilFlowIsResumed() as R
}
@Suspendable
override fun scheduleEvent(event: Event) {
getTransientField(TransientValues::eventQueue).send(event)
}
override fun snapshot(): StateMachineState {
return transientState!!.value
}
override val stateMachine get() = getTransientField(TransientValues::stateMachine)
/**
* Records the duration of this flow from call() to completion or failure.
* Note that the duration will include the time the flow spent being parked, and not just the total

View File

@ -0,0 +1,21 @@
package net.corda.node.services.statemachine
import net.corda.core.utilities.debug
import net.corda.core.utilities.loggerFor
/**
* A simple [FlowHospital] implementation that immediately triggers error propagation when a flow dirties.
*/
object PropagatingFlowHospital : FlowHospital {
private val log = loggerFor<PropagatingFlowHospital>()
override fun flowErrored(flowFiber: FlowFiber) {
log.debug { "Flow ${flowFiber.id} dirtied ${flowFiber.snapshot().checkpoint.errorState}" }
flowFiber.scheduleEvent(Event.StartErrorPropagation)
}
override fun flowCleaned(flowFiber: FlowFiber) {
throw IllegalStateException("Flow ${flowFiber.id} cleaned after error propagation triggered")
}
}

View File

@ -1,59 +1,122 @@
package net.corda.node.services.statemachine
import net.corda.core.crypto.random63BitValue
import net.corda.core.flows.FlowException
import net.corda.core.flows.UnexpectedFlowEndException
import net.corda.core.identity.Party
import net.corda.core.internal.castIfPossible
import net.corda.core.flows.FlowInfo
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.UntrustworthyData
import java.io.IOException
import java.security.SecureRandom
/**
* A session between two flows is identified by two session IDs, the initiating and the initiated session ID.
* However after the session has been established the communication is symmetric. From then on we differentiate between
* the two session IDs with "source" ID (the ID from which we receive) and "sink" ID (the ID to which we send).
*
* Flow A (initiating) Flow B (initiated)
* initiatingId=sourceId=0
* send(Initiate(initiatingId=0)) -----> initiatingId=sinkId=0
* initiatedId=sourceId=1
* initiatedId=sinkId=1 <----- send(Confirm(initiatedId=1))
*/
@CordaSerializable
sealed class SessionMessage
@CordaSerializable
interface SessionMessage
interface ExistingSessionMessage : SessionMessage {
val recipientSessionId: Long
data class SessionId(val toLong: Long) {
companion object {
fun createRandom(secureRandom: SecureRandom) = SessionId(secureRandom.nextLong())
}
}
interface SessionInitResponse : ExistingSessionMessage {
val initiatorSessionId: Long
override val recipientSessionId: Long get() = initiatorSessionId
}
interface SessionEnd : ExistingSessionMessage
data class SessionInit(val initiatorSessionId: Long,
val initiatingFlowClass: String,
/**
* The initial message to initiate a session with.
*
* @param initiatorSessionId the session ID of the initiator. On the sending side this is the *source* ID, on the
* receiving side this is the *sink* ID.
* @param initiationEntropy additional randomness to seed the initiated flow's deduplication ID.
* @param initiatorFlowClassName the class name to be used to determine the initiating-initiated mapping on the receiver
* side.
* @param flowVersion the version of the initiating flow.
* @param appName the name of the cordapp defining the initiating flow, or "corda" if it's a core flow.
* @param firstPayload the optional first payload.
*/
data class InitialSessionMessage(
val initiatorSessionId: SessionId,
val initiationEntropy: Long,
val initiatorFlowClassName: String,
val flowVersion: Int,
val appName: String,
val firstPayload: SerializedBytes<Any>?) : SessionMessage
data class SessionConfirm(override val initiatorSessionId: Long,
val initiatedSessionId: Long,
val flowVersion: Int,
val appName: String) : SessionInitResponse
data class SessionReject(override val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse
data class SessionData(override val recipientSessionId: Long, val payload: SerializedBytes<Any>) : ExistingSessionMessage
data class NormalSessionEnd(override val recipientSessionId: Long) : SessionEnd
data class ErrorSessionEnd(override val recipientSessionId: Long, val errorResponse: FlowException?) : SessionEnd
data class ReceivedSessionMessage<out M : ExistingSessionMessage>(val sender: Party, val message: M)
fun <T : Any> ReceivedSessionMessage<SessionData>.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
val payloadData: T = try {
val serializer = SerializationDefaults.SERIALIZATION_FACTORY
serializer.deserialize<T>(message.payload, type, SerializationDefaults.P2P_CONTEXT)
} catch (ex: Exception) {
throw IOException("Payload invalid", ex)
val firstPayload: SerializedBytes<Any>?
) : SessionMessage() {
override fun toString() = "InitialSessionMessage(" +
"initiatorSessionId=$initiatorSessionId, " +
"initiationEntropy=$initiationEntropy, " +
"initiatorFlowClassName=$initiatorFlowClassName, " +
"appName=$appName, " +
"firstPayload=${firstPayload?.javaClass}" +
")"
}
return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?:
throw UnexpectedFlowEndException("We were expecting a ${type.name} from $sender but we instead got a " +
"${payloadData.javaClass.name} (${payloadData})")
/**
* A message sent when a session has been established already.
*
* @param recipientSessionId the recipient session ID. On the sending side this is the *sink* ID, on the receiving side
* this is the *source* ID.
* @param payload the rest of the message.
*/
data class ExistingSessionMessage(
val recipientSessionId: SessionId,
val payload: ExistingSessionMessagePayload
) : SessionMessage()
/**
* The payload of an [ExistingSessionMessage]
*/
@CordaSerializable
sealed class ExistingSessionMessagePayload
/**
* The confirmation message sent by the initiated side.
* @param initiatedSessionId the initiated session ID, the other half of [InitialSessionMessage.initiatorSessionId].
* This is the *source* ID on the sending(initiated) side, and the *sink* ID on the receiving(initiating) side.
*/
data class ConfirmSessionMessage(
val initiatedSessionId: SessionId,
val initiatedFlowInfo: FlowInfo
) : ExistingSessionMessagePayload()
/**
* A message containing flow-related data.
*
* @param payload the serialised payload.
*/
data class DataSessionMessage(val payload: SerializedBytes<Any>) : ExistingSessionMessagePayload() {
override fun toString() = "DataSessionMessage(payload=${payload.javaClass})"
}
/**
* A message indicating that an error has happened.
*
* @param flowException the exception that happened. This is null if the error condition wasn't revealed to the
* receiving side.
* @param errorId the ID of the source error. This is always specified to allow posteriori correlation of error conditions.
*/
data class ErrorSessionMessage(val flowException: FlowException?, val errorId: Long) : ExistingSessionMessagePayload()
/**
* A message indicating that a session initiation has failed.
*
* @param message a message describing the problem to the initator.
* @param errorId an error ID identifying this error condition.
*/
data class RejectSessionMessage(val message: String, val errorId: Long) : ExistingSessionMessagePayload()
/**
* A message indicating that the flow hosting the session has ended. Note that this message is strictly part of the
* session protocol, the flow may be removed before all counter-flows have ended.
*
* The sole purpose of this message currently is to provide diagnostic in cases where the two communicating flows'
* protocols don't match up, e.g. one is waiting for the other, but the other side has already finished.
*/
object EndSessionMessage : ExistingSessionMessagePayload()

View File

@ -1,9 +1,11 @@
package net.corda.node.services.statemachine
import net.corda.core.concurrent.CordaFuture
import net.corda.core.flows.FlowLogic
import net.corda.core.internal.FlowStateMachine
import net.corda.core.context.InvocationContext
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.Party
import net.corda.core.internal.FlowStateMachine
import net.corda.core.messaging.DataFeed
import net.corda.core.utilities.Try
import rx.Observable
@ -23,7 +25,6 @@ import rx.Observable
* TODO: Think about how to bring the system to a clean stop so it can be upgraded without any serialised stacks on disk
* TODO: Timeouts
* TODO: Surfacing of exceptions via an API and/or management UI
* TODO: Ability to control checkpointing explicitly, for cases where you know replaying a message can't hurt
* TODO: Don't store all active flows in memory, load from the database on demand.
*/
interface StateMachineManager {
@ -43,7 +44,11 @@ interface StateMachineManager {
* @param flowLogic The flow's code.
* @param context The context of the flow.
*/
fun <A> startFlow(flowLogic: FlowLogic<A>, context: InvocationContext): CordaFuture<FlowStateMachine<A>>
fun <A> startFlow(
flowLogic: FlowLogic<A>,
context: InvocationContext,
ourIdentity: Party? = null
): CordaFuture<FlowStateMachine<A>>
/**
* Represents an addition/removal of a state machine.
@ -74,3 +79,12 @@ interface StateMachineManager {
*/
val allStateMachines: List<FlowLogic<*>>
}
// These must be idempotent! A later failure in the state transition may error the flow state, and a replay may call
// these functions again
interface StateMachineManagerInternal {
fun signalFlowHasStarted(flowId: StateMachineRunId)
fun addSessionBinding(flowId: StateMachineRunId, sessionId: SessionId)
fun removeSessionBindings(sessionIds: Set<SessionId>)
fun removeFlow(flowId: StateMachineRunId, removalReason: FlowRemovalReason, lastState: StateMachineState)
}

View File

@ -0,0 +1,232 @@
package net.corda.node.services.statemachine
import net.corda.core.context.InvocationContext
import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowLogic
import net.corda.core.identity.Party
import net.corda.core.internal.FlowIORequest
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.Try
import net.corda.node.services.messaging.AcknowledgeHandle
/**
* The state of the state machine, capturing the state of a flow. It consists of two parts, an *immutable* part that is
* persisted to the database ([Checkpoint]), and the rest, which is an in-memory-only state.
*
* @param checkpoint the persisted part of the state.
* @param flowLogic the [FlowLogic] associated with the flow. Note that this is mutable by the user.
* @param unacknowledgedMessages the list of currently unacknowledged messages.
* @param isFlowResumed true if the control is returned (or being returned) to "user-space" flow code. This is used
* to make [Event.DoRemainingWork] idempotent.
* @param isTransactionTracked true if a ledger transaction has been tracked as part of a
* [FlowIORequest.WaitForLedgerCommit]. This used is to make tracking idempotent.
* @param isAnyCheckpointPersisted true if at least a single checkpoint has been persisted. This is used to determine
* whether we should DELETE the checkpoint at the end of the flow.
* @param isStartIdempotent true if the start of the flow is idempotent, making the skipping of the initial checkpoint
* possible.
* @param isRemoved true if the flow has been removed from the state machine manager. This is used to avoid any further
* work.
*/
// TODO perhaps add a read-only environment to the state machine for things that don't change over time?
// TODO evaluate persistent datastructure libraries to replace the inefficient copying we currently do.
data class StateMachineState(
val checkpoint: Checkpoint,
val flowLogic: FlowLogic<*>,
val unacknowledgedMessages: List<AcknowledgeHandle>,
val isFlowResumed: Boolean,
val isTransactionTracked: Boolean,
val isAnyCheckpointPersisted: Boolean,
val isStartIdempotent: Boolean,
val isRemoved: Boolean
)
/**
* @param invocationContext the initiator of the flow.
* @param ourIdentity the identity the flow is run as.
* @param sessions map of source session ID to session state.
* @param subFlowStack the stack of currently executing subflows.
* @param flowState the state of the flow itself, including the frozen fiber/FlowLogic.
* @param errorState the "dirtiness" state including the involved errors and their propagation status.
* @param numberOfSuspends the number of flow suspends due to IO API calls.
* @param deduplicationSeed the basis seed for the deduplication ID. This is used to produce replayable IDs.
*/
data class Checkpoint(
val invocationContext: InvocationContext,
val ourIdentity: Party,
val sessions: SessionMap, // This must preserve the insertion order!
val subFlowStack: List<SubFlow>,
val flowState: FlowState,
val errorState: ErrorState,
val numberOfSuspends: Int,
val deduplicationSeed: String
) {
companion object {
fun create(
invocationContext: InvocationContext,
flowStart: FlowStart,
flowLogicClass: Class<FlowLogic<*>>,
frozenFlowLogic: SerializedBytes<FlowLogic<*>>,
ourIdentity: Party,
deduplicationSeed: String
): Try<Checkpoint> {
return SubFlow.create(flowLogicClass).map { topLevelSubFlow ->
Checkpoint(
invocationContext = invocationContext,
ourIdentity = ourIdentity,
sessions = emptyMap(),
subFlowStack = listOf(topLevelSubFlow),
flowState = FlowState.Unstarted(flowStart, frozenFlowLogic),
errorState = ErrorState.Clean,
numberOfSuspends = 0,
deduplicationSeed = deduplicationSeed
)
}
}
}
}
/**
* The state of a session.
*/
sealed class SessionState {
/**
* We haven't yet sent the initialisation message
*/
data class Uninitiated(
val party: Party,
val initiatingSubFlow: SubFlow.Initiating
) : SessionState()
/**
* We have sent the initialisation message but have not yet received a confirmation.
* @property rejectionError if non-null the initiation failed.
*/
data class Initiating(
val bufferedMessages: List<Pair<DeduplicationId, ExistingSessionMessagePayload>>,
val rejectionError: FlowError?
) : SessionState()
/**
* We have received a confirmation, the peer party and session id is resolved.
* @property errors if not empty the session is in an errored state.
*/
data class Initiated(
val peerParty: Party,
val peerFlowInfo: FlowInfo,
val receivedMessages: List<DataSessionMessage>,
val initiatedState: InitiatedSessionState,
val errors: List<FlowError>
) : SessionState()
}
typealias SessionMap = Map<SessionId, SessionState>
/**
* Tracks whether an initiated session state is live or has ended. This is a separate state, as we still need the rest
* of [SessionState.Initiated], even when the session has ended, for un-drained session messages and potential future
* [FlowInfo] requests.
*/
sealed class InitiatedSessionState {
data class Live(val peerSinkSessionId: SessionId) : InitiatedSessionState()
object Ended : InitiatedSessionState() { override fun toString() = "Ended" }
}
/**
* Represents the way the flow has started.
*/
sealed class FlowStart {
/**
* The flow was started explicitly e.g. through RPC or a scheduled state.
*/
object Explicit : FlowStart() { override fun toString() = "Explicit" }
/**
* The flow was started implicitly as part of session initiation.
*/
data class Initiated(
val peerSession: FlowSessionImpl,
val initiatedSessionId: SessionId,
val initiatingMessage: InitialSessionMessage,
val senderCoreFlowVersion: Int?,
val initiatedFlowInfo: FlowInfo
) : FlowStart() { override fun toString() = "Initiated" }
}
/**
* Represents the user-space related state of the flow.
*/
sealed class FlowState {
/**
* The flow's unstarted state. We should always be able to start a fresh flow fiber from this datastructure.
*
* @param flowStart How the flow was started.
* @param frozenFlowLogic The serialized user-provided [FlowLogic].
*/
data class Unstarted(
val flowStart: FlowStart,
val frozenFlowLogic: SerializedBytes<FlowLogic<*>>
) : FlowState() {
override fun toString() = "Unstarted(flowStart=$flowStart, frozenFlowLogic=${frozenFlowLogic.hash}"
}
/**
* The flow's started state, this means the user-code has suspended on an IO request.
*
* @param flowIORequest what IO request the flow has suspended on.
* @param frozenFiber the serialized fiber itself.
*/
data class Started(
val flowIORequest: FlowIORequest<*>,
val frozenFiber: SerializedBytes<FlowStateMachineImpl<*>>
) : FlowState() {
override fun toString() = "Started(flowIORequest=$flowIORequest, frozenFiber=${frozenFiber.hash}"
}
}
/**
* @param errorId the ID of the error. This is generated once for the source error and is propagated to neighbour
* sessions.
* @param exception the exception itself. Note that this may not contain information about the source error depending
* on whether the source error was a FlowException or otherwise.
*/
data class FlowError(val errorId: Long, val exception: Throwable)
/**
* The flow's error state.
*/
sealed class ErrorState {
abstract fun addErrors(newErrors: List<FlowError>): ErrorState
/**
* The flow is in a clean state.
*/
object Clean : ErrorState() {
override fun addErrors(newErrors: List<FlowError>): ErrorState {
return Errored(newErrors, 0, false)
}
override fun toString() = "Clean"
}
/**
* The flow has dirtied because of an uncaught exception from user code or other error condition during a state
* transition.
* @param errors the list of errors. Multiple errors may be associated with the errored flow e.g. when multiple
* sessions are errored and have been waited on.
* @param propagatedIndex the index of the first error that hasn't yet been propagated.
* @param propagating true if error propagation was triggered. If this is set the dirtiness is permanent as the
* sessions associated with the flow have been (or about to be) dirtied in counter-flows.
*/
data class Errored(
val errors: List<FlowError>,
val propagatedIndex: Int,
val propagating: Boolean
) : ErrorState() {
override fun addErrors(newErrors: List<FlowError>): ErrorState {
return copy(errors = errors + newErrors)
}
}
}

View File

@ -0,0 +1,74 @@
package net.corda.node.services.statemachine
import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.InitiatingFlow
import net.corda.core.utilities.Try
/**
* A [SubFlow] contains metadata about a currently executing sub-flow. At any point the flow execution is
* characterised with a stack of [SubFlow]s. This stack is used to determine the initiating-initiated flow mapping.
*
* Note that Initiat*ed*ness is an orthogonal property of the top-level subflow, so we don't store any information about
* it here.
*/
sealed class SubFlow {
abstract val flowClass: Class<out FlowLogic<*>>
/**
* An inlined subflow.
*/
data class Inlined(override val flowClass: Class<FlowLogic<*>>) : SubFlow()
/**
* An initiating subflow.
* @param [flowClass] the concrete class of the subflow.
* @param [classToInitiateWith] an ancestor class of [flowClass] with the [InitiatingFlow] annotation, to be sent
* to the initiated side.
* @param flowInfo the [FlowInfo] associated with the initiating flow.
*/
data class Initiating(
override val flowClass: Class<FlowLogic<*>>,
val classToInitiateWith: Class<in FlowLogic<*>>,
val flowInfo: FlowInfo
) : SubFlow()
companion object {
fun create(flowClass: Class<FlowLogic<*>>): Try<SubFlow> {
// Are we an InitiatingFlow?
val initiatingAnnotations = getInitiatingFlowAnnotations(flowClass)
return when (initiatingAnnotations.size) {
0 -> {
Try.Success(Inlined(flowClass))
}
1 -> {
val initiatingAnnotation = initiatingAnnotations[0]
val flowContext = FlowInfo(initiatingAnnotation.second.version, flowClass.appName)
Try.Success(Initiating(flowClass, initiatingAnnotation.first, flowContext))
}
else -> {
Try.Failure(IllegalArgumentException("${InitiatingFlow::class.java.name} can only be annotated " +
"once, however the following classes all have the annotation: " +
"${initiatingAnnotations.map { it.first }}"))
}
}
}
private fun <C> getSuperClasses(clazz: Class<C>): List<Class<in C>> {
var currentClass: Class<in C>? = clazz
val result = ArrayList<Class<in C>>()
while (currentClass != null) {
result.add(currentClass)
currentClass = currentClass.superclass
}
return result
}
private fun getInitiatingFlowAnnotations(flowClass: Class<FlowLogic<*>>): List<Pair<Class<in FlowLogic<*>>, InitiatingFlow>> {
return getSuperClasses(flowClass).mapNotNull { clazz ->
val initiatingAnnotation = clazz.getDeclaredAnnotation(InitiatingFlow::class.java)
initiatingAnnotation?.let { Pair(clazz, it) }
}
}
}
}

View File

@ -0,0 +1,25 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable
import net.corda.node.services.statemachine.transitions.FlowContinuation
import net.corda.node.services.statemachine.transitions.TransitionResult
/**
* An executor of state machine transitions. This is mostly a wrapper interface around an [ActionExecutor], but can be
* used to create interceptors of transitions.
*/
interface TransitionExecutor {
@Suspendable
fun executeTransition(
fiber: FlowFiber,
previousState: StateMachineState,
event: Event,
transition: TransitionResult,
actionExecutor: ActionExecutor
): Pair<FlowContinuation, StateMachineState>
}
/**
* An interceptor of a transition. These are currently explicitly hooked up in [StateMachineManagerImpl].
*/
typealias TransitionInterceptor = (TransitionExecutor) -> TransitionExecutor

View File

@ -0,0 +1,66 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.utilities.contextLogger
import net.corda.node.services.statemachine.transitions.FlowContinuation
import net.corda.node.services.statemachine.transitions.TransitionResult
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseTransactionManager
import java.security.SecureRandom
/**
* This [TransitionExecutor] runs the transition actions using the passed in [ActionExecutor] and manually dirties the
* state on failure.
*
* If a failure happens when we're already transitioning into a errored state then the transition and the flow fiber is
* completely aborted to avoid error loops.
*/
class TransitionExecutorImpl(
val secureRandom: SecureRandom,
val database: CordaPersistence
) : TransitionExecutor {
private companion object {
val log = contextLogger()
}
@Suspendable
override fun executeTransition(
fiber: FlowFiber,
previousState: StateMachineState,
event: Event,
transition: TransitionResult,
actionExecutor: ActionExecutor
): Pair<FlowContinuation, StateMachineState> {
DatabaseTransactionManager.dataSource = database
for (action in transition.actions) {
try {
actionExecutor.executeAction(fiber, action)
} catch (exception: Throwable) {
DatabaseTransactionManager.currentOrNull()?.close()
if (transition.newState.checkpoint.errorState is ErrorState.Errored) {
// If we errored while transitioning to an error state then we cannot record the additional
// error as that may result in an infinite loop, e.g. error propagation fails -> record error -> propagate fails again.
// Instead we just keep around the old error state and wait for a new schedule, perhaps
// triggered from a flow hospital
log.error("Error while executing $action during transition to errored state, aborting transition", exception)
return Pair(FlowContinuation.Abort, previousState.copy(isFlowResumed = false))
} else {
// Otherwise error the state manually keeping the old flow state and schedule a DoRemainingWork
// to trigger error propagation
log.error("Error while executing $action, erroring state", exception)
val newState = previousState.copy(
checkpoint = previousState.checkpoint.copy(
errorState = previousState.checkpoint.errorState.addErrors(
listOf(FlowError(secureRandom.nextLong(), exception))
)
),
isFlowResumed = false
)
fiber.scheduleEvent(Event.DoRemainingWork)
return Pair(FlowContinuation.ProcessEvents, newState)
}
}
}
return Pair(transition.continuation, transition.newState)
}
}

View File

@ -0,0 +1,47 @@
package net.corda.node.services.statemachine.interceptors
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.StateMachineRunId
import net.corda.core.utilities.contextLogger
import net.corda.node.services.statemachine.*
import net.corda.node.services.statemachine.transitions.FlowContinuation
import net.corda.node.services.statemachine.transitions.TransitionResult
import java.util.concurrent.ConcurrentHashMap
/**
* This interceptor records a trace of all of the flows' states and transitions. If the flow dirties it dumps the trace
* transition to the logger.
*/
class DumpHistoryOnErrorInterceptor(val delegate: TransitionExecutor) : TransitionExecutor {
companion object {
private val log = contextLogger()
}
private val records = ConcurrentHashMap<StateMachineRunId, ArrayList<TransitionDiagnosticRecord>>()
@Suspendable
override fun executeTransition(
fiber: FlowFiber,
previousState: StateMachineState,
event: Event,
transition: TransitionResult,
actionExecutor: ActionExecutor
): Pair<FlowContinuation, StateMachineState> {
val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor)
val transitionRecord = TransitionDiagnosticRecord(fiber.id, previousState, nextState, event, transition, continuation)
val record = records.compute(fiber.id) { _, record ->
(record ?: ArrayList()).apply { add(transitionRecord) }
}
if (nextState.checkpoint.errorState is ErrorState.Errored) {
log.warn("Flow ${fiber.id} dirtied, dumping all transitions:\n${record!!.joinToString("\n")}")
}
if (transition.newState.isRemoved) {
records.remove(fiber.id)
}
return Pair(continuation, nextState)
}
}

View File

@ -0,0 +1,94 @@
package net.corda.node.services.statemachine.interceptors
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize
import net.corda.core.utilities.contextLogger
import net.corda.node.services.statemachine.*
import net.corda.node.services.statemachine.transitions.FlowContinuation
import net.corda.node.services.statemachine.transitions.TransitionResult
import java.util.concurrent.LinkedBlockingQueue
import kotlin.concurrent.thread
/**
* This interceptor checks whether a checkpointed fiber state can be deserialised in a separate thread.
*/
class FiberDeserializationCheckingInterceptor(
val fiberDeserializationChecker: FiberDeserializationChecker,
val delegate: TransitionExecutor
) : TransitionExecutor {
@Suspendable
override fun executeTransition(
fiber: FlowFiber,
previousState: StateMachineState,
event: Event,
transition: TransitionResult,
actionExecutor: ActionExecutor
): Pair<FlowContinuation, StateMachineState> {
val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor)
val previousFlowState = previousState.checkpoint.flowState
val nextFlowState = nextState.checkpoint.flowState
if (nextFlowState is FlowState.Started) {
if (previousFlowState !is FlowState.Started || previousFlowState.frozenFiber != nextFlowState.frozenFiber) {
fiberDeserializationChecker.submitCheck(nextFlowState.frozenFiber)
}
}
return Pair(continuation, nextState)
}
}
/**
* A fiber deserialisation checker thread. It checks the queued up serialised checkpoints to see if they can be
* deserialised. This is only run in development mode to allow detecting of corrupt serialised checkpoints before they
* are actually used.
*/
class FiberDeserializationChecker {
companion object {
val log = contextLogger()
}
private sealed class Job {
class Check(val serializedFiber: SerializedBytes<FlowStateMachineImpl<*>>) : Job()
object Finish : Job()
}
private var checkerThread: Thread? = null
private val jobQueue = LinkedBlockingQueue<Job>()
private var foundUnrestorableFibers: Boolean = false
fun start(checkpointSerializationContext: SerializationContext) {
require(checkerThread == null)
checkerThread = thread(name = "FiberDeserializationChecker") {
while (true) {
val job = jobQueue.take()
when (job) {
is Job.Check -> {
try {
job.serializedFiber.deserialize(context = checkpointSerializationContext)
} catch (throwable: Throwable) {
log.error("Encountered unrestorable checkpoint!", throwable)
foundUnrestorableFibers = true
}
}
Job.Finish -> {
return@thread
}
}
}
}
}
fun submitCheck(serializedFiber: SerializedBytes<FlowStateMachineImpl<*>>) {
jobQueue.add(Job.Check(serializedFiber))
}
/**
* Returns true if some unrestorable checkpoints were encountered, false otherwise
*/
fun stop(): Boolean {
jobQueue.add(Job.Finish)
checkerThread?.join()
return foundUnrestorableFibers
}
}

View File

@ -0,0 +1,43 @@
package net.corda.node.services.statemachine.interceptors
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.StateMachineRunId
import net.corda.node.services.statemachine.*
import net.corda.node.services.statemachine.transitions.FlowContinuation
import net.corda.node.services.statemachine.transitions.TransitionResult
import java.util.concurrent.ConcurrentHashMap
/**
* This interceptor notifies the passed in [flowHospital] in case a flow went through a clean->errored or a errored->clean
* transition.
*/
class HospitalisingInterceptor(
private val flowHospital: FlowHospital,
private val delegate: TransitionExecutor
) : TransitionExecutor {
private val hospitalisedFlows = ConcurrentHashMap<StateMachineRunId, FlowFiber>()
@Suspendable
override fun executeTransition(
fiber: FlowFiber,
previousState: StateMachineState,
event: Event,
transition: TransitionResult,
actionExecutor: ActionExecutor
): Pair<FlowContinuation, StateMachineState> {
val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor)
when (nextState.checkpoint.errorState) {
ErrorState.Clean -> {
if (hospitalisedFlows.remove(fiber.id) != null) {
flowHospital.flowCleaned(fiber)
}
}
is ErrorState.Errored -> {
if (hospitalisedFlows.putIfAbsent(fiber.id, fiber) == null) {
flowHospital.flowErrored(fiber)
}
}
}
return Pair(continuation, nextState)
}
}

View File

@ -0,0 +1,30 @@
package net.corda.node.services.statemachine.interceptors
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.utilities.contextLogger
import net.corda.node.services.statemachine.*
import net.corda.node.services.statemachine.transitions.FlowContinuation
import net.corda.node.services.statemachine.transitions.TransitionResult
/**
* This interceptor simply prints all state machine transitions. Useful for debugging.
*/
class PrintingInterceptor(val delegate: TransitionExecutor) : TransitionExecutor {
companion object {
val log = contextLogger()
}
@Suspendable
override fun executeTransition(
fiber: FlowFiber,
previousState: StateMachineState,
event: Event,
transition: TransitionResult,
actionExecutor: ActionExecutor
): Pair<FlowContinuation, StateMachineState> {
val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor)
val transitionRecord = TransitionDiagnosticRecord(fiber.id, previousState, nextState, event, transition, continuation)
log.info("Transition for flow ${fiber.id} $transitionRecord")
return Pair(continuation, nextState)
}
}

View File

@ -0,0 +1,48 @@
package net.corda.node.services.statemachine.interceptors
import net.corda.core.flows.StateMachineRunId
import net.corda.node.services.statemachine.transitions.FlowContinuation
import net.corda.node.services.statemachine.Event
import net.corda.node.services.statemachine.StateMachineState
import net.corda.node.services.statemachine.transitions.TransitionResult
import net.corda.node.utilities.ObjectDiffer
/**
* This is a diagnostic record that stores information about a state machine transition and provides pretty printing
* by diffing the two states.
*/
data class TransitionDiagnosticRecord(
val flowId: StateMachineRunId,
val previousState: StateMachineState,
val nextState: StateMachineState,
val event: Event,
val transition: TransitionResult,
val continuation: FlowContinuation
) {
override fun toString(): String {
val diffIntended = ObjectDiffer.diff(previousState, transition.newState)
val diffNext = ObjectDiffer.diff(previousState, nextState)
return (
listOf(
"",
" --- Transition of flow $flowId ---",
" Event: $event",
" Actions: ",
" ${transition.actions.joinToString("\n ")}",
" Continuation: ${transition.continuation}"
) +
if (diffIntended != diffNext) {
listOf(
" Diff between previous and intended state:",
"${diffIntended?.toPaths()?.joinToString("")}"
)
} else {
emptyList()
} + listOf(
" Diff between previous and next state:",
"${diffNext?.toPaths()?.joinToString("")}"
)
).joinToString("\n")
}
}

View File

@ -0,0 +1,189 @@
package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.UnexpectedFlowEndException
import net.corda.node.services.statemachine.*
/**
* This transition handles incoming session messages. It handles the following cases:
* - DataSessionMessage: these arrive to initiated and confirmed sessions and are expected to be received by the flow.
* - ConfirmSessionMessage: these arrive as a response to an InitialSessionMessage and include information about the
* counterparty flow's session ID as well as their [FlowInfo].
* - ErrorSessionMessage: these arrive to initiated and confirmed sessions and put the corresponding session into an
* "errored" state. This means that whenever that session is subsequently interacted with the error will be thrown
* in the flow.
* - RejectSessionMessage: these arrive as a response to an InitialSessionMessage when the initiation failed. It
* behaves similarly to ErrorSessionMessage aside from the type of exceptions stored/raised.
* - EndSessionMessage: these are sent when the counterparty flow has finished. They put the corresponding session into
* an "ended" state. This means that subsequent sends on this session will fail, and receives will start failing
* after the buffer of already received messages is drained.
*/
class DeliverSessionMessageTransition(
override val context: TransitionContext,
override val startingState: StateMachineState,
val event: Event.DeliverSessionMessage
) : Transition {
override fun transition(): TransitionResult {
return builder {
// Add the AcknowledgeHandle to the unacknowledged messages ASAP so in case an error happens we still know
// about the message. Note that in case of an error during deliver this message *will be acked*.
// For example if the session corresponding to the message is not found the message is still acked to free
// up the broker but the flow will error.
currentState = currentState.copy(
unacknowledgedMessages = currentState.unacknowledgedMessages + event.acknowledgeHandle
)
// Check whether we have a session corresponding to the message.
val existingSession = startingState.checkpoint.sessions[event.sessionMessage.recipientSessionId]
if (existingSession == null) {
freshErrorTransition(CannotFindSessionException(event.sessionMessage.recipientSessionId))
} else {
val payload = event.sessionMessage.payload
// Dispatch based on what kind of message it is.
val _exhaustive = when (payload) {
is ConfirmSessionMessage -> confirmMessageTransition(existingSession, payload)
is DataSessionMessage -> dataMessageTransition(existingSession, payload)
is ErrorSessionMessage -> errorMessageTransition(existingSession, payload)
is RejectSessionMessage -> rejectMessageTransition(existingSession, payload)
is EndSessionMessage -> endMessageTransition()
}
}
if (!isErrored()) {
persistCheckpointIfNeeded()
}
// Schedule a DoRemainingWork to check whether the flow needs to be woken up.
actions.add(Action.ScheduleEvent(Event.DoRemainingWork))
FlowContinuation.ProcessEvents
}
}
private fun TransitionBuilder.confirmMessageTransition(sessionState: SessionState, message: ConfirmSessionMessage) {
// We received a confirmation message. The corresponding session state must be Initiating.
when (sessionState) {
is SessionState.Initiating -> {
// Create the new session state that is now Initiated.
val initiatedSession = SessionState.Initiated(
peerParty = event.sender,
peerFlowInfo = message.initiatedFlowInfo,
receivedMessages = emptyList(),
initiatedState = InitiatedSessionState.Live(message.initiatedSessionId),
errors = emptyList()
)
val newCheckpoint = currentState.checkpoint.copy(
sessions = currentState.checkpoint.sessions + (event.sessionMessage.recipientSessionId to initiatedSession)
)
// Send messages that were buffered pending confirmation of session.
val sendActions = sessionState.bufferedMessages.map { (deduplicationId, bufferedMessage) ->
val existingMessage = ExistingSessionMessage(message.initiatedSessionId, bufferedMessage)
Action.SendExisting(initiatedSession.peerParty, existingMessage, deduplicationId)
}
actions.addAll(sendActions)
currentState = currentState.copy(checkpoint = newCheckpoint)
}
else -> freshErrorTransition(UnexpectedEventInState())
}
}
private fun TransitionBuilder.dataMessageTransition(sessionState: SessionState, message: DataSessionMessage) {
// We received a data message. The corresponding session must be Initiated.
return when (sessionState) {
is SessionState.Initiated -> {
// Buffer the message in the session's receivedMessages buffer.
val newSessionState = sessionState.copy(
receivedMessages = sessionState.receivedMessages + message
)
currentState = currentState.copy(
checkpoint = currentState.checkpoint.copy(
sessions = startingState.checkpoint.sessions + (event.sessionMessage.recipientSessionId to newSessionState)
)
)
}
else -> freshErrorTransition(UnexpectedEventInState())
}
}
private fun TransitionBuilder.errorMessageTransition(sessionState: SessionState, payload: ErrorSessionMessage) {
val exception: Throwable = if (payload.flowException == null) {
UnexpectedFlowEndException("Counter-flow errored", cause = null, originalErrorId = payload.errorId)
} else {
payload.flowException.originalErrorId = payload.errorId
payload.flowException
}
return when (sessionState) {
is SessionState.Initiated -> {
val checkpoint = currentState.checkpoint
val sessionId = event.sessionMessage.recipientSessionId
val flowError = FlowError(payload.errorId, exception)
val newSessionState = sessionState.copy(errors = sessionState.errors + flowError)
currentState = currentState.copy(
checkpoint = checkpoint.copy(
sessions = checkpoint.sessions + (sessionId to newSessionState)
)
)
}
else -> freshErrorTransition(UnexpectedEventInState())
}
}
private fun TransitionBuilder.rejectMessageTransition(sessionState: SessionState, payload: RejectSessionMessage) {
val exception = UnexpectedFlowEndException(payload.message, cause = null, originalErrorId = payload.errorId)
return when (sessionState) {
is SessionState.Initiating -> {
if (sessionState.rejectionError != null) {
// Double reject
freshErrorTransition(UnexpectedEventInState())
} else {
val checkpoint = currentState.checkpoint
val sessionId = event.sessionMessage.recipientSessionId
val flowError = FlowError(payload.errorId, exception)
currentState = currentState.copy(
checkpoint = checkpoint.copy(
sessions = checkpoint.sessions + (sessionId to sessionState.copy(rejectionError = flowError))
)
)
}
}
else -> freshErrorTransition(UnexpectedEventInState())
}
}
private fun TransitionBuilder.persistCheckpointIfNeeded() {
// We persist the message as soon as it arrives.
if (context.configuration.sessionDeliverPersistenceStrategy == SessionDeliverPersistenceStrategy.OnDeliver &&
event.sessionMessage.payload !is EndSessionMessage) {
actions.addAll(arrayOf(
Action.CreateTransaction,
Action.PersistCheckpoint(context.id, currentState.checkpoint),
Action.PersistDeduplicationIds(currentState.unacknowledgedMessages),
Action.CommitTransaction,
Action.AcknowledgeMessages(currentState.unacknowledgedMessages)
))
currentState = currentState.copy(
unacknowledgedMessages = emptyList(),
isAnyCheckpointPersisted = true
)
}
}
private fun TransitionBuilder.endMessageTransition() {
val sessionId = event.sessionMessage.recipientSessionId
val sessions = currentState.checkpoint.sessions
val sessionState = sessions[sessionId]
if (sessionState == null) {
return freshErrorTransition(CannotFindSessionException(sessionId))
}
when (sessionState) {
is SessionState.Initiated -> {
val newSessionState = sessionState.copy(initiatedState = InitiatedSessionState.Ended)
currentState = currentState.copy(
checkpoint = currentState.checkpoint.copy(
sessions = sessions + (sessionId to newSessionState)
)
)
}
else -> {
freshErrorTransition(UnexpectedEventInState())
}
}
}
}

View File

@ -0,0 +1,37 @@
package net.corda.node.services.statemachine.transitions
import net.corda.node.services.statemachine.*
/**
* This transition checks the current state of the flow and determines whether anything needs to be done.
*/
class DoRemainingWorkTransition(
override val context: TransitionContext,
override val startingState: StateMachineState
) : Transition {
override fun transition(): TransitionResult {
val checkpoint = startingState.checkpoint
// If the flow is removed or has been resumed don't do work.
if (startingState.isFlowResumed || startingState.isRemoved) {
return TransitionResult(startingState)
}
// Check whether the flow is errored
return when (checkpoint.errorState) {
is ErrorState.Clean -> cleanTransition()
is ErrorState.Errored -> erroredTransition(checkpoint.errorState)
}
}
// If the flow is clean check the FlowState
private fun cleanTransition(): TransitionResult {
val checkpoint = startingState.checkpoint
return when (checkpoint.flowState) {
is FlowState.Unstarted -> UnstartedFlowTransition(context, startingState, checkpoint.flowState).transition()
is FlowState.Started -> StartedFlowTransition(context, startingState, checkpoint.flowState).transition()
}
}
private fun erroredTransition(errorState: ErrorState.Errored): TransitionResult {
return ErrorFlowTransition(context, startingState, errorState).transition()
}
}

View File

@ -0,0 +1,124 @@
package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.FlowException
import net.corda.node.services.statemachine.*
/**
* This transition defines what should happen when a flow has errored.
*
* In general there are two flow-level error conditions:
*
* - Internal exceptions. These may arise due to problems in the flow framework or errors during state machine
* transitions e.g. network or database failure.
* - User-raised exceptions. These are exceptions that are (re)raised in user code, allowing the user to catch them.
* These may come from illegal flow API calls, and FlowExceptions or other counterparty failures that are re-raised
* when the flow tries to use the corresponding sessions.
*
* Both internal exceptions and uncaught user-raised exceptions cause the flow to be errored. This flags the flow as
* unable to be resumed. When a flow is in this state an external source (e.g. Flow hospital) may decide to
*
* 1. Retry it (not implemented yet). This throws away the errored state and re-tries from the last clean checkpoint.
* 2. Start error propagation. This seals the flow as errored permanently and propagates the associated error(s) to
* all live sessions. This causes these sessions to errored on the other side, which may in turn cause the
* counter-flows themselves to errored.
*
* See [net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor] for how to detect flow errors.
*
* Note that in general we handle multiple errors at a time as several error conditions may arise at the same time and
* new errors may arise while the flow is in the errored state already.
*/
class ErrorFlowTransition(
override val context: TransitionContext,
override val startingState: StateMachineState,
private val errorState: ErrorState.Errored
) : Transition {
override fun transition(): TransitionResult {
val allErrors: List<FlowError> = errorState.errors
val remainingErrorsToPropagate: List<FlowError> = allErrors.subList(errorState.propagatedIndex, allErrors.size)
val errorMessages: List<ErrorSessionMessage> = remainingErrorsToPropagate.map(this::createErrorMessageFromError)
return builder {
// If we're errored and propagating do the actual propagation and update the index.
if (remainingErrorsToPropagate.isNotEmpty() && errorState.propagating) {
val (initiatedSessions, newSessions) = bufferErrorMessagesInInitiatingSessions(startingState.checkpoint.sessions, errorMessages)
val newCheckpoint = startingState.checkpoint.copy(
errorState = errorState.copy(propagatedIndex = allErrors.size),
sessions = newSessions
)
currentState = currentState.copy(checkpoint = newCheckpoint)
actions.add(Action.PropagateErrors(errorMessages, initiatedSessions))
}
// If we're errored but not propagating keep processing events.
if (remainingErrorsToPropagate.isNotEmpty() && !errorState.propagating) {
return@builder FlowContinuation.ProcessEvents
}
// If we haven't been removed yet remove the flow.
if (!currentState.isRemoved) {
actions.add(Action.CreateTransaction)
if (currentState.isAnyCheckpointPersisted) {
actions.add(Action.RemoveCheckpoint(context.id))
}
actions.addAll(arrayOf(
Action.PersistDeduplicationIds(currentState.unacknowledgedMessages),
Action.CommitTransaction,
Action.AcknowledgeMessages(currentState.unacknowledgedMessages),
Action.RemoveSessionBindings(currentState.checkpoint.sessions.keys)
))
currentState = currentState.copy(
unacknowledgedMessages = emptyList(),
isRemoved = true
)
val removalReason = FlowRemovalReason.ErrorFinish(allErrors)
actions.add(Action.RemoveFlow(context.id, removalReason, currentState))
FlowContinuation.Abort
} else {
// Otherwise keep processing events. This branch happens when there are some outstanding initiating
// sessions that prevent the removal of the flow.
FlowContinuation.ProcessEvents
}
}
}
private fun createErrorMessageFromError(error: FlowError): ErrorSessionMessage {
val exception = error.exception
// If the exception doesn't contain an originalErrorId that means it's a fresh FlowException that should
// propagate to the neighbouring flows. If it has the ID filled in that means it's a rethrown FlowException and
// shouldn't be propagated.
return if (exception is FlowException && exception.originalErrorId == null) {
ErrorSessionMessage(flowException = exception, errorId = error.errorId)
} else {
ErrorSessionMessage(flowException = null, errorId = error.errorId)
}
}
// Buffer error messages in Initiating sessions, return the initialised ones.
private fun bufferErrorMessagesInInitiatingSessions(
sessions: Map<SessionId, SessionState>,
errorMessages: List<ErrorSessionMessage>
): Pair<List<SessionState.Initiated>, Map<SessionId, SessionState>> {
val newSessions = sessions.mapValues { (sourceSessionId, sessionState) ->
if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) {
// *prepend* the error messages in order to error the other sessions ASAP. The other messages will
// be delivered all the same, they just won't trigger flow resumption because of dirtiness.
val errorMessagesWithDeduplication = errorMessages.map {
DeduplicationId.createForError(it.errorId, sourceSessionId) to it
}
sessionState.copy(bufferedMessages = errorMessagesWithDeduplication + sessionState.bufferedMessages)
} else {
sessionState
}
}
val initiatedSessions = sessions.values.mapNotNull { session ->
if (session is SessionState.Initiated && session.errors.isEmpty()) {
session
} else {
null
}
}
return Pair(initiatedSessions, newSessions)
}
}

View File

@ -0,0 +1,399 @@
package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowSession
import net.corda.core.flows.UnexpectedFlowEndException
import net.corda.core.internal.FlowIORequest
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.toNonEmptySet
import net.corda.node.services.statemachine.*
/**
* This transition describes what should happen with a specific [FlowIORequest]. Note that at this time the request
* is persisted (unless checkpoint was skipped) and the user-space DB transaction is commited.
*
* Before this transition we either did a checkpoint or the checkpoint was restored from the database.
*/
class StartedFlowTransition(
override val context: TransitionContext,
override val startingState: StateMachineState,
val started: FlowState.Started
) : Transition {
override fun transition(): TransitionResult {
val flowIORequest = started.flowIORequest
val checkpoint = startingState.checkpoint
val errorsToThrow = collectRelevantErrorsToThrow(flowIORequest, checkpoint)
if (errorsToThrow.isNotEmpty()) {
return TransitionResult(
newState = startingState.copy(isFlowResumed = true),
// throw the first exception. TODO should this aggregate all of them somehow?
actions = listOf(Action.CreateTransaction),
continuation = FlowContinuation.Throw(errorsToThrow[0])
)
}
return when (flowIORequest) {
is FlowIORequest.Send -> sendTransition(flowIORequest)
is FlowIORequest.Receive -> receiveTransition(flowIORequest)
is FlowIORequest.SendAndReceive -> sendAndReceiveTransition(flowIORequest)
is FlowIORequest.WaitForLedgerCommit -> waitForLedgerCommitTransition(flowIORequest)
is FlowIORequest.Sleep -> sleepTransition(flowIORequest)
is FlowIORequest.GetFlowInfo -> getFlowInfoTransition(flowIORequest)
is FlowIORequest.WaitForSessionConfirmations -> waitForSessionConfirmationsTransition()
}
}
private fun waitForSessionConfirmationsTransition(): TransitionResult {
return builder {
if (currentState.checkpoint.sessions.values.any { it is SessionState.Initiating }) {
FlowContinuation.ProcessEvents
} else {
resumeFlowLogic(Unit)
}
}
}
private fun getFlowInfoTransition(flowIORequest: FlowIORequest.GetFlowInfo): TransitionResult {
val sessionIdToSession = LinkedHashMap<SessionId, FlowSessionImpl>()
for (session in flowIORequest.sessions) {
sessionIdToSession[(session as FlowSessionImpl).sourceSessionId] = session
}
return builder {
// Initialise uninitialised sessions in order to receive the associated FlowInfo. Some or all sessions may
// not be initialised yet.
sendInitialSessionMessagesIfNeeded(sessionIdToSession.keys)
val flowInfoMap = getFlowInfoFromSessions(sessionIdToSession)
if (flowInfoMap == null) {
FlowContinuation.ProcessEvents
} else {
resumeFlowLogic(flowInfoMap)
}
}
}
private fun TransitionBuilder.getFlowInfoFromSessions(sessionIdToSession: Map<SessionId, FlowSessionImpl>): Map<FlowSession, FlowInfo>? {
val checkpoint = currentState.checkpoint
val resultMap = LinkedHashMap<FlowSession, FlowInfo>()
for ((sessionId, session) in sessionIdToSession) {
val sessionState = checkpoint.sessions[sessionId]
if (sessionState is SessionState.Initiated) {
resultMap[session] = sessionState.peerFlowInfo
} else {
return null
}
}
return resultMap
}
private fun sleepTransition(flowIORequest: FlowIORequest.Sleep): TransitionResult {
return builder {
actions.add(Action.SleepUntil(flowIORequest.wakeUpAfter))
resumeFlowLogic(Unit)
}
}
private fun waitForLedgerCommitTransition(flowIORequest: FlowIORequest.WaitForLedgerCommit): TransitionResult {
return if (!startingState.isTransactionTracked) {
TransitionResult(
newState = startingState.copy(isTransactionTracked = true),
actions = listOf(
Action.CreateTransaction,
Action.TrackTransaction(flowIORequest.hash),
Action.CommitTransaction
)
)
} else {
TransitionResult(startingState)
}
}
private fun sendAndReceiveTransition(flowIORequest: FlowIORequest.SendAndReceive): TransitionResult {
val sessionIdToMessage = LinkedHashMap<SessionId, SerializedBytes<Any>>()
val sessionIdToSession = LinkedHashMap<SessionId, FlowSessionImpl>()
for ((session, message) in flowIORequest.sessionToMessage) {
val sessionId = (session as FlowSessionImpl).sourceSessionId
sessionIdToMessage[sessionId] = message
sessionIdToSession[sessionId] = session
}
return builder {
sendToSessionsTransition(sessionIdToMessage)
if (isErrored()) {
FlowContinuation.ProcessEvents
} else {
val receivedMap = receiveFromSessionsTransition(sessionIdToSession)
if (receivedMap == null) {
// We don't yet have the messages, change the suspension to be on Receive
val newIoRequest = FlowIORequest.Receive(flowIORequest.sessionToMessage.keys.toNonEmptySet())
currentState = currentState.copy(
checkpoint = currentState.checkpoint.copy(
flowState = FlowState.Started(newIoRequest, started.frozenFiber)
)
)
FlowContinuation.ProcessEvents
} else {
resumeFlowLogic(receivedMap)
}
}
}
}
private fun receiveTransition(flowIORequest: FlowIORequest.Receive): TransitionResult {
return builder {
val sessionIdToSession = LinkedHashMap<SessionId, FlowSessionImpl>()
for (session in flowIORequest.sessions) {
sessionIdToSession[(session as FlowSessionImpl).sourceSessionId] = session
}
// send initialises to uninitialised sessions
sendInitialSessionMessagesIfNeeded(sessionIdToSession.keys)
val receivedMap = receiveFromSessionsTransition(sessionIdToSession)
if (receivedMap == null) {
FlowContinuation.ProcessEvents
} else {
resumeFlowLogic(receivedMap)
}
}
}
private fun TransitionBuilder.receiveFromSessionsTransition(
sourceSessionIdToSessionMap: Map<SessionId, FlowSessionImpl>
): Map<FlowSession, SerializedBytes<Any>>? {
val checkpoint = currentState.checkpoint
val pollResult = pollSessionMessages(checkpoint.sessions, sourceSessionIdToSessionMap.keys) ?: return null
val resultMap = LinkedHashMap<FlowSession, SerializedBytes<Any>>()
for ((sessionId, message) in pollResult.messages) {
val session = sourceSessionIdToSessionMap[sessionId]!!
resultMap[session] = message
}
currentState = currentState.copy(
checkpoint = checkpoint.copy(sessions = pollResult.newSessionMap)
)
return resultMap
}
data class PollResult(
val messages: Map<SessionId, SerializedBytes<Any>>,
val newSessionMap: SessionMap
)
private fun pollSessionMessages(sessions: SessionMap, sessionIds: Set<SessionId>): PollResult? {
val newSessionMessages = LinkedHashMap(sessions)
val resultMessages = LinkedHashMap<SessionId, SerializedBytes<Any>>()
var someNotFound = false
for (sessionId in sessionIds) {
val sessionState = sessions[sessionId]
when (sessionState) {
is SessionState.Initiated -> {
val messages = sessionState.receivedMessages
if (messages.isEmpty()) {
someNotFound = true
} else {
newSessionMessages[sessionId] = sessionState.copy(receivedMessages = messages.subList(1, messages.size).toList())
resultMessages[sessionId] = messages[0].payload
}
}
else -> {
someNotFound = true
}
}
}
return if (someNotFound) {
return null
} else {
PollResult(resultMessages, newSessionMessages)
}
}
private fun TransitionBuilder.sendInitialSessionMessagesIfNeeded(sourceSessions: Set<SessionId>) {
val checkpoint = startingState.checkpoint
val newSessions = LinkedHashMap<SessionId, SessionState>(checkpoint.sessions)
var index = 0
for (sourceSessionId in sourceSessions) {
val sessionState = checkpoint.sessions[sourceSessionId]
if (sessionState == null) {
return freshErrorTransition(CannotFindSessionException(sourceSessionId))
}
if (sessionState !is SessionState.Uninitiated) {
continue
}
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++)
val initialMessage = createInitialSessionMessage(sessionState.initiatingSubFlow, sourceSessionId, null)
actions.add(Action.SendInitial(sessionState.party, initialMessage, deduplicationId))
newSessions[sourceSessionId] = SessionState.Initiating(
bufferedMessages = emptyList(),
rejectionError = null
)
}
currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions))
}
private fun sendTransition(flowIORequest: FlowIORequest.Send): TransitionResult {
return builder {
val sessionIdToMessage = flowIORequest.sessionToMessage.mapKeys {
sessionToSessionId(it.key)
}
sendToSessionsTransition(sessionIdToMessage)
if (isErrored()) {
FlowContinuation.ProcessEvents
} else {
resumeFlowLogic(Unit)
}
}
}
private fun TransitionBuilder.sendToSessionsTransition(sourceSessionIdToMessage: Map<SessionId, SerializedBytes<Any>>) {
val checkpoint = startingState.checkpoint
val newSessions = LinkedHashMap(checkpoint.sessions)
var index = 0
for ((sourceSessionId, message) in sourceSessionIdToMessage) {
val existingSessionState = checkpoint.sessions[sourceSessionId]
if (existingSessionState == null) {
return freshErrorTransition(CannotFindSessionException(sourceSessionId))
} else {
val sessionMessage = DataSessionMessage(message)
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++)
val _exhaustive = when (existingSessionState) {
is SessionState.Uninitiated -> {
val initialMessage = createInitialSessionMessage(existingSessionState.initiatingSubFlow, sourceSessionId, message)
actions.add(Action.SendInitial(existingSessionState.party, initialMessage, deduplicationId))
newSessions[sourceSessionId] = SessionState.Initiating(
bufferedMessages = emptyList(),
rejectionError = null
)
Unit
}
is SessionState.Initiating -> {
// We're initiating this session, buffer the message
val newBufferedMessages = existingSessionState.bufferedMessages + Pair(deduplicationId, sessionMessage)
newSessions[sourceSessionId] = existingSessionState.copy(bufferedMessages = newBufferedMessages)
}
is SessionState.Initiated -> {
when (existingSessionState.initiatedState) {
is InitiatedSessionState.Live -> {
val sinkSessionId = existingSessionState.initiatedState.peerSinkSessionId
val existingMessage = ExistingSessionMessage(sinkSessionId, sessionMessage)
actions.add(Action.SendExisting(existingSessionState.peerParty, existingMessage, deduplicationId))
Unit
}
InitiatedSessionState.Ended -> {
return freshErrorTransition(IllegalStateException("Tried to send to ended session $sourceSessionId"))
}
}
}
}
}
}
currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions))
}
private fun sessionToSessionId(session: FlowSession): SessionId {
return (session as FlowSessionImpl).sourceSessionId
}
private fun collectErroredSessionErrors(sessionIds: Collection<SessionId>, checkpoint: Checkpoint): List<Throwable> {
return sessionIds.flatMap { sessionId ->
val sessionState = checkpoint.sessions[sessionId]!!
when (sessionState) {
is SessionState.Uninitiated -> emptyList()
is SessionState.Initiating -> {
if (sessionState.rejectionError == null) {
emptyList()
} else {
listOf(sessionState.rejectionError.exception)
}
}
is SessionState.Initiated -> sessionState.errors.map(FlowError::exception)
}
}
}
private fun collectErroredInitiatingSessionErrors(checkpoint: Checkpoint): List<Throwable> {
return checkpoint.sessions.values.mapNotNull { sessionState ->
(sessionState as? SessionState.Initiating)?.rejectionError?.exception
}
}
private fun collectEndedSessionErrors(sessionIds: Collection<SessionId>, checkpoint: Checkpoint): List<Throwable> {
return sessionIds.mapNotNull { sessionId ->
val sessionState = checkpoint.sessions[sessionId]!!
when (sessionState) {
is SessionState.Initiated -> {
if (sessionState.initiatedState is InitiatedSessionState.Ended) {
UnexpectedFlowEndException(
"Tried to access ended session $sessionId",
cause = null,
originalErrorId = context.secureRandom.nextLong()
)
} else {
null
}
}
else -> null
}
}
}
private fun collectEndedEmptySessionErrors(sessionIds: Collection<SessionId>, checkpoint: Checkpoint): List<Throwable> {
return sessionIds.mapNotNull { sessionId ->
val sessionState = checkpoint.sessions[sessionId]!!
when (sessionState) {
is SessionState.Initiated -> {
if (sessionState.initiatedState is InitiatedSessionState.Ended &&
sessionState.receivedMessages.isEmpty()) {
UnexpectedFlowEndException(
"Tried to access ended session $sessionId with empty buffer",
cause = null,
originalErrorId = context.secureRandom.nextLong()
)
} else {
null
}
}
else -> null
}
}
}
private fun collectRelevantErrorsToThrow(flowIORequest: FlowIORequest<*>, checkpoint: Checkpoint): List<Throwable> {
return when (flowIORequest) {
is FlowIORequest.Send -> {
val sessionIds = flowIORequest.sessionToMessage.keys.map(this::sessionToSessionId)
collectErroredSessionErrors(sessionIds, checkpoint) + collectEndedSessionErrors(sessionIds, checkpoint)
}
is FlowIORequest.Receive -> {
val sessionIds = flowIORequest.sessions.map(this::sessionToSessionId)
collectErroredSessionErrors(sessionIds, checkpoint) + collectEndedEmptySessionErrors(sessionIds, checkpoint)
}
is FlowIORequest.SendAndReceive -> {
val sessionIds = flowIORequest.sessionToMessage.keys.map(this::sessionToSessionId)
collectErroredSessionErrors(sessionIds, checkpoint) + collectEndedSessionErrors(sessionIds, checkpoint)
}
is FlowIORequest.WaitForLedgerCommit -> {
collectErroredSessionErrors(checkpoint.sessions.keys, checkpoint)
}
is FlowIORequest.GetFlowInfo -> {
collectErroredSessionErrors(flowIORequest.sessions.map(this::sessionToSessionId), checkpoint)
}
is FlowIORequest.Sleep -> {
emptyList()
}
is FlowIORequest.WaitForSessionConfirmations -> {
collectErroredInitiatingSessionErrors(checkpoint)
}
}
}
private fun createInitialSessionMessage(
initiatingSubFlow: SubFlow.Initiating,
sourceSessionId: SessionId,
payload: SerializedBytes<Any>?
): InitialSessionMessage {
return InitialSessionMessage(
initiatorSessionId = sourceSessionId,
// We add additional entropy to add to the initiated side's deduplication seed.
initiationEntropy = context.secureRandom.nextLong(),
initiatorFlowClassName = initiatingSubFlow.classToInitiateWith.name,
flowVersion = initiatingSubFlow.flowInfo.flowVersion,
appName = initiatingSubFlow.flowInfo.appName,
firstPayload = payload
)
}
}

View File

@ -0,0 +1,30 @@
package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.*
import net.corda.node.services.statemachine.*
import java.security.SecureRandom
enum class SessionDeliverPersistenceStrategy {
OnDeliver,
OnNextCommit
}
data class StateMachineConfiguration(
val sessionDeliverPersistenceStrategy: SessionDeliverPersistenceStrategy
) {
companion object {
val default = StateMachineConfiguration(
sessionDeliverPersistenceStrategy = SessionDeliverPersistenceStrategy.OnDeliver
)
}
}
class StateMachine(
val id: StateMachineRunId,
val configuration: StateMachineConfiguration,
val secureRandom: SecureRandom
) {
fun transition(event: Event, state: StateMachineState): TransitionResult {
return TopLevelTransition(TransitionContext(id, configuration, secureRandom), state, event).transition()
}
}

View File

@ -0,0 +1,236 @@
package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.InitiatingFlow
import net.corda.core.internal.FlowIORequest
import net.corda.core.utilities.Try
import net.corda.node.services.statemachine.*
/**
* This is the top level event-handling transition function capable of handling any [Event].
*
* It is a *pure* function taking a state machine state and an event, returning the next state along with a list of IO
* actions to execute.
*/
class TopLevelTransition(
override val context: TransitionContext,
override val startingState: StateMachineState,
val event: Event
) : Transition {
override fun transition(): TransitionResult {
return when (event) {
is Event.DoRemainingWork -> DoRemainingWorkTransition(context, startingState).transition()
is Event.DeliverSessionMessage -> DeliverSessionMessageTransition(context, startingState, event).transition()
is Event.Error -> errorTransition(event)
is Event.TransactionCommitted -> transactionCommittedTransition(event)
is Event.SoftShutdown -> softShutdownTransition()
is Event.StartErrorPropagation -> startErrorPropagationTransition()
is Event.EnterSubFlow -> enterSubFlowTransition(event)
is Event.LeaveSubFlow -> leaveSubFlowTransition()
is Event.Suspend -> suspendTransition(event)
is Event.FlowFinish -> flowFinishTransition(event)
is Event.InitiateFlow -> initiateFlowTransition(event)
}
}
private fun errorTransition(event: Event.Error): TransitionResult {
return builder {
freshErrorTransition(event.exception)
FlowContinuation.ProcessEvents
}
}
private fun transactionCommittedTransition(event: Event.TransactionCommitted): TransitionResult {
return builder {
val checkpoint = currentState.checkpoint
if (currentState.isTransactionTracked &&
checkpoint.flowState is FlowState.Started &&
checkpoint.flowState.flowIORequest is FlowIORequest.WaitForLedgerCommit &&
checkpoint.flowState.flowIORequest.hash == event.transaction.id) {
currentState = currentState.copy(isTransactionTracked = false)
if (isErrored()) {
return@builder FlowContinuation.ProcessEvents
}
resumeFlowLogic(event.transaction)
} else {
freshErrorTransition(UnexpectedEventInState())
FlowContinuation.ProcessEvents
}
}
}
private fun softShutdownTransition(): TransitionResult {
val lastState = startingState.copy(isRemoved = true)
return TransitionResult(
newState = lastState,
actions = listOf(
Action.RemoveSessionBindings(startingState.checkpoint.sessions.keys),
Action.RemoveFlow(context.id, FlowRemovalReason.SoftShutdown, lastState)
),
continuation = FlowContinuation.Abort
)
}
private fun startErrorPropagationTransition(): TransitionResult {
return builder {
val errorState = currentState.checkpoint.errorState
when (errorState) {
ErrorState.Clean -> freshErrorTransition(UnexpectedEventInState())
is ErrorState.Errored -> {
currentState = currentState.copy(
checkpoint = currentState.checkpoint.copy(
errorState = errorState.copy(propagating = true)
)
)
actions.add(Action.ScheduleEvent(Event.DoRemainingWork))
}
}
FlowContinuation.ProcessEvents
}
}
private fun enterSubFlowTransition(event: Event.EnterSubFlow): TransitionResult {
return builder {
val subFlow = SubFlow.create(event.subFlowClass)
when (subFlow) {
is Try.Success -> {
currentState = currentState.copy(
checkpoint = currentState.checkpoint.copy(
subFlowStack = currentState.checkpoint.subFlowStack + subFlow.value
)
)
}
is Try.Failure -> {
freshErrorTransition(subFlow.exception)
}
}
FlowContinuation.ProcessEvents
}
}
private fun leaveSubFlowTransition(): TransitionResult {
return builder {
val checkpoint = currentState.checkpoint
if (checkpoint.subFlowStack.isEmpty()) {
freshErrorTransition(UnexpectedEventInState())
} else {
currentState = currentState.copy(
checkpoint = checkpoint.copy(
subFlowStack = checkpoint.subFlowStack.subList(0, checkpoint.subFlowStack.size - 1).toList()
)
)
}
FlowContinuation.ProcessEvents
}
}
private fun suspendTransition(event: Event.Suspend): TransitionResult {
return builder {
val newCheckpoint = currentState.checkpoint.copy(
flowState = FlowState.Started(event.ioRequest, event.fiber),
numberOfSuspends = currentState.checkpoint.numberOfSuspends + 1
)
if (event.maySkipCheckpoint) {
actions.addAll(arrayOf(
Action.CommitTransaction,
Action.ScheduleEvent(Event.DoRemainingWork)
))
currentState = currentState.copy(
checkpoint = newCheckpoint,
isFlowResumed = false
)
} else {
actions.addAll(arrayOf(
Action.PersistCheckpoint(context.id, newCheckpoint),
Action.PersistDeduplicationIds(currentState.unacknowledgedMessages),
Action.CommitTransaction,
Action.AcknowledgeMessages(currentState.unacknowledgedMessages),
Action.ScheduleEvent(Event.DoRemainingWork)
))
currentState = currentState.copy(
checkpoint = newCheckpoint,
unacknowledgedMessages = emptyList(),
isFlowResumed = false,
isAnyCheckpointPersisted = true
)
}
FlowContinuation.ProcessEvents
}
}
private fun flowFinishTransition(event: Event.FlowFinish): TransitionResult {
return builder {
val checkpoint = currentState.checkpoint
when (checkpoint.errorState) {
ErrorState.Clean -> {
val unacknowledgedMessages = currentState.unacknowledgedMessages
currentState = currentState.copy(
checkpoint = checkpoint.copy(
numberOfSuspends = checkpoint.numberOfSuspends + 1
),
unacknowledgedMessages = emptyList(),
isFlowResumed = false,
isRemoved = true
)
val allSourceSessionIds = checkpoint.sessions.keys
if (currentState.isAnyCheckpointPersisted) {
actions.add(Action.RemoveCheckpoint(context.id))
}
actions.addAll(arrayOf(
Action.PersistDeduplicationIds(unacknowledgedMessages),
Action.CommitTransaction,
Action.AcknowledgeMessages(unacknowledgedMessages),
Action.RemoveSessionBindings(allSourceSessionIds),
Action.RemoveFlow(context.id, FlowRemovalReason.OrderlyFinish(event.returnValue), currentState)
))
sendEndMessages()
// Resume to end fiber
FlowContinuation.Resume(null)
}
is ErrorState.Errored -> {
currentState = currentState.copy(isFlowResumed = false)
actions.add(Action.RollbackTransaction)
FlowContinuation.ProcessEvents
}
}
}
}
private fun TransitionBuilder.sendEndMessages() {
val sendEndMessageActions = currentState.checkpoint.sessions.values.mapIndexed { index, state ->
if (state is SessionState.Initiated && state.initiatedState is InitiatedSessionState.Live) {
val message = ExistingSessionMessage(state.initiatedState.peerSinkSessionId, EndSessionMessage)
val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index)
Action.SendExisting(state.peerParty, message, deduplicationId)
} else {
null
}
}.filterNotNull()
actions.addAll(sendEndMessageActions)
}
private fun initiateFlowTransition(event: Event.InitiateFlow): TransitionResult {
return builder {
val checkpoint = currentState.checkpoint
val initiatingSubFlow = getClosestAncestorInitiatingSubFlow(checkpoint)
if (initiatingSubFlow == null) {
freshErrorTransition(IllegalStateException("Tried to initiate in a flow not annotated with @${InitiatingFlow::class.java.simpleName}"))
return@builder FlowContinuation.ProcessEvents
}
val sourceSessionId = SessionId.createRandom(context.secureRandom)
val sessionImpl = FlowSessionImpl(event.party, sourceSessionId)
val newSessions = checkpoint.sessions + (sourceSessionId to SessionState.Uninitiated(event.party, initiatingSubFlow))
currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions))
actions.add(Action.AddSessionBinding(context.id, sourceSessionId))
FlowContinuation.Resume(sessionImpl)
}
}
private fun getClosestAncestorInitiatingSubFlow(checkpoint: Checkpoint): SubFlow.Initiating? {
for (subFlow in checkpoint.subFlowStack.asReversed()) {
if (subFlow is SubFlow.Initiating) {
return subFlow
}
}
return null
}
}

View File

@ -0,0 +1,32 @@
package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.StateMachineRunId
import net.corda.node.services.statemachine.StateMachineState
import java.security.SecureRandom
/**
* An interface used to separate out different parts of the state machine transition function.
*/
interface Transition {
/** The context of the transition. */
val context: TransitionContext
/** The state the transition is starting in. */
val startingState: StateMachineState
/** The (almost) pure transition function. The only side-effect we allow is random number generation. */
fun transition(): TransitionResult
/**
* A helper
*/
fun builder(build: TransitionBuilder.() -> FlowContinuation): TransitionResult {
val builder = TransitionBuilder(context, startingState)
val continuation = build(builder)
return TransitionResult(builder.currentState, builder.actions, continuation)
}
}
class TransitionContext(
val id: StateMachineRunId,
val configuration: StateMachineConfiguration,
val secureRandom: SecureRandom
)

View File

@ -0,0 +1,74 @@
package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.IdentifiableException
import net.corda.node.services.statemachine.*
// This is a file defining some common utilities for creating state machine transitions.
/**
* A builder that helps creating [Transition]s. This allows for a more imperative style of specifying the transition.
*/
class TransitionBuilder(val context: TransitionContext, initialState: StateMachineState) {
/** The current state machine state of the builder */
var currentState = initialState
/** The list of actions to execute */
val actions = ArrayList<Action>()
/** Check if [currentState] state is errored */
fun isErrored(): Boolean = currentState.checkpoint.errorState is ErrorState.Errored
/**
* Transition the builder into an error state because of a fresh error that happened.
* Existing actions and the current state are thrown away, and the initial state is dirtied.
*
* @param error the error.
*/
fun freshErrorTransition(error: Throwable) {
val flowError = FlowError(
errorId = (error as? IdentifiableException)?.errorId ?: context.secureRandom.nextLong(),
exception = error
)
errorTransition(flowError)
}
/**
* Transition the builder into an error state because of a list of errors that happened.
* Existing actions and the current state are thrown away, and the initial state is dirtied.
*
* @param error the error.
*/
fun errorsTransition(errors: List<FlowError>) {
currentState = currentState.copy(
checkpoint = currentState.checkpoint.copy(
errorState = currentState.checkpoint.errorState.addErrors(errors)
),
isFlowResumed = false
)
actions.clear()
actions.addAll(arrayOf(
Action.RollbackTransaction,
Action.ScheduleEvent(Event.DoRemainingWork)
))
}
/**
* Transition the builder into an error state because of a non-fresh error has happened.
* Existing actions and the current state are thrown away, and the initial state is dirtied.
*
* @param error the error.
*/
fun errorTransition(error: FlowError) {
errorsTransition(listOf(error))
}
fun resumeFlowLogic(result: Any?): FlowContinuation {
actions.add(Action.CreateTransaction)
currentState = currentState.copy(isFlowResumed = true)
return FlowContinuation.Resume(result)
}
}
class CannotFindSessionException(sessionId: SessionId) : IllegalStateException("Couldn't find session with id $sessionId")
class UnexpectedEventInState : IllegalStateException("Unexpected event")

View File

@ -0,0 +1,46 @@
package net.corda.node.services.statemachine.transitions
import net.corda.node.services.statemachine.Action
import net.corda.node.services.statemachine.StateMachineState
/**
* A datastructure capturing the intended new state of the flow, the actions to be executed as part of the transition
* and a [FlowContinuation].
*
* Read this datastructure as an instruction to the state machine executor:
* "Transition to [newState] *if* [actions] execute cleanly. If so, use [continuation] to decide what to do next. If
* there was an error it's up to you what to do".
* Also see [net.corda.node.services.statemachine.TransitionExecutorImpl] on how this is interpreted.
*/
data class TransitionResult(
val newState: StateMachineState,
val actions: List<Action> = emptyList(),
val continuation: FlowContinuation = FlowContinuation.ProcessEvents
)
/**
* A datastructure describing what to do after a transition has succeeded.
*/
sealed class FlowContinuation {
/**
* Return to user code with the supplied [result].
*/
data class Resume(val result: Any?) : FlowContinuation() {
override fun toString() = "Resume(result=${result?.javaClass})"
}
/**
* Throw an exception [throwable] in user code.
*/
data class Throw(val throwable: Throwable) : FlowContinuation()
/**
* Keep processing pending events.
*/
object ProcessEvents : FlowContinuation() { override fun toString() = "ProcessEvents" }
/**
* Immediately abort the flow. Note that this does not imply an error condition.
*/
object Abort : FlowContinuation() { override fun toString() = "Abort" }
}

View File

@ -0,0 +1,80 @@
package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.FlowInfo
import net.corda.node.services.statemachine.*
/**
* This transition is responsible for starting the flow from a FlowLogic instance. It creates the first checkpoint and
* initialises the initiated session in case the flow is an initiated one.
*/
class UnstartedFlowTransition(
override val context: TransitionContext,
override val startingState: StateMachineState,
val unstarted: FlowState.Unstarted
) : Transition {
override fun transition(): TransitionResult {
return builder {
if (!currentState.isAnyCheckpointPersisted && !currentState.isStartIdempotent) {
createInitialCheckpoint()
}
actions.add(Action.SignalFlowHasStarted(context.id))
if (unstarted.flowStart is FlowStart.Initiated) {
initialiseInitiatedSession(unstarted.flowStart)
}
currentState = currentState.copy(isFlowResumed = true)
actions.add(Action.CreateTransaction)
FlowContinuation.Resume(null)
}
}
// Initialise initiated session, store initial payload, send confirmation back.
private fun TransitionBuilder.initialiseInitiatedSession(flowStart: FlowStart.Initiated) {
val initiatingMessage = flowStart.initiatingMessage
val initiatedState = SessionState.Initiated(
peerParty = flowStart.peerSession.counterparty,
initiatedState = InitiatedSessionState.Live(initiatingMessage.initiatorSessionId),
peerFlowInfo = FlowInfo(
flowVersion = flowStart.senderCoreFlowVersion ?: initiatingMessage.flowVersion,
appName = initiatingMessage.appName
),
receivedMessages = if (initiatingMessage.firstPayload == null) {
emptyList()
} else {
listOf(DataSessionMessage(initiatingMessage.firstPayload))
},
errors = emptyList()
)
val confirmationMessage = ConfirmSessionMessage(flowStart.initiatedSessionId, flowStart.initiatedFlowInfo)
val sessionMessage = ExistingSessionMessage(initiatingMessage.initiatorSessionId, confirmationMessage)
currentState = currentState.copy(
checkpoint = currentState.checkpoint.copy(
sessions = mapOf(flowStart.initiatedSessionId to initiatedState)
)
)
actions.add(
Action.SendExisting(
flowStart.peerSession.counterparty,
sessionMessage,
DeduplicationId.createForNormal(currentState.checkpoint, 0)
)
)
}
// Create initial checkpoint and acknowledge triggering messages.
private fun TransitionBuilder.createInitialCheckpoint() {
actions.addAll(arrayOf(
Action.CreateTransaction,
Action.PersistCheckpoint(context.id, currentState.checkpoint),
Action.PersistDeduplicationIds(currentState.unacknowledgedMessages),
Action.CommitTransaction,
Action.AcknowledgeMessages(currentState.unacknowledgedMessages)
))
currentState = currentState.copy(
unacknowledgedMessages = emptyList(),
isAnyCheckpointPersisted = true
)
}
}

View File

@ -4,9 +4,13 @@ import net.corda.core.contracts.FungibleAsset
import net.corda.core.contracts.StateRef
import net.corda.core.flows.FlowLogic
import net.corda.core.node.services.VaultService
import net.corda.core.utilities.*
import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.toNonEmptySet
import net.corda.core.utilities.trace
import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.node.services.statemachine.StateMachineManager
import net.corda.nodeapi.internal.persistence.DatabaseTransactionManager
import java.util.*
class VaultSoftLockManager private constructor(private val vault: VaultService) {
@ -48,11 +52,15 @@ class VaultSoftLockManager private constructor(private val vault: VaultService)
private fun registerSoftLocks(flowId: UUID, stateRefs: NonEmptySet<StateRef>) {
log.trace { "Reserving soft locks for flow id $flowId and states $stateRefs" }
DatabaseTransactionManager.dataSource.transaction {
vault.softLockReserve(flowId, stateRefs)
}
}
private fun unregisterSoftLocks(flowId: UUID, logic: FlowLogic<*>) {
log.trace { "Releasing soft locks for flow ${logic.javaClass.simpleName} with flow id $flowId" }
DatabaseTransactionManager.dataSource.transaction {
vault.softLockRelease(flowId)
}
}
}

View File

@ -0,0 +1,144 @@
package net.corda.node.utilities
import java.lang.reflect.Method
import java.lang.reflect.Modifier
import java.lang.reflect.Type
import java.time.Instant
/**
* A tree describing the diff between two objects.
*
* For example:
* data class A(val field1: Int, val field2: String, val field3: Unit)
* fun main(args: Array<String>) {
* val someA = A(1, "hello", Unit)
* val someOtherA = A(2, "bello", Unit)
* println(ObjectDiffer.diff(someA, someOtherA))
* }
*
* Will give back Step(branches=[(field1, Last(a=1, b=2)), (field2, Last(a=hello, b=bello))])
*/
sealed class DiffTree {
/**
* Describes a "step" from the object root. It contains a list of field-subtree pairs.
*/
data class Step(val branches: List<Pair<String, DiffTree>>) : DiffTree()
/**
* Describes the leaf of the diff. This is either where the diffing was cutoff (e.g. primitives) or where it failed.
*/
data class Last(val a: Any?, val b: Any?) : DiffTree()
/**
* Flattens the [DiffTree] into a list of [DiffPath]s
*/
fun toPaths(): List<DiffPath> {
return when (this) {
is Step -> branches.flatMap { (step, tree) -> tree.toPaths().map { it.copy(path = listOf(step) + it.path) } }
is Last -> listOf(DiffPath(emptyList(), a, b))
}
}
}
/**
* A diff focused on a single [DiffTree.Last] diff, including the path leading there.
*/
data class DiffPath(
val path: List<String>,
val a: Any?,
val b: Any?
) {
override fun toString(): String {
return "${path.joinToString(".")}: \n $a\n $b\n"
}
}
/**
* This is a very simple differ used to diff objects of any kind, to be used for diagnostic.
*/
object ObjectDiffer {
fun diff(a: Any?, b: Any?): DiffTree? {
if (a == null || b == null) {
if (a == b) {
return null
} else {
return DiffTree.Last(a, b)
}
}
if (a != b) {
if (a.javaClass.isPrimitive || a.javaClass in diffCutoffClasses) {
return DiffTree.Last(a, b)
}
// TODO deduplicate this code
if (a is Map<*, *> && b is Map<*, *>) {
val allKeys = a.keys + b.keys
val branches = allKeys.mapNotNull { key -> diff(a.get(key), b.get(key))?.let { key.toString() to it } }
if (branches.isEmpty()) {
return null
} else {
return DiffTree.Step(branches)
}
}
if (a is java.util.Map<*, *> && b is java.util.Map<*, *>) {
val allKeys = a.keySet() + b.keySet()
val branches = allKeys.mapNotNull { key -> diff(a.get(key), b.get(key))?.let { key.toString() to it } }
if (branches.isEmpty()) {
return null
} else {
return DiffTree.Step(branches)
}
}
val aFields = getFieldFoci(a)
val bFields = getFieldFoci(b)
try {
if (aFields != bFields) {
return DiffTree.Last(a, b)
} else {
// TODO need to account for cases where the fields don't match up (different subclasses)
val branches = aFields.map { field -> diff(field.get(a), field.get(b))?.let { field.name to it } }.filterNotNull()
if (branches.isEmpty()) {
return DiffTree.Last(a, b)
} else {
return DiffTree.Step(branches)
}
}
} catch (throwable: Exception) {
Exception("Error while diffing $a with $b", throwable).printStackTrace(System.out)
return DiffTree.Last(a, b)
}
} else {
return null
}
}
// List of types to cutoff the diffing at.
private val diffCutoffClasses: Set<Class<*>> = setOf(
String::class.java,
Class::class.java,
Instant::class.java
)
// A type capturing the accessor to a field. This is a separate abstraction to simple reflection as we identify
// getX() and isX() calls as fields as well.
private data class FieldFocus(val name: String, val type: Type, val getter: Method) {
fun get(obj: Any): Any? {
return getter.invoke(obj)
}
}
private fun getFieldFoci(obj: Any) : List<FieldFocus> {
val foci = ArrayList<FieldFocus>()
for (method in obj.javaClass.declaredMethods) {
if (Modifier.isStatic(method.modifiers)) {
continue
}
if (method.name.startsWith("get") && method.name.length > 3 && method.parameterCount == 0) {
val fieldName = method.name[3].toLowerCase() + method.name.substring(4)
foci.add(FieldFocus(fieldName, method.returnType, method))
} else if (method.name.startsWith("is") && method.parameterCount == 0) {
foci.add(FieldFocus(method.name, method.returnType, method))
}
}
return foci
}
}

View File

@ -2,7 +2,6 @@ package net.corda.node.messaging
import net.corda.node.services.messaging.Message
import net.corda.node.services.messaging.TopicStringValidator
import net.corda.node.services.messaging.createMessage
import net.corda.testing.node.MockNetwork
import org.junit.After
import org.junit.Before
@ -48,10 +47,10 @@ class InMemoryMessagingTests {
val bits = "test-content".toByteArray()
var finalDelivery: Message? = null
node2.network.addMessageHandler { msg, _ ->
node2.network.addMessageHandler("test.topic") { msg, _, _ ->
node2.network.send(msg, node3.network.myAddress)
}
node3.network.addMessageHandler { msg, _ ->
node3.network.addMessageHandler("test.topic") { msg, _, _ ->
finalDelivery = msg
}
@ -60,7 +59,7 @@ class InMemoryMessagingTests {
mockNet.runNetwork(rounds = 1)
assertTrue(Arrays.equals(finalDelivery!!.data, bits))
assertTrue(Arrays.equals(finalDelivery!!.data.bytes, bits))
}
@Test
@ -72,7 +71,7 @@ class InMemoryMessagingTests {
val bits = "test-content".toByteArray()
var counter = 0
listOf(node1, node2, node3).forEach { it.network.addMessageHandler { _, _ -> counter++ } }
listOf(node1, node2, node3).forEach { it.network.addMessageHandler("test.topic") { _, _, _ -> counter++ } }
node1.network.send(node2.network.createMessage("test.topic", data = bits), mockNet.messagingNetwork.everyoneOnline)
mockNet.runNetwork(rounds = 1)
assertEquals(3, counter)
@ -88,12 +87,12 @@ class InMemoryMessagingTests {
val node2 = mockNet.createNode()
var received = 0
node1.network.addMessageHandler("valid_message") { _, _ ->
node1.network.addMessageHandler("valid_message") { _, _, _ ->
received++
}
val invalidMessage = node2.network.createMessage("invalid_message", data = ByteArray(0))
val validMessage = node2.network.createMessage("valid_message", data = ByteArray(0))
val invalidMessage = node2.network.createMessage("invalid_message", data = ByteArray(1))
val validMessage = node2.network.createMessage("valid_message", data = ByteArray(1))
node2.network.send(invalidMessage, node1.network.myAddress)
mockNet.runNetwork()
assertEquals(0, received)
@ -104,8 +103,8 @@ class InMemoryMessagingTests {
// Here's the core of the test; previously the unhandled message would cause runNetwork() to abort early, so
// this would fail. Make fresh messages to stop duplicate uniqueMessageId causing drops
val invalidMessage2 = node2.network.createMessage("invalid_message", data = ByteArray(0))
val validMessage2 = node2.network.createMessage("valid_message", data = ByteArray(0))
val invalidMessage2 = node2.network.createMessage("invalid_message", data = ByteArray(1))
val validMessage2 = node2.network.createMessage("valid_message", data = ByteArray(1))
node2.network.send(invalidMessage2, node1.network.myAddress)
node2.network.send(validMessage2, node1.network.myAddress)
mockNet.runNetwork()

View File

@ -720,6 +720,12 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) {
private val database: CordaPersistence,
private val delegate: WritableTransactionStorage
) : WritableTransactionStorage, SingletonSerializeAsToken() {
override fun trackTransaction(id: SecureHash): CordaFuture<SignedTransaction> {
return database.transaction {
delegate.trackTransaction(id)
}
}
override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> {
return database.transaction {
delegate.track()

View File

@ -4,6 +4,7 @@ import co.paralleluniverse.fibers.Suspendable
import com.codahale.metrics.MetricRegistry
import com.nhaarman.mockito_kotlin.*
import net.corda.core.contracts.*
import net.corda.core.crypto.newSecureRandom
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowLogicRef
import net.corda.core.flows.FlowLogicRefFactory
@ -117,7 +118,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
}
smmExecutor = AffinityExecutor.ServiceAffinityExecutor("test", 1)
mockSMM = StateMachineManagerImpl(services, DBCheckpointStorage(), smmExecutor, database)
mockSMM = StateMachineManagerImpl(services, DBCheckpointStorage(), smmExecutor, database, newSecureRandom())
scheduler = NodeSchedulerService(testClock, database, FlowStarterImpl(smmExecutor, mockSMM), stateLoader, schedulerGatedExecutor, serverThread = smmExecutor)
mockSMM.changes.subscribe { change ->
if (change is StateMachineManager.Change.Removed && mockSMM.allStateMachines.isEmpty()) {

View File

@ -1,6 +1,10 @@
package net.corda.node.services.messaging
import net.corda.core.crypto.generateKeyPair
import net.corda.core.concurrent.CordaFuture
import com.codahale.metrics.MetricRegistry
import net.corda.core.crypto.generateKeyPair
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.node.services.RPCUserService
import net.corda.node.services.RPCUserServiceImpl
@ -118,7 +122,7 @@ class ArtemisMessagingTests {
messagingClient.send(message, messagingClient.myAddress)
val actual: Message = receivedMessages.take()
assertEquals("first msg", String(actual.data))
assertEquals("first msg", String(actual.data.bytes))
assertNull(receivedMessages.poll(200, MILLISECONDS))
}
@ -143,7 +147,8 @@ class ArtemisMessagingTests {
val messagingClient = createMessagingClient(platformVersion = platformVersion)
startNodeMessagingClient()
messagingClient.addMessageHandler(TOPIC) { message, _ ->
messagingClient.addMessageHandler(TOPIC) { message, _, handle ->
handle.acknowledge() // We ACK first so that if it fails we won't get a duplicate in [receivedMessages]
receivedMessages.add(message)
}
// Run after the handlers are added, otherwise (some of) the messages get delivered and discarded / dead-lettered.

View File

@ -1,13 +1,19 @@
package net.corda.node.services.persistence
import com.google.common.primitives.Ints
import net.corda.core.context.InvocationContext
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.node.services.api.Checkpoint
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.transactions.PersistentUniquenessProvider
import net.corda.core.serialization.serialize
import net.corda.node.internal.configureDatabase
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.statemachine.Checkpoint
import net.corda.node.services.statemachine.FlowStart
import net.corda.node.services.transactions.PersistentUniquenessProvider
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.testing.ALICE
import net.corda.testing.LogHelper
import net.corda.testing.SerializationEnvironmentRule
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
@ -17,14 +23,11 @@ import org.junit.After
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import kotlin.streams.toList
internal fun CheckpointStorage.checkpoints(): List<Checkpoint> {
val checkpoints = mutableListOf<Checkpoint>()
forEach {
checkpoints += it
true
}
return checkpoints
internal fun CheckpointStorage.checkpoints(): List<SerializedBytes<Checkpoint>> {
val checkpoints = getAllCheckpoints().toList()
return checkpoints.map { it.second }
}
class DBCheckpointStorageTests {
@ -50,9 +53,9 @@ class DBCheckpointStorageTests {
@Test
fun `add new checkpoint`() {
val checkpoint = newCheckpoint()
val (id, checkpoint) = newCheckpoint()
database.transaction {
checkpointStorage.addCheckpoint(checkpoint)
checkpointStorage.addCheckpoint(id, checkpoint)
}
database.transaction {
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
@ -65,12 +68,12 @@ class DBCheckpointStorageTests {
@Test
fun `remove checkpoint`() {
val checkpoint = newCheckpoint()
val (id, checkpoint) = newCheckpoint()
database.transaction {
checkpointStorage.addCheckpoint(checkpoint)
checkpointStorage.addCheckpoint(id, checkpoint)
}
database.transaction {
checkpointStorage.removeCheckpoint(checkpoint)
checkpointStorage.removeCheckpoint(id)
}
database.transaction {
assertThat(checkpointStorage.checkpoints()).isEmpty()
@ -83,12 +86,12 @@ class DBCheckpointStorageTests {
@Test
fun `add and remove checkpoint in single commit operate`() {
val checkpoint = newCheckpoint()
val checkpoint2 = newCheckpoint()
val (id, checkpoint) = newCheckpoint()
val (id2, checkpoint2) = newCheckpoint()
database.transaction {
checkpointStorage.addCheckpoint(checkpoint)
checkpointStorage.addCheckpoint(checkpoint2)
checkpointStorage.removeCheckpoint(checkpoint)
checkpointStorage.addCheckpoint(id, checkpoint)
checkpointStorage.addCheckpoint(id2, checkpoint2)
checkpointStorage.removeCheckpoint(id)
}
database.transaction {
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint2)
@ -101,16 +104,16 @@ class DBCheckpointStorageTests {
@Test
fun `add two checkpoints then remove first one`() {
val firstCheckpoint = newCheckpoint()
val (id, firstCheckpoint) = newCheckpoint()
database.transaction {
checkpointStorage.addCheckpoint(firstCheckpoint)
checkpointStorage.addCheckpoint(id, firstCheckpoint)
}
val secondCheckpoint = newCheckpoint()
val (id2, secondCheckpoint) = newCheckpoint()
database.transaction {
checkpointStorage.addCheckpoint(secondCheckpoint)
checkpointStorage.addCheckpoint(id2, secondCheckpoint)
}
database.transaction {
checkpointStorage.removeCheckpoint(firstCheckpoint)
checkpointStorage.removeCheckpoint(id)
}
database.transaction {
assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint)
@ -123,9 +126,9 @@ class DBCheckpointStorageTests {
@Test
fun `add checkpoint and then remove after 'restart'`() {
val originalCheckpoint = newCheckpoint()
val (id, originalCheckpoint) = newCheckpoint()
database.transaction {
checkpointStorage.addCheckpoint(originalCheckpoint)
checkpointStorage.addCheckpoint(id, originalCheckpoint)
}
newCheckpointStorage()
val reconstructedCheckpoint = database.transaction {
@ -135,7 +138,7 @@ class DBCheckpointStorageTests {
assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint)
}
database.transaction {
checkpointStorage.removeCheckpoint(reconstructedCheckpoint)
checkpointStorage.removeCheckpoint(id)
}
database.transaction {
assertThat(checkpointStorage.checkpoints()).isEmpty()
@ -148,7 +151,14 @@ class DBCheckpointStorageTests {
}
}
private var checkpointCount = 1
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)))
private fun newCheckpoint(): Pair<StateMachineRunId, SerializedBytes<Checkpoint>> {
val id = StateMachineRunId.createRandom()
val logic: FlowLogic<*> = object : FlowLogic<Unit>() {
override fun call() {}
}
val frozenLogic = logic.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT)
val checkpoint = Checkpoint.create(InvocationContext.shell(), FlowStart.Explicit, logic.javaClass, frozenLogic, ALICE, "").getOrThrow()
return id to checkpoint.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT)
}
}

View File

@ -2,6 +2,7 @@ package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand
import co.paralleluniverse.strands.concurrent.Semaphore
import net.corda.core.concurrent.CordaFuture
import net.corda.core.contracts.ContractState
@ -48,6 +49,7 @@ import rx.Notification
import rx.Observable
import java.time.Instant
import java.util.*
import java.util.concurrent.ExecutionException
import kotlin.reflect.KClass
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
@ -110,6 +112,19 @@ class FlowFrameworkTests {
assertThat(flow.lazyTime).isNotNull()
}
class ThrowingActionExecutor(private val exception: Exception, val delegate: ActionExecutor) : ActionExecutor {
var thrown = false
@Suspendable
override fun executeAction(fiber: FlowFiber, action: Action) {
if (thrown) {
delegate.executeAction(fiber, action)
} else {
thrown = true
throw exception
}
}
}
@Test
fun `exception while fiber suspended`() {
bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) }
@ -117,16 +132,15 @@ class FlowFrameworkTests {
val fiber = aliceNode.services.startFlow(flow) as FlowStateMachineImpl
// Before the flow runs change the suspend action to throw an exception
val exceptionDuringSuspend = Exception("Thrown during suspend")
fiber.actionOnSuspend = {
throw exceptionDuringSuspend
}
val throwingActionExecutor = ThrowingActionExecutor(exceptionDuringSuspend, fiber.transientValues!!.value.actionExecutor)
fiber.transientValues = TransientReference(fiber.transientValues!!.value.copy(actionExecutor = throwingActionExecutor))
mockNet.runNetwork()
assertThatThrownBy {
fiber.resultFuture.getOrThrow()
}.isSameAs(exceptionDuringSuspend)
assertThat(aliceNode.smm.allStateMachines).isEmpty()
// Make sure the fiber does actually terminate
assertThat(fiber.isTerminated).isTrue()
assertThat(fiber.state).isEqualTo(Strand.State.WAITING)
}
@Test
@ -217,6 +231,8 @@ class FlowFrameworkTests {
val payload = "Hello World"
aliceNode.services.startFlow(SendFlow(payload, bob, charlie))
mockNet.runNetwork()
bobNode.internals.acceptableLiveFiberCountOnStop = 1
charlieNode.internals.acceptableLiveFiberCountOnStop = 1
val bobFlow = bobNode.getSingleFlow<InitiatedReceiveFlow>().first
val charlieFlow = charlieNode.getSingleFlow<InitiatedReceiveFlow>().first
assertThat(bobFlow.receivedPayloads[0]).isEqualTo(payload)
@ -235,9 +251,6 @@ class FlowFrameworkTests {
aliceNode sent normalEnd to charlieNode
//There's no session end from the other flows as they're manually suspended
)
bobNode.internals.acceptableLiveFiberCountOnStop = 1
charlieNode.internals.acceptableLiveFiberCountOnStop = 1
}
@Test
@ -294,14 +307,16 @@ class FlowFrameworkTests {
mockNet.runNetwork()
assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy {
resultFuture.getOrThrow()
}.withMessageContaining(String::class.java.name) // Make sure the exception message mentions the type the flow was expecting to receive
}
}
@Test
fun `receiving unexpected session end before entering sendAndReceive`() {
bobNode.registerFlowFactory(WaitForOtherSideEndBeforeSendAndReceive::class) { NoOpFlow() }
val sessionEndReceived = Semaphore(0)
receivedSessionMessagesObservable().filter { it.message is SessionEnd }.subscribe { sessionEndReceived.release() }
receivedSessionMessagesObservable().filter {
it.message is ExistingSessionMessage && it.message.payload is EndSessionMessage
}.subscribe { sessionEndReceived.release() }
val resultFuture = aliceNode.services.startFlow(
WaitForOtherSideEndBeforeSendAndReceive(bob, sessionEndReceived)).resultFuture
mockNet.runNetwork()
@ -337,7 +352,9 @@ class FlowFrameworkTests {
mockNet.runNetwork()
assertThat(erroringFlowSteps.get()).containsExactly(
erroringFlowFuture.getOrThrow()
val flowSteps = erroringFlowSteps.get()
assertThat(flowSteps).containsExactly(
Notification.createOnNext(ExceptionFlow.START_STEP),
Notification.createOnError(erroringFlowFuture.get().exceptionThrown)
)
@ -354,7 +371,7 @@ class FlowFrameworkTests {
assertSessionTransfers(
aliceNode sent sessionInit(ReceiveFlow::class) to bobNode,
bobNode sent sessionConfirm() to aliceNode,
bobNode sent erroredEnd() to aliceNode
bobNode sent errorMessage() to aliceNode
)
}
@ -377,8 +394,8 @@ class FlowFrameworkTests {
assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty()
}
assertThat(receivingFiber.isTerminated).isTrue()
assertThat((erroringFlow.get().stateMachine as FlowStateMachineImpl).isTerminated).isTrue()
assertThat(receivingFiber.state).isEqualTo(Strand.State.WAITING)
assertThat((erroringFlow.get().stateMachine as FlowStateMachineImpl).state).isEqualTo(Strand.State.WAITING)
assertThat(erroringFlowSteps.get()).containsExactly(
Notification.createOnNext(ExceptionFlow.START_STEP),
Notification.createOnError(erroringFlow.get().exceptionThrown)
@ -387,14 +404,15 @@ class FlowFrameworkTests {
assertSessionTransfers(
aliceNode sent sessionInit(ReceiveFlow::class) to bobNode,
bobNode sent sessionConfirm() to aliceNode,
bobNode sent erroredEnd(erroringFlow.get().exceptionThrown) to aliceNode
bobNode sent errorMessage(erroringFlow.get().exceptionThrown) to aliceNode
)
// Make sure the original stack trace isn't sent down the wire
assertThat((receivedSessionMessages.last().message as ErrorSessionEnd).errorResponse!!.stackTrace).isEmpty()
val lastMessage = receivedSessionMessages.last().message as ExistingSessionMessage
assertThat((lastMessage.payload as ErrorSessionMessage).flowException!!.stackTrace).isEmpty()
}
@Test
fun `FlowException propagated in invocation chain`() {
fun `FlowException only propagated to parent`() {
val charlieNode = mockNet.createNode(MockNodeParameters(legalName = CHARLIE_NAME))
val charlie = charlieNode.info.singleIdentity()
@ -402,9 +420,8 @@ class FlowFrameworkTests {
bobNode.registerFlowFactory(ReceiveFlow::class) { ReceiveFlow(charlie) }
val receivingFiber = aliceNode.services.startFlow(ReceiveFlow(bob))
mockNet.runNetwork()
assertThatExceptionOfType(MyFlowException::class.java)
assertThatExceptionOfType(UnexpectedFlowEndException::class.java)
.isThrownBy { receivingFiber.resultFuture.getOrThrow() }
.withMessage("Chain")
}
@Test
@ -436,7 +453,7 @@ class FlowFrameworkTests {
aliceNode sent sessionInit(ReceiveFlow::class) to bobNode,
bobNode sent sessionConfirm() to aliceNode,
bobNode sent sessionData("Hello") to aliceNode,
aliceNode sent erroredEnd() to bobNode
aliceNode sent errorMessage() to bobNode
)
}
@ -556,10 +573,8 @@ class FlowFrameworkTests {
@Test
fun `customised client flow which has annotated @InitiatingFlow again`() {
val result = aliceNode.services.startFlow(IncorrectCustomSendFlow("Hello", bob)).resultFuture
mockNet.runNetwork()
assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy {
result.getOrThrow()
assertThatExceptionOfType(ExecutionException::class.java).isThrownBy {
aliceNode.services.startFlow(IncorrectCustomSendFlow("Hello", bob)).resultFuture
}.withMessageContaining(InitiatingFlow::class.java.simpleName)
}
@ -601,20 +616,20 @@ class FlowFrameworkTests {
@Test
fun `unknown class in session init`() {
aliceNode.sendSessionMessage(SessionInit(random63BitValue(), "not.a.real.Class", 1, "version", null), bob)
aliceNode.sendSessionMessage(InitialSessionMessage(SessionId(random63BitValue()), 0, "not.a.real.Class", 1, "", null), bob)
mockNet.runNetwork()
assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected
val reject = receivedSessionMessages.last().message as SessionReject
assertThat(reject.errorMessage).isEqualTo("Don't know not.a.real.Class")
val lastMessage = receivedSessionMessages.last().message as ExistingSessionMessage
assertThat((lastMessage.payload as RejectSessionMessage).message).isEqualTo("Don't know not.a.real.Class")
}
@Test
fun `non-flow class in session init`() {
aliceNode.sendSessionMessage(SessionInit(random63BitValue(), String::class.java.name, 1, "version", null), bob)
aliceNode.sendSessionMessage(InitialSessionMessage(SessionId(random63BitValue()), 0, String::class.java.name, 1, "", null), bob)
mockNet.runNetwork()
assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected
val reject = receivedSessionMessages.last().message as SessionReject
assertThat(reject.errorMessage).isEqualTo("${String::class.java.name} is not a flow")
val lastMessage = receivedSessionMessages.last().message as ExistingSessionMessage
assertThat((lastMessage.payload as RejectSessionMessage).message).isEqualTo("${String::class.java.name} is not a flow")
}
@Test
@ -633,24 +648,6 @@ class FlowFrameworkTests {
assertThat(result.getOrThrow()).isEqualTo("HelloHello")
}
@Test
fun `double initiateFlow throws`() {
val future = aliceNode.services.startFlow(DoubleInitiatingFlow()).resultFuture
mockNet.runNetwork()
assertThatExceptionOfType(IllegalStateException::class.java)
.isThrownBy { future.getOrThrow() }
.withMessageContaining("Attempted to initiateFlow() twice")
}
@InitiatingFlow
private class DoubleInitiatingFlow : FlowLogic<Unit>() {
@Suspendable
override fun call() {
initiateFlow(ourIdentity)
initiateFlow(ourIdentity)
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////
//region Helpers
@ -680,19 +677,18 @@ class FlowFrameworkTests {
return observable.toFuture()
}
private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): SessionInit {
return SessionInit(0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage {
return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
}
private fun sessionConfirm(flowVersion: Int = 1) = SessionConfirm(0, 0, flowVersion, "")
private fun sessionData(payload: Any) = SessionData(0, payload.serialize())
private val normalEnd = NormalSessionEnd(0)
private fun erroredEnd(errorResponse: FlowException? = null) = ErrorSessionEnd(0, errorResponse)
private fun sessionConfirm(flowVersion: Int = 1) = ExistingSessionMessage(SessionId(0), ConfirmSessionMessage(SessionId(0), FlowInfo(flowVersion, "")))
private fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(0), DataSessionMessage(payload.serialize()))
private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0)
private fun errorMessage(errorResponse: FlowException? = null) = ExistingSessionMessage(SessionId(0), ErrorSessionMessage(errorResponse, 0))
private fun StartedNode<*>.sendSessionMessage(message: SessionMessage, destination: Party) {
services.networkService.apply {
val address = getAddressOfParty(PartyInfo.SingleNode(destination, emptyList()))
send(createMessage(StateMachineManagerImpl.sessionTopic, message.serialize().bytes), address)
send(createMessage(FlowMessagingImpl.sessionTopic, message.serialize().bytes), address)
}
}
@ -707,7 +703,9 @@ class FlowFrameworkTests {
}
private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: MessageRecipients) {
val isPayloadTransfer: Boolean get() = message is SessionData || message is SessionInit && message.firstPayload != null
val isPayloadTransfer: Boolean get() =
message is ExistingSessionMessage && message.payload is DataSessionMessage ||
message is InitialSessionMessage && message.firstPayload != null
override fun toString(): String = "$from sent $message to $to"
}
@ -716,7 +714,7 @@ class FlowFrameworkTests {
}
private fun Observable<MessageTransfer>.toSessionTransfers(): Observable<SessionTransfer> {
return filter { it.message.topicSession == StateMachineManagerImpl.sessionTopic }.map {
return filter { it.message.topic == FlowMessagingImpl.sessionTopic }.map {
val from = it.sender.id
val message = it.message.data.deserialize<SessionMessage>()
SessionTransfer(from, sanitise(message), it.recipients)
@ -724,12 +722,23 @@ class FlowFrameworkTests {
}
private fun sanitise(message: SessionMessage) = when (message) {
is SessionData -> message.copy(recipientSessionId = 0)
is SessionInit -> message.copy(initiatorSessionId = 0, appName = "")
is SessionConfirm -> message.copy(initiatorSessionId = 0, initiatedSessionId = 0, appName = "")
is NormalSessionEnd -> message.copy(recipientSessionId = 0)
is ErrorSessionEnd -> message.copy(recipientSessionId = 0)
else -> message
is InitialSessionMessage -> message.copy(initiatorSessionId = SessionId(0), initiationEntropy = 0, appName = "")
is ExistingSessionMessage -> {
val payload = message.payload
message.copy(
recipientSessionId = SessionId(0),
payload = when (payload) {
is ConfirmSessionMessage -> payload.copy(
initiatedSessionId = SessionId(0),
initiatedFlowInfo = payload.initiatedFlowInfo.copy(appName = "")
)
is ErrorSessionMessage -> payload.copy(
errorId = 0
)
else -> payload
}
)
}
}
private infix fun StartedNode<MockNode>.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(internals.id, message)

View File

@ -27,6 +27,7 @@ import net.corda.core.messaging.DataFeed
import net.corda.core.messaging.StateMachineTransactionMapping
import net.corda.core.node.services.Vault
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.toFuture
import net.corda.core.transactions.SignedTransaction
@ -37,10 +38,10 @@ import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.toNonEmptySet
import net.corda.core.utilities.unwrap
import net.corda.node.internal.StartedNode
import net.corda.node.services.api.Checkpoint
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.WritableTransactionStorage
import net.corda.node.services.persistence.DBTransactionStorage
import net.corda.node.services.statemachine.Checkpoint
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.testing.*
import net.corda.testing.node.*
@ -56,21 +57,14 @@ import java.io.ByteArrayOutputStream
import java.util.*
import java.util.jar.JarOutputStream
import java.util.zip.ZipEntry
import kotlin.streams.toList
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertTrue
/**
* Copied from DBCheckpointStorageTests as it is required as helper for this test
*/
internal fun CheckpointStorage.checkpoints(): List<Checkpoint> {
val checkpoints = mutableListOf<Checkpoint>()
forEach {
checkpoints += it
true
}
return checkpoints
internal fun CheckpointStorage.checkpoints(): List<SerializedBytes<Checkpoint>> {
val checkpoints = getAllCheckpoints().toList()
return checkpoints.map { it.second }
}
/**
@ -740,6 +734,12 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) {
private val database: CordaPersistence,
private val delegate: WritableTransactionStorage
) : WritableTransactionStorage, SingletonSerializeAsToken() {
override fun trackTransaction(id: SecureHash): CordaFuture<SignedTransaction> {
return database.transaction {
delegate.trackTransaction(id)
}
}
override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> {
return database.transaction {
delegate.track()

View File

@ -15,9 +15,7 @@ import net.corda.core.serialization.deserialize
import net.corda.core.utilities.ProgressTracker
import net.corda.netmap.VisualiserViewModel.Style
import net.corda.netmap.simulation.IRSSimulation
import net.corda.node.services.statemachine.SessionConfirm
import net.corda.node.services.statemachine.SessionEnd
import net.corda.node.services.statemachine.SessionInit
import net.corda.node.services.statemachine.*
import net.corda.testing.chooseIdentity
import net.corda.testing.node.InMemoryMessagingNetwork
import net.corda.testing.node.MockNetwork
@ -342,12 +340,16 @@ class NetworkMapVisualiser : Application() {
private fun transferIsInteresting(transfer: InMemoryMessagingNetwork.MessageTransfer): Boolean {
// Loopback messages are boring.
if (transfer.sender == transfer.recipients) return false
val message = transfer.message.data.deserialize<Any>()
val message = transfer.message.data.deserialize<SessionMessage>()
return when (message) {
is SessionEnd -> false
is SessionConfirm -> false
is SessionInit -> message.firstPayload != null
else -> true
is InitialSessionMessage -> message.firstPayload != null
is ExistingSessionMessage -> when (message.payload) {
is ConfirmSessionMessage -> false
is DataSessionMessage -> true
is ErrorSessionMessage -> true
is RejectSessionMessage -> true
is EndSessionMessage -> false
}
}
}
}

View File

@ -20,6 +20,7 @@ include 'experimental:sandbox'
include 'experimental:quasar-hook'
include 'experimental:kryo-hook'
include 'experimental:intellij-plugin'
include 'experimental:flow-hook'
include 'verifier'
include 'test-common'
include 'test-utils'

View File

@ -634,7 +634,7 @@ class DriverDSL(
throw ListenProcessDeathException(rpcAddress, processDeathFuture.getOrThrow())
}
val connection = connectionFuture.getOrThrow()
shutdownManager.registerShutdown(connection::close)
shutdownManager.registerShutdown(connection::forceClose)
connection.proxy
}
}

View File

@ -13,9 +13,12 @@ import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.node.services.PartyInfo
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.trace
import net.corda.node.services.messaging.*
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.utilities.AffinityExecutor
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.testing.node.InMemoryMessagingNetwork.InMemoryMessaging
@ -57,7 +60,7 @@ class InMemoryMessagingNetwork(
@CordaSerializable
data class MessageTransfer(val sender: PeerHandle, val message: Message, val recipients: MessageRecipients) {
override fun toString() = "${message.topicSession} from '$sender' to '$recipients'"
override fun toString() = "${message.topic} from '$sender' to '$recipients'"
}
// All sent messages are kept here until pumpSend is called, or manuallyPumped is set to false
@ -242,17 +245,17 @@ class InMemoryMessagingNetwork(
_sentMessages.onNext(transfer)
}
data class InMemoryMessage(override val topicSession: TopicSession,
override val data: ByteArray,
override val uniqueMessageId: UUID,
data class InMemoryMessage(override val topic: String,
override val data: ByteSequence,
override val uniqueMessageId: DeduplicationId,
override val debugTimestamp: Instant = Instant.now()) : Message {
override fun toString() = "$topicSession#${String(data)}"
override fun toString() = "$topic#${String(data.bytes)}"
}
private data class InMemoryReceivedMessage(override val topicSession: TopicSession,
override val data: ByteArray,
private data class InMemoryReceivedMessage(override val topic: String,
override val data: ByteSequence,
override val platformVersion: Int,
override val uniqueMessageId: UUID,
override val uniqueMessageId: DeduplicationId,
override val debugTimestamp: Instant,
override val peer: CordaX500Name) : ReceivedMessage
@ -268,8 +271,7 @@ class InMemoryMessagingNetwork(
private val peerHandle: PeerHandle,
private val executor: AffinityExecutor,
private val database: CordaPersistence) : SingletonSerializeAsToken(), MessagingService {
inner class Handler(val topicSession: TopicSession,
val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration
inner class Handler(val topicSession: String, val callback: MessageHandler) : MessageHandlerRegistration
@Volatile
private var running = true
@ -280,7 +282,7 @@ class InMemoryMessagingNetwork(
}
private val state = ThreadBox(InnerState())
private val processedMessages: MutableSet<UUID> = Collections.synchronizedSet(HashSet<UUID>())
private val processedMessages: MutableSet<DeduplicationId> = Collections.synchronizedSet(HashSet<DeduplicationId>())
override val myAddress: PeerHandle get() = peerHandle
@ -302,13 +304,10 @@ class InMemoryMessagingNetwork(
}
}
override fun addMessageHandler(topic: String, sessionID: Long, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
= addMessageHandler(TopicSession(topic, sessionID), callback)
override fun addMessageHandler(topicSession: TopicSession, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
override fun addMessageHandler(topic: String, callback: MessageHandler): MessageHandlerRegistration {
check(running)
val (handler, transfers) = state.locked {
val handler = Handler(topicSession, callback).apply { handlers.add(this) }
val handler = Handler(topic, callback).apply { handlers.add(this) }
val pending = ArrayList<MessageTransfer>()
database.transaction {
pending.addAll(pendingRedelivery)
@ -354,8 +353,8 @@ class InMemoryMessagingNetwork(
override fun cancelRedelivery(retryId: Long) {}
/** Returns the given (topic & session, data) pair as a newly created message object. */
override fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID): Message {
return InMemoryMessage(topicSession, data, uuid)
override fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId): Message {
return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId)
}
/**
@ -388,14 +387,14 @@ class InMemoryMessagingNetwork(
while (deliverTo == null) {
val transfer = (if (block) q.take() else q.poll()) ?: return null
deliverTo = state.locked {
val matchingHandlers = handlers.filter { it.topicSession.isBlank() || transfer.message.topicSession == it.topicSession }
val matchingHandlers = handlers.filter { it.topicSession.isBlank() || transfer.message.topic == it.topicSession }
if (matchingHandlers.isEmpty()) {
// Got no handlers for this message yet. Keep the message around and attempt redelivery after a new
// handler has been registered. The purpose of this path is to make unit tests that have multi-threading
// reliable, as a sender may attempt to send a message to a receiver that hasn't finished setting
// up a handler for yet. Most unit tests don't run threaded, but we want to test true parallelism at
// least sometimes.
log.warn("Message to ${transfer.message.topicSession} could not be delivered")
log.warn("Message to ${transfer.message.topic} could not be delivered")
database.transaction {
pendingRedelivery.add(transfer)
}
@ -419,7 +418,13 @@ class InMemoryMessagingNetwork(
database.transaction {
for (handler in deliverTo) {
try {
handler.callback(transfer.toReceivedMessage(), handler)
val acknowledgeHandle = object : AcknowledgeHandle {
override fun acknowledge() {
}
override fun persistDeduplicationId() {
}
}
handler.callback(transfer.toReceivedMessage(), handler, acknowledgeHandle)
} catch (e: Exception) {
log.error("Caught exception in handler for $this/${handler.topicSession}", e)
}
@ -436,8 +441,8 @@ class InMemoryMessagingNetwork(
}
private fun MessageTransfer.toReceivedMessage(): ReceivedMessage = InMemoryReceivedMessage(
message.topicSession,
message.data.copyOf(), // Kryo messes with the buffer so give each client a unique copy
message.topic,
OpaqueBytes(message.data.bytes.copyOf()), // Kryo messes with the buffer so give each client a unique copy
1,
message.uniqueMessageId,
message.debugTimestamp,

View File

@ -4,12 +4,14 @@ import com.google.common.collect.MutableClassToInstanceMap
import com.typesafe.config.Config
import com.typesafe.config.ConfigFactory
import com.typesafe.config.ConfigParseOptions
import net.corda.core.concurrent.CordaFuture
import net.corda.core.cordapp.CordappProvider
import net.corda.core.crypto.*
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.PartyAndCertificate
import net.corda.core.internal.concurrent.doneFuture
import net.corda.core.messaging.DataFeed
import net.corda.core.messaging.FlowHandle
import net.corda.core.messaging.FlowProgressHandle
@ -17,6 +19,7 @@ import net.corda.core.node.*
import net.corda.core.node.services.*
import net.corda.core.serialization.SerializeAsToken
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.toFuture
import net.corda.core.transactions.SignedTransaction
import net.corda.node.VersionInfo
import net.corda.node.internal.StateLoaderImpl
@ -285,6 +288,10 @@ class MockStateMachineRecordedTransactionMappingStorage(
) : StateMachineRecordedTransactionMappingStorage by storage
open class MockTransactionStorage : WritableTransactionStorage, SingletonSerializeAsToken() {
override fun trackTransaction(id: SecureHash): CordaFuture<SignedTransaction> {
return txns[id]?.let { doneFuture(it) } ?: _updatesPublisher.filter { it.id == id }.toFuture()
}
override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> {
return DataFeed(txns.values.toList(), _updatesPublisher)
}

View File

@ -1,6 +1,5 @@
package com.r3.corda.jmeter
import com.sun.javaws.exceptions.InvalidArgumentException
import net.corda.core.internal.div
import org.apache.jmeter.JMeter
import org.slf4j.LoggerFactory
@ -68,7 +67,7 @@ class Launcher {
if (args[index] == "-XsshUser") {
++index
if (index == args.size || args[index].startsWith("-")) {
throw InvalidArgumentException(args)
throw IllegalArgumentException(args.toList().toString())
}
userName = args[index]
} else if (args[index] == "-Xssh") {

View File

@ -2,7 +2,6 @@ package com.r3.corda.jmeter
import com.jcraft.jsch.JSch
import com.jcraft.jsch.Session
import com.sun.javaws.exceptions.InvalidArgumentException
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.internal.addShutdownHook
import org.slf4j.LoggerFactory
@ -28,7 +27,7 @@ class Ssh {
if (args[index] == "-XsshUser") {
++index
if (index == args.size || args[index].startsWith("-")) {
throw InvalidArgumentException(args)
throw IllegalArgumentException(args.toList().toString())
}
userName = args[index]
} else if (args[index] == "-Xssh") {