diff --git a/.ci/api-current.txt b/.ci/api-current.txt index 84b4d41817..0b6a177fae 100644 --- a/.ci/api-current.txt +++ b/.ci/api-current.txt @@ -491,6 +491,7 @@ public final class net.corda.core.contracts.StateAndContract extends java.lang.O public final class net.corda.core.contracts.Structures extends java.lang.Object @org.jetbrains.annotations.NotNull public static final net.corda.core.crypto.SecureHash hash(net.corda.core.contracts.ContractState) @org.jetbrains.annotations.NotNull public static final net.corda.core.contracts.Amount withoutIssuer(net.corda.core.contracts.Amount) + public static final int MAX_ISSUER_REF_SIZE = 512 ## @net.corda.core.serialization.CordaSerializable public abstract class net.corda.core.contracts.TimeWindow extends java.lang.Object public () @@ -827,6 +828,9 @@ public final class net.corda.core.crypto.CryptoUtils extends java.lang.Object public final boolean verify(byte[]) @org.jetbrains.annotations.NotNull public final net.corda.core.crypto.DigitalSignature withoutKey() ## +public final class net.corda.core.crypto.DummySecureRandom extends java.security.SecureRandom + public static final net.corda.core.crypto.DummySecureRandom INSTANCE +## public abstract class net.corda.core.crypto.MerkleTree extends java.lang.Object @org.jetbrains.annotations.NotNull public abstract net.corda.core.crypto.SecureHash getHash() public static final net.corda.core.crypto.MerkleTree$Companion Companion @@ -1140,11 +1144,14 @@ public static final class net.corda.core.flows.FinalityFlow$Companion extends ja @org.jetbrains.annotations.NotNull public net.corda.core.utilities.ProgressTracker childProgressTracker() public static final net.corda.core.flows.FinalityFlow$Companion$NOTARISING INSTANCE ## -@net.corda.core.serialization.CordaSerializable public class net.corda.core.flows.FlowException extends net.corda.core.CordaException +@net.corda.core.serialization.CordaSerializable public class net.corda.core.flows.FlowException extends net.corda.core.CordaException implements net.corda.core.flows.IdentifiableException public () public (String) public (String, Throwable) public (Throwable) + @org.jetbrains.annotations.Nullable public Long getErrorId() + @org.jetbrains.annotations.Nullable public final Long getOriginalErrorId() + public final void setOriginalErrorId(Long) ## @net.corda.core.serialization.CordaSerializable public final class net.corda.core.flows.FlowInfo extends java.lang.Object public (int, String) @@ -1209,7 +1216,6 @@ public abstract class net.corda.core.flows.FlowLogic extends java.lang.Object public final void checkFlowPermission(String, Map) @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.Nullable public final net.corda.core.flows.FlowStackSnapshot flowStackSnapshot() @org.jetbrains.annotations.Nullable public static final net.corda.core.flows.FlowLogic getCurrentTopLevel() - @kotlin.Deprecated @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public final net.corda.core.flows.FlowInfo getFlowInfo(net.corda.core.identity.Party) @org.jetbrains.annotations.NotNull public final org.slf4j.Logger getLogger() @org.jetbrains.annotations.NotNull public final net.corda.core.identity.Party getOurIdentity() @org.jetbrains.annotations.NotNull public final net.corda.core.identity.PartyAndCertificate getOurIdentityAndCert() @@ -1219,16 +1225,18 @@ public abstract class net.corda.core.flows.FlowLogic extends java.lang.Object @org.jetbrains.annotations.NotNull public final net.corda.core.internal.FlowStateMachine getStateMachine() @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public final net.corda.core.flows.FlowSession initiateFlow(net.corda.core.identity.Party) @co.paralleluniverse.fibers.Suspendable public final void persistFlowStackSnapshot() - @kotlin.Deprecated @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public net.corda.core.utilities.UntrustworthyData receive(Class, net.corda.core.identity.Party) @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public List receiveAll(Class, List) - @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public Map receiveAll(Map) + @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public List receiveAll(Class, List, boolean) + @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public Map receiveAllMap(Map) + @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public Map receiveAllMap(Map, boolean) public final void recordAuditEvent(String, String, Map) - @kotlin.Deprecated @co.paralleluniverse.fibers.Suspendable public void send(net.corda.core.identity.Party, Object) - @kotlin.Deprecated @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public net.corda.core.utilities.UntrustworthyData sendAndReceive(Class, net.corda.core.identity.Party, Object) public final void setStateMachine(net.corda.core.internal.FlowStateMachine) @co.paralleluniverse.fibers.Suspendable @kotlin.jvm.JvmStatic public static final void sleep(java.time.Duration) + @co.paralleluniverse.fibers.Suspendable @kotlin.jvm.JvmStatic public static final void sleep(java.time.Duration, boolean) @co.paralleluniverse.fibers.Suspendable public Object subFlow(net.corda.core.flows.FlowLogic) @org.jetbrains.annotations.Nullable public final net.corda.core.messaging.DataFeed track() + @org.jetbrains.annotations.Nullable public final net.corda.core.messaging.DataFeed trackStepsTree() + @org.jetbrains.annotations.Nullable public final net.corda.core.messaging.DataFeed trackStepsTreeIndex() @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public final net.corda.core.transactions.SignedTransaction waitForLedgerCommit(net.corda.core.crypto.SecureHash) @co.paralleluniverse.fibers.Suspendable @org.jetbrains.annotations.NotNull public final net.corda.core.transactions.SignedTransaction waitForLedgerCommit(net.corda.core.crypto.SecureHash, boolean) public static final net.corda.core.flows.FlowLogic$Companion Companion @@ -1236,6 +1244,7 @@ public abstract class net.corda.core.flows.FlowLogic extends java.lang.Object public static final class net.corda.core.flows.FlowLogic$Companion extends java.lang.Object @org.jetbrains.annotations.Nullable public final net.corda.core.flows.FlowLogic getCurrentTopLevel() @co.paralleluniverse.fibers.Suspendable @kotlin.jvm.JvmStatic public final void sleep(java.time.Duration) + @co.paralleluniverse.fibers.Suspendable @kotlin.jvm.JvmStatic public final void sleep(java.time.Duration, boolean) ## @net.corda.core.serialization.CordaSerializable @net.corda.core.DoNotImplement public interface net.corda.core.flows.FlowLogicRef ## @@ -1277,6 +1286,9 @@ public static final class net.corda.core.flows.FlowStackSnapshot$Frame extends j public int hashCode() @org.jetbrains.annotations.NotNull public String toString() ## +public interface net.corda.core.flows.IdentifiableException + @javax.annotation.Nullable public Long getErrorId() +## @net.corda.core.serialization.CordaSerializable public final class net.corda.core.flows.IllegalFlowLogicException extends java.lang.IllegalArgumentException public (Class, String) ## @@ -1434,9 +1446,10 @@ public final class net.corda.core.flows.TransactionParts extends java.lang.Objec public int hashCode() public String toString() ## -@net.corda.core.serialization.CordaSerializable public final class net.corda.core.flows.UnexpectedFlowEndException extends net.corda.core.CordaRuntimeException - public (String) - public (String, Throwable) +@net.corda.core.serialization.CordaSerializable public final class net.corda.core.flows.UnexpectedFlowEndException extends net.corda.core.CordaRuntimeException implements net.corda.core.flows.IdentifiableException + public (String, Throwable, long) + @org.jetbrains.annotations.NotNull public Long getErrorId() + public final long getOriginalErrorId() ## @net.corda.core.serialization.CordaSerializable @net.corda.core.DoNotImplement public abstract class net.corda.core.identity.AbstractParty extends java.lang.Object public (java.security.PublicKey) @@ -1593,18 +1606,27 @@ public final class net.corda.core.messaging.CordaRPCOpsKt extends java.lang.Obje @net.corda.core.DoNotImplement public interface net.corda.core.messaging.FlowProgressHandle extends net.corda.core.messaging.FlowHandle public abstract void close() @org.jetbrains.annotations.NotNull public abstract rx.Observable getProgress() + @org.jetbrains.annotations.Nullable public abstract net.corda.core.messaging.DataFeed getStepsTreeFeed() + @org.jetbrains.annotations.Nullable public abstract net.corda.core.messaging.DataFeed getStepsTreeIndexFeed() ## @net.corda.core.serialization.CordaSerializable @net.corda.core.DoNotImplement public final class net.corda.core.messaging.FlowProgressHandleImpl extends java.lang.Object implements net.corda.core.messaging.FlowProgressHandle public (net.corda.core.flows.StateMachineRunId, net.corda.core.concurrent.CordaFuture, rx.Observable) + public (net.corda.core.flows.StateMachineRunId, net.corda.core.concurrent.CordaFuture, rx.Observable, net.corda.core.messaging.DataFeed) + public (net.corda.core.flows.StateMachineRunId, net.corda.core.concurrent.CordaFuture, rx.Observable, net.corda.core.messaging.DataFeed, net.corda.core.messaging.DataFeed) public void close() @org.jetbrains.annotations.NotNull public final net.corda.core.flows.StateMachineRunId component1() @org.jetbrains.annotations.NotNull public final net.corda.core.concurrent.CordaFuture component2() @org.jetbrains.annotations.NotNull public final rx.Observable component3() + @org.jetbrains.annotations.Nullable public final net.corda.core.messaging.DataFeed component4() + @org.jetbrains.annotations.Nullable public final net.corda.core.messaging.DataFeed component5() @org.jetbrains.annotations.NotNull public final net.corda.core.messaging.FlowProgressHandleImpl copy(net.corda.core.flows.StateMachineRunId, net.corda.core.concurrent.CordaFuture, rx.Observable) + @org.jetbrains.annotations.NotNull public final net.corda.core.messaging.FlowProgressHandleImpl copy(net.corda.core.flows.StateMachineRunId, net.corda.core.concurrent.CordaFuture, rx.Observable, net.corda.core.messaging.DataFeed, net.corda.core.messaging.DataFeed) public boolean equals(Object) @org.jetbrains.annotations.NotNull public net.corda.core.flows.StateMachineRunId getId() @org.jetbrains.annotations.NotNull public rx.Observable getProgress() @org.jetbrains.annotations.NotNull public net.corda.core.concurrent.CordaFuture getReturnValue() + @org.jetbrains.annotations.Nullable public net.corda.core.messaging.DataFeed getStepsTreeFeed() + @org.jetbrains.annotations.Nullable public net.corda.core.messaging.DataFeed getStepsTreeIndexFeed() public int hashCode() public String toString() ## @@ -1692,6 +1714,7 @@ public @interface net.corda.core.messaging.RPCReturnsObservables public final int getPlatformVersion() public final long getSerial() public int hashCode() + @org.jetbrains.annotations.NotNull public final net.corda.core.identity.PartyAndCertificate identityAndCertFromX500Name(net.corda.core.identity.CordaX500Name) @org.jetbrains.annotations.NotNull public final net.corda.core.identity.Party identityFromX500Name(net.corda.core.identity.CordaX500Name) public final boolean isLegalIdentity(net.corda.core.identity.Party) public String toString() @@ -1728,6 +1751,7 @@ public @interface net.corda.core.messaging.RPCReturnsObservables ## @net.corda.core.DoNotImplement public interface net.corda.core.node.StateLoader @org.jetbrains.annotations.NotNull public abstract net.corda.core.contracts.TransactionState loadState(net.corda.core.contracts.StateRef) + @org.jetbrains.annotations.NotNull public abstract Set loadStates(Set) ## public final class net.corda.core.node.StatesToRecord extends java.lang.Enum protected (String, int) @@ -1735,8 +1759,10 @@ public final class net.corda.core.node.StatesToRecord extends java.lang.Enum public static net.corda.core.node.StatesToRecord[] values() ## @net.corda.core.DoNotImplement public interface net.corda.core.node.services.AttachmentStorage + public abstract boolean hasAttachment(net.corda.core.crypto.SecureHash) @org.jetbrains.annotations.NotNull public abstract net.corda.core.crypto.SecureHash importAttachment(java.io.InputStream) @org.jetbrains.annotations.NotNull public abstract net.corda.core.crypto.SecureHash importAttachment(java.io.InputStream, String, String) + @org.jetbrains.annotations.NotNull public abstract net.corda.core.crypto.SecureHash importOrGetAttachment(java.io.InputStream) @org.jetbrains.annotations.Nullable public abstract net.corda.core.contracts.Attachment openAttachment(net.corda.core.crypto.SecureHash) @org.jetbrains.annotations.NotNull public abstract List queryAttachments(net.corda.core.node.services.vault.AttachmentQueryCriteria, net.corda.core.node.services.vault.AttachmentSort) ## @@ -1877,6 +1903,7 @@ public final class net.corda.core.node.services.TimeWindowChecker extends java.l @org.jetbrains.annotations.Nullable public abstract net.corda.core.transactions.SignedTransaction getTransaction(net.corda.core.crypto.SecureHash) @org.jetbrains.annotations.NotNull public abstract rx.Observable getUpdates() @org.jetbrains.annotations.NotNull public abstract net.corda.core.messaging.DataFeed track() + @org.jetbrains.annotations.NotNull public abstract net.corda.core.concurrent.CordaFuture trackTransaction(net.corda.core.crypto.SecureHash) ## @net.corda.core.DoNotImplement public interface net.corda.core.node.services.TransactionVerifierService @org.jetbrains.annotations.NotNull public abstract net.corda.core.concurrent.CordaFuture verify(net.corda.core.transactions.LedgerTransaction) @@ -1890,6 +1917,9 @@ public final class net.corda.core.node.services.TimeWindowChecker extends java.l @org.jetbrains.annotations.NotNull public final net.corda.core.crypto.TransactionSignature sign(net.corda.core.crypto.SecureHash) @org.jetbrains.annotations.NotNull public final net.corda.core.crypto.DigitalSignature$WithKey sign(byte[]) public final void validateTimeWindow(net.corda.core.contracts.TimeWindow) + public static final net.corda.core.node.services.TrustedAuthorityNotaryService$Companion Companion +## +public static final class net.corda.core.node.services.TrustedAuthorityNotaryService$Companion extends java.lang.Object ## @net.corda.core.serialization.CordaSerializable public final class net.corda.core.node.services.UniquenessException extends net.corda.core.CordaException public (net.corda.core.node.services.UniquenessProvider$Conflict) @@ -2609,6 +2639,7 @@ public final class net.corda.core.schemas.CommonSchemaV1 extends net.corda.core. public static final net.corda.core.schemas.CommonSchemaV1 INSTANCE ## @javax.persistence.MappedSuperclass @net.corda.core.serialization.CordaSerializable public static class net.corda.core.schemas.CommonSchemaV1$FungibleState extends net.corda.core.schemas.PersistentState + public () public (Set, net.corda.core.identity.AbstractParty, long, net.corda.core.identity.AbstractParty, byte[]) @org.jetbrains.annotations.NotNull public final net.corda.core.identity.AbstractParty getIssuer() @org.jetbrains.annotations.NotNull public final byte[] getIssuerRef() @@ -2622,6 +2653,7 @@ public final class net.corda.core.schemas.CommonSchemaV1 extends net.corda.core. public final void setQuantity(long) ## @javax.persistence.MappedSuperclass @net.corda.core.serialization.CordaSerializable public static class net.corda.core.schemas.CommonSchemaV1$LinearState extends net.corda.core.schemas.PersistentState + public () public (Set, String, UUID) public (net.corda.core.contracts.UniqueIdentifier, Set) @org.jetbrains.annotations.Nullable public final String getExternalId() @@ -3141,6 +3173,7 @@ public static final class net.corda.core.utilities.Id$Companion extends java.lan @kotlin.jvm.JvmStatic @org.jetbrains.annotations.NotNull public final net.corda.core.utilities.Id newInstance(Object, String, java.time.Instant) ## public final class net.corda.core.utilities.KotlinUtilsKt extends java.lang.Object + @org.jetbrains.annotations.NotNull public static final org.slf4j.Logger contextLogger(Object) public static final void debug(org.slf4j.Logger, kotlin.jvm.functions.Function0) public static final int exactAdd(int, int) public static final long exactAdd(long, long) @@ -3226,6 +3259,7 @@ public static final class net.corda.core.utilities.OpaqueBytes$Companion extends @net.corda.core.serialization.CordaSerializable public final class net.corda.core.utilities.ProgressTracker extends java.lang.Object public final void endWithError(Throwable) @org.jetbrains.annotations.NotNull public final List getAllSteps() + @org.jetbrains.annotations.NotNull public final List getAllStepsLabels() @org.jetbrains.annotations.NotNull public final rx.Observable getChanges() @org.jetbrains.annotations.Nullable public final net.corda.core.utilities.ProgressTracker getChildProgressTracker(net.corda.core.utilities.ProgressTracker$Step) @org.jetbrains.annotations.NotNull public final net.corda.core.utilities.ProgressTracker$Step getCurrentStep() @@ -3234,12 +3268,16 @@ public static final class net.corda.core.utilities.OpaqueBytes$Companion extends @org.jetbrains.annotations.Nullable public final net.corda.core.utilities.ProgressTracker getParent() public final int getStepIndex() @org.jetbrains.annotations.NotNull public final net.corda.core.utilities.ProgressTracker$Step[] getSteps() + @org.jetbrains.annotations.NotNull public final rx.Observable getStepsTreeChanges() + public final int getStepsTreeIndex() + @org.jetbrains.annotations.NotNull public final rx.Observable getStepsTreeIndexChanges() @org.jetbrains.annotations.NotNull public final net.corda.core.utilities.ProgressTracker getTopLevelTracker() @org.jetbrains.annotations.NotNull public final net.corda.core.utilities.ProgressTracker$Step nextStep() public final void setChildProgressTracker(net.corda.core.utilities.ProgressTracker$Step, net.corda.core.utilities.ProgressTracker) public final void setCurrentStep(net.corda.core.utilities.ProgressTracker$Step) ## @net.corda.core.serialization.CordaSerializable public abstract static class net.corda.core.utilities.ProgressTracker$Change extends java.lang.Object + @org.jetbrains.annotations.NotNull public final net.corda.core.utilities.ProgressTracker getProgressTracker() ## @net.corda.core.serialization.CordaSerializable public static final class net.corda.core.utilities.ProgressTracker$Change$Position extends net.corda.core.utilities.ProgressTracker$Change public (net.corda.core.utilities.ProgressTracker, net.corda.core.utilities.ProgressTracker$Step) @@ -3292,6 +3330,10 @@ public static final class net.corda.core.utilities.OpaqueBytes$Companion extends public interface net.corda.core.utilities.PropertyDelegate public abstract Object getValue(Object, kotlin.reflect.KProperty) ## +public final class net.corda.core.utilities.SgxSupport extends java.lang.Object + public static final boolean isInsideEnclave() + public static final net.corda.core.utilities.SgxSupport INSTANCE +## @net.corda.core.serialization.CordaSerializable public abstract class net.corda.core.utilities.Try extends java.lang.Object @org.jetbrains.annotations.NotNull public final net.corda.core.utilities.Try combine(net.corda.core.utilities.Try, kotlin.jvm.functions.Function2) @org.jetbrains.annotations.NotNull public final net.corda.core.utilities.Try flatMap(kotlin.jvm.functions.Function1) @@ -3338,6 +3380,7 @@ public static interface net.corda.core.utilities.UntrustworthyData$Validator ext @co.paralleluniverse.fibers.Suspendable public abstract Object validate(Object) ## public final class net.corda.core.utilities.UntrustworthyDataKt extends java.lang.Object + @org.jetbrains.annotations.NotNull public static final net.corda.core.utilities.UntrustworthyData checkPayloadIs(net.corda.core.serialization.SerializedBytes, Class) public static final Object unwrap(net.corda.core.utilities.UntrustworthyData, kotlin.jvm.functions.Function1) ## public final class net.corda.core.utilities.UuidGenerator extends java.lang.Object diff --git a/.idea/compiler.xml b/.idea/compiler.xml index 898d3fc2e0..d72e565e78 100644 --- a/.idea/compiler.xml +++ b/.idea/compiler.xml @@ -57,6 +57,8 @@ + + @@ -170,4 +172,4 @@ - \ No newline at end of file + diff --git a/core/src/main/java/net/corda/core/flows/IdentifiableException.java b/core/src/main/java/net/corda/core/flows/IdentifiableException.java new file mode 100644 index 0000000000..d1d32a97f3 --- /dev/null +++ b/core/src/main/java/net/corda/core/flows/IdentifiableException.java @@ -0,0 +1,16 @@ +package net.corda.core.flows; + +import javax.annotation.Nullable; + +/** + * An exception that may be identified with an ID. If an exception originates in a counter-flow this ID will be + * propagated. This allows correlation of error conditions across different flows. + */ +public interface IdentifiableException { + /** + * @return the ID of the error, or null if the error doesn't have it set (yet). + */ + default @Nullable Long getErrorId() { + return null; + } +} diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowException.kt b/core/src/main/kotlin/net/corda/core/flows/FlowException.kt index 33251020f8..ac0fbdaa23 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowException.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowException.kt @@ -7,16 +7,27 @@ import net.corda.core.CordaRuntimeException /** * Exception which can be thrown by a [FlowLogic] at any point in its logic to unexpectedly bring it to a permanent end. * The exception will propagate to all counterparty flows and will be thrown on their end the next time they wait on a - * [FlowSession.receive] or [FlowSession.sendAndReceive]. Any flow which no longer needs to do a receive, or has already ended, - * will not receive the exception (if this is required then have them wait for a confirmation message). + * [FlowSession.receive] or [FlowSession.sendAndReceive]. Any flow which no longer needs to do a receive, or has already + * ended, will not receive the exception (if this is required then have them wait for a confirmation message). + * + * If the *rethrown* [FlowException] is uncaught in counterparty flows and propagation triggers then the exception is + * downgraded to an [UnexpectedFlowEndException]. This means only immediate counterparty flows will receive information + * about what the exception was. * * [FlowException] (or a subclass) can be a valid expected response from a flow, particularly ones which act as a service. * It is recommended a [FlowLogic] document the [FlowException] types it can throw. + * + * @property originalErrorId the ID backing [getErrorId]. If null it will be set dynamically by the flow framework when + * the exception is handled. This ID is propagated to counterparty flows, even when the [FlowException] is + * downgraded to an [UnexpectedFlowEndException]. This is so the error conditions may be correlated later on. */ -open class FlowException(message: String?, cause: Throwable?) : CordaException(message, cause) { +open class FlowException(message: String?, cause: Throwable?) : + CordaException(message, cause), IdentifiableException { constructor(message: String?) : this(message, null) constructor(cause: Throwable?) : this(cause?.toString(), cause) constructor() : this(null, null) + var originalErrorId: Long? = null + override fun getErrorId(): Long? = originalErrorId } // DOCEND 1 @@ -25,6 +36,7 @@ open class FlowException(message: String?, cause: Throwable?) : CordaException(m * that we were not expecting), or the other side had an internal error, or the other side terminated when we * were waiting for a response. */ -class UnexpectedFlowEndException(message: String?, cause: Throwable?) : CordaRuntimeException(message, cause) { - constructor(msg: String) : this(msg, null) -} \ No newline at end of file +class UnexpectedFlowEndException(message: String, cause: Throwable?, val originalErrorId: Long) : + CordaRuntimeException(message, cause), IdentifiableException { + override fun getErrorId(): Long = originalErrorId +} diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt index a10a3d3fe2..d7c2a50c1c 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt @@ -5,6 +5,7 @@ import co.paralleluniverse.strands.Strand import net.corda.core.crypto.SecureHash import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate +import net.corda.core.internal.FlowIORequest import net.corda.core.internal.FlowStateMachine import net.corda.core.internal.abbreviate import net.corda.core.internal.uncheckedCast @@ -12,10 +13,10 @@ import net.corda.core.messaging.DataFeed import net.corda.core.node.NodeInfo import net.corda.core.node.ServiceHub import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.serialize import net.corda.core.transactions.SignedTransaction -import net.corda.core.utilities.ProgressTracker -import net.corda.core.utilities.UntrustworthyData -import net.corda.core.utilities.debug +import net.corda.core.utilities.* import org.slf4j.Logger import java.time.Duration import java.time.Instant @@ -75,12 +76,19 @@ abstract class FlowLogic { */ @Suspendable @JvmStatic + @JvmOverloads @Throws(FlowException::class) - fun sleep(duration: Duration) { + fun sleep(duration: Duration, maySkipCheckpoint: Boolean = false) { if (duration.compareTo(Duration.ofMinutes(5)) > 0) { throw FlowException("Attempt to sleep for longer than 5 minutes is not supported. Consider using SchedulableState.") } - (Strand.currentStrand() as? FlowStateMachine<*>)?.sleepUntil(Instant.now() + duration) ?: Strand.sleep(duration.toMillis()) + val fiber = (Strand.currentStrand() as? FlowStateMachine<*>) + if (fiber == null) { + Strand.sleep(duration.toMillis()) + } else { + val request = FlowIORequest.Sleep(wakeUpAfter = fiber.serviceHub.clock.instant() + duration) + fiber.suspend(request, maySkipCheckpoint = maySkipCheckpoint) + } } } @@ -92,7 +100,7 @@ abstract class FlowLogic { /** * Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is - * only available once the flow has started, which means it cannnot be accessed in the constructor. Either + * only available once the flow has started, which means it cannot be accessed in the constructor. Either * access this lazily or from inside [call]. */ val serviceHub: ServiceHub get() = stateMachine.serviceHub @@ -102,7 +110,7 @@ abstract class FlowLogic { * that this function does not communicate in itself, the counter-flow will be kicked off by the first send/receive. */ @Suspendable - fun initiateFlow(party: Party): FlowSession = stateMachine.initiateFlow(party, flowUsedForSessions) + fun initiateFlow(party: Party): FlowSession = stateMachine.initiateFlow(party) /** * Specifies the identity, with certificate, to use for this flow. This will be one of the multiple identities that @@ -112,7 +120,10 @@ abstract class FlowLogic { * Note: The current implementation returns the single identity of the node. This will change once multiple identities * is implemented. */ - val ourIdentityAndCert: PartyAndCertificate get() = stateMachine.ourIdentityAndCert + val ourIdentityAndCert: PartyAndCertificate get() { + return serviceHub.myInfo.legalIdentitiesAndCerts.find { it.party == stateMachine.ourIdentity } + ?: throw IllegalStateException("Identity specified by ${stateMachine.id} (${stateMachine.ourIdentity}) is not one of ours!") + } /** * Specifies the identity to use for this flow. This will be one of the multiple identities that belong to this node. @@ -122,102 +133,23 @@ abstract class FlowLogic { * Note: The current implementation returns the single identity of the node. This will change once multiple identities * is implemented. */ - val ourIdentity: Party get() = ourIdentityAndCert.party - /** - * Returns a [FlowInfo] object describing the flow [otherParty] is using. With [FlowInfo.flowVersion] it - * provides the necessary information needed for the evolution of flows and enabling backwards compatibility. - * - * This method can be called before any send or receive has been done with [otherParty]. In such a case this will force - * them to start their flow. - */ - @Deprecated("Use FlowSession.getFlowInfo()", level = DeprecationLevel.WARNING) - @Suspendable - fun getFlowInfo(otherParty: Party): FlowInfo = stateMachine.getFlowInfo(otherParty, flowUsedForSessions, maySkipCheckpoint = false) - - /** - * Serializes and queues the given [payload] object for sending to the [otherParty]. Suspends until a response - * is received, which must be of the given [R] type. - * - * Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly - * verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly - * corrupted data in order to exploit your code. - * - * Note that this function is not just a simple send+receive pair: it is more efficient and more correct to - * use this when you expect to do a message swap than do use [send] and then [receive] in turn. - * - * @return an [UntrustworthyData] wrapper around the received object. - */ - @Deprecated("Use FlowSession.sendAndReceive()", level = DeprecationLevel.WARNING) - inline fun sendAndReceive(otherParty: Party, payload: Any): UntrustworthyData { - return sendAndReceive(R::class.java, otherParty, payload) - } - - /** - * Serializes and queues the given [payload] object for sending to the [otherParty]. Suspends until a response - * is received, which must be of the given [receiveType]. Remember that when receiving data from other parties the data - * should not be trusted until it's been thoroughly verified for consistency and that all expectations are - * satisfied, as a malicious peer may send you subtly corrupted data in order to exploit your code. - * - * Note that this function is not just a simple send+receive pair: it is more efficient and more correct to - * use this when you expect to do a message swap than do use [send] and then [receive] in turn. - * - * @return an [UntrustworthyData] wrapper around the received object. - */ - @Deprecated("Use FlowSession.sendAndReceive()", level = DeprecationLevel.WARNING) - @Suspendable - open fun sendAndReceive(receiveType: Class, otherParty: Party, payload: Any): UntrustworthyData { - return stateMachine.sendAndReceive(receiveType, otherParty, payload, flowUsedForSessions, retrySend = false, maySkipCheckpoint = false) - } - - /** - * Similar to [sendAndReceive] but also instructs the `payload` to be redelivered until the expected message is received. - * - * Note that this method should NOT be used for regular party-to-party communication, use [sendAndReceive] instead. - * It is only intended for the case where the [otherParty] is running a distributed service with an idempotent - * flow which only accepts a single request and sends back a single response – e.g. a notary or certain types of - * oracle services. If one or more nodes in the service cluster go down mid-session, the message will be redelivered - * to a different one, so there is no need to wait until the initial node comes back up to obtain a response. - */ - @Deprecated("Use FlowSession.sendAndReceiveWithRetry()", level = DeprecationLevel.WARNING) - internal inline fun sendAndReceiveWithRetry(otherParty: Party, payload: Any): UntrustworthyData { - return stateMachine.sendAndReceive(R::class.java, otherParty, payload, flowUsedForSessions, retrySend = true, maySkipCheckpoint = false) - } + val ourIdentity: Party get() = stateMachine.ourIdentity @Suspendable internal fun FlowSession.sendAndReceiveWithRetry(receiveType: Class, payload: Any): UntrustworthyData { - return stateMachine.sendAndReceive(receiveType, counterparty, payload, flowUsedForSessions, retrySend = true, maySkipCheckpoint = false) + val request = FlowIORequest.SendAndReceive( + sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT)), + shouldRetrySend = true + ) + return stateMachine.suspend(request, maySkipCheckpoint = true)[this]!!.checkPayloadIs(receiveType) } @Suspendable internal inline fun FlowSession.sendAndReceiveWithRetry(payload: Any): UntrustworthyData { - return stateMachine.sendAndReceive(R::class.java, counterparty, payload, flowUsedForSessions, retrySend = true, maySkipCheckpoint = false) + return sendAndReceiveWithRetry(R::class.java, payload) } - /** - * Suspends until the specified [otherParty] sends us a message of type [R]. - * - * Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly - * verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly - * corrupted data in order to exploit your code. - */ - @Deprecated("Use FlowSession.receive()", level = DeprecationLevel.WARNING) - inline fun receive(otherParty: Party): UntrustworthyData = receive(R::class.java, otherParty) - - /** - * Suspends until the specified [otherParty] sends us a message of type [receiveType]. - * - * Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly - * verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly - * corrupted data in order to exploit your code. - * - * @return an [UntrustworthyData] wrapper around the received object. - */ - @Deprecated("Use FlowSession.receive()", level = DeprecationLevel.WARNING) - @Suspendable - open fun receive(receiveType: Class, otherParty: Party): UntrustworthyData { - return stateMachine.receive(receiveType, otherParty, flowUsedForSessions, maySkipCheckpoint = false) - } /** Suspends until a message has been received for each session in the specified [sessions]. * @@ -230,8 +162,14 @@ abstract class FlowLogic { * @returns a [Map] containing the objects received, wrapped in an [UntrustworthyData], by the [FlowSession]s who sent them. */ @Suspendable - open fun receiveAll(sessions: Map>): Map> { - return stateMachine.receiveAll(sessions, this) + @JvmOverloads + open fun receiveAllMap(sessions: Map>, maySkipCheckpoint: Boolean = false): Map> { + enforceNoPrimitiveInReceive(sessions.values) + val replies = stateMachine.suspend( + ioRequest = FlowIORequest.Receive(sessions.keys.toNonEmptySet()), + maySkipCheckpoint = maySkipCheckpoint + ) + return replies.mapValues { (session, payload) -> payload.checkPayloadIs(sessions[session]!!) } } /** @@ -246,22 +184,11 @@ abstract class FlowLogic { * @returns a [List] containing the objects received, wrapped in an [UntrustworthyData], with the same order of [sessions]. */ @Suspendable - open fun receiveAll(receiveType: Class, sessions: List): List> { + @JvmOverloads + open fun receiveAll(receiveType: Class, sessions: List, maySkipCheckpoint: Boolean = false): List> { + enforceNoPrimitiveInReceive(listOf(receiveType)) enforceNoDuplicates(sessions) - return castMapValuesToKnownType(receiveAll(associateSessionsToReceiveType(receiveType, sessions))) - } - - /** - * Queues the given [payload] for sending to the [otherParty] and continues without suspending. - * - * Note that the other party may receive the message at some arbitrary later point or not at all: if [otherParty] - * is offline then message delivery will be retried until it comes back or until the message is older than the - * network's event horizon time. - */ - @Deprecated("Use FlowSession.send()", level = DeprecationLevel.WARNING) - @Suspendable - open fun send(otherParty: Party, payload: Any) { - stateMachine.send(otherParty, payload, flowUsedForSessions, maySkipCheckpoint = false) + return castMapValuesToKnownType(receiveAllMap(associateSessionsToReceiveType(receiveType, sessions), maySkipCheckpoint)) } /** @@ -281,11 +208,8 @@ abstract class FlowLogic { open fun subFlow(subLogic: FlowLogic): R { subLogic.stateMachine = stateMachine maybeWireUpProgressTracking(subLogic) - if (!subLogic.javaClass.isAnnotationPresent(InitiatingFlow::class.java)) { - subLogic.flowUsedForSessions = flowUsedForSessions - } logger.debug { "Calling subflow: $subLogic" } - val result = subLogic.call() + val result = stateMachine.subFlow(subLogic) logger.debug { "Subflow finished with result ${result.toString().abbreviate(300)}" } // It's easy to forget this when writing flows so we just step it to the DONE state when it completes. subLogic.progressTracker?.currentStep = ProgressTracker.DONE @@ -382,7 +306,8 @@ abstract class FlowLogic { @Suspendable @JvmOverloads fun waitForLedgerCommit(hash: SecureHash, maySkipCheckpoint: Boolean = false): SignedTransaction { - return stateMachine.waitForLedgerCommit(hash, this, maySkipCheckpoint = maySkipCheckpoint) + val request = FlowIORequest.WaitForLedgerCommit(hash) + return stateMachine.suspend(request, maySkipCheckpoint = maySkipCheckpoint) } /** @@ -423,10 +348,6 @@ abstract class FlowLogic { _stateMachine = value } - // This is the flow used for managing sessions. It defaults to the current flow but if this is an inlined sub-flow - // then it will point to the flow it's been inlined to. - private var flowUsedForSessions: FlowLogic<*> = this - private fun maybeWireUpProgressTracking(subLogic: FlowLogic<*>) { val ours = progressTracker val theirs = subLogic.progressTracker @@ -443,6 +364,11 @@ abstract class FlowLogic { require(sessions.size == sessions.toSet().size) { "A flow session can only appear once as argument." } } + private fun enforceNoPrimitiveInReceive(types: Collection>) { + val primitiveTypes = types.filter { it.isPrimitive } + require(primitiveTypes.isEmpty()) { "Cannot receive primitive type(s) $primitiveTypes" } + } + private fun associateSessionsToReceiveType(receiveType: Class, sessions: List): Map> { return sessions.associateByTo(LinkedHashMap(), { it }, { receiveType }) } diff --git a/core/src/main/kotlin/net/corda/core/internal/FlowIORequest.kt b/core/src/main/kotlin/net/corda/core/internal/FlowIORequest.kt new file mode 100644 index 0000000000..ed9052c36b --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/internal/FlowIORequest.kt @@ -0,0 +1,85 @@ +package net.corda.core.internal + +import net.corda.core.crypto.SecureHash +import net.corda.core.flows.FlowInfo +import net.corda.core.flows.FlowSession +import net.corda.core.serialization.SerializedBytes +import net.corda.core.transactions.SignedTransaction +import net.corda.core.utilities.NonEmptySet +import java.time.Instant + +/** + * A [FlowIORequest] represents an IO request of a flow when it suspends. It is persisted in checkpoints. + */ +sealed class FlowIORequest { + /** + * Send messages to sessions. + * + * @property sessionToMessage a map from session to message-to-be-sent. + * @property shouldRetrySend specifies whether the send should be retried. + */ + data class Send( + val sessionToMessage: Map>, + val shouldRetrySend: Boolean + ) : FlowIORequest() { + override fun toString() = "Send(" + + "sessionToMessage=${sessionToMessage.mapValues { it.value.hash }}, " + + "shouldRetrySend=$shouldRetrySend" + + ")" + } + + /** + * Receive messages from sessions. + * + * @property sessions the sessions to receive messages from. + * @return a map from session to received message. + */ + data class Receive( + val sessions: NonEmptySet + ) : FlowIORequest>>() + + /** + * Send and receive messages from the specified sessions. + * + * @property sessionToMessage a map from session to message-to-be-sent. The keys also specify which sessions to + * receive from. + * @property shouldRetrySend specifies whether the send should be retried. + * @return a map from session to received message. + */ + data class SendAndReceive( + val sessionToMessage: Map>, + val shouldRetrySend: Boolean + ) : FlowIORequest>>() { + override fun toString() = "SendAndReceive(${sessionToMessage.mapValues { (key, value) -> + "$key=${value.hash}" }}, shouldRetrySend=$shouldRetrySend)" + } + + /** + * Wait for a transaction to be committed to the database. + * + * @property hash the hash of the transaction. + * @return the committed transaction. + */ + data class WaitForLedgerCommit(val hash: SecureHash) : FlowIORequest() + + /** + * Get the FlowInfo of the specified sessions. + * + * @property sessions the sessions to get the FlowInfo of. + * @return a map from session to FlowInfo. + */ + data class GetFlowInfo(val sessions: NonEmptySet) : FlowIORequest>() + + /** + * Suspend the flow until the specified time. + * + * @property wakeUpAfter the time to sleep until. + */ + data class Sleep(val wakeUpAfter: Instant) : FlowIORequest() + + /** + * Suspend the flow until all Initiating sessions are confirmed. + */ + object WaitForSessionConfirmations : FlowIORequest() +} + diff --git a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt index 3d9af23d71..7ea31b7bb7 100644 --- a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt +++ b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt @@ -1,64 +1,42 @@ package net.corda.core.internal import co.paralleluniverse.fibers.Suspendable +import net.corda.core.DoNotImplement import net.corda.core.concurrent.CordaFuture -import net.corda.core.crypto.SecureHash import net.corda.core.flows.* import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate import net.corda.core.context.InvocationContext import net.corda.core.node.ServiceHub -import net.corda.core.transactions.SignedTransaction -import net.corda.core.utilities.UntrustworthyData import org.slf4j.Logger -import java.time.Instant /** This is an internal interface that is implemented by code in the node module. You should look at [FlowLogic]. */ -interface FlowStateMachine { +@DoNotImplement +interface FlowStateMachine { @Suspendable - fun getFlowInfo(otherParty: Party, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): FlowInfo + fun suspend(ioRequest: FlowIORequest, maySkipCheckpoint: Boolean): SUSPENDRETURN @Suspendable - fun initiateFlow(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSession - - @Suspendable - fun sendAndReceive(receiveType: Class, - otherParty: Party, - payload: Any, - sessionFlow: FlowLogic<*>, - retrySend: Boolean, - maySkipCheckpoint: Boolean): UntrustworthyData - - @Suspendable - fun receive(receiveType: Class, otherParty: Party, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): UntrustworthyData - - @Suspendable - fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean) - - @Suspendable - fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): SignedTransaction - - @Suspendable - fun sleepUntil(until: Instant) + fun initiateFlow(party: Party): FlowSession fun checkFlowPermission(permissionName: String, extraAuditData: Map) fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map) + @Suspendable + fun subFlow(subFlow: FlowLogic): SUBFLOWRETURN + @Suspendable fun flowStackSnapshot(flowClass: Class>): FlowStackSnapshot? @Suspendable fun persistFlowStackSnapshot(flowClass: Class>) - val logic: FlowLogic + val logic: FlowLogic val serviceHub: ServiceHub val logger: Logger val id: StateMachineRunId - val resultFuture: CordaFuture + val resultFuture: CordaFuture val context: InvocationContext - val ourIdentityAndCert: PartyAndCertificate - - @Suspendable - fun receiveAll(sessions: Map>, sessionFlow: FlowLogic<*>): Map> + val ourIdentity: Party } diff --git a/core/src/main/kotlin/net/corda/core/node/services/TransactionStorage.kt b/core/src/main/kotlin/net/corda/core/node/services/TransactionStorage.kt index 9b6b713ed2..b04c96729f 100644 --- a/core/src/main/kotlin/net/corda/core/node/services/TransactionStorage.kt +++ b/core/src/main/kotlin/net/corda/core/node/services/TransactionStorage.kt @@ -1,6 +1,7 @@ package net.corda.core.node.services import net.corda.core.DoNotImplement +import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.SecureHash import net.corda.core.messaging.DataFeed import net.corda.core.transactions.SignedTransaction @@ -26,4 +27,9 @@ interface TransactionStorage { * Returns all currently stored transactions and further fresh ones. */ fun track(): DataFeed, SignedTransaction> + + /** + * Returns a future that completes with the transaction corresponding to [id] once it has been committed + */ + fun trackTransaction(id: SecureHash): CordaFuture } \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/utilities/UntrustworthyData.kt b/core/src/main/kotlin/net/corda/core/utilities/UntrustworthyData.kt index 272b5ec200..6e2c3f412d 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/UntrustworthyData.kt +++ b/core/src/main/kotlin/net/corda/core/utilities/UntrustworthyData.kt @@ -2,6 +2,9 @@ package net.corda.core.utilities import co.paralleluniverse.fibers.Suspendable import net.corda.core.flows.FlowException +import net.corda.core.internal.castIfPossible +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.SerializedBytes import java.io.Serializable /** @@ -29,3 +32,15 @@ class UntrustworthyData(@PublishedApi internal val fromUntrustedWorld: T) } inline fun UntrustworthyData.unwrap(validator: (T) -> R): R = validator(fromUntrustedWorld) + +fun SerializedBytes.checkPayloadIs(type: Class): UntrustworthyData { + val payloadData: T = try { + val serializer = SerializationDefaults.SERIALIZATION_FACTORY + serializer.deserialize(this, type, SerializationDefaults.P2P_CONTEXT) + } catch (ex: Exception) { + throw IllegalArgumentException("Payload invalid", ex) + } + return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?: + throw IllegalArgumentException("We were expecting a ${type.name} but we instead got a " + + "${payloadData.javaClass.name} (${payloadData})") +} \ No newline at end of file diff --git a/core/src/smoke-test/kotlin/net/corda/core/cordapp/CordappSmokeTest.kt b/core/src/smoke-test/kotlin/net/corda/core/cordapp/CordappSmokeTest.kt index 2f826b5f54..44db07a4ed 100644 --- a/core/src/smoke-test/kotlin/net/corda/core/cordapp/CordappSmokeTest.kt +++ b/core/src/smoke-test/kotlin/net/corda/core/cordapp/CordappSmokeTest.kt @@ -85,7 +85,7 @@ class CordappSmokeTest { class SendBackInitiatorFlowContext(private val otherPartySession: FlowSession) : FlowLogic() { @Suspendable override fun call() { - // An initiated flow calling getFlowContext on its initiator will get the context from the session-init + // An initiated flow calling getFlowInfo on its initiator will get the context from the session-init val sessionInitContext = otherPartySession.getCounterpartyFlowInfo() otherPartySession.send(sessionInitContext) } diff --git a/core/src/test/java/net/corda/core/flows/FlowsInJavaTest.java b/core/src/test/java/net/corda/core/flows/FlowsInJavaTest.java index 1b69fc5a6c..6d998b4917 100644 --- a/core/src/test/java/net/corda/core/flows/FlowsInJavaTest.java +++ b/core/src/test/java/net/corda/core/flows/FlowsInJavaTest.java @@ -14,9 +14,9 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import static net.corda.testing.CoreTestUtils.singleIdentity; +import static net.corda.testing.NodeTestUtils.startFlow; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.junit.Assert.fail; -import static net.corda.testing.NodeTestUtils.startFlow; public class FlowsInJavaTest { private final MockNetwork mockNet = new MockNetwork(); @@ -62,9 +62,8 @@ public class FlowsInJavaTest { fail("ExecutionException should have been thrown"); } catch (ExecutionException e) { assertThat(e.getCause()) - .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("primitive") - .hasMessageContaining(receiveType.getName()); + .hasMessageContaining(Primitives.unwrap(receiveType).getName()); } } @@ -102,6 +101,18 @@ public class FlowsInJavaTest { } } + @InitiatedBy(PrimitiveReceiveFlow.class) + private static class PrimitiveSendFlow extends FlowLogic { + public PrimitiveSendFlow(FlowSession session) { + } + + @Suspendable + @Override + public Void call() throws FlowException { + return null; + } + } + @InitiatingFlow private static class PrimitiveReceiveFlow extends FlowLogic { private final Party otherParty; diff --git a/core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt b/core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt index 4d2d12335d..9b004f4e6f 100644 --- a/core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt +++ b/core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt @@ -79,7 +79,7 @@ infix fun KClass.from(session: FlowSession): Pair.receiveAll(session: Pair>, vararg sessions: Pair>): Map> { val allSessions = arrayOf(session, *sessions) allSessions.enforceNoDuplicates() - return receiveAll(mapOf(*allSessions)) + return receiveAllMap(mapOf(*allSessions)) } /** diff --git a/docs/source/api-flows.rst b/docs/source/api-flows.rst index 4a0e284621..5e7a240a25 100644 --- a/docs/source/api-flows.rst +++ b/docs/source/api-flows.rst @@ -409,68 +409,6 @@ Our side of the flow must mirror these calls. We could do this as follows: :end-before: DOCEND 08 :dedent: 12 -Why sessions? -^^^^^^^^^^^^^ - -Before ``FlowSession`` s were introduced the send/receive API looked a bit different. They were functions on -``FlowLogic`` and took the address ``Party`` as argument. The platform internally maintained a mapping from ``Party`` to -session, hiding sessions from the user completely. - -Although this is a convenient API it introduces subtle issues where a message that was originally meant for a specific -session may end up in another. - -Consider the following contrived example using the old ``Party`` based API: - -.. container:: codeset - - .. literalinclude:: ../../docs/source/example-code/src/main/kotlin/net/corda/docs/LaunchSpaceshipFlow.kt - :language: kotlin - :start-after: DOCSTART LaunchSpaceshipFlow - :end-before: DOCEND LaunchSpaceshipFlow - -The intention of the flows is very clear: LaunchSpaceshipFlow asks the president whether a spaceship should be launched. -It is expecting a boolean reply. The president in return first tells the secretary that they need coffee, which is also -communicated with a boolean. Afterwards the president replies to the launcher that they don't want to launch. - -However the above can go horribly wrong when the ``launcher`` happens to be the same party ``getSecretary`` returns. In -this case the boolean meant for the secretary will be received by the launcher! - -This indicates that ``Party`` is not a good identifier for the communication sequence, and indeed the ``Party`` based -API may introduce ways for an attacker to fish for information and even trigger unintended control flow like in the -above case. - -Hence we introduced ``FlowSession``, which identifies the communication sequence. With ``FlowSession`` s the above set -of flows would look like this: - -.. container:: codeset - - .. literalinclude:: ../../docs/source/example-code/src/main/kotlin/net/corda/docs/LaunchSpaceshipFlow.kt - :language: kotlin - :start-after: DOCSTART LaunchSpaceshipFlowCorrect - :end-before: DOCEND LaunchSpaceshipFlowCorrect - -Note how the president is now explicit about which session it wants to send to. - -Porting from the old Party-based API -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In the old API the first ``send`` or ``receive`` to a ``Party`` was the one kicking off the counter-flow. This is now -explicit in the ``initiateFlow`` function call. To port existing code: - -.. container:: codeset - - .. literalinclude:: ../../docs/source/example-code/src/main/kotlin/net/corda/docs/FlowCookbook.kt - :language: kotlin - :start-after: DOCSTART FlowSession porting - :end-before: DOCEND FlowSession porting - :dedent: 8 - - .. literalinclude:: ../../docs/source/example-code/src/main/java/net/corda/docs/FlowCookbookJava.java - :language: java - :start-after: DOCSTART FlowSession porting - :end-before: DOCEND FlowSession porting - :dedent: 12 - Subflows -------- diff --git a/docs/source/example-code/src/main/java/net/corda/docs/FlowCookbookJava.java b/docs/source/example-code/src/main/java/net/corda/docs/FlowCookbookJava.java index ec928e9997..6b24146635 100644 --- a/docs/source/example-code/src/main/java/net/corda/docs/FlowCookbookJava.java +++ b/docs/source/example-code/src/main/java/net/corda/docs/FlowCookbookJava.java @@ -575,13 +575,6 @@ public class FlowCookbookJava { SignedTransaction notarisedTx2 = subFlow(new FinalityFlow(fullySignedTx, additionalParties, FINALISATION.childProgressTracker())); // DOCEND 10 - // DOCSTART FlowSession porting - send(regulator, new Object()); // Old API - // becomes - FlowSession session = initiateFlow(regulator); - session.send(new Object()); - // DOCEND FlowSession porting - return null; } } diff --git a/docs/source/example-code/src/main/kotlin/net/corda/docs/FlowCookbook.kt b/docs/source/example-code/src/main/kotlin/net/corda/docs/FlowCookbook.kt index 64b3c0845a..3718ec1bfe 100644 --- a/docs/source/example-code/src/main/kotlin/net/corda/docs/FlowCookbook.kt +++ b/docs/source/example-code/src/main/kotlin/net/corda/docs/FlowCookbook.kt @@ -553,13 +553,6 @@ class InitiatorFlow(val arg1: Boolean, val arg2: Int, private val counterparty: val additionalParties: Set = setOf(regulator) val notarisedTx2: SignedTransaction = subFlow(FinalityFlow(fullySignedTx, additionalParties, FINALISATION.childProgressTracker())) // DOCEND 10 - - // DOCSTART FlowSession porting - send(regulator, Any()) // Old API - // becomes - val session = initiateFlow(regulator) - session.send(Any()) - // DOCEND FlowSession porting } } diff --git a/docs/source/example-code/src/main/kotlin/net/corda/docs/LaunchSpaceshipFlow.kt b/docs/source/example-code/src/main/kotlin/net/corda/docs/LaunchSpaceshipFlow.kt deleted file mode 100644 index e6826fa213..0000000000 --- a/docs/source/example-code/src/main/kotlin/net/corda/docs/LaunchSpaceshipFlow.kt +++ /dev/null @@ -1,99 +0,0 @@ -package net.corda.docs - -import co.paralleluniverse.fibers.Suspendable -import net.corda.core.flows.FlowLogic -import net.corda.core.flows.FlowSession -import net.corda.core.flows.InitiatedBy -import net.corda.core.flows.InitiatingFlow -import net.corda.core.identity.Party -import net.corda.core.utilities.unwrap - -// DOCSTART LaunchSpaceshipFlow -@InitiatingFlow -class LaunchSpaceshipFlow : FlowLogic() { - @Suspendable - override fun call() { - val shouldLaunchSpaceship = receive(getPresident()).unwrap { it } - if (shouldLaunchSpaceship) { - launchSpaceship() - } - } - - fun launchSpaceship() { - } - - fun getPresident(): Party { - TODO() - } -} - -@InitiatedBy(LaunchSpaceshipFlow::class) -@InitiatingFlow -class PresidentSpaceshipFlow(val launcher: Party) : FlowLogic() { - @Suspendable - override fun call() { - val needCoffee = true - send(getSecretary(), needCoffee) - val shouldLaunchSpaceship = false - send(launcher, shouldLaunchSpaceship) - } - - fun getSecretary(): Party { - TODO() - } -} - -@InitiatedBy(PresidentSpaceshipFlow::class) -class SecretaryFlow(val president: Party) : FlowLogic() { - @Suspendable - override fun call() { - // ignore - } -} -// DOCEND LaunchSpaceshipFlow - -// DOCSTART LaunchSpaceshipFlowCorrect -@InitiatingFlow -class LaunchSpaceshipFlowCorrect : FlowLogic() { - @Suspendable - override fun call() { - val presidentSession = initiateFlow(getPresident()) - val shouldLaunchSpaceship = presidentSession.receive().unwrap { it } - if (shouldLaunchSpaceship) { - launchSpaceship() - } - } - - fun launchSpaceship() { - } - - fun getPresident(): Party { - TODO() - } -} - -@InitiatedBy(LaunchSpaceshipFlowCorrect::class) -@InitiatingFlow -class PresidentSpaceshipFlowCorrect(val launcherSession: FlowSession) : FlowLogic() { - @Suspendable - override fun call() { - val needCoffee = true - val secretarySession = initiateFlow(getSecretary()) - secretarySession.send(needCoffee) - val shouldLaunchSpaceship = false - launcherSession.send(shouldLaunchSpaceship) - } - - fun getSecretary(): Party { - TODO() - } -} - -@InitiatedBy(PresidentSpaceshipFlowCorrect::class) -class SecretaryFlowCorrect(val presidentSession: FlowSession) : FlowLogic() { - @Suspendable - override fun call() { - // ignore - } -} -// DOCEND LaunchSpaceshipFlowCorrect diff --git a/docs/source/example-code/src/main/kotlin/net/corda/docs/tutorial/mocknetwork/TutorialMockNetwork.kt b/docs/source/example-code/src/main/kotlin/net/corda/docs/tutorial/mocknetwork/TutorialMockNetwork.kt index ba8802dceb..595db4bb6e 100644 --- a/docs/source/example-code/src/main/kotlin/net/corda/docs/tutorial/mocknetwork/TutorialMockNetwork.kt +++ b/docs/source/example-code/src/main/kotlin/net/corda/docs/tutorial/mocknetwork/TutorialMockNetwork.kt @@ -10,11 +10,13 @@ import net.corda.core.identity.Party import net.corda.core.messaging.MessageRecipients import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize +import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.unwrap import net.corda.node.internal.StartedNode import net.corda.node.services.messaging.Message -import net.corda.node.services.statemachine.SessionData +import net.corda.node.services.statemachine.DataSessionMessage +import net.corda.node.services.statemachine.ExistingSessionMessage import net.corda.testing.node.InMemoryMessagingNetwork import net.corda.testing.node.MessagingServiceSpy import net.corda.testing.node.MockNetwork @@ -84,11 +86,11 @@ class TutorialMockNetwork { nodeB.setMessagingServiceSpy(object : MessagingServiceSpy(nodeB.network) { override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, acknowledgementHandler: (() -> Unit)?) { - val messageData = message.data.deserialize() - - if (messageData is SessionData && messageData.payload.deserialize() == 1) { - val alteredMessageData = SessionData(messageData.recipientSessionId, 99.serialize()).serialize().bytes - messagingService.send(InMemoryMessagingNetwork.InMemoryMessage(message.topicSession, alteredMessageData, message.uniqueMessageId), target, retryId) + val messageData = message.data.deserialize() as? ExistingSessionMessage + val payload = messageData?.payload + if (payload is DataSessionMessage && payload.payload.deserialize() == 1) { + val alteredMessageData = messageData.copy(payload = payload.copy(99.serialize())).serialize().bytes + messagingService.send(InMemoryMessagingNetwork.InMemoryMessage(message.topic, OpaqueBytes(alteredMessageData), message.uniqueMessageId), target, retryId) } else { messagingService.send(message, target, retryId) } diff --git a/docs/source/versioning.rst b/docs/source/versioning.rst index 9f5967e1ef..f53e9fb852 100644 --- a/docs/source/versioning.rst +++ b/docs/source/versioning.rst @@ -38,8 +38,8 @@ or if the semantics of a particular receive changes. The ``InitiatingFlow`` annotation (see :doc:`flow-state-machine` for more information on the flow annotations) has a ``version`` property, which if not specified defaults to 1. This flow version is included in the flow session handshake and exposed -to both parties in the communication via ``FlowLogic.getFlowContext``. This takes in a ``Party`` and will return a -``FlowContext`` object which describes the flow running on the other side. In particular it has the ``flowVersion`` property +to both parties in the communication via ``FlowLogic.getFlowInfo``. This takes in a ``Party`` and will return a +``FlowInfo`` object which describes the flow running on the other side. In particular it has the ``flowVersion`` property which can be used to programmatically evolve flows across versions. .. container:: codeset @@ -48,7 +48,7 @@ which can be used to programmatically evolve flows across versions. @Suspendable override fun call() { - val flowVersionOfOtherParty = getFlowContext(otherParty).flowVersion + val flowVersionOfOtherParty = getFlowInfo(otherParty).flowVersion val receivedString = if (flowVersionOfOtherParty == 1) { receive(otherParty).unwrap { it.toString() } } else { @@ -63,7 +63,7 @@ running the older flow (or rather older CorDapps containing the older flow). .. warning:: It's important that ``InitiatingFlow.version`` be incremented each time the flow protocol changes in an incompatible way. -``FlowContext`` also has ``appName`` which is the name of the CorDapp hosting the flow. This can be used to determine +``FlowInfo`` also has ``appName`` which is the name of the CorDapp hosting the flow. This can be used to determine implementation details of the CorDapp. See :doc:`cordapp-build-systems` for more information on the CorDapp filename. .. note:: Currently changing any of the properties of a ``CordaSerializable`` type is also backwards incompatible and diff --git a/experimental/flow-hook/build.gradle b/experimental/flow-hook/build.gradle new file mode 100644 index 0000000000..8ae4292898 --- /dev/null +++ b/experimental/flow-hook/build.gradle @@ -0,0 +1,53 @@ +buildscript { + // For sharing constants between builds + Properties constants = new Properties() + file("$projectDir/../../constants.properties").withInputStream { constants.load(it) } + + ext.kotlin_version = constants.getProperty("kotlinVersion") + ext.javaassist_version = "3.12.1.GA" + + repositories { + mavenLocal() + mavenCentral() + jcenter() + } + + dependencies { + classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version" + } +} + +repositories { + mavenLocal() + mavenCentral() + jcenter() +} + +apply plugin: 'kotlin' +apply plugin: 'kotlin-kapt' +apply plugin: 'idea' + +description 'A javaagent to allow hooking into Kryo' + +dependencies { + compile project(':node') + compile "org.jetbrains.kotlin:kotlin-stdlib-jre8:$kotlin_version" + compile "org.jetbrains.kotlin:kotlin-reflect:$kotlin_version" + compile "javassist:javassist:$javaassist_version" + compile "com.esotericsoftware:kryo:4.0.0" + compile "co.paralleluniverse:quasar-core:$quasar_version:jdk8" +} + +jar { + archiveName = "${project.name}.jar" + manifest { + attributes( + 'Premain-Class': 'net.corda.flowhook.FlowHookAgent', + 'Can-Redefine-Classes': 'true', + 'Can-Retransform-Classes': 'true', + 'Can-Set-Native-Method-Prefix': 'true', + 'Implementation-Title': "FlowHook", + 'Implementation-Version': rootProject.version + ) + } +} diff --git a/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/FiberMonitor.kt b/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/FiberMonitor.kt new file mode 100644 index 0000000000..46889f41fa --- /dev/null +++ b/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/FiberMonitor.kt @@ -0,0 +1,209 @@ +package net.corda.flowhook + +import co.paralleluniverse.fibers.Fiber +import net.corda.core.internal.uncheckedCast +import net.corda.core.utilities.contextLogger +import net.corda.nodeapi.internal.persistence.DatabaseTransaction +import java.sql.Connection +import java.time.Instant +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean +import kotlin.concurrent.thread + +/** + * This is a debugging helper class that dumps the map of Fiber->DB connection, or more precisely, the + * Fiber->(DB tx -> DB connection) map, as there may be multiple transactions per fiber. + */ + +data class MonitorEvent(val type: MonitorEventType, val keys: List, val extra: Any? = null) + +data class FullMonitorEvent(val timestamp: Instant, val trace: List, val event: MonitorEvent) { + override fun toString() = event.toString() +} + +enum class MonitorEventType { + TransactionCreated, + ConnectionRequested, + ConnectionAcquired, + ConnectionReleased, + + FiberStarted, + FiberParking, + FiberException, + FiberResumed, + FiberEnded, + + ExecuteTransition +} + +object FiberMonitor { + private val log = contextLogger() + private val jobQueue = LinkedBlockingQueue() + private val started = AtomicBoolean(false) + private var trackerThread: Thread? = null + + val correlator = MonitorEventCorrelator() + + sealed class Job { + data class NewEvent(val event: FullMonitorEvent) : Job() + object Finish : Job() + } + + fun newEvent(event: MonitorEvent) { + if (trackerThread != null) { + jobQueue.add(Job.NewEvent(FullMonitorEvent(Instant.now(), Exception().stackTrace.toList(), event))) + } + } + + fun start() { + if (started.compareAndSet(false, true)) { + require(trackerThread == null) + trackerThread = thread(name = "Fiber monitor", isDaemon = true) { + while (true) { + val job = jobQueue.poll(1, TimeUnit.SECONDS) + when (job) { + is Job.NewEvent -> processEvent(job) + Job.Finish -> return@thread + } + } + } + } + } + + private fun processEvent(job: Job.NewEvent) { + correlator.addEvent(job.event) + checkLeakedTransactions(job.event.event) + checkLeakedConnections(job.event.event) + } + + inline fun R.getField(name: String): A { + val field = R::class.java.getDeclaredField(name) + field.isAccessible = true + return uncheckedCast(field.get(this)) + } + + fun Any.getFieldFromObject(name: String): A { + val field = javaClass.getDeclaredField(name) + field.isAccessible = true + return uncheckedCast(field.get(this)) + } + + fun getThreadLocalMapEntryValues(locals: Any): List { + val table: Array = locals.getFieldFromObject("table") + return table.mapNotNull { it?.getFieldFromObject("value") } + } + + fun getStashedThreadLocals(fiber: Fiber<*>): List { + val fiberLocals: Any = fiber.getField("fiberLocals") + val inheritableFiberLocals: Any = fiber.getField("inheritableFiberLocals") + return getThreadLocalMapEntryValues(fiberLocals) + getThreadLocalMapEntryValues(inheritableFiberLocals) + } + + fun getTransactionStack(transaction: DatabaseTransaction): List { + val transactions = ArrayList() + var currentTransaction: DatabaseTransaction? = transaction + while (currentTransaction != null) { + transactions.add(currentTransaction) + currentTransaction = currentTransaction.outerTransaction + } + return transactions + } + + private fun checkLeakedTransactions(event: MonitorEvent) { + if (event.type == MonitorEventType.FiberParking) { + val fiber = event.keys.mapNotNull { it as? Fiber<*> }.first() + val threadLocals = getStashedThreadLocals(fiber) + val transactions = threadLocals.mapNotNull { it as? DatabaseTransaction }.flatMap { getTransactionStack(it) } + val leakedTransactions = transactions.filter { it.connectionCreated && !it.connection.isClosed } + if (leakedTransactions.isNotEmpty()) { + log.warn("Leaked open database transactions on yield $leakedTransactions") + } + } + } + + private fun checkLeakedConnections(event: MonitorEvent) { + if (event.type == MonitorEventType.FiberParking) { + val events = correlator.events[event.keys[0]]!! + val acquiredConnections = events.mapNotNullTo(HashSet()) { + if (it.event.type == MonitorEventType.ConnectionAcquired) { + it.event.keys.mapNotNull { it as? Connection }.first() + } else { + null + } + } + val releasedConnections = events.mapNotNullTo(HashSet()) { + if (it.event.type == MonitorEventType.ConnectionReleased) { + it.event.keys.mapNotNull { it as? Connection }.first() + } else { + null + } + } + val leakedConnections = (acquiredConnections - releasedConnections).filter { !it.isClosed } + if (leakedConnections.isNotEmpty()) { + log.warn("Leaked open connections $leakedConnections") + } + } + } +} + +class MonitorEventCorrelator { + private val _events = HashMap>() + val events: Map> get() = _events + + fun getUnique() = events.values.toSet().associateBy { it.flatMap { it.event.keys }.toSet() } + + fun getByType() = events.entries.groupBy { it.key.javaClass } + + fun addEvent(fullMonitorEvent: FullMonitorEvent) { + val list = link(fullMonitorEvent.event.keys) + list.add(fullMonitorEvent) + for (key in fullMonitorEvent.event.keys) { + _events[key] = list + } + } + + fun link(keys: List): ArrayList { + val eventLists = HashSet>() + for (key in keys) { + val list = _events[key] + if (list != null) { + eventLists.add(list) + } + } + return when { + eventLists.isEmpty() -> ArrayList() + eventLists.size == 1 -> eventLists.first() + else -> mergeAll(eventLists) + } + } + + fun mergeAll(lists: Collection>): ArrayList { + return lists.fold(ArrayList()) { merged, next -> merge(merged, next) } + } + + fun merge(a: List, b: List): ArrayList { + val merged = ArrayList() + var aIndex = 0 + var bIndex = 0 + while (true) { + if (aIndex >= a.size) { + merged.addAll(b.subList(bIndex, b.size)) + return merged + } + if (bIndex >= b.size) { + merged.addAll(a.subList(aIndex, a.size)) + return merged + } + val aElem = a[aIndex] + val bElem = b[bIndex] + if (aElem.timestamp < bElem.timestamp) { + merged.add(aElem) + aIndex++ + } else { + merged.add(bElem) + bIndex++ + } + } + } +} \ No newline at end of file diff --git a/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/FlowHook.kt b/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/FlowHook.kt new file mode 100644 index 0000000000..ebd9cf1cd5 --- /dev/null +++ b/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/FlowHook.kt @@ -0,0 +1,15 @@ +package net.corda.flowhook + +import java.lang.instrument.Instrumentation + +@Suppress("UNUSED") +class FlowHookAgent { + companion object { + @JvmStatic + fun premain(argumentsString: String?, instrumentation: Instrumentation) { + FiberMonitor.start() + instrumentation.addTransformer(Hooker(FlowHookContainer)) + } + } +} + diff --git a/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/FlowHookContainer.kt b/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/FlowHookContainer.kt new file mode 100644 index 0000000000..d61750b00c --- /dev/null +++ b/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/FlowHookContainer.kt @@ -0,0 +1,104 @@ +package net.corda.flowhook + +import co.paralleluniverse.fibers.Fiber +import net.corda.node.services.statemachine.ActionExecutor +import net.corda.node.services.statemachine.Event +import net.corda.node.services.statemachine.FlowFiber +import net.corda.node.services.statemachine.StateMachineState +import net.corda.node.services.statemachine.transitions.TransitionResult +import net.corda.nodeapi.internal.persistence.CordaPersistence +import net.corda.nodeapi.internal.persistence.DatabaseTransactionManager +import rx.subjects.Subject +import java.sql.Connection + +@Suppress("UNUSED") +object FlowHookContainer { + + @JvmStatic + @Hook("co.paralleluniverse.fibers.Fiber") + fun park() { + FiberMonitor.newEvent(MonitorEvent(MonitorEventType.FiberParking, keys = listOf(Fiber.currentFiber()))) + } + + @JvmStatic + @Hook("net.corda.node.services.statemachine.FlowStateMachineImpl") + fun run() { + FiberMonitor.newEvent(MonitorEvent(MonitorEventType.FiberStarted, keys = listOf(Fiber.currentFiber()))) + } + + @JvmStatic + @Hook("co.paralleluniverse.fibers.Fiber") + fun onCompleted() { + FiberMonitor.newEvent(MonitorEvent(MonitorEventType.FiberEnded, keys = listOf(Fiber.currentFiber()))) + } + + @JvmStatic + @Hook("co.paralleluniverse.fibers.Fiber") + fun onException(exception: Throwable) { + FiberMonitor.newEvent(MonitorEvent(MonitorEventType.FiberException, keys = listOf(Fiber.currentFiber()), extra = exception)) + } + + @JvmStatic + @Hook("co.paralleluniverse.fibers.Fiber") + fun onResumed() { + FiberMonitor.newEvent(MonitorEvent(MonitorEventType.FiberResumed, keys = listOf(Fiber.currentFiber()))) + } + + @JvmStatic + @Hook("net.corda.node.utilities.DatabaseTransaction", passThis = true, position = HookPosition.After) + fun DatabaseTransaction( + transaction: Any, + isolation: Int, + threadLocal: ThreadLocal<*>, + transactionBoundaries: Subject<*, *>, + cordaPersistence: CordaPersistence + ) { + val keys = ArrayList().apply { + add(transaction) + Fiber.currentFiber()?.let { add(it) } + } + FiberMonitor.newEvent(MonitorEvent(MonitorEventType.TransactionCreated, keys = keys)) + } + + @JvmStatic + @Hook("com.zaxxer.hikari.HikariDataSource") + fun getConnection(): (Connection) -> Unit { + val transactionOrThread = currentTransactionOrThread() + FiberMonitor.newEvent(MonitorEvent(MonitorEventType.ConnectionRequested, keys = listOf(transactionOrThread))) + return { connection -> + FiberMonitor.newEvent(MonitorEvent(MonitorEventType.ConnectionAcquired, keys = listOf(transactionOrThread, connection))) + } + } + + @JvmStatic + @Hook("com.zaxxer.hikari.pool.ProxyConnection", passThis = true, position = HookPosition.After) + fun close(connection: Any) { + connection as Connection + val transactionOrThread = currentTransactionOrThread() + FiberMonitor.newEvent(MonitorEvent(MonitorEventType.ConnectionReleased, keys = listOf(transactionOrThread, connection))) + } + + @JvmStatic + @Hook("net.corda.node.services.statemachine.TransitionExecutorImpl") + fun executeTransition( + fiber: FlowFiber, + previousState: StateMachineState, + event: Event, + transition: TransitionResult, + actionExecutor: ActionExecutor + ) { + FiberMonitor.newEvent(MonitorEvent(MonitorEventType.ExecuteTransition, keys = listOf(fiber), extra = object { + val previousState = previousState + val event = event + val transition = transition + })) + } + + private fun currentTransactionOrThread(): Any { + return try { + DatabaseTransactionManager.currentOrNull() + } catch (exception: IllegalStateException) { + null + } ?: Thread.currentThread() + } +} diff --git a/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/Hooker.kt b/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/Hooker.kt new file mode 100644 index 0000000000..9a9999f260 --- /dev/null +++ b/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/Hooker.kt @@ -0,0 +1,130 @@ +package net.corda.flowhook + +import javassist.ClassPool +import javassist.CtBehavior +import javassist.CtClass +import java.io.ByteArrayInputStream +import java.lang.instrument.ClassFileTransformer +import java.lang.reflect.Method +import java.security.ProtectionDomain + +class Hooker(hookContainer: Any) : ClassFileTransformer { + private val classPool = ClassPool.getDefault() + + private val hooks = createHooks(hookContainer) + + private fun createHooks(hookContainer: Any): Hooks { + val hooks = HashMap>>() + for (method in hookContainer.javaClass.methods) { + val hookAnnotation = method.getAnnotation(Hook::class.java) + if (hookAnnotation != null) { + val signature = if (hookAnnotation.passThis) { + if (method.parameterTypes.isEmpty() || method.parameterTypes[0] != Any::class.java) { + println("Method should accept an object as first parameter for 'this' $method") + continue + } + Signature(method.name, method.parameterTypes.toList().drop(1).map { it.canonicalName }) + } else { + Signature(method.name, method.parameterTypes.map { it.canonicalName }) + } + hooks.getOrPut(hookAnnotation.clazz) { HashMap() }.put(signature, Pair(method, hookAnnotation)) + } + } + return hooks + } + + override fun transform( + loader: ClassLoader?, + className: String, + classBeingRedefined: Class<*>?, + protectionDomain: ProtectionDomain?, + classfileBuffer: ByteArray + ): ByteArray? { + if (className.startsWith("java") || className.startsWith("sun") || className.startsWith("javassist") || className.startsWith("kotlin")) { + return null + } + return try { + val clazz = classPool.makeClass(ByteArrayInputStream(classfileBuffer)) + instrumentClass(clazz)?.toBytecode() + } catch (throwable: Throwable) { + println("SOMETHING WENT WRONG") + throwable.printStackTrace(System.out) + null + } + } + + private fun instrumentClass(clazz: CtClass): CtClass? { + val hookMethods = hooks[clazz.name] ?: return null + val usedHookMethods = HashSet() + var isAnyInstrumented = false + for (method in clazz.declaredBehaviors) { + val hookMethod = instrumentBehaviour(method, hookMethods) + if (hookMethod != null) { + isAnyInstrumented = true + usedHookMethods.add(hookMethod) + } + } + val unusedHookMethods = hookMethods.values.mapTo(HashSet()) { it.first } - usedHookMethods + if (unusedHookMethods.isNotEmpty()) { + println("Unused hook methods $unusedHookMethods") + } + return if (isAnyInstrumented) { + clazz + } else { + null + } + } + + private fun instrumentBehaviour(method: CtBehavior, methodHooks: MethodHooks): Method? { + val signature = Signature(method.name, method.parameterTypes.map { it.name }) + val (hookMethod, annotation) = methodHooks[signature] ?: return null + val invocationString = if (annotation.passThis) { + "${hookMethod.declaringClass.canonicalName}.${hookMethod.name}(this, \$\$)" + } else { + "${hookMethod.declaringClass.canonicalName}.${hookMethod.name}(\$\$)" + } + + val overriddenPosition = if (method.methodInfo.isConstructor && annotation.passThis && annotation.position == HookPosition.Before) { + println("passThis=true and position=${HookPosition.Before} for a constructor. " + + "You can only inspect 'this' at the end of the constructor! Hooking *after*.. $method") + HookPosition.After + } else { + annotation.position + } + + val insertHook: (CtBehavior.(code: String) -> Unit) = when (overriddenPosition) { + HookPosition.Before -> CtBehavior::insertBefore + HookPosition.After -> CtBehavior::insertAfter + } + when { + Function0::class.java.isAssignableFrom(hookMethod.returnType) -> { + method.addLocalVariable("after", classPool.get("kotlin.jvm.functions.Function0")) + method.insertHook("after = $invocationString;") + method.insertAfter("after.invoke();") + } + Function1::class.java.isAssignableFrom(hookMethod.returnType) -> { + method.addLocalVariable("after", classPool.get("kotlin.jvm.functions.Function1")) + method.insertHook("after = $invocationString;") + method.insertAfter("after.invoke((\$w)\$_);") + } + else -> { + method.insertHook("$invocationString;") + } + } + return hookMethod + } +} + + +enum class HookPosition { + Before, + After +} + +@Target(AnnotationTarget.FUNCTION) +annotation class Hook(val clazz: String, val position: HookPosition = HookPosition.Before, val passThis: Boolean = false) + +private data class Signature(val functionName: String, val parameterTypes: List) + +private typealias MethodHooks = Map> +private typealias Hooks = Map diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/CordaPersistence.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/CordaPersistence.kt index cce28ec508..f05849212d 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/CordaPersistence.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/CordaPersistence.kt @@ -215,3 +215,15 @@ fun rx.Observable.wrapWithDatabaseTransaction(db: CordaPersistence? } } } + +fun parserTransactionIsolationLevel(property: String?): Int = + when (property) { + "none" -> Connection.TRANSACTION_NONE + "readUncommitted" -> Connection.TRANSACTION_READ_UNCOMMITTED + "readCommitted" -> Connection.TRANSACTION_READ_COMMITTED + "repeatableRead" -> Connection.TRANSACTION_REPEATABLE_READ + "serializable" -> Connection.TRANSACTION_SERIALIZABLE + else -> { + Connection.TRANSACTION_REPEATABLE_READ + } + } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/DatabaseTransaction.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/DatabaseTransaction.kt index 5c78dc6a40..626a5433b4 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/DatabaseTransaction.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/DatabaseTransaction.kt @@ -14,11 +14,15 @@ class DatabaseTransaction( ) { val id: UUID = UUID.randomUUID() + private var _connectionCreated = false + val connectionCreated get() = _connectionCreated val connection: Connection by lazy(LazyThreadSafetyMode.NONE) { - cordaPersistence.dataSource.connection.apply { - autoCommit = false - transactionIsolation = isolation - } + cordaPersistence.dataSource.connection + .apply { + _connectionCreated = true + autoCommit = false + transactionIsolation = isolation + } } private val sessionDelegate = lazy { @@ -30,20 +34,22 @@ class DatabaseTransaction( val session: Session by sessionDelegate private lateinit var hibernateTransaction: Transaction - private val outerTransaction: DatabaseTransaction? = threadLocal.get() + val outerTransaction: DatabaseTransaction? = threadLocal.get() fun commit() { if (sessionDelegate.isInitialized()) { hibernateTransaction.commit() } - connection.commit() + if (_connectionCreated) { + connection.commit() + } } fun rollback() { if (sessionDelegate.isInitialized() && session.isOpen) { session.clear() } - if (!connection.isClosed) { + if (_connectionCreated && !connection.isClosed) { connection.rollback() } } @@ -52,7 +58,9 @@ class DatabaseTransaction( if (sessionDelegate.isInitialized() && session.isOpen) { session.close() } - connection.close() + if (_connectionCreated) { + connection.close() + } threadLocal.set(outerTransaction) if (outerTransaction == null) { transactionBoundaries.onNext(DatabaseTransactionManager.Boundary(id)) diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt index 5663a78beb..d6796023b3 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt @@ -3,7 +3,7 @@ package net.corda.nodeapi.internal.serialization import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.util.DefaultClassResolver import net.corda.core.serialization.* -import net.corda.node.services.statemachine.SessionData +import net.corda.node.services.statemachine.DataSessionMessage import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput import net.corda.nodeapi.internal.serialization.amqp.Envelope import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory @@ -47,17 +47,17 @@ class ListsSerializationTest { @Test fun `check list can be serialized as part of SessionData`() { run { - val sessionData = SessionData(123, listOf(1).serialize()) + val sessionData = DataSessionMessage(listOf(1).serialize()) assertEqualAfterRoundTripSerialization(sessionData) assertEquals(listOf(1), sessionData.payload.deserialize()) } run { - val sessionData = SessionData(123, listOf(1, 2).serialize()) + val sessionData = DataSessionMessage(listOf(1, 2).serialize()) assertEqualAfterRoundTripSerialization(sessionData) assertEquals(listOf(1, 2), sessionData.payload.deserialize()) } run { - val sessionData = SessionData(123, emptyList().serialize()) + val sessionData = DataSessionMessage(emptyList().serialize()) assertEqualAfterRoundTripSerialization(sessionData) assertEquals(emptyList(), sessionData.payload.deserialize()) } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/MapsSerializationTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/MapsSerializationTest.kt index f726d0b97a..3bb028f437 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/MapsSerializationTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/MapsSerializationTest.kt @@ -6,7 +6,7 @@ import net.corda.core.identity.CordaX500Name import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize -import net.corda.node.services.statemachine.SessionData +import net.corda.node.services.statemachine.DataSessionMessage import net.corda.nodeapi.internal.serialization.kryo.KryoHeaderV0_1 import net.corda.testing.SerializationEnvironmentRule import net.corda.testing.amqpSpecific @@ -41,7 +41,7 @@ class MapsSerializationTest { @Test fun `check list can be serialized as part of SessionData`() { - val sessionData = SessionData(123, smallMap.serialize()) + val sessionData = DataSessionMessage(smallMap.serialize()) assertEqualAfterRoundTripSerialization(sessionData) assertEquals(smallMap, sessionData.payload.deserialize()) } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt index 243b73a803..03c4fb08c1 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt @@ -4,10 +4,10 @@ import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.util.DefaultClassResolver import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize -import net.corda.node.services.statemachine.SessionData +import net.corda.node.services.statemachine.DataSessionMessage import net.corda.nodeapi.internal.serialization.kryo.KryoHeaderV0_1 -import net.corda.testing.kryoSpecific import net.corda.testing.SerializationEnvironmentRule +import net.corda.testing.kryoSpecific import org.junit.Assert.assertArrayEquals import org.junit.Assert.assertEquals import org.junit.Rule @@ -34,17 +34,17 @@ class SetsSerializationTest { @Test fun `check set can be serialized as part of SessionData`() { run { - val sessionData = SessionData(123, setOf(1).serialize()) + val sessionData = DataSessionMessage(setOf(1).serialize()) assertEqualAfterRoundTripSerialization(sessionData) assertEquals(setOf(1), sessionData.payload.deserialize()) } run { - val sessionData = SessionData(123, setOf(1, 2).serialize()) + val sessionData = DataSessionMessage(setOf(1, 2).serialize()) assertEqualAfterRoundTripSerialization(sessionData) assertEquals(setOf(1, 2), sessionData.payload.deserialize()) } run { - val sessionData = SessionData(123, emptySet().serialize()) + val sessionData = DataSessionMessage(emptySet().serialize()) assertEqualAfterRoundTripSerialization(sessionData) assertEquals(emptySet(), sessionData.payload.deserialize()) } diff --git a/node/src/integration-test/kotlin/net/corda/node/NodePerformanceTests.kt b/node/src/integration-test/kotlin/net/corda/node/NodePerformanceTests.kt index 925f14928e..56f4e79845 100644 --- a/node/src/integration-test/kotlin/net/corda/node/NodePerformanceTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/NodePerformanceTests.kt @@ -17,11 +17,11 @@ import net.corda.nodeapi.User import net.corda.testing.* import net.corda.testing.driver.NodeHandle import net.corda.testing.driver.driver -import net.corda.testing.node.NotarySpec import net.corda.testing.internal.performance.div import net.corda.testing.internal.performance.startPublishingFixedRateInjector import net.corda.testing.internal.performance.startReporter import net.corda.testing.internal.performance.startTightLoopInjector +import net.corda.testing.node.NotarySpec import org.junit.Before import org.junit.ClassRule import org.junit.Ignore @@ -78,7 +78,7 @@ class NodePerformanceTests : IntegrationTest() { queueBound = 50 ) { val timing = Stopwatch.createStarted().apply { - connection.proxy.startFlow(::EmptyFlow).returnValue.get() + connection.proxy.startFlow(::EmptyFlow).returnValue.getOrThrow() }.stop().elapsed(TimeUnit.MICROSECONDS) timings.add(timing) } @@ -100,13 +100,27 @@ class NodePerformanceTests : IntegrationTest() { a as NodeHandle.InProcess val metricRegistry = startReporter(shutdownManager, a.node.services.monitoringService.metrics) a.rpcClientToNode().use("A", "A") { connection -> - startPublishingFixedRateInjector(metricRegistry, 8, 5.minutes, 2000L / TimeUnit.SECONDS) { + startPublishingFixedRateInjector(metricRegistry, 1, 5.minutes, 2000L / TimeUnit.SECONDS) { connection.proxy.startFlow(::EmptyFlow).returnValue.get() } } } } + @Test + fun `issue flow rate`() { + driver(startNodesInProcess = true, extraCordappPackagesToScan = listOf("net.corda.finance")) { + val a = startNode(rpcUsers = listOf(User("A", "A", setOf(startFlow())))).get() + a as NodeHandle.InProcess + val metricRegistry = startReporter(shutdownManager, a.node.services.monitoringService.metrics) + a.rpcClientToNode().use("A", "A") { connection -> + startPublishingFixedRateInjector(metricRegistry, 1, 5.minutes, 2000L / TimeUnit.SECONDS) { + connection.proxy.startFlow(::CashIssueFlow, 1.DOLLARS, OpaqueBytes.of(0), ALICE).returnValue.get() + } + } + } + } + @Test fun `self pay rate`() { val user = User("A", "A", setOf(startFlow(), startFlow())) diff --git a/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt b/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt index 588cb325f8..dbbeaba860 100644 --- a/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt +++ b/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt @@ -1,9 +1,9 @@ package net.corda.services.messaging import net.corda.core.concurrent.CordaFuture -import net.corda.core.crypto.random63BitValue import net.corda.core.identity.CordaX500Name import net.corda.core.internal.concurrent.map +import net.corda.core.internal.concurrent.openFuture import net.corda.core.internal.randomOrNull import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.SingleMessageRecipient @@ -14,7 +14,9 @@ import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.seconds import net.corda.node.internal.Node import net.corda.node.internal.StartedNode -import net.corda.node.services.messaging.* +import net.corda.node.services.messaging.MessagingService +import net.corda.node.services.messaging.ReceivedMessage +import net.corda.node.services.messaging.send import net.corda.node.services.transactions.RaftValidatingNotaryService import net.corda.testing.* import net.corda.testing.driver.DriverDSLExposedInterface @@ -28,6 +30,7 @@ import org.junit.Test import java.util.* import java.util.concurrent.CountDownLatch import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicInteger class P2PMessagingTest : IntegrationTest() { @@ -54,19 +57,12 @@ class P2PMessagingTest : IntegrationTest() { alice.network.getAddressOfParty(getPartyInfo(notaryParty)!!) } - val dummyTopic = "dummy.topic" val responseMessage = "response" - val crashingNodes = simulateCrashingNodes(distributedServiceNodes, dummyTopic, responseMessage) + val crashingNodes = simulateCrashingNodes(distributedServiceNodes, responseMessage) // Send a single request with retry - val responseFuture = with(alice.network) { - val request = TestRequest(replyTo = myAddress) - val responseFuture = onNext(dummyTopic, request.sessionID) - val msg = createMessage(TopicSession(dummyTopic), data = request.serialize().bytes) - send(msg, serviceAddress, retryId = request.sessionID) - responseFuture - } + val responseFuture = alice.receiveFrom(serviceAddress, retryId = 0) crashingNodes.firstRequestReceived.await(5, TimeUnit.SECONDS) // The request wasn't successful. assertThat(responseFuture.isDone).isFalse() @@ -87,22 +83,15 @@ class P2PMessagingTest : IntegrationTest() { alice.network.getAddressOfParty(getPartyInfo(notaryParty)!!) } - val dummyTopic = "dummy.topic" val responseMessage = "response" - val crashingNodes = simulateCrashingNodes(distributedServiceNodes, dummyTopic, responseMessage) - - val sessionId = random63BitValue() + val crashingNodes = simulateCrashingNodes(distributedServiceNodes, responseMessage) // Send a single request with retry - with(alice.network) { - val request = TestRequest(sessionId, myAddress) - val msg = createMessage(TopicSession(dummyTopic), data = request.serialize().bytes) - send(msg, serviceAddress, retryId = request.sessionID) - } + alice.receiveFrom(serviceAddress, retryId = 0) // Wait until the first request is received - crashingNodes.firstRequestReceived.await(5, TimeUnit.SECONDS) + crashingNodes.firstRequestReceived.await() // Stop alice's node after we ensured that the first request was delivered and ignored. alice.dispose() val numberOfRequestsReceived = crashingNodes.requestsReceived.get() @@ -112,7 +101,12 @@ class P2PMessagingTest : IntegrationTest() { // Restart the node and expect a response val aliceRestarted = startAlice() - val response = aliceRestarted.network.onNext(dummyTopic, sessionId).getOrThrow(5.seconds) + + val responseFuture = openFuture() + aliceRestarted.network.runOnNextMessage("test.response") { + responseFuture.set(it.data.deserialize()) + } + val response = responseFuture.getOrThrow() assertThat(crashingNodes.requestsReceived.get()).isGreaterThan(numberOfRequestsReceived) assertThat(response).isEqualTo(responseMessage) @@ -138,11 +132,12 @@ class P2PMessagingTest : IntegrationTest() { ) /** - * Sets up the [distributedServiceNodes] to respond to [dummyTopic] requests. All nodes will receive requests and - * either ignore them or respond, depending on the value of [CrashingNodes.ignoreRequests], initially set to true. - * This may be used to simulate scenarios where nodes receive request messages but crash before sending back a response. + * Sets up the [distributedServiceNodes] to respond to "test.request" requests. All nodes will receive requests and + * either ignore them or respond to "test.response", depending on the value of [CrashingNodes.ignoreRequests], + * initially set to true. This may be used to simulate scenarios where nodes receive request messages but crash + * before sending back a response. */ - private fun simulateCrashingNodes(distributedServiceNodes: List>, dummyTopic: String, responseMessage: String): CrashingNodes { + private fun simulateCrashingNodes(distributedServiceNodes: List>, responseMessage: String): CrashingNodes { val crashingNodes = CrashingNodes( requestsReceived = AtomicInteger(0), firstRequestReceived = CountDownLatch(1), @@ -151,7 +146,7 @@ class P2PMessagingTest : IntegrationTest() { distributedServiceNodes.forEach { val nodeName = it.info.chooseIdentity().name - it.network.addMessageHandler(dummyTopic) { netMessage, _ -> + it.network.addMessageHandler("test.request") { netMessage, _, handler -> crashingNodes.requestsReceived.incrementAndGet() crashingNodes.firstRequestReceived.countDown() // The node which receives the first request will ignore all requests @@ -163,9 +158,10 @@ class P2PMessagingTest : IntegrationTest() { } else { println("sending response") val request = netMessage.data.deserialize() - val response = it.network.createMessage(dummyTopic, request.sessionID, responseMessage.serialize().bytes) + val response = it.network.createMessage("test.response", responseMessage.serialize().bytes) it.network.send(response, request.replyTo) } + handler.acknowledge() } } return crashingNodes @@ -193,19 +189,41 @@ class P2PMessagingTest : IntegrationTest() { } private fun StartedNode<*>.respondWith(message: Any) { - network.addMessageHandler(javaClass.name) { netMessage, _ -> + network.addMessageHandler("test.request") { netMessage, _, handle -> val request = netMessage.data.deserialize() - val response = network.createMessage(javaClass.name, request.sessionID, message.serialize().bytes) + val response = network.createMessage("test.response", message.serialize().bytes) network.send(response, request.replyTo) + handle.acknowledge() } } - private fun StartedNode<*>.receiveFrom(target: MessageRecipients): CordaFuture { - val request = TestRequest(replyTo = network.myAddress) - return network.sendRequest(javaClass.name, request, target) + private fun StartedNode<*>.receiveFrom(target: MessageRecipients, retryId: Long? = null): CordaFuture { + val response = openFuture() + network.runOnNextMessage("test.response") { netMessage -> + response.set(netMessage.data.deserialize()) + } + network.send("test.request", TestRequest(replyTo = network.myAddress), target, retryId = retryId) + return response + } + + /** + * Registers a handler for the given topic and session that runs the given callback with the message and then removes + * itself. This is useful for one-shot handlers that aren't supposed to stick around permanently. Note that this callback + * doesn't take the registration object, unlike the callback to [MessagingService.addMessageHandler]. + * + * @param topic identifier for the topic and session to listen for messages arriving on. + */ + inline fun MessagingService.runOnNextMessage(topic: String, crossinline callback: (ReceivedMessage) -> Unit) { + val consumed = AtomicBoolean() + addMessageHandler(topic) { msg, reg, handle -> + removeMessageHandler(reg) + check(!consumed.getAndSet(true)) { "Called more than once" } + check(msg.topic == topic) { "Topic/session mismatch: ${msg.topic} vs $topic" } + callback(msg) + handle.acknowledge() + } } @CordaSerializable - private data class TestRequest(override val sessionID: Long = random63BitValue(), - override val replyTo: SingleMessageRecipient) : ServiceRequestMessage + private data class TestRequest(val replyTo: SingleMessageRecipient) } diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index 2f536769cb..de4e986dab 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -12,6 +12,7 @@ import net.corda.core.concurrent.CordaFuture import net.corda.core.context.InvocationContext import net.corda.core.crypto.SignedData import net.corda.core.crypto.sign +import net.corda.core.crypto.newSecureRandom import net.corda.core.flows.* import net.corda.core.identity.CordaX500Name import net.corda.core.identity.Party @@ -280,6 +281,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, checkpointStorage, serverThread, database, + newSecureRandom(), busyNodeLatch, cordappLoader.appClassLoader ) @@ -556,7 +558,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val database = configureDatabase(props, configuration.database, identityService, schemaService) // Now log the vendor string as this will also cause a connection to be tested eagerly. database.transaction { - log.info("Connected to ${database.dataSource.connection.metaData.databaseProductName} database.") + log.info("Connected to ${connection.metaData.databaseProductName} database.") } runOnStop += database::close return database.transaction { diff --git a/node/src/main/kotlin/net/corda/node/internal/InitiatedFlowFactory.kt b/node/src/main/kotlin/net/corda/node/internal/InitiatedFlowFactory.kt index f259512109..3b86147c4e 100644 --- a/node/src/main/kotlin/net/corda/node/internal/InitiatedFlowFactory.kt +++ b/node/src/main/kotlin/net/corda/node/internal/InitiatedFlowFactory.kt @@ -12,4 +12,3 @@ sealed class InitiatedFlowFactory> { val appName: String, override val factory: (FlowSession) -> F) : InitiatedFlowFactory() } - diff --git a/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt b/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt index 867e9d6c65..227780cb4c 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt @@ -1,42 +1,28 @@ package net.corda.node.services.api -import net.corda.core.crypto.SecureHash +import net.corda.core.flows.StateMachineRunId import net.corda.core.serialization.SerializedBytes -import net.corda.node.services.statemachine.FlowStateMachineImpl +import net.corda.node.services.statemachine.Checkpoint +import java.util.stream.Stream /** * Thread-safe storage of fiber checkpoints. */ interface CheckpointStorage { - /** * Add a new checkpoint to the store. */ - fun addCheckpoint(checkpoint: Checkpoint) + fun addCheckpoint(id: StateMachineRunId, checkpoint: SerializedBytes) /** * Remove existing checkpoint from the store. It is an error to attempt to remove a checkpoint which doesn't exist * in the store. Doing so will throw an [IllegalArgumentException]. */ - fun removeCheckpoint(checkpoint: Checkpoint) + fun removeCheckpoint(id: StateMachineRunId) /** - * Allows the caller to process safely in a thread safe fashion the set of all checkpoints. - * The checkpoints are only valid during the lifetime of a single call to the block, to allow memory management. - * Return false from the block to terminate further iteration. + * Stream all checkpoints from the store. If this is backed by a database the stream will be valid until the + * underlying database connection is open, so any processing should happen before it is closed. */ - fun forEach(block: (Checkpoint) -> Boolean) - -} - -// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo). -class Checkpoint(val serializedFiber: SerializedBytes>) { - - val id: SecureHash get() = serializedFiber.hash - - override fun equals(other: Any?): Boolean = other === this || other is Checkpoint && other.id == this.id - - override fun hashCode(): Int = id.hashCode() - - override fun toString(): String = "${javaClass.simpleName}(id=$id)" + fun getAllCheckpoints(): Stream>> } diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt index 468de4d8f5..898c881e90 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt @@ -1,18 +1,15 @@ package net.corda.node.services.messaging -import net.corda.core.concurrent.CordaFuture +import net.corda.core.crypto.newSecureRandom import net.corda.core.identity.CordaX500Name -import net.corda.core.internal.concurrent.openFuture -import net.corda.core.internal.uncheckedCast import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.node.services.PartyInfo import net.corda.core.serialization.CordaSerializable -import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize +import net.corda.core.utilities.ByteSequence +import net.corda.node.services.statemachine.DeduplicationId import java.time.Instant -import java.util.* -import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.concurrent.ThreadSafe /** @@ -27,29 +24,6 @@ import javax.annotation.concurrent.ThreadSafe */ @ThreadSafe interface MessagingService { - companion object { - /** - * Session ID to use for services listening for the first message in a session (before a - * specific session ID has been established). - */ - val DEFAULT_SESSION_ID = 0L - } - - /** - * The provided function will be invoked for each received message whose topic matches the given string. The callback - * will run on threads provided by the messaging service, and the callback is expected to be thread safe as a result. - * - * The returned object is an opaque handle that may be used to un-register handlers later with [removeMessageHandler]. - * The handle is passed to the callback as well, to avoid race conditions whereby the callback wants to unregister - * itself and yet addMessageHandler hasn't returned the handle yet. - * - * @param topic identifier for the general subject of the message, for example "platform.network_map.fetch". - * The topic can be the empty string to match all messages (session ID must be [DEFAULT_SESSION_ID]). - * @param sessionID identifier for the session the message is part of. For services listening before - * a session is established, use [DEFAULT_SESSION_ID]. - */ - fun addMessageHandler(topic: String = "", sessionID: Long = DEFAULT_SESSION_ID, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration - /** * The provided function will be invoked for each received message whose topic and session matches. The callback * will run on the main server thread provided when the messaging service is constructed, and a database @@ -59,9 +33,9 @@ interface MessagingService { * The handle is passed to the callback as well, to avoid race conditions whereby the callback wants to unregister * itself and yet addMessageHandler hasn't returned the handle yet. * - * @param topicSession identifier for the topic and session to listen for messages arriving on. + * @param topic identifier for the topic to listen for messages arriving on. */ - fun addMessageHandler(topicSession: TopicSession, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration + fun addMessageHandler(topic: String, callback: MessageHandler): MessageHandlerRegistration /** * Removes a handler given the object returned from [addMessageHandler]. The callback will no longer be invoked once @@ -110,8 +84,6 @@ interface MessagingService { * implementation. * * @param addressedMessages The list of messages together with the recipients, retry ids and sequence keys. - * @param retryId if provided the message will be scheduled for redelivery until [cancelRedelivery] is called for this id. - * Note that this feature should only be used when the target is an idempotent distributed service, e.g. a notary. * @param acknowledgementHandler if non-null this handler will be called once all sent messages have been committed * by the broker. Note that if specified [send] itself may return earlier than the commit. */ @@ -123,9 +95,9 @@ interface MessagingService { /** * Returns an initialised [Message] with the current time, etc, already filled in. * - * @param topicSession identifier for the topic and session the message is sent to. + * @param topic identifier for the topic the message is sent to. */ - fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID = UUID.randomUUID()): Message + fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId = DeduplicationId.createRandom(newSecureRandom())): Message /** Given information about either a specific node or a service returns its corresponding address */ fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients @@ -134,86 +106,12 @@ interface MessagingService { val myAddress: SingleMessageRecipient } -/** - * Returns an initialised [Message] with the current time, etc, already filled in. - * - * @param topic identifier for the general subject of the message, for example "platform.network_map.fetch". - * Must not be blank. - * @param sessionID identifier for the session the message is part of. For messages sent to services before the - * construction of a session, use [DEFAULT_SESSION_ID]. - */ -fun MessagingService.createMessage(topic: String, sessionID: Long = MessagingService.DEFAULT_SESSION_ID, data: ByteArray): Message - = createMessage(TopicSession(topic, sessionID), data) -/** - * Registers a handler for the given topic and session ID that runs the given callback with the message and then removes - * itself. This is useful for one-shot handlers that aren't supposed to stick around permanently. Note that this callback - * doesn't take the registration object, unlike the callback to [MessagingService.addMessageHandler], as the handler is - * automatically deregistered before the callback runs. - * - * @param topic identifier for the general subject of the message, for example "platform.network_map.fetch". - * The topic can be the empty string to match all messages (session ID must be [DEFAULT_SESSION_ID]). - * @param sessionID identifier for the session the message is part of. For services listening before - * a session is established, use [DEFAULT_SESSION_ID]. - */ -fun MessagingService.runOnNextMessage(topic: String, sessionID: Long, callback: (ReceivedMessage) -> Unit) - = runOnNextMessage(TopicSession(topic, sessionID), callback) - -/** - * Registers a handler for the given topic and session that runs the given callback with the message and then removes - * itself. This is useful for one-shot handlers that aren't supposed to stick around permanently. Note that this callback - * doesn't take the registration object, unlike the callback to [MessagingService.addMessageHandler]. - * - * @param topicSession identifier for the topic and session to listen for messages arriving on. - */ -inline fun MessagingService.runOnNextMessage(topicSession: TopicSession, crossinline callback: (ReceivedMessage) -> Unit) { - val consumed = AtomicBoolean() - addMessageHandler(topicSession) { msg, reg -> - removeMessageHandler(reg) - check(!consumed.getAndSet(true)) { "Called more than once" } - check(msg.topicSession == topicSession) { "Topic/session mismatch: ${msg.topicSession} vs $topicSession" } - callback(msg) - } -} - -/** - * Returns a [CordaFuture] of the next message payload ([Message.data]) which is received on the given topic and sessionId. - * The payload is deserialized to an object of type [M]. Any exceptions thrown will be captured by the future. - */ -fun MessagingService.onNext(topic: String, sessionId: Long): CordaFuture { - val messageFuture = openFuture() - runOnNextMessage(topic, sessionId) { message -> - messageFuture.capture { - uncheckedCast(message.data.deserialize()) - } - } - return messageFuture -} - -fun MessagingService.send(topic: String, sessionID: Long, payload: Any, to: MessageRecipients, uuid: UUID = UUID.randomUUID()) { - send(TopicSession(topic, sessionID), payload, to, uuid) -} - -fun MessagingService.send(topicSession: TopicSession, payload: Any, to: MessageRecipients, uuid: UUID = UUID.randomUUID(), retryId: Long? = null) { - send(createMessage(topicSession, payload.serialize().bytes, uuid), to, retryId) -} +fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: DeduplicationId = DeduplicationId.createRandom(newSecureRandom()), retryId: Long? = null) + = send(createMessage(topicSession, payload.serialize().bytes, deduplicationId), to, retryId) interface MessageHandlerRegistration -/** - * An identifier for the endpoint [MessagingService] message handlers listen at. - * - * @param topic identifier for the general subject of the message, for example "platform.network_map.fetch". - * The topic can be the empty string to match all messages (session ID must be [DEFAULT_SESSION_ID]). - * @param sessionID identifier for the session the message is part of. For services listening before - * a session is established, use [DEFAULT_SESSION_ID]. - */ -@CordaSerializable -data class TopicSession(val topic: String, val sessionID: Long = MessagingService.DEFAULT_SESSION_ID) { - fun isBlank() = topic.isBlank() && sessionID == MessagingService.DEFAULT_SESSION_ID - override fun toString(): String = "$topic.$sessionID" -} - /** * A message is defined, at this level, to be a (topic, timestamp, byte arrays) triple, where the topic is a string in * Java-style reverse dns form, with "platform." being a prefix reserved by the platform for its own use. Vendor @@ -226,10 +124,10 @@ data class TopicSession(val topic: String, val sessionID: Long = MessagingServic */ @CordaSerializable interface Message { - val topicSession: TopicSession - val data: ByteArray + val topic: String + val data: ByteSequence val debugTimestamp: Instant - val uniqueMessageId: UUID + val uniqueMessageId: DeduplicationId } // TODO Have ReceivedMessage point to the TLS certificate of the peer, and [peer] would simply be the subject DN of that. @@ -248,3 +146,20 @@ object TopicStringValidator { /** @throws IllegalArgumentException if the given topic contains invalid characters */ fun check(tag: String) = require(regex.matcher(tag).matches()) } + +/** + * Represents a to-be-acknowledged message. It has an associated deduplication ID. + */ +interface AcknowledgeHandle { + /** + * Acknowledge the message. + */ + fun acknowledge() + + /** + * Store the deduplication ID. TODO this should be moved into the flow state machine completely. + */ + fun persistDeduplicationId() +} + +typealias MessageHandler = (ReceivedMessage, MessageHandlerRegistration, AcknowledgeHandle) -> Unit diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt index f4976610d8..a879af3b56 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt @@ -10,13 +10,11 @@ import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize -import net.corda.core.utilities.NetworkHostAndPort -import net.corda.core.utilities.contextLogger -import net.corda.core.utilities.sequence -import net.corda.core.utilities.trace +import net.corda.core.utilities.* import net.corda.node.VersionInfo import net.corda.node.services.config.NodeConfiguration -import net.corda.node.services.statemachine.StateMachineManagerImpl +import net.corda.node.services.statemachine.DeduplicationId +import net.corda.node.services.statemachine.FlowMessagingImpl import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.AppendOnlyPersistentMap import net.corda.node.utilities.PersistentMap @@ -33,15 +31,16 @@ import org.apache.activemq.artemis.api.core.client.ClientMessage import java.security.PublicKey import java.time.Instant import java.util.* -import java.util.concurrent.* +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.CountDownLatch +import java.util.concurrent.ScheduledFuture +import java.util.concurrent.TimeUnit import javax.annotation.concurrent.ThreadSafe import javax.persistence.Column import javax.persistence.Entity import javax.persistence.Id import javax.persistence.Lob -// TODO: Stop the wallet explorer and other clients from using this class and get rid of persistentInbox - /** * This class implements the [MessagingService] API using Apache Artemis, the successor to their ActiveMQ product. * Artemis is a message queue broker and here we run a client connecting to the specified broker instance @@ -77,20 +76,19 @@ class P2PMessagingClient(config: NodeConfiguration, // that will handle messages, like a URL) with the terminology used by underlying MQ libraries, to avoid // confusion. private val topicProperty = SimpleString("platform-topic") - private val sessionIdProperty = SimpleString("session-id") private val cordaVendorProperty = SimpleString("corda-vendor") private val releaseVersionProperty = SimpleString("release-version") private val platformVersionProperty = SimpleString("platform-version") private val amqDelayMillis = System.getProperty("amq.delivery.delay.ms", "0").toInt() private val messageMaxRetryCount: Int = 3 - fun createProcessedMessage(): AppendOnlyPersistentMap { + fun createProcessedMessages(): AppendOnlyPersistentMap { return AppendOnlyPersistentMap( - toPersistentEntityKey = { it.toString() }, - fromPersistentEntity = { Pair(UUID.fromString(it.uuid), it.insertionTime) }, - toPersistentEntity = { key: UUID, value: Instant -> + toPersistentEntityKey = { it.toString }, + fromPersistentEntity = { Pair(DeduplicationId(it.id), it.insertionTime) }, + toPersistentEntity = { key: DeduplicationId, value: Instant -> ProcessedMessage().apply { - uuid = key.toString() + id = key.toString insertionTime = value } }, @@ -118,9 +116,9 @@ class P2PMessagingClient(config: NodeConfiguration, ) } - private class NodeClientMessage(override val topicSession: TopicSession, override val data: ByteArray, override val uniqueMessageId: UUID) : Message { + private class NodeClientMessage(override val topic: String, override val data: ByteSequence, override val uniqueMessageId: DeduplicationId) : Message { override val debugTimestamp: Instant = Instant.now() - override fun toString() = "$topicSession#${String(data)}" + override fun toString() = "$topic#${String(data.bytes)}" } } @@ -136,8 +134,7 @@ class P2PMessagingClient(config: NodeConfiguration, private val scheduledMessageRedeliveries = ConcurrentHashMap>() /** A registration to handle messages of different types */ - data class Handler(val topicSession: TopicSession, - val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration + data class HandlerRegistration(val topic: String, val callback: Any) : MessageHandlerRegistration private val cordaVendor = SimpleString(versionInfo.vendor) private val releaseVersion = SimpleString(versionInfo.releaseVersion) @@ -148,16 +145,17 @@ class P2PMessagingClient(config: NodeConfiguration, private val messageRedeliveryDelaySeconds = config.messageRedeliveryDelaySeconds.toLong() private val artemis = ArtemisMessagingClient(config, serverAddress) private val state = ThreadBox(InnerState()) - private val handlers = CopyOnWriteArrayList() - private val processedMessages = createProcessedMessage() + private val handlers = ConcurrentHashMap() + + private val processedMessages = createProcessedMessages() @Entity @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_ids") class ProcessedMessage( @Id - @Column(name = "message_id", length = 36) - var uuid: String = "", + @Column(name = "message_id", length = 64) + var id: String = "", @Column(name = "insertion_time") var insertionTime: Instant = Instant.now() @@ -167,7 +165,7 @@ class P2PMessagingClient(config: NodeConfiguration, @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_retry") class RetryMessage( @Id - @Column(name = "message_id", length = 36) + @Column(name = "message_id", length = 64) var key: Long = 0, @Lob @@ -214,22 +212,7 @@ class P2PMessagingClient(config: NodeConfiguration, val message: ReceivedMessage? = artemisToCordaMessage(artemisMessage) if (message != null) - deliver(message) - - // Ack the message so it won't be redelivered. We should only really do this when there were no - // transient failures. If we caught an exception in the handler, we could back off and retry delivery - // a few times before giving up and redirecting the message to a dead-letter address for admin or - // developer inspection. Artemis has the features to do this for us, we just need to enable them. - // - // TODO: Setup Artemis delayed redelivery and dead letter addresses. - // - // ACKing a message calls back into the session which isn't thread safe, so we have to ensure it - // doesn't collide with a send here. Note that stop() could have been called whilst we were - // processing a message but if so, it'll be parked waiting for us to count down the latch, so - // the session itself is still around and we can still ack messages as a result. - state.locked { - artemisMessage.acknowledge() - } + deliver(artemisMessage, message) return true } @@ -255,14 +238,13 @@ class P2PMessagingClient(config: NodeConfiguration, private fun artemisToCordaMessage(message: ClientMessage): ReceivedMessage? { try { val topic = message.required(topicProperty) { getStringProperty(it) } - val sessionID = message.required(sessionIdProperty) { getLongProperty(it) } val user = requireNotNull(message.getStringProperty(HDR_VALIDATED_USER)) { "Message is not authenticated" } val platformVersion = message.required(platformVersionProperty) { getIntProperty(it) } // Use the magic deduplication property built into Artemis as our message identity too - val uuid = message.required(HDR_DUPLICATE_DETECTION_ID) { UUID.fromString(message.getStringProperty(it)) } - log.trace { "Received message from: ${message.address} user: $user topic: $topic sessionID: $sessionID uuid: $uuid" } + val uniqueMessageId = message.required(HDR_DUPLICATE_DETECTION_ID) { DeduplicationId(message.getStringProperty(it)) } + log.trace { "Received message from: ${message.address} user: $user topic: $topic id: $uniqueMessageId" } - return ArtemisReceivedMessage(TopicSession(topic, sessionID), CordaX500Name.parse(user), platformVersion, uuid, message) + return ArtemisReceivedMessage(topic, CordaX500Name.parse(user), platformVersion, uniqueMessageId, message) } catch (e: Exception) { log.error("Unable to process message, ignoring it: $message", e) return null @@ -274,21 +256,19 @@ class P2PMessagingClient(config: NodeConfiguration, return extractor(key) } - private class ArtemisReceivedMessage(override val topicSession: TopicSession, + private class ArtemisReceivedMessage(override val topic: String, override val peer: CordaX500Name, override val platformVersion: Int, - override val uniqueMessageId: UUID, + override val uniqueMessageId: DeduplicationId, private val message: ClientMessage) : ReceivedMessage { - override val data: ByteArray by lazy { ByteArray(message.bodySize).apply { message.bodyBuffer.readBytes(this) } } + override val data: ByteSequence by lazy { OpaqueBytes(ByteArray(message.bodySize).apply { message.bodyBuffer.readBytes(this) }) } override val debugTimestamp: Instant get() = Instant.ofEpochMilli(message.timestamp) - override fun toString() = "${topicSession.topic}#${data.sequence()}" + override fun toString() = "$topic#$data" } - private fun deliver(msg: ReceivedMessage): Boolean { + private fun deliver(artemisMessage: ClientMessage, msg: ReceivedMessage) { state.checkNotLocked() - // Because handlers is a COW list, the loop inside filter will operate on a snapshot. Handlers being added - // or removed whilst the filter is executing will not affect anything. - val deliverTo = handlers.filter { it.topicSession.isBlank() || it.topicSession == msg.topicSession } + val deliverTo = handlers[msg.topic] try { // This will perform a BLOCKING call onto the executor. Thus if the handlers are slow, we will // be slow, and Artemis can handle that case intelligently. We don't just invoke the handler @@ -298,31 +278,34 @@ class P2PMessagingClient(config: NodeConfiguration, // // Note that handlers may re-enter this class. We aren't holding any locks and methods like // start/run/stop have re-entrancy assertions at the top, so it is OK. - nodeExecutor.fetchFrom { - database.transaction { - if (msg.uniqueMessageId in processedMessages) { - log.trace { "Discard duplicate message ${msg.uniqueMessageId} for ${msg.topicSession}" } - } else { - if (deliverTo.isEmpty()) { - // TODO: Implement dead letter queue, and send it there. - log.warn("Received message ${msg.uniqueMessageId} for ${msg.topicSession} that doesn't have any registered handlers yet") - } else { - callHandlers(msg, deliverTo) - } - // TODO We will at some point need to decide a trimming policy for the id's + if (deliverTo != null) { + val isDuplicate = database.transaction { msg.uniqueMessageId in processedMessages } + if (isDuplicate) { + log.trace { "Discard duplicate message ${msg.uniqueMessageId} for ${msg.topic}" } + return + } + val acknowledgeHandle = object : AcknowledgeHandle { + override fun persistDeduplicationId() { processedMessages[msg.uniqueMessageId] = Instant.now() } + + // ACKing a message calls back into the session which isn't thread safe, so we have to ensure it + // doesn't collide with a send here. Note that stop() could have been called whilst we were + // processing a message but if so, it'll be parked waiting for us to count down the latch, so + // the session itself is still around and we can still ack messages as a result. + override fun acknowledge() { + state.locked { + artemisMessage.individualAcknowledge() + artemis.started!!.session.commit() + } + } } + deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), acknowledgeHandle) + } else { + log.warn("Received message ${msg.uniqueMessageId} for ${msg.topic} that doesn't have any registered handlers yet") } } catch (e: Exception) { - log.error("Caught exception whilst executing message handler for ${msg.topicSession}", e) - } - return true - } - - private fun callHandlers(msg: ReceivedMessage, deliverTo: List) { - for (handler in deliverTo) { - handler.callback(msg, handler) + log.error("Caught exception whilst executing message handler for ${msg.topic}", e) } } @@ -370,20 +353,19 @@ class P2PMessagingClient(config: NodeConfiguration, putStringProperty(cordaVendorProperty, cordaVendor) putStringProperty(releaseVersionProperty, releaseVersion) putIntProperty(platformVersionProperty, versionInfo.platformVersion) - putStringProperty(topicProperty, SimpleString(message.topicSession.topic)) - putLongProperty(sessionIdProperty, message.topicSession.sessionID) - writeBodyBufferBytes(message.data) + putStringProperty(topicProperty, SimpleString(message.topic)) + writeBodyBufferBytes(message.data.bytes) // Use the magic deduplication property built into Artemis as our message identity too - putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString())) + putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString)) // For demo purposes - if set then add a delay to messages in order to demonstrate that the flows are doing as intended - if (amqDelayMillis > 0 && message.topicSession.topic == StateMachineManagerImpl.sessionTopic.topic) { + if (amqDelayMillis > 0 && message.topic == FlowMessagingImpl.sessionTopic) { putLongProperty(HDR_SCHEDULED_DELIVERY_TIME, System.currentTimeMillis() + amqDelayMillis) } } log.trace { - "Send to: $mqAddress topic: ${message.topicSession.topic} " + - "sessionID: ${message.topicSession.sessionID} uuid: ${message.uniqueMessageId}" + "Send to: $mqAddress topic: ${message.topic} " + + "sessionID: ${message.topic} id: ${message.uniqueMessageId}" } artemis.producer.send(mqAddress, artemisMessage) retryId?.let { @@ -467,30 +449,26 @@ class P2PMessagingClient(config: NodeConfiguration, } } - override fun addMessageHandler(topic: String, - sessionID: Long, - callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration { - return addMessageHandler(TopicSession(topic, sessionID), callback) - } - - override fun addMessageHandler(topicSession: TopicSession, - callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration { - require(!topicSession.isBlank()) { "Topic must not be blank, as the empty topic is a special case." } - val handler = Handler(topicSession, callback) - handlers.add(handler) - return handler + override fun addMessageHandler(topic: String, callback: MessageHandler): MessageHandlerRegistration { + require(!topic.isBlank()) { "Topic must not be blank, as the empty topic is a special case." } + handlers.compute(topic) { _, handler -> + if (handler != null) { + throw IllegalStateException("Cannot add another acking handler for $topic, there is already an acking one") + } + callback + } + return HandlerRegistration(topic, callback) } override fun removeMessageHandler(registration: MessageHandlerRegistration) { - handlers.remove(registration) + registration as HandlerRegistration + handlers.remove(registration.topic) } - override fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID): Message { - // TODO: We could write an object that proxies directly to an underlying MQ message here and avoid copying. - return NodeClientMessage(topicSession, data, uuid) + override fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId): Message { + return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId) } - // TODO Rethink PartyInfo idea and merging PeerAddress/ServiceAddress (the only difference is that Service address doesn't hold host and port) override fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients { return when (partyInfo) { is PartyInfo.SingleNode -> NodeAddress(partyInfo.party.owningKey, partyInfo.addresses.first()) diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/RPCMessagingClient.kt b/node/src/main/kotlin/net/corda/node/services/messaging/RPCMessagingClient.kt index c7033b6baf..aa9fea626c 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/RPCMessagingClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/RPCMessagingClient.kt @@ -27,6 +27,7 @@ class RPCMessagingClient(private val config: SSLConfiguration, serverAddress: Ne } fun stop() = synchronized(this) { + rpcServer?.close() artemis.stop() } } diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/ServiceRequestMessage.kt b/node/src/main/kotlin/net/corda/node/services/messaging/ServiceRequestMessage.kt deleted file mode 100644 index 15c68a3b66..0000000000 --- a/node/src/main/kotlin/net/corda/node/services/messaging/ServiceRequestMessage.kt +++ /dev/null @@ -1,27 +0,0 @@ -package net.corda.node.services.messaging - -import net.corda.core.concurrent.CordaFuture -import net.corda.core.messaging.MessageRecipients -import net.corda.core.messaging.SingleMessageRecipient -import net.corda.core.serialization.CordaSerializable - -/** - * Abstract superclass for request messages sent to services which expect a reply. - */ -@CordaSerializable -interface ServiceRequestMessage { - val sessionID: Long - val replyTo: SingleMessageRecipient -} - -/** - * Sends a [ServiceRequestMessage] to [target] and returns a [CordaFuture] of the response. - * @param R The type of the response. - */ -fun MessagingService.sendRequest(topic: String, - request: ServiceRequestMessage, - target: MessageRecipients): CordaFuture { - val responseFuture = onNext(topic, request.sessionID) - send(topic, MessagingService.DEFAULT_SESSION_ID, request, target) - return responseFuture -} diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt index 5e92461c18..e0db523b07 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt @@ -1,10 +1,14 @@ package net.corda.node.services.persistence +import net.corda.core.flows.StateMachineRunId import net.corda.core.serialization.SerializedBytes -import net.corda.node.services.api.Checkpoint import net.corda.node.services.api.CheckpointStorage +import net.corda.node.services.statemachine.Checkpoint +import net.corda.nodeapi.internal.persistence.DatabaseTransactionManager import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.nodeapi.internal.persistence.currentDBSession +import java.util.* +import java.util.stream.Stream import javax.persistence.Column import javax.persistence.Entity import javax.persistence.Id @@ -27,32 +31,29 @@ class DBCheckpointStorage : CheckpointStorage { var checkpoint: ByteArray = ByteArray(0) ) - override fun addCheckpoint(checkpoint: Checkpoint) { - currentDBSession().save(DBCheckpoint().apply { - checkpointId = checkpoint.id.toString() - this.checkpoint = checkpoint.serializedFiber.bytes + override fun addCheckpoint(id: StateMachineRunId, checkpoint: SerializedBytes) { + currentDBSession().saveOrUpdate(DBCheckpoint().apply { + checkpointId = id.uuid.toString() + this.checkpoint = checkpoint.bytes }) } - override fun removeCheckpoint(checkpoint: Checkpoint) { - val session = currentDBSession() + override fun removeCheckpoint(id: StateMachineRunId) { + val session = DatabaseTransactionManager.current().session val criteriaBuilder = session.criteriaBuilder val delete = criteriaBuilder.createCriteriaDelete(DBCheckpoint::class.java) val root = delete.from(DBCheckpoint::class.java) - delete.where(criteriaBuilder.equal(root.get(DBCheckpoint::checkpointId.name), checkpoint.id.toString())) + delete.where(criteriaBuilder.equal(root.get(DBCheckpoint::checkpointId.name), id.uuid.toString())) session.createQuery(delete).executeUpdate() } - override fun forEach(block: (Checkpoint) -> Boolean) { + override fun getAllCheckpoints(): Stream>> { val session = currentDBSession() val criteriaQuery = session.criteriaBuilder.createQuery(DBCheckpoint::class.java) val root = criteriaQuery.from(DBCheckpoint::class.java) criteriaQuery.select(root) - for (row in session.createQuery(criteriaQuery).resultList) { - val checkpoint = Checkpoint(SerializedBytes(row.checkpoint)) - if (!block(checkpoint)) { - break - } + return session.createQuery(criteriaQuery).stream().map { + StateMachineRunId(UUID.fromString(it.checkpointId)) to SerializedBytes(it.checkpoint) } } } diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt index f8c887d073..cb855ded3e 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt @@ -1,13 +1,20 @@ package net.corda.node.services.persistence +import net.corda.core.concurrent.CordaFuture +import net.corda.core.crypto.SecureHash +import net.corda.core.internal.ThreadBox import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.bufferUntilSubscribed -import net.corda.core.crypto.SecureHash +import net.corda.core.internal.concurrent.doneFuture import net.corda.core.messaging.DataFeed -import net.corda.core.serialization.* +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.SingletonSerializeAsToken +import net.corda.core.serialization.deserialize +import net.corda.core.serialization.serialize +import net.corda.core.toFuture import net.corda.core.transactions.SignedTransaction import net.corda.node.services.api.WritableTransactionStorage -import net.corda.node.utilities.* +import net.corda.node.utilities.AppendOnlyPersistentMap import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.nodeapi.internal.persistence.bufferUntilDatabaseCommit import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction @@ -48,22 +55,37 @@ class DBTransactionStorage : WritableTransactionStorage, SingletonSerializeAsTok } } - private val txStorage = createTransactionsMap() + private val txStorage = ThreadBox(createTransactionsMap()) override fun addTransaction(transaction: SignedTransaction): Boolean = - txStorage.addWithDuplicatesAllowed(transaction.id, transaction).apply { - updatesPublisher.bufferUntilDatabaseCommit().onNext(transaction) + txStorage.locked { + addWithDuplicatesAllowed(transaction.id, transaction).apply { + updatesPublisher.bufferUntilDatabaseCommit().onNext(transaction) + } } - override fun getTransaction(id: SecureHash): SignedTransaction? = txStorage[id] + override fun getTransaction(id: SecureHash): SignedTransaction? = txStorage.content[id] private val updatesPublisher = PublishSubject.create().toSerialized() override val updates: Observable = updatesPublisher.wrapWithDatabaseTransaction() - override fun track(): DataFeed, SignedTransaction> = - DataFeed(txStorage.allPersisted().map { it.second }.toList(), updatesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction()) + override fun track(): DataFeed, SignedTransaction> { + return txStorage.locked { + DataFeed(allPersisted().map { it.second }.toList(), updatesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction()) + } + } + + override fun trackTransaction(id: SecureHash): CordaFuture { + return txStorage.locked { + val existingTransaction = get(id) + if (existingTransaction == null) { + updatesPublisher.filter { it.id == id }.toFuture() + } else { + doneFuture(existingTransaction) + } + } + } @VisibleForTesting - val transactions: Iterable - get() = txStorage.allPersisted().map { it.second }.toList() + val transactions: Iterable get() = txStorage.content.allPersisted().map { it.second }.toList() } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt new file mode 100644 index 0000000000..40ffda500d --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt @@ -0,0 +1,126 @@ +package net.corda.node.services.statemachine + +import net.corda.core.crypto.SecureHash +import net.corda.core.flows.StateMachineRunId +import net.corda.core.identity.Party +import net.corda.node.services.messaging.AcknowledgeHandle +import java.time.Instant + +/** + * [Action]s are reified IO actions to execute as part of state machine transitions. + */ +sealed class Action { + + /** + * Track a transaction hash and notify the state machine once the corresponding transaction has committed. + */ + data class TrackTransaction(val hash: SecureHash) : Action() + + /** + * Send an initial session message to [party]. + */ + data class SendInitial( + val party: Party, + val initialise: InitialSessionMessage, + val deduplicationId: DeduplicationId + ) : Action() + + /** + * Send a session message to a [peerParty] with which we have an established session. + */ + data class SendExisting( + val peerParty: Party, + val message: ExistingSessionMessage, + val deduplicationId: DeduplicationId + ) : Action() + + /** + * Persist the specified [checkpoint]. + */ + data class PersistCheckpoint(val id: StateMachineRunId, val checkpoint: Checkpoint) : Action() + + /** + * Remove the checkpoint corresponding to [id]. + */ + data class RemoveCheckpoint(val id: StateMachineRunId) : Action() + + /** + * Persist the deduplication IDs of [acknowledgeHandles]. + */ + data class PersistDeduplicationIds(val acknowledgeHandles: List) : Action() + + /** + * Acknowledge messages in [acknowledgeHandles]. + */ + data class AcknowledgeMessages(val acknowledgeHandles: List) : Action() + + /** + * Propagate [errorMessages] to [sessions]. + * @param sessions a map from source session IDs to initiated sessions. + */ + data class PropagateErrors( + val errorMessages: List, + val sessions: List + ) : Action() + + /** + * Create a session binding from [sessionId] to [flowId] to allow routing of incoming messages. + */ + data class AddSessionBinding(val flowId: StateMachineRunId, val sessionId: SessionId) : Action() + + /** + * Remove the session bindings corresponding to [sessionIds]. + */ + data class RemoveSessionBindings(val sessionIds: Set) : Action() + + /** + * Signal that the flow corresponding to [flowId] is considered started. + */ + data class SignalFlowHasStarted(val flowId: StateMachineRunId) : Action() + + /** + * Remove the flow corresponding to [flowId]. + */ + data class RemoveFlow( + val flowId: StateMachineRunId, + val removalReason: FlowRemovalReason, + val lastState: StateMachineState + ) : Action() + + /** + * Schedule [event] to self. + */ + data class ScheduleEvent(val event: Event) : Action() + + /** + * Sleep until [time]. + */ + data class SleepUntil(val time: Instant) : Action() + + /** + * Create a new database transaction. + */ + object CreateTransaction : Action() { override fun toString() = "CreateTransaction" } + + /** + * Roll back the current database transaction. + */ + object RollbackTransaction : Action() { override fun toString() = "RollbackTransaction" } + + /** + * Commit the current database transaction. + */ + object CommitTransaction : Action() { override fun toString() = "CommitTransaction" } +} + +/** + * Reason for flow removal. + */ +sealed class FlowRemovalReason { + data class OrderlyFinish(val flowReturnValue: Any?) : FlowRemovalReason() + data class ErrorFinish(val flowErrors: List) : FlowRemovalReason() + object SoftShutdown : FlowRemovalReason() { override fun toString() = "SoftShutdown" } + // TODO Should we remove errored flows? How will the flow hospital work? Perhaps keep them in memory for a while, flush + // them after a timeout, reload them on flow hospital request. In any case if we ever want to remove them + // (e.g. temporarily) then add a case for that here. +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutor.kt new file mode 100644 index 0000000000..9ba0881931 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutor.kt @@ -0,0 +1,15 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Suspendable + +/** + * An executor of a single [Action]. + */ +interface ActionExecutor { + /** + * Execute [action] by [fiber]. + * Precondition: [executeAction] is run inside an open database transaction. + */ + @Suspendable + fun executeAction(fiber: FlowFiber, action: Action) +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt new file mode 100644 index 0000000000..6449a8b090 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt @@ -0,0 +1,190 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Fiber +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.internal.concurrent.thenMatch +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.serialize +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.trace +import net.corda.node.services.api.CheckpointStorage +import net.corda.node.services.api.ServiceHubInternal +import net.corda.nodeapi.internal.persistence.DatabaseTransactionManager +import java.time.Duration +import java.time.Instant +import java.util.concurrent.TimeUnit + +/** + * This is the bottom execution engine of flow side-effects. + */ +class ActionExecutorImpl( + private val services: ServiceHubInternal, + private val checkpointStorage: CheckpointStorage, + private val flowMessaging: FlowMessaging, + private val stateMachineManager: StateMachineManagerInternal, + private val checkpointSerializationContext: SerializationContext +) : ActionExecutor { + + private companion object { + val log = contextLogger() + } + + @Suspendable + override fun executeAction(fiber: FlowFiber, action: Action) { + log.trace { "Flow ${fiber.id} executing $action" } + return when (action) { + is Action.TrackTransaction -> executeTrackTransaction(fiber, action) + is Action.PersistCheckpoint -> executePersistCheckpoint(action) + is Action.PersistDeduplicationIds -> executePersistDeduplicationIds(action) + is Action.AcknowledgeMessages -> executeAcknowledgeMessages(action) + is Action.PropagateErrors -> executePropagateErrors(action) + is Action.ScheduleEvent -> executeScheduleEvent(fiber, action) + is Action.SleepUntil -> executeSleepUntil(action) + is Action.RemoveCheckpoint -> executeRemoveCheckpoint(action) + is Action.SendInitial -> executeSendInitial(action) + is Action.SendExisting -> executeSendExisting(action) + is Action.AddSessionBinding -> executeAddSessionBinding(action) + is Action.RemoveSessionBindings -> executeRemoveSessionBindings(action) + is Action.SignalFlowHasStarted -> executeSignalFlowHasStarted(action) + is Action.RemoveFlow -> executeRemoveFlow(action) + is Action.CreateTransaction -> executeCreateTransaction() + is Action.RollbackTransaction -> executeRollbackTransaction() + is Action.CommitTransaction -> executeCommitTransaction() + } + } + + @Suspendable + private fun executeTrackTransaction(fiber: FlowFiber, action: Action.TrackTransaction) { + services.validatedTransactions.trackTransaction(action.hash).thenMatch( + success = { transaction -> + fiber.scheduleEvent(Event.TransactionCommitted(transaction)) + }, + failure = { exception -> + fiber.scheduleEvent(Event.Error(exception)) + } + ) + } + + @Suspendable + private fun executePersistCheckpoint(action: Action.PersistCheckpoint) { + val checkpointBytes = serializeCheckpoint(action.checkpoint) + checkpointStorage.addCheckpoint(action.id, checkpointBytes) + } + + @Suspendable + private fun executePersistDeduplicationIds(action: Action.PersistDeduplicationIds) { + for (handle in action.acknowledgeHandles) { + handle.persistDeduplicationId() + } + } + + @Suspendable + private fun executeAcknowledgeMessages(action: Action.AcknowledgeMessages) { + action.acknowledgeHandles.forEach { + it.acknowledge() + } + } + + @Suspendable + private fun executePropagateErrors(action: Action.PropagateErrors) { + action.errorMessages.forEach { error -> + val exception = error.flowException + log.debug("Propagating error", exception) + } + val pendingSendAcks = CountUpDownLatch(0) + for (sessionState in action.sessions) { + // We cannot propagate if the session isn't live. + if (sessionState.initiatedState !is InitiatedSessionState.Live) { + continue + } + // Don't propagate errors to the originating session + for (errorMessage in action.errorMessages) { + val sinkSessionId = sessionState.initiatedState.peerSinkSessionId + val existingMessage = ExistingSessionMessage(sinkSessionId, errorMessage) + val deduplicationId = DeduplicationId.createForError(errorMessage.errorId, sinkSessionId) + pendingSendAcks.countUp() + flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, deduplicationId) { + pendingSendAcks.countDown() + } + } + } + // TODO we simply block here, perhaps this should be explicit in the worker state + pendingSendAcks.await() + } + + @Suspendable + private fun executeScheduleEvent(fiber: FlowFiber, action: Action.ScheduleEvent) { + fiber.scheduleEvent(action.event) + } + + @Suspendable + private fun executeSleepUntil(action: Action.SleepUntil) { + // TODO introduce explicit sleep state + wakeup event instead of relying on Fiber.sleep. This is so shutdown + // conditions may "interrupt" the sleep instead of waiting until wakeup. + val duration = Duration.between(Instant.now(), action.time) + Fiber.sleep(duration.toNanos(), TimeUnit.NANOSECONDS) + } + + @Suspendable + private fun executeRemoveCheckpoint(action: Action.RemoveCheckpoint) { + checkpointStorage.removeCheckpoint(action.id) + } + + @Suspendable + private fun executeSendInitial(action: Action.SendInitial) { + flowMessaging.sendSessionMessage(action.party, action.initialise, action.deduplicationId, null) + } + + @Suspendable + private fun executeSendExisting(action: Action.SendExisting) { + flowMessaging.sendSessionMessage(action.peerParty, action.message, action.deduplicationId, null) + } + + @Suspendable + private fun executeAddSessionBinding(action: Action.AddSessionBinding) { + stateMachineManager.addSessionBinding(action.flowId, action.sessionId) + } + + @Suspendable + private fun executeRemoveSessionBindings(action: Action.RemoveSessionBindings) { + stateMachineManager.removeSessionBindings(action.sessionIds) + } + + @Suspendable + private fun executeSignalFlowHasStarted(action: Action.SignalFlowHasStarted) { + stateMachineManager.signalFlowHasStarted(action.flowId) + } + + @Suspendable + private fun executeRemoveFlow(action: Action.RemoveFlow) { + stateMachineManager.removeFlow(action.flowId, action.removalReason, action.lastState) + } + + @Suspendable + private fun executeCreateTransaction() { + if (DatabaseTransactionManager.currentOrNull() != null) { + throw IllegalStateException("Refusing to create a second transaction") + } + DatabaseTransactionManager.newTransaction() + } + + @Suspendable + private fun executeRollbackTransaction() { + DatabaseTransactionManager.currentOrNull()?.close() + } + + @Suspendable + private fun executeCommitTransaction() { + try { + DatabaseTransactionManager.current().commit() + } finally { + DatabaseTransactionManager.current().close() + DatabaseTransactionManager.setThreadLocalTx(null) + } + } + + private fun serializeCheckpoint(checkpoint: Checkpoint): SerializedBytes { + return checkpoint.serialize(context = checkpointSerializationContext) + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/CountUpDownLatch.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/CountUpDownLatch.kt new file mode 100644 index 0000000000..26286bd294 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/CountUpDownLatch.kt @@ -0,0 +1,66 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.strands.concurrent.AbstractQueuedSynchronizer +import co.paralleluniverse.fibers.Suspendable + +/** + * Quasar-compatible latch that may be incremented. + */ +class CountUpDownLatch(initialValue: Int) { + + // See quasar CountDownLatch + private class Sync(initialValue: Int) : AbstractQueuedSynchronizer() { + init { + state = initialValue + } + + override fun tryAcquireShared(arg: Int): Int { + if (arg >= 0) { + return if (state == arg) 1 else -1 + } else { + return if (state <= -arg) 1 else -1 + } + } + + override fun tryReleaseShared(arg: Int): Boolean { + while (true) { + val c = state + if (c == 0) + return false + val nextc = c - Math.min(c, arg) + if (compareAndSetState(c, nextc)) + return nextc == 0 + } + } + + fun increment() { + while (true) { + val c = state + val nextc = c + 1 + if (compareAndSetState(c, nextc)) + return + } + } + } + + private val sync = Sync(initialValue) + + @Suspendable + fun await() { + sync.acquireSharedInterruptibly(0) + } + + @Suspendable + fun awaitLessThanOrEqual(number: Int) { + sync.acquireSharedInterruptibly(number) + } + + fun countDown(number: Int = 1) { + require(number > 0) + sync.releaseShared(number) + } + + fun countUp() { + sync.increment() + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/DeduplicationId.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/DeduplicationId.kt new file mode 100644 index 0000000000..5b60d5ad3d --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/DeduplicationId.kt @@ -0,0 +1,47 @@ +package net.corda.node.services.statemachine + +import java.security.SecureRandom + +/** + * A deduplication ID of a flow message. + */ +data class DeduplicationId(val toString: String) { + companion object { + /** + * Create a random deduplication ID. Note that this isn't deterministic, which means we will never dedupe it, + * unless we persist the ID somehow. + */ + fun createRandom(random: SecureRandom) = DeduplicationId("R-${random.nextLong()}") + + /** + * Create a deduplication ID for a normal clean state message. This is used to have a deterministic way of + * creating IDs in case the message-generating flow logic is replayed on hard failure. + * + * A normal deduplication ID consists of: + * 1. A deduplication seed set per flow. This is either the flow's ID or in case of an initated flow the + * initiator's session ID. + * 2. The number of *clean* suspends since the start of the flow. + * 3. An optional additional index, for cases where several messages are sent as part of the state transition. + * Note that care must be taken with this index, it must be a deterministic counter. For example a naive + * iteration over a HashMap will produce a different list of indeces than a previous run, causing the + * message-id map to change, which means deduplication will not happen correctly. + */ + fun createForNormal(checkpoint: Checkpoint, index: Int): DeduplicationId { + return DeduplicationId("N-${checkpoint.deduplicationSeed}-${checkpoint.numberOfSuspends}-$index") + } + + /** + * Create a deduplication ID for an error message. Note that these IDs live in a different namespace than normal + * IDs, as we don't want error conditions to affect the determinism of clean deduplication IDs. This allows the + * dirtiness state to be thrown away for resumption. + * + * An error deduplication ID consists of: + * 1. The error's ID. This is a unique value per "source" of error and is propagated. + * See [net.corda.core.flows.IdentifiableException]. + * 2. The recipient's session ID. + */ + fun createForError(errorId: Long, recipientSessionId: SessionId): DeduplicationId { + return DeduplicationId("E-$errorId-${recipientSessionId.toLong}") + } + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt new file mode 100644 index 0000000000..87d7712dfb --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt @@ -0,0 +1,117 @@ +package net.corda.node.services.statemachine + +import net.corda.core.flows.FlowLogic +import net.corda.core.identity.Party +import net.corda.core.internal.FlowIORequest +import net.corda.core.serialization.SerializedBytes +import net.corda.core.transactions.SignedTransaction +import net.corda.node.services.messaging.AcknowledgeHandle + +/** + * Transitions in the flow state machine are triggered by [Event]s that may originate from the flow itself or from + * outside (e.g. in case of message delivery or external event). + */ +sealed class Event { + /** + * Check the current state for pending work. For example if the flow is waiting for a message from a particular + * session this event may cause a flow resume if we have a corresponding message. In general the state machine + * should be idempotent in the [DoRemainingWork] event, meaning a second subsequent event shouldn't modify the state + * or produce [Action]s. + */ + object DoRemainingWork : Event() { override fun toString() = "DoRemainingWork" } + + /** + * Deliver a session message. + * @param sessionMessage the message itself. + * @param acknowledgeHandle the handle to acknowledge the message after checkpointing. + * @param sender the sender [Party]. + */ + data class DeliverSessionMessage( + val sessionMessage: ExistingSessionMessage, + val acknowledgeHandle: AcknowledgeHandle, + val sender: Party + ) : Event() + + /** + * Signal that an error has happened. This may be due to an uncaught exception in the flow or some external error. + * @param exception the exception itself. + */ + data class Error(val exception: Throwable) : Event() + + /** + * Signal that a ledger transaction has committed. This is an event completing a [FlowIORequest.WaitForLedgerCommit] + * suspension. + * @param transaction the transaction that was committed. + */ + data class TransactionCommitted(val transaction: SignedTransaction) : Event() + + /** + * Trigger a soft shutdown, removing the flow as soon as possible. This causes the flow to be removed as soon as + * this event is processed. Note that on restart the flow will resume as normal. + */ + object SoftShutdown : Event() { override fun toString() = "SoftShutdown" } + + /** + * Start error propagation on a errored flow. This may be triggered by e.g. a [FlowHospital]. + */ + object StartErrorPropagation : Event() { override fun toString() = "StartErrorPropagation" } + + /** + * + * Scheduled by the flow. + * + * Initiate a flow. This causes a new session object to be created and returned to the flow. Note that no actual + * communication takes place at this time, only on the first send/receive operation on the session. + * @param party the [Party] to create a session with. + */ + data class InitiateFlow(val party: Party) : Event() + + /** + * Signal the entering into a subflow. + * + * Scheduled and executed by the flow. + * + * @param subFlowClass the [Class] of the subflow, to be used to determine whether it's Initiating or inlined. + */ + data class EnterSubFlow(val subFlowClass: Class>) : Event() + + /** + * Signal the leaving of a subflow. + * + * Scheduled by the flow. + * + */ + object LeaveSubFlow : Event() { override fun toString() = "LeaveSubFlow" } + + /** + * Signal a flow suspension. This causes the flow's stack and the state machine's state together with the suspending + * IO request to be persisted into the database. + * + * Scheduled by the flow and executed inside the park closure. + * + * @param ioRequest the request triggering the suspension. + * @param maySkipCheckpoint indicates whether the persistence may be skipped. + * @param fiber the serialised stack of the flow. + */ + data class Suspend( + val ioRequest: FlowIORequest<*>, + val maySkipCheckpoint: Boolean, + val fiber: SerializedBytes> + ) : Event() { + override fun toString() = + "Suspend(" + + "ioRequest=$ioRequest, " + + "maySkipCheckpoint=$maySkipCheckpoint, " + + "fiber=${fiber.hash}, " + + ")" + } + + /** + * Signals clean flow finish. + * + * Scheduled by the flow. + * + * @param returnValue the return value of the flow. + */ + data class FlowFinish(val returnValue: Any?) : Event() +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowFiber.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowFiber.kt new file mode 100644 index 0000000000..40768c261e --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowFiber.kt @@ -0,0 +1,18 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.flows.StateMachineRunId +import net.corda.node.services.statemachine.transitions.StateMachine + +/** + * An interface wrapping a fiber running a flow. + */ +interface FlowFiber { + val id: StateMachineRunId + val stateMachine: StateMachine + + @Suspendable + fun scheduleEvent(event: Event) + + fun snapshot(): StateMachineState +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowHospital.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowHospital.kt new file mode 100644 index 0000000000..20dda167c3 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowHospital.kt @@ -0,0 +1,18 @@ +package net.corda.node.services.statemachine + +/** + * A flow hospital is a class that is notified when a flow transitions into an error state due to an uncaught exception + * or internal error condition, and when it becomes clean again (e.g. due to a resume). + * Also see [net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor]. + */ +interface FlowHospital { + /** + * The flow running in [flowFiber] has errored. + */ + fun flowErrored(flowFiber: FlowFiber) + + /** + * The flow running in [flowFiber] has cleaned, possibly as a result of a flow hospital resume. + */ + fun flowCleaned(flowFiber: FlowFiber) +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt deleted file mode 100644 index bd29525072..0000000000 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt +++ /dev/null @@ -1,121 +0,0 @@ -package net.corda.node.services.statemachine - -import co.paralleluniverse.fibers.Suspendable -import net.corda.core.crypto.SecureHash -import java.time.Instant - -interface FlowIORequest { - // This is used to identify where we suspended, in case of message mismatch errors and other things where we - // don't have the original stack trace because it's in a suspended fiber. - val stackTraceInCaseOfProblems: StackSnapshot -} - -interface WaitingRequest : FlowIORequest { - fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean -} - -interface SessionedFlowIORequest : FlowIORequest { - val session: FlowSessionInternal -} - -interface SendRequest : SessionedFlowIORequest { - val message: SessionMessage -} - -interface ReceiveRequest : SessionedFlowIORequest, WaitingRequest { - val receiveType: Class - val userReceiveType: Class<*>? - - override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = this.session === session -} - -data class SendAndReceive(override val session: FlowSessionInternal, - override val message: SessionMessage, - override val receiveType: Class, - override val userReceiveType: Class<*>?) : SendRequest, ReceiveRequest { - @Transient - override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() -} - -data class ReceiveOnly(override val session: FlowSessionInternal, - override val receiveType: Class, - override val userReceiveType: Class<*>?) : ReceiveRequest { - @Transient - override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() -} - -class ReceiveAll(val requests: List>) : WaitingRequest { - @Transient - override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() - - private fun isComplete(received: LinkedHashMap): Boolean { - return received.keys == requests.map { it.session }.toSet() - } - private fun shouldResumeIfRelevant() = requests.all { hasSuccessfulEndMessage(it) } - - private fun hasSuccessfulEndMessage(it: ReceiveRequest): Boolean { - return it.session.receivedMessages.map { it.message }.any { it is SessionData || it is SessionEnd } - } - - @Suspendable - fun suspendAndExpectReceive(suspend: Suspend): Map { - val receivedMessages = LinkedHashMap() - - poll(receivedMessages) - return if (isComplete(receivedMessages)) { - receivedMessages - } else { - suspend(this) - poll(receivedMessages) - if (isComplete(receivedMessages)) { - receivedMessages - } else { - throw IllegalStateException(requests.filter { it.session !in receivedMessages.keys }.map { "Was expecting a ${it.receiveType.simpleName} but instead got nothing for $it." }.joinToString { "\n" }) - } - } - } - - interface Suspend { - @Suspendable - operator fun invoke(request: FlowIORequest) - } - - @Suspendable - private fun poll(receivedMessages: LinkedHashMap) { - return requests.filter { it.session !in receivedMessages.keys }.forEach { request -> - poll(request)?.let { - receivedMessages[request.session] = RequestMessage(request, it) - } - } - } - - @Suspendable - private fun poll(request: ReceiveRequest): ReceivedSessionMessage<*>? { - return request.session.receivedMessages.poll() - } - - override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = isRelevant(session) && shouldResumeIfRelevant() - - private fun isRelevant(session: FlowSessionInternal) = requests.any { it.session === session } - - data class RequestMessage(val request: ReceiveRequest, val message: ReceivedSessionMessage<*>) -} - -data class SendOnly(override val session: FlowSessionInternal, override val message: SessionMessage) : SendRequest { - @Transient - override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() -} - -data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachineImpl<*>) : WaitingRequest { - @Transient - override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() - - override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = message is ErrorSessionEnd -} - -data class Sleep(val until: Instant, val fiber: FlowStateMachineImpl<*>) : FlowIORequest { - @Transient - override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() -} - -class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem") diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt new file mode 100644 index 0000000000..8adf000fca --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt @@ -0,0 +1,75 @@ +package net.corda.node.services.statemachine + +import com.esotericsoftware.kryo.KryoException +import net.corda.core.flows.FlowException +import net.corda.core.identity.Party +import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.serialize +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.trace +import net.corda.node.services.api.ServiceHubInternal +import net.corda.node.services.messaging.AcknowledgeHandle +import net.corda.node.services.messaging.ReceivedMessage +import java.io.NotSerializableException + +/** + * A wrapper interface around flow messaging. + */ +interface FlowMessaging { + /** + * Send [message] to [party] using [deduplicationId]. Optionally [acknowledgementHandler] may be specified to + * listen on the send acknowledgement. + */ + fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId, acknowledgementHandler: (() -> Unit)?) + + /** + * Start the messaging using the [onMessage] message handler. + */ + fun start(onMessage: (ReceivedMessage, acknowledgeHandle: AcknowledgeHandle) -> Unit) +} + +/** + * Implementation of [FlowMessaging] using a [ServiceHubInternal] to do the messaging and routing. + */ +class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging { + + companion object { + val log = contextLogger() + + val sessionTopic = "platform.session" + } + + override fun start(onMessage: (ReceivedMessage, acknowledgeHandle: AcknowledgeHandle) -> Unit) { + serviceHub.networkService.addMessageHandler(sessionTopic) { receivedMessage, _, acknowledgeHandle -> + onMessage(receivedMessage, acknowledgeHandle) + } + } + + override fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId, acknowledgementHandler: (() -> Unit)?) { + log.trace { "Sending message $deduplicationId $message to party $party" } + val networkMessage = serviceHub.networkService.createMessage(sessionTopic, serializeSessionMessage(message).bytes, deduplicationId) + val partyInfo = serviceHub.networkMapCache.getPartyInfo(party) ?: throw IllegalArgumentException("Don't know about $party") + val address = serviceHub.networkService.getAddressOfParty(partyInfo) + val sequenceKey = when (message) { + is InitialSessionMessage -> message.initiatorSessionId + is ExistingSessionMessage -> message.recipientSessionId + } + serviceHub.networkService.send(networkMessage, address, sequenceKey = sequenceKey, acknowledgementHandler = acknowledgementHandler) + } + + private fun serializeSessionMessage(message: SessionMessage): SerializedBytes { + return try { + message.serialize() + } catch (exception: Exception) { + // Handling Kryo and AMQP serialization problems. Unfortunately the two exception types do not share much of a common exception interface. + if ((exception is KryoException || exception is NotSerializableException) + && message is ExistingSessionMessage && message.payload is ErrorSessionMessage) { + val error = message.payload.flowException + val rewrappedError = FlowException(error?.message) + message.copy(payload = message.payload.copy(flowException = rewrappedError)).serialize() + } else { + throw exception + } + } + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionImpl.kt index 479bbe86da..b7758939b5 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionImpl.kt @@ -1,20 +1,39 @@ package net.corda.node.services.statemachine +import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Suspendable import net.corda.core.flows.FlowInfo -import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowSession import net.corda.core.identity.Party +import net.corda.core.internal.FlowIORequest import net.corda.core.internal.FlowStateMachine +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.serialize +import net.corda.core.utilities.NonEmptySet import net.corda.core.utilities.UntrustworthyData +import net.corda.core.utilities.checkPayloadIs -class FlowSessionImpl(override val counterparty: Party) : FlowSession() { - internal lateinit var stateMachine: FlowStateMachine<*> - internal lateinit var sessionFlow: FlowLogic<*> +class FlowSessionImpl( + override val counterparty: Party, + val sourceSessionId: SessionId +) : FlowSession() { + + override fun toString() = "FlowSessionImpl(counterparty=$counterparty, sourceSessionId=$sourceSessionId)" + + override fun equals(other: Any?): Boolean { + return (other as? FlowSessionImpl)?.sourceSessionId == sourceSessionId + } + + override fun hashCode() = sourceSessionId.hashCode() + + private fun getFlowStateMachine(): FlowStateMachine<*> { + return Fiber.currentFiber() as FlowStateMachine<*> + } @Suspendable override fun getCounterpartyFlowInfo(maySkipCheckpoint: Boolean): FlowInfo { - return stateMachine.getFlowInfo(counterparty, sessionFlow, maySkipCheckpoint) + val request = FlowIORequest.GetFlowInfo(NonEmptySet.of(this)) + return getFlowStateMachine().suspend(request, maySkipCheckpoint)[this]!! } @Suspendable @@ -26,14 +45,12 @@ class FlowSessionImpl(override val counterparty: Party) : FlowSession() { payload: Any, maySkipCheckpoint: Boolean ): UntrustworthyData { - return stateMachine.sendAndReceive( - receiveType, - counterparty, - payload, - sessionFlow, - retrySend = false, - maySkipCheckpoint = maySkipCheckpoint + enforceNotPrimitive(receiveType) + val request = FlowIORequest.SendAndReceive( + sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT)), + shouldRetrySend = false ) + return getFlowStateMachine().suspend(request, maySkipCheckpoint)[this]!!.checkPayloadIs(receiveType) } @Suspendable @@ -41,7 +58,9 @@ class FlowSessionImpl(override val counterparty: Party) : FlowSession() { @Suspendable override fun receive(receiveType: Class, maySkipCheckpoint: Boolean): UntrustworthyData { - return stateMachine.receive(receiveType, counterparty, sessionFlow, maySkipCheckpoint) + enforceNotPrimitive(receiveType) + val request = FlowIORequest.Receive(NonEmptySet.of(this)) + return getFlowStateMachine().suspend(request, maySkipCheckpoint)[this]!!.checkPayloadIs(receiveType) } @Suspendable @@ -49,12 +68,18 @@ class FlowSessionImpl(override val counterparty: Party) : FlowSession() { @Suspendable override fun send(payload: Any, maySkipCheckpoint: Boolean) { - return stateMachine.send(counterparty, payload, sessionFlow, maySkipCheckpoint) + val request = FlowIORequest.Send( + sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT)), + shouldRetrySend = false + ) + return getFlowStateMachine().suspend(request, maySkipCheckpoint) } @Suspendable override fun send(payload: Any) = send(payload, maySkipCheckpoint = false) - override fun toString() = "Flow session with $counterparty" + private fun enforceNotPrimitive(type: Class<*>) { + require(!type.isPrimitive) { "Cannot receive primitive type $type" } + } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionInternal.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionInternal.kt deleted file mode 100644 index dc5b39c6f5..0000000000 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionInternal.kt +++ /dev/null @@ -1,56 +0,0 @@ -package net.corda.node.services.statemachine - -import net.corda.core.flows.FlowInfo -import net.corda.core.flows.FlowLogic -import net.corda.core.flows.FlowSession -import net.corda.core.identity.Party -import net.corda.node.services.statemachine.FlowSessionState.Initiated -import net.corda.node.services.statemachine.FlowSessionState.Initiating -import java.util.concurrent.ConcurrentLinkedQueue - -/** - * @param retryable Indicates that the session initialisation should be retried until an expected [SessionData] response - * is received. Note that this requires the party on the other end to be a distributed service and run an idempotent flow - * that only sends back a single [SessionData] message before termination. - */ -// TODO rename this -class FlowSessionInternal( - val flow: FlowLogic<*>, - val flowSession : FlowSession, - val ourSessionId: Long, - val initiatingParty: Party?, - var state: FlowSessionState, - var retryable: Boolean = false) { - val receivedMessages = ConcurrentLinkedQueue>() - val fiber: FlowStateMachineImpl<*> get() = flow.stateMachine as FlowStateMachineImpl<*> - - override fun toString(): String { - return "${javaClass.simpleName}(flow=$flow, ourSessionId=$ourSessionId, initiatingParty=$initiatingParty, state=$state)" - } -} - -/** - * [FlowSessionState] describes the session's state. - * - * [Uninitiated] is pre-handshake, where no communication has happened. [Initiating.otherParty] at this point holds a - * [Party] corresponding to either a specific peer or a service. - * [Initiating] is pre-handshake, where the initiating message has been sent. - * [Initiated] is post-handshake. At this point [Initiating.otherParty] will have been resolved to a specific peer - * [Initiated.peerParty], and the peer's sessionId has been initialised. - */ -sealed class FlowSessionState { - abstract val sendToParty: Party - - data class Uninitiated(val otherParty: Party) : FlowSessionState() { - override val sendToParty: Party get() = otherParty - } - - /** [otherParty] may be a specific peer or a service party */ - data class Initiating(val otherParty: Party) : FlowSessionState() { - override val sendToParty: Party get() = otherParty - } - - data class Initiated(val peerParty: Party, val peerSessionId: Long, val context: FlowInfo) : FlowSessionState() { - override val sendToParty: Party get() = peerParty - } -} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index 2fb26c4f23..4b0408c622 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -5,249 +5,184 @@ import co.paralleluniverse.fibers.Fiber.parkAndSerialize import co.paralleluniverse.fibers.FiberScheduler import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.strands.Strand -import com.google.common.primitives.Primitives +import co.paralleluniverse.strands.channels.Channel import net.corda.core.concurrent.CordaFuture import net.corda.core.context.InvocationContext -import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.random63BitValue import net.corda.core.flows.* import net.corda.core.identity.Party -import net.corda.core.identity.PartyAndCertificate -import net.corda.core.internal.* -import net.corda.core.internal.concurrent.OpenFuture -import net.corda.core.internal.concurrent.openFuture -import net.corda.core.serialization.SerializationDefaults +import net.corda.core.internal.FlowIORequest +import net.corda.core.internal.FlowStateMachine +import net.corda.core.internal.isRegularFile +import net.corda.core.internal.uncheckedCast +import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.serialize -import net.corda.core.transactions.SignedTransaction -import net.corda.core.utilities.* +import net.corda.core.utilities.Try +import net.corda.core.utilities.debug +import net.corda.core.utilities.trace import net.corda.node.services.api.FlowAppAuditEvent import net.corda.node.services.api.FlowPermissionAuditEvent import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.logging.pushToLoggingContext -import net.corda.node.services.statemachine.FlowSessionState.Initiating +import net.corda.node.services.statemachine.transitions.FlowContinuation +import net.corda.node.services.statemachine.transitions.StateMachine import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseTransaction import net.corda.nodeapi.internal.persistence.DatabaseTransactionManager import org.slf4j.Logger import org.slf4j.LoggerFactory import java.nio.file.Paths -import java.sql.SQLException -import java.time.Duration -import java.time.Instant -import java.util.* import java.util.concurrent.TimeUnit +import kotlin.reflect.KProperty1 class FlowPermissionException(message: String) : FlowException(message) +class TransientReference(@Transient val value: A) + class FlowStateMachineImpl(override val id: StateMachineRunId, override val logic: FlowLogic, - scheduler: FiberScheduler, - val ourIdentity: Party, - override val context: InvocationContext) : Fiber(id.toString(), scheduler), FlowStateMachine { - + scheduler: FiberScheduler + // Store the Party rather than the full cert path with PartyAndCertificate +) : Fiber(id.toString(), scheduler), FlowStateMachine, FlowFiber { companion object { - // Used to work around a small limitation in Quasar. - private val QUASAR_UNBLOCKER = Fiber::class.staticField("SERIALIZER_BLOCKER").value - /** * Return the current [FlowStateMachineImpl] or null if executing outside of one. */ fun currentStateMachine(): FlowStateMachineImpl<*>? = Strand.currentStrand() as? FlowStateMachineImpl<*> + + private val log: Logger = LoggerFactory.getLogger("net.corda.flow") + + @Suspendable + private fun abortFiber(): Nothing { + Fiber.park() + throw IllegalStateException("Ended fiber unparked") + } + + private fun extractThreadLocalTransaction(): TransientReference { + val transaction = DatabaseTransactionManager.current() + DatabaseTransactionManager.setThreadLocalTx(null) + return TransientReference(transaction) + } } - // These fields shouldn't be serialised, so they are marked @Transient. - @Transient override lateinit var serviceHub: ServiceHubInternal - @Transient override lateinit var ourIdentityAndCert: PartyAndCertificate - @Transient internal lateinit var database: CordaPersistence - @Transient internal lateinit var actionOnSuspend: (FlowIORequest) -> Unit - @Transient internal lateinit var actionOnEnd: (Try, Boolean) -> Unit - @Transient internal var fromCheckpoint: Boolean = false - @Transient private var txTrampoline: DatabaseTransaction? = null + override val serviceHub get() = getTransientField(TransientValues::serviceHub) + + data class TransientValues( + val eventQueue: Channel, + val resultFuture: CordaFuture, + val database: CordaPersistence, + val transitionExecutor: TransitionExecutor, + val actionExecutor: ActionExecutor, + val stateMachine: StateMachine, + val serviceHub: ServiceHubInternal, + val checkpointSerializationContext: SerializationContext + ) + + internal var transientValues: TransientReference? = null + internal var transientState: TransientReference? = null + + private fun getTransientField(field: KProperty1): A { + val suppliedValues = transientValues ?: throw IllegalStateException("${field.name} wasn't supplied!") + return field.get(suppliedValues.value) + } /** * Return the logger for this state machine. The logger name incorporates [id] and so including it in the log message * is not necessary. */ - override val logger: Logger = LoggerFactory.getLogger("net.corda.flow.$id") - @Transient private var resultFutureTransient: OpenFuture? = openFuture() - private val _resultFuture get() = resultFutureTransient ?: openFuture().also { resultFutureTransient = it } - /** This future will complete when the call method returns. */ - override val resultFuture: CordaFuture get() = _resultFuture - // This state IS serialised, as we need it to know what the fiber is waiting for. - internal val openSessions = HashMap, Party>, FlowSessionInternal>() - internal var waitingForResponse: WaitingRequest? = null + override val logger = log + override val resultFuture: CordaFuture get() = uncheckedCast(getTransientField(TransientValues::resultFuture)) + override val context: InvocationContext get() = transientState!!.value.checkpoint.invocationContext + override val ourIdentity: Party get() = transientState!!.value.checkpoint.ourIdentity internal var hasSoftLockedStates: Boolean = false set(value) { if (value) field = value else throw IllegalArgumentException("Can only set to true") } - init { - logic.stateMachine = this + @Suspendable + private fun processEvent(transitionExecutor: TransitionExecutor, event: Event): FlowContinuation { + val stateMachine = getTransientField(TransientValues::stateMachine) + val oldState = transientState!!.value + val actionExecutor = getTransientField(TransientValues::actionExecutor) + val transition = stateMachine.transition(event, oldState) + val (continuation, newState) = transitionExecutor.executeTransition(this, oldState, event, transition, actionExecutor) + transientState = TransientReference(newState) + return continuation + } + + @Suspendable + private fun processEventsUntilFlowIsResumed(): Any? { + val transitionExecutor = getTransientField(TransientValues::transitionExecutor) + val eventQueue = getTransientField(TransientValues::eventQueue) + eventLoop@while (true) { + val nextEvent = eventQueue.receive() + val continuation = processEvent(transitionExecutor, nextEvent) + when (continuation) { + is FlowContinuation.Resume -> return continuation.result + is FlowContinuation.Throw -> { + continuation.throwable.fillInStackTrace() + throw continuation.throwable + } + FlowContinuation.ProcessEvents -> continue@eventLoop + FlowContinuation.Abort -> abortFiber() + } + } } @Suspendable override fun run() { - createTransaction() + logic.stateMachine = this + + context.pushToLoggingContext() + + initialiseFlow() + logger.debug { "Calling flow: $logic" } val startTime = System.nanoTime() - val result = try { - val r = logic.call() - // Only sessions which have done a single send and nothing else will block here - openSessions.values - .filter { it.state is Initiating } - .forEach { it.waitForConfirmation() } - r - } catch (e: FlowException) { - recordDuration(startTime, success = false) - // Check if the FlowException was propagated by looking at where the stack trace originates (see suspendAndExpectReceive). - val propagated = e.stackTrace[0].className == javaClass.name - processException(e, propagated) - logger.warn(if (propagated) "Flow ended due to receiving exception" else "Flow finished with exception", e) - return - } catch (t: Throwable) { - recordDuration(startTime, success = false) - logger.warn("Terminated by unexpected exception", t) - processException(t, false) - return + val resultOrError = try { + val result = logic.call() + // TODO expose maySkipCheckpoint here + suspend(FlowIORequest.WaitForSessionConfirmations, maySkipCheckpoint = false) + Try.Success(result) + } catch (throwable: Throwable) { + logger.warn("Flow threw exception", throwable) + Try.Failure(throwable) } - - recordDuration(startTime) - // This is to prevent actionOnEnd being called twice if it throws an exception - actionOnEnd(Try.Success(result), false) - _resultFuture.set(result) - logic.progressTracker?.currentStep = ProgressTracker.DONE - logger.debug { "Flow finished with result ${result.toString().abbreviate(300)}" } - } - - private fun createTransaction() { - // Make sure we have a database transaction - database.createTransaction() - logger.trace { "Starting database transaction ${DatabaseTransactionManager.currentOrNull()} on ${Strand.currentStrand()}" } - } - - private fun processException(exception: Throwable, propagated: Boolean) { - actionOnEnd(Try.Failure(exception), propagated) - _resultFuture.setException(exception) - logic.progressTracker?.endWithError(exception) - } - - internal fun commitTransaction() { - val transaction = DatabaseTransactionManager.current() - try { - logger.trace { "Committing database transaction $transaction on ${Strand.currentStrand()}." } - transaction.commit() - } catch (e: SQLException) { - // TODO: we will get here if the database is not available. Think about how to shutdown and restart cleanly. - logger.error("Transaction commit failed: ${e.message}", e) - System.exit(1) - } finally { - transaction.close() - } - } - - @Suspendable - override fun initiateFlow(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSession { - val sessionKey = Pair(sessionFlow, otherParty) - if (openSessions.containsKey(sessionKey)) { - throw IllegalStateException( - "Attempted to initiateFlow() twice in the same InitiatingFlow $sessionFlow for the same party " + - "$otherParty. This isn't supported in this version of Corda. Alternatively you may " + - "initiate a new flow by calling initiateFlow() in an " + - "@${InitiatingFlow::class.java.simpleName} sub-flow." - ) - } - val flowSession = FlowSessionImpl(otherParty) - createNewSession(otherParty, flowSession, sessionFlow) - flowSession.stateMachine = this - flowSession.sessionFlow = sessionFlow - return flowSession - } - - @Suspendable - override fun getFlowInfo(otherParty: Party, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): FlowInfo { - val state = getConfirmedSession(otherParty, sessionFlow).state as FlowSessionState.Initiated - return state.context - } - - @Suspendable - override fun sendAndReceive(receiveType: Class, - otherParty: Party, - payload: Any, - sessionFlow: FlowLogic<*>, - retrySend: Boolean, - maySkipCheckpoint: Boolean): UntrustworthyData { - requireNonPrimitive(receiveType) - logger.debug { "sendAndReceive(${receiveType.name}, $otherParty, ${payload.toString().abbreviate(300)}) ..." } - val session = getConfirmedSessionIfPresent(otherParty, sessionFlow) - val receivedSessionData: ReceivedSessionMessage = if (session == null) { - val newSession = initiateSession(otherParty, sessionFlow, payload, waitForConfirmation = true, retryable = retrySend) - // Only do a receive here as the session init has carried the payload - receiveInternal(newSession, receiveType) - } else { - val sendData = createSessionData(session, payload) - sendAndReceiveInternal(session, sendData, receiveType) - } - logger.debug { "Received ${receivedSessionData.message.payload.toString().abbreviate(300)}" } - return receivedSessionData.checkPayloadIs(receiveType) - } - - @Suspendable - override fun receive(receiveType: Class, - otherParty: Party, - sessionFlow: FlowLogic<*>, - maySkipCheckpoint: Boolean): UntrustworthyData { - requireNonPrimitive(receiveType) - logger.debug { "receive(${receiveType.name}, $otherParty) ..." } - val session = getConfirmedSession(otherParty, sessionFlow) - val sessionData = receiveInternal(session, receiveType) - logger.debug { "Received ${sessionData.message.payload.toString().abbreviate(300)}" } - return sessionData.checkPayloadIs(receiveType) - } - - private fun requireNonPrimitive(receiveType: Class<*>) { - require(!receiveType.isPrimitive) { - "Use the wrapper type ${Primitives.wrap(receiveType).name} instead of the primitive $receiveType.class" - } - } - - @Suspendable - override fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean) { - logger.debug { "send($otherParty, ${payload.toString().abbreviate(300)})" } - val session = getConfirmedSessionIfPresent(otherParty, sessionFlow) - if (session == null) { - // Don't send the payload again if it was already piggy-backed on a session init - initiateSession(otherParty, sessionFlow, payload, waitForConfirmation = false) - } else { - sendInternal(session, createSessionData(session, payload)) - } - } - - @Suspendable - override fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>, maySkipCheckpoint: Boolean): SignedTransaction { - logger.debug { "waitForLedgerCommit($hash) ..." } - suspend(WaitForLedgerCommit(hash, sessionFlow.stateMachine as FlowStateMachineImpl<*>)) - val stx = serviceHub.validatedTransactions.getTransaction(hash) - if (stx != null) { - logger.debug { "Transaction $hash committed to ledger" } - return stx - } - - // If the tx isn't committed then we may have been resumed due to an session ending in an error - for (session in openSessions.values) { - for (receivedMessage in session.receivedMessages) { - if (receivedMessage.message is ErrorSessionEnd) { - session.erroredEnd(receivedMessage.message) - } + val finalEvent = when (resultOrError) { + is Try.Success -> { + Event.FlowFinish(resultOrError.value) + } + is Try.Failure -> { + Event.Error(resultOrError.exception) } } - throw IllegalStateException("We were resumed after waiting for $hash but it wasn't found in our local storage") + processEvent(getTransientField(TransientValues::transitionExecutor), finalEvent) + processEventsUntilFlowIsResumed() + + recordDuration(startTime) } - // Provide a mechanism to sleep within a Strand without locking any transactional state. - // This checkpoints, since we cannot undo any database writes up to this point. @Suspendable - override fun sleepUntil(until: Instant) { - suspend(Sleep(until, this)) + private fun initialiseFlow() { + processEventsUntilFlowIsResumed() + } + + @Suspendable + override fun subFlow(subFlow: FlowLogic): R { + processEvent(getTransientField(TransientValues::transitionExecutor), Event.EnterSubFlow(subFlow.javaClass)) + return try { + subFlow.call() + } finally { + processEvent(getTransientField(TransientValues::transitionExecutor), Event.LeaveSubFlow) + } + } + + @Suspendable + override fun initiateFlow(party: Party): FlowSession { + val resume = processEvent( + getTransientField(TransientValues::transitionExecutor), + Event.InitiateFlow(party) + ) as FlowContinuation.Resume + return resume.result as FlowSession } // TODO Dummy implementation of access to application specific permission controls and audit logging @@ -292,231 +227,43 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } @Suspendable - override fun receiveAll(sessions: Map>, sessionFlow: FlowLogic<*>): Map> { - val requests = ArrayList>() - for ((session, receiveType) in sessions) { - val sessionInternal = getConfirmedSession(session.counterparty, sessionFlow) - requests.add(ReceiveOnly(sessionInternal, SessionData::class.java, receiveType)) - } - val receivedMessages = ReceiveAll(requests).suspendAndExpectReceive(suspend) - val result = LinkedHashMap>() - for ((sessionInternal, requestAndMessage) in receivedMessages) { - val message = requestAndMessage.message.confirmReceiveType(requestAndMessage.request) - result[sessionInternal.flowSession] = message.checkPayloadIs(requestAndMessage.request.userReceiveType as Class) - } - return result - } - - internal fun pushToLoggingContext() = context.pushToLoggingContext() - - /** - * This method will suspend the state machine and wait for incoming session init response from other party. - */ - @Suspendable - private fun FlowSessionInternal.waitForConfirmation() { - val (peerParty, sessionInitResponse) = receiveInternal(this, null) - if (sessionInitResponse is SessionConfirm) { - state = FlowSessionState.Initiated( - peerParty, - sessionInitResponse.initiatedSessionId, - FlowInfo(sessionInitResponse.flowVersion, sessionInitResponse.appName)) - } else { - sessionInitResponse as SessionReject - throw UnexpectedFlowEndException("Party ${state.sendToParty} rejected session request: ${sessionInitResponse.errorMessage}") - } - } - - private fun createSessionData(session: FlowSessionInternal, payload: Any): SessionData { - val sessionState = session.state - val peerSessionId = when (sessionState) { - is FlowSessionState.Initiated -> sessionState.peerSessionId - else -> throw IllegalStateException("We've somehow held onto a non-initiated session: $session") - } - return SessionData(peerSessionId, payload.serialize(context = SerializationDefaults.P2P_CONTEXT)) - } - - @Suspendable - private fun sendInternal(session: FlowSessionInternal, message: SessionMessage) = suspend(SendOnly(session, message)) - - private inline fun receiveInternal( - session: FlowSessionInternal, - userReceiveType: Class<*>?): ReceivedSessionMessage { - return waitForMessage(ReceiveOnly(session, M::class.java, userReceiveType)) - } - - private inline fun sendAndReceiveInternal( - session: FlowSessionInternal, - message: SessionMessage, - userReceiveType: Class<*>?): ReceivedSessionMessage { - return waitForMessage(SendAndReceive(session, message, M::class.java, userReceiveType)) - } - - @Suspendable - private fun getConfirmedSessionIfPresent(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSessionInternal? { - val session = openSessions[Pair(sessionFlow, otherParty)] ?: return null - return when (session.state) { - is FlowSessionState.Uninitiated -> null - is FlowSessionState.Initiating -> { - session.waitForConfirmation() - session - } - is FlowSessionState.Initiated -> session - } - } - - @Suspendable - private fun getConfirmedSession(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSessionInternal { - return getConfirmedSessionIfPresent(otherParty, sessionFlow) ?: - initiateSession(otherParty, sessionFlow, null, waitForConfirmation = true) - } - - private fun createNewSession( - otherParty: Party, - flowSession: FlowSession, - sessionFlow: FlowLogic<*> - ) { - logger.trace { "Creating a new session with $otherParty" } - val session = FlowSessionInternal(sessionFlow, flowSession, random63BitValue(), null, FlowSessionState.Uninitiated(otherParty)) - openSessions[Pair(sessionFlow, otherParty)] = session - } - - @Suspendable - private fun initiateSession( - otherParty: Party, - sessionFlow: FlowLogic<*>, - firstPayload: Any?, - waitForConfirmation: Boolean, - retryable: Boolean = false - ): FlowSessionInternal { - val session = openSessions[Pair(sessionFlow, otherParty)] ?: throw IllegalStateException("Expected an Uninitiated session for $otherParty") - val state = session.state as? FlowSessionState.Uninitiated ?: throw IllegalStateException("Tried to initiate a session $session, but it's already initiating/initiated") - logger.trace { "Initiating a new session with ${state.otherParty}" } - session.state = FlowSessionState.Initiating(state.otherParty) - session.retryable = retryable - val (version, initiatingFlowClass) = session.flow.javaClass.flowVersionAndInitiatingClass - val payloadBytes = firstPayload?.serialize(context = SerializationDefaults.P2P_CONTEXT) - logger.info("Initiating flow session with party ${otherParty.name}. Session id for tracing purposes is ${session.ourSessionId}.") - val sessionInit = SessionInit(session.ourSessionId, initiatingFlowClass.name, version, session.flow.javaClass.appName, payloadBytes) - sendInternal(session, sessionInit) - if (waitForConfirmation) { - session.waitForConfirmation() - } - return session - } - - @Suspendable - private fun waitForMessage(receiveRequest: ReceiveRequest): ReceivedSessionMessage { - return receiveRequest.suspendAndExpectReceive().confirmReceiveType(receiveRequest) - } - - private val suspend : ReceiveAll.Suspend = object : ReceiveAll.Suspend { - @Suspendable - override fun invoke(request: FlowIORequest) { - suspend(request) - } - } - - @Suspendable - private fun ReceiveRequest<*>.suspendAndExpectReceive(): ReceivedSessionMessage<*> { - val polledMessage = session.receivedMessages.poll() - return if (polledMessage != null) { - if (this is SendAndReceive) { - // Since we've already received the message, we downgrade to a send only to get the payload out and not - // inadvertently block - suspend(SendOnly(session, message)) - } - polledMessage - } else { - // Suspend while we wait for a receive - suspend(this) - session.receivedMessages.poll() ?: - throw IllegalStateException("Was expecting a ${receiveType.simpleName} but instead got nothing for $this") - } - } - - private fun ReceivedSessionMessage<*>.confirmReceiveType( - receiveRequest: ReceiveRequest): ReceivedSessionMessage { - val session = receiveRequest.session - val receiveType = receiveRequest.receiveType - if (receiveType.isInstance(message)) { - return uncheckedCast(this) - } else if (message is SessionEnd) { - openSessions.values.remove(session) - if (message is ErrorSessionEnd) { - session.erroredEnd(message) - } else { - val expectedType = receiveRequest.userReceiveType?.name ?: receiveType.simpleName - throw UnexpectedFlowEndException("Counterparty flow on ${session.state.sendToParty} has completed without " + - "sending a $expectedType") - } - } else { - throw IllegalStateException("Was expecting a ${receiveType.simpleName} but instead got $message for $receiveRequest") - } - } - - private fun FlowSessionInternal.erroredEnd(end: ErrorSessionEnd): Nothing { - if (end.errorResponse != null) { - @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") - (end.errorResponse as java.lang.Throwable).fillInStackTrace() - throw end.errorResponse - } else { - throw UnexpectedFlowEndException("Counterparty flow on ${state.sendToParty} had an internal error and has terminated") - } - } - - @Suspendable - private fun suspend(ioRequest: FlowIORequest) { - // We have to pass the thread local database transaction across via a transient field as the fiber park - // swaps them out. - txTrampoline = DatabaseTransactionManager.setThreadLocalTx(null) - if (ioRequest is WaitingRequest) - waitingForResponse = ioRequest - - var exceptionDuringSuspend: Throwable? = null + override fun suspend(ioRequest: FlowIORequest, maySkipCheckpoint: Boolean): R { + val serializationContext = TransientReference(getTransientField(TransientValues::checkpointSerializationContext)) + val transaction = extractThreadLocalTransaction() + val transitionExecutor = TransientReference(getTransientField(TransientValues::transitionExecutor)) parkAndSerialize { _, _ -> logger.trace { "Suspended on $ioRequest" } - // restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB - try { - DatabaseTransactionManager.setThreadLocalTx(txTrampoline) - txTrampoline = null - actionOnSuspend(ioRequest) - } catch (t: Throwable) { - // Quasar does not terminate the fiber properly if an exception occurs during a suspend. We have to - // resume the fiber just so that we can throw it when it's running. - exceptionDuringSuspend = t - logger.trace("Resuming so fiber can it terminate with the exception thrown during suspend process", t) - resume(scheduler) - } - } - if (exceptionDuringSuspend == null && ioRequest is Sleep) { - // Sleep on the fiber. This will not sleep if it's in the past. - Strand.sleep(Duration.between(Instant.now(), ioRequest.until).toNanos(), TimeUnit.NANOSECONDS) + DatabaseTransactionManager.setThreadLocalTx(transaction.value) + val event = try { + Event.Suspend( + ioRequest = ioRequest, + maySkipCheckpoint = maySkipCheckpoint, + fiber = this.serialize(context = serializationContext.value) + ) + } catch (throwable: Throwable) { + Event.Error(throwable) + } + + // We must commit the database transaction before returning from this closure, otherwise Quasar may schedule + // other fibers + require(processEvent(transitionExecutor.value, event) == FlowContinuation.ProcessEvents) + Fiber.unparkDeserialized(this, scheduler) } - createTransaction() - // TODO Now that we're throwing outside of the suspend the FlowLogic can catch it. We need Quasar to terminate - // the fiber when exceptions occur inside a suspend. - exceptionDuringSuspend?.let { throw it } - logger.trace { "Resumed from $ioRequest" } + return processEventsUntilFlowIsResumed() as R } - internal fun resume(scheduler: FiberScheduler) { - try { - if (fromCheckpoint) { - logger.info("Resumed from checkpoint") - fromCheckpoint = false - Fiber.unparkDeserialized(this, scheduler) - } else if (state == State.NEW) { - logger.trace("Started") - start() - } else { - Fiber.unpark(this, QUASAR_UNBLOCKER) - } - } catch (t: Throwable) { - logger.error("Error during resume", t) - } + @Suspendable + override fun scheduleEvent(event: Event) { + getTransientField(TransientValues::eventQueue).send(event) } + override fun snapshot(): StateMachineState { + return transientState!!.value + } + + override val stateMachine get() = getTransientField(TransientValues::stateMachine) + /** * Records the duration of this flow – from call() to completion or failure. * Note that the duration will include the time the flow spent being parked, and not just the total diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/PropagatingFlowHospital.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/PropagatingFlowHospital.kt new file mode 100644 index 0000000000..973c18fa0b --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/PropagatingFlowHospital.kt @@ -0,0 +1,21 @@ +package net.corda.node.services.statemachine + +import net.corda.core.utilities.debug +import net.corda.core.utilities.loggerFor + +/** + * A simple [FlowHospital] implementation that immediately triggers error propagation when a flow dirties. + */ +object PropagatingFlowHospital : FlowHospital { + private val log = loggerFor() + + override fun flowErrored(flowFiber: FlowFiber) { + log.debug { "Flow ${flowFiber.id} dirtied ${flowFiber.snapshot().checkpoint.errorState}" } + flowFiber.scheduleEvent(Event.StartErrorPropagation) + } + + override fun flowCleaned(flowFiber: FlowFiber) { + throw IllegalStateException("Flow ${flowFiber.id} cleaned after error propagation triggered") + } +} + diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt index c321d3768a..5481275f05 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt @@ -1,59 +1,122 @@ package net.corda.node.services.statemachine +import net.corda.core.crypto.random63BitValue import net.corda.core.flows.FlowException -import net.corda.core.flows.UnexpectedFlowEndException -import net.corda.core.identity.Party -import net.corda.core.internal.castIfPossible +import net.corda.core.flows.FlowInfo import net.corda.core.serialization.CordaSerializable -import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializedBytes -import net.corda.core.utilities.UntrustworthyData -import java.io.IOException +import java.security.SecureRandom + +/** + * A session between two flows is identified by two session IDs, the initiating and the initiated session ID. + * However after the session has been established the communication is symmetric. From then on we differentiate between + * the two session IDs with "source" ID (the ID from which we receive) and "sink" ID (the ID to which we send). + * + * Flow A (initiating) Flow B (initiated) + * initiatingId=sourceId=0 + * send(Initiate(initiatingId=0)) -----> initiatingId=sinkId=0 + * initiatedId=sourceId=1 + * initiatedId=sinkId=1 <----- send(Confirm(initiatedId=1)) + */ +@CordaSerializable +sealed class SessionMessage + @CordaSerializable -interface SessionMessage - -interface ExistingSessionMessage : SessionMessage { - val recipientSessionId: Long -} - -interface SessionInitResponse : ExistingSessionMessage { - val initiatorSessionId: Long - override val recipientSessionId: Long get() = initiatorSessionId -} - -interface SessionEnd : ExistingSessionMessage - -data class SessionInit(val initiatorSessionId: Long, - val initiatingFlowClass: String, - val flowVersion: Int, - val appName: String, - val firstPayload: SerializedBytes?) : SessionMessage - -data class SessionConfirm(override val initiatorSessionId: Long, - val initiatedSessionId: Long, - val flowVersion: Int, - val appName: String) : SessionInitResponse - -data class SessionReject(override val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse - -data class SessionData(override val recipientSessionId: Long, val payload: SerializedBytes) : ExistingSessionMessage - -data class NormalSessionEnd(override val recipientSessionId: Long) : SessionEnd - -data class ErrorSessionEnd(override val recipientSessionId: Long, val errorResponse: FlowException?) : SessionEnd - -data class ReceivedSessionMessage(val sender: Party, val message: M) - -fun ReceivedSessionMessage.checkPayloadIs(type: Class): UntrustworthyData { - val payloadData: T = try { - val serializer = SerializationDefaults.SERIALIZATION_FACTORY - serializer.deserialize(message.payload, type, SerializationDefaults.P2P_CONTEXT) - } catch (ex: Exception) { - throw IOException("Payload invalid", ex) +data class SessionId(val toLong: Long) { + companion object { + fun createRandom(secureRandom: SecureRandom) = SessionId(secureRandom.nextLong()) } - return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?: - throw UnexpectedFlowEndException("We were expecting a ${type.name} from $sender but we instead got a " + - "${payloadData.javaClass.name} (${payloadData})") - } + +/** + * The initial message to initiate a session with. + * + * @param initiatorSessionId the session ID of the initiator. On the sending side this is the *source* ID, on the + * receiving side this is the *sink* ID. + * @param initiationEntropy additional randomness to seed the initiated flow's deduplication ID. + * @param initiatorFlowClassName the class name to be used to determine the initiating-initiated mapping on the receiver + * side. + * @param flowVersion the version of the initiating flow. + * @param appName the name of the cordapp defining the initiating flow, or "corda" if it's a core flow. + * @param firstPayload the optional first payload. + */ +data class InitialSessionMessage( + val initiatorSessionId: SessionId, + val initiationEntropy: Long, + val initiatorFlowClassName: String, + val flowVersion: Int, + val appName: String, + val firstPayload: SerializedBytes? +) : SessionMessage() { + override fun toString() = "InitialSessionMessage(" + + "initiatorSessionId=$initiatorSessionId, " + + "initiationEntropy=$initiationEntropy, " + + "initiatorFlowClassName=$initiatorFlowClassName, " + + "appName=$appName, " + + "firstPayload=${firstPayload?.javaClass}" + + ")" +} + +/** + * A message sent when a session has been established already. + * + * @param recipientSessionId the recipient session ID. On the sending side this is the *sink* ID, on the receiving side + * this is the *source* ID. + * @param payload the rest of the message. + */ +data class ExistingSessionMessage( + val recipientSessionId: SessionId, + val payload: ExistingSessionMessagePayload +) : SessionMessage() + +/** + * The payload of an [ExistingSessionMessage] + */ +@CordaSerializable +sealed class ExistingSessionMessagePayload + +/** + * The confirmation message sent by the initiated side. + * @param initiatedSessionId the initiated session ID, the other half of [InitialSessionMessage.initiatorSessionId]. + * This is the *source* ID on the sending(initiated) side, and the *sink* ID on the receiving(initiating) side. + */ +data class ConfirmSessionMessage( + val initiatedSessionId: SessionId, + val initiatedFlowInfo: FlowInfo +) : ExistingSessionMessagePayload() + +/** + * A message containing flow-related data. + * + * @param payload the serialised payload. + */ +data class DataSessionMessage(val payload: SerializedBytes) : ExistingSessionMessagePayload() { + override fun toString() = "DataSessionMessage(payload=${payload.javaClass})" +} + +/** + * A message indicating that an error has happened. + * + * @param flowException the exception that happened. This is null if the error condition wasn't revealed to the + * receiving side. + * @param errorId the ID of the source error. This is always specified to allow posteriori correlation of error conditions. + */ +data class ErrorSessionMessage(val flowException: FlowException?, val errorId: Long) : ExistingSessionMessagePayload() + +/** + * A message indicating that a session initiation has failed. + * + * @param message a message describing the problem to the initator. + * @param errorId an error ID identifying this error condition. + */ +data class RejectSessionMessage(val message: String, val errorId: Long) : ExistingSessionMessagePayload() + +/** + * A message indicating that the flow hosting the session has ended. Note that this message is strictly part of the + * session protocol, the flow may be removed before all counter-flows have ended. + * + * The sole purpose of this message currently is to provide diagnostic in cases where the two communicating flows' + * protocols don't match up, e.g. one is waiting for the other, but the other side has already finished. + */ +object EndSessionMessage : ExistingSessionMessagePayload() diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index fcddc980ea..2a1bc3b219 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -1,9 +1,11 @@ package net.corda.node.services.statemachine import net.corda.core.concurrent.CordaFuture -import net.corda.core.flows.FlowLogic -import net.corda.core.internal.FlowStateMachine import net.corda.core.context.InvocationContext +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.StateMachineRunId +import net.corda.core.identity.Party +import net.corda.core.internal.FlowStateMachine import net.corda.core.messaging.DataFeed import net.corda.core.utilities.Try import rx.Observable @@ -23,7 +25,6 @@ import rx.Observable * TODO: Think about how to bring the system to a clean stop so it can be upgraded without any serialised stacks on disk * TODO: Timeouts * TODO: Surfacing of exceptions via an API and/or management UI - * TODO: Ability to control checkpointing explicitly, for cases where you know replaying a message can't hurt * TODO: Don't store all active flows in memory, load from the database on demand. */ interface StateMachineManager { @@ -43,7 +44,11 @@ interface StateMachineManager { * @param flowLogic The flow's code. * @param context The context of the flow. */ - fun startFlow(flowLogic: FlowLogic, context: InvocationContext): CordaFuture> + fun startFlow( + flowLogic: FlowLogic, + context: InvocationContext, + ourIdentity: Party? = null + ): CordaFuture> /** * Represents an addition/removal of a state machine. @@ -73,4 +78,13 @@ interface StateMachineManager { * Returns all currently live flows. */ val allStateMachines: List> -} \ No newline at end of file +} + +// These must be idempotent! A later failure in the state transition may error the flow state, and a replay may call +// these functions again +interface StateMachineManagerInternal { + fun signalFlowHasStarted(flowId: StateMachineRunId) + fun addSessionBinding(flowId: StateMachineRunId, sessionId: SessionId) + fun removeSessionBindings(sessionIds: Set) + fun removeFlow(flowId: StateMachineRunId, removalReason: FlowRemovalReason, lastState: StateMachineState) +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt index 500529aa5c..a281e7fc4d 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt @@ -4,56 +4,51 @@ import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.FiberExecutorScheduler import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.instrument.SuspendableHelper -import co.paralleluniverse.strands.Strand +import co.paralleluniverse.strands.channels.Channels import com.codahale.metrics.Gauge -import com.esotericsoftware.kryo.KryoException -import com.google.common.collect.HashMultimap -import com.google.common.util.concurrent.MoreExecutors import net.corda.core.CordaException import net.corda.core.concurrent.CordaFuture import net.corda.core.context.InvocationContext -import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.random63BitValue import net.corda.core.flows.FlowException import net.corda.core.flows.FlowInfo import net.corda.core.flows.FlowLogic import net.corda.core.flows.StateMachineRunId import net.corda.core.identity.Party -import net.corda.core.internal.* -import net.corda.core.internal.concurrent.doneFuture +import net.corda.core.internal.FlowStateMachine +import net.corda.core.internal.ThreadBox +import net.corda.core.internal.bufferUntilSubscribed +import net.corda.core.internal.castIfPossible +import net.corda.core.internal.concurrent.OpenFuture +import net.corda.core.internal.concurrent.map +import net.corda.core.internal.concurrent.openFuture import net.corda.core.messaging.DataFeed -import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT -import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY -import net.corda.core.serialization.SerializedBytes -import net.corda.core.serialization.deserialize -import net.corda.core.serialization.serialize +import net.corda.core.serialization.* +import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.Try import net.corda.core.utilities.contextLogger import net.corda.core.utilities.debug -import net.corda.core.utilities.trace import net.corda.node.internal.InitiatedFlowFactory -import net.corda.node.services.api.Checkpoint import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.config.shouldCheckCheckpoints +import net.corda.node.services.messaging.AcknowledgeHandle import net.corda.node.services.messaging.ReceivedMessage -import net.corda.node.services.messaging.TopicSession +import net.corda.node.services.statemachine.interceptors.* +import net.corda.node.services.statemachine.transitions.StateMachine +import net.corda.node.services.statemachine.transitions.StateMachineConfiguration import net.corda.node.utilities.AffinityExecutor -import net.corda.node.utilities.newNamedSingleThreadExecutor import net.corda.nodeapi.internal.persistence.CordaPersistence -import net.corda.nodeapi.internal.persistence.bufferUntilDatabaseCommit -import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl import net.corda.nodeapi.internal.serialization.withTokenContext import org.apache.activemq.artemis.utils.ReusableLatch -import org.slf4j.Logger import rx.Observable import rx.subjects.PublishSubject -import java.io.NotSerializableException +import java.security.SecureRandom import java.util.* import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.TimeUnit.SECONDS import javax.annotation.concurrent.ThreadSafe +import kotlin.collections.ArrayList +import kotlin.streams.toList /** * The StateMachineManagerImpl will always invoke the flow fibers on the given [AffinityExecutor], regardless of which @@ -65,82 +60,43 @@ class StateMachineManagerImpl( val checkpointStorage: CheckpointStorage, val executor: AffinityExecutor, val database: CordaPersistence, + val secureRandom: SecureRandom, private val unfinishedFibers: ReusableLatch = ReusableLatch(), private val classloader: ClassLoader = StateMachineManagerImpl::class.java.classLoader -) : StateMachineManager { - inner class FiberScheduler : FiberExecutorScheduler("Same thread scheduler", executor) - +) : StateMachineManager, StateMachineManagerInternal { companion object { private val logger = contextLogger() - internal val sessionTopic = TopicSession("platform.session") - - init { - Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable -> - (fiber as FlowStateMachineImpl<*>).logger.warn("Caught exception from flow", throwable) - } - } } + private class Flow(val fiber: FlowStateMachineImpl<*>, val resultFuture: OpenFuture) + // A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines // property. private class InnerState { - var started = false - val stateMachines = LinkedHashMap, Checkpoint>() val changesPublisher = PublishSubject.create()!! - val fibersWaitingForLedgerCommit = HashMultimap.create>()!! - - fun notifyChangeObservers(change: StateMachineManager.Change) { - changesPublisher.bufferUntilDatabaseCommit().onNext(change) - } + // True if we're shutting down, so don't resume anything. + var stopping = false + val flows = HashMap() + val startedFutures = HashMap>() } - private val scheduler = FiberScheduler() private val mutex = ThreadBox(InnerState()) - // This thread (only enabled in dev mode) deserialises checkpoints in the background to shake out bugs in checkpoint restore. - private val checkpointCheckerThread = if (serviceHub.configuration.shouldCheckCheckpoints()) { - newNamedSingleThreadExecutor("CheckpointChecker") - } else { - null - } - - @Volatile private var unrestorableCheckpoints = false - - // True if we're shutting down, so don't resume anything. - @Volatile private var stopping = false + private val scheduler = FiberExecutorScheduler("Same thread scheduler", executor) // How many Fibers are running and not suspended. If zero and stopping is true, then we are halted. private val liveFibers = ReusableLatch() - // Monitoring support. private val metrics = serviceHub.monitoringService.metrics + private val sessionToFlow = ConcurrentHashMap() + private val flowMessaging: FlowMessaging = FlowMessagingImpl(serviceHub) + private val fiberDeserializationChecker = if (serviceHub.configuration.shouldCheckCheckpoints()) FiberDeserializationChecker() else null + private val transitionExecutor = makeTransitionExecutor() - init { - metrics.register("Flows.InFlight", Gauge { mutex.content.stateMachines.size }) - } - - private val checkpointingMeter = metrics.meter("Flows.Checkpointing Rate") - private val totalStartedFlows = metrics.counter("Flows.Started") - private val totalFinishedFlows = metrics.counter("Flows.Finished") - - private val openSessions = ConcurrentHashMap() - private val recentlyClosedSessions = ConcurrentHashMap() - - // Context for tokenized services in checkpoints - private lateinit var tokenizableServices: List - private val serializationContext by lazy { - SerializeAsTokenContextImpl(tokenizableServices, SERIALIZATION_FACTORY, CHECKPOINT_CONTEXT, serviceHub) - } - - /** Returns a list of all state machines executing the given flow logic at the top level (subflows do not count) */ - override fun > findStateMachines(flowClass: Class): List>> { - return mutex.locked { - stateMachines.keys.mapNotNull { - flowClass.castIfPossible(it.logic)?.let { it to uncheckedCast, FlowStateMachineImpl<*>>(it.stateMachine).resultFuture } - } - } - } + private var checkpointSerializationContext: SerializationContext? = null + private var tokenizableServices: List? = null + private var actionExecutor: ActionExecutor? = null override val allStateMachines: List> - get() = mutex.locked { stateMachines.keys.map { it.logic } } + get() = mutex.locked { flows.values.map { it.fiber.logic } } /** * An observable that emits triples of the changing flow, the type of change, and a process-specific ID number @@ -148,46 +104,38 @@ class StateMachineManagerImpl( * * We use assignment here so that multiple subscribers share the same wrapped Observable. */ - override val changes: Observable = mutex.content.changesPublisher.wrapWithDatabaseTransaction() + override val changes: Observable = mutex.content.changesPublisher override fun start(tokenizableServices: List) { - this.tokenizableServices = tokenizableServices checkQuasarJavaAgentPresence() - restoreFibersFromCheckpoints() - listenToLedgerTransactions() - serviceHub.networkMapCache.nodeReady.then { executor.execute(this::resumeRestoredFibers) } - } - - private fun checkQuasarJavaAgentPresence() { - check(SuspendableHelper.isJavaAgentActive(), { - """Missing the '-javaagent' JVM argument. Make sure you run the tests with the Quasar java agent attached to your JVM. - #See https://docs.corda.net/troubleshooting.html - 'Fiber classes not instrumented' for more details.""".trimMargin("#") - }) - } - - private fun listenToLedgerTransactions() { - // Observe the stream of committed, validated transactions and resume fibers that are waiting for them. - serviceHub.validatedTransactions.updates.subscribe { stx -> - val hash = stx.id - val fibers: Set> = mutex.locked { fibersWaitingForLedgerCommit.removeAll(hash) } - if (fibers.isNotEmpty()) { - executor.executeASAP { - for (fiber in fibers) { - fiber.logger.trace { "Transaction $hash has committed to the ledger, resuming" } - fiber.waitingForResponse = null - resumeFiber(fiber) - } + this.tokenizableServices = tokenizableServices + val checkpointSerializationContext = SerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext( + SerializeAsTokenContextImpl(tokenizableServices, SerializationDefaults.SERIALIZATION_FACTORY, SerializationDefaults.CHECKPOINT_CONTEXT, serviceHub) + ) + this.checkpointSerializationContext = checkpointSerializationContext + this.actionExecutor = makeActionExecutor(checkpointSerializationContext) + fiberDeserializationChecker?.start(checkpointSerializationContext) + val fibers = restoreFlowsFromCheckpoints() + metrics.register("Flows.InFlight", Gauge { mutex.content.flows.size }) + Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable -> + (fiber as FlowStateMachineImpl<*>).logger.warn("Caught exception from flow", throwable) + } + serviceHub.networkMapCache.nodeReady.then { + resumeRestoredFlows(fibers) + flowMessaging.start { receivedMessage, acknowledgeHandle -> + executor.execute { + onSessionMessage(receivedMessage, acknowledgeHandle) } } } } - private fun decrementLiveFibers() { - liveFibers.countDown() - } - - private fun incrementLiveFibers() { - liveFibers.countUp() + override fun > findStateMachines(flowClass: Class): List>> { + return mutex.locked { + flows.values.mapNotNull { + flowClass.castIfPossible(it.fiber.logic)?.let { it to it.stateMachine.resultFuture } + } + } } /** @@ -201,12 +149,17 @@ class StateMachineManagerImpl( mutex.locked { if (stopping) throw IllegalStateException("Already stopping!") stopping = true + for ((_, flow) in flows) { + flow.fiber.scheduleEvent(Event.SoftShutdown) + } } // Account for any expected Fibers in a test scenario. liveFibers.countDown(allowedUnsuspendedFiberCount) liveFibers.await() - checkpointCheckerThread?.let { MoreExecutors.shutdownAndAwaitTermination(it, 5, SECONDS) } - check(!unrestorableCheckpoints) { "Unrestorable checkpoints where created, please check the logs for details." } + fiberDeserializationChecker?.let { + val foundUnrestorableFibers = it.stop() + check(!foundUnrestorableFibers) { "Unrestorable checkpoints were created, please check the logs for details." } + } } /** @@ -215,253 +168,276 @@ class StateMachineManagerImpl( */ override fun track(): DataFeed>, StateMachineManager.Change> { return mutex.locked { - DataFeed(stateMachines.keys.map { it.logic }, changesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction()) + DataFeed(flows.values.map { it.fiber.logic }, changesPublisher.bufferUntilSubscribed()) } } - private fun restoreFibersFromCheckpoints() { - mutex.locked { - checkpointStorage.forEach { checkpoint -> - // If a flow is added before start() then don't attempt to restore it - if (!stateMachines.containsValue(checkpoint)) { - deserializeFiber(checkpoint, logger)?.let { - initFiber(it) - stateMachines[it] = checkpoint - } - } - true + override fun startFlow( + flowLogic: FlowLogic, + context: InvocationContext, + ourIdentity: Party? + ): CordaFuture> { + return startFlowInternal( + invocationContext = context, + flowLogic = flowLogic, + flowStart = FlowStart.Explicit, + ourIdentity = ourIdentity ?: getOurFirstIdentity(), + initialUnacknowledgedMessage = null, + isStartIdempotent = false + ) + } + + override fun addSessionBinding(flowId: StateMachineRunId, sessionId: SessionId) { + val previousFlowId = sessionToFlow.put(sessionId, flowId) + if (previousFlowId != null) { + if (previousFlowId == flowId) { + logger.warn("Session binding from $sessionId to $flowId re-added") + } else { + throw IllegalStateException( + "Attempted to add session binding from session $sessionId to flow $flowId, " + + "however there was already a binding to $previousFlowId" + ) } } } - private fun resumeRestoredFibers() { - mutex.locked { - started = true - stateMachines.keys.forEach { resumeRestoredFiber(it) } + override fun removeSessionBindings(sessionIds: Set) { + val reRemovedSessionIds = HashSet() + for (sessionId in sessionIds) { + val flowId = sessionToFlow.remove(sessionId) + if (flowId == null) { + reRemovedSessionIds.add(sessionId) + } } - serviceHub.networkService.addMessageHandler(sessionTopic) { message, _ -> - executor.checkOnThread() - onSessionMessage(message) + if (reRemovedSessionIds.isNotEmpty()) { + logger.warn("Session binding from $reRemovedSessionIds re-removed") } } - private fun resumeRestoredFiber(fiber: FlowStateMachineImpl<*>) { - fiber.openSessions.values.forEach { openSessions[it.ourSessionId] = it } - val waitingForResponse = fiber.waitingForResponse - if (waitingForResponse != null) { - if (waitingForResponse is WaitForLedgerCommit) { - val stx = database.transaction { - serviceHub.validatedTransactions.getTransaction(waitingForResponse.hash) - } - if (stx != null) { - fiber.logger.info("Resuming fiber as tx ${waitingForResponse.hash} has committed") - fiber.waitingForResponse = null - resumeFiber(fiber) - } else { - fiber.logger.info("Restored, pending on ledger commit of ${waitingForResponse.hash}") - mutex.locked { fibersWaitingForLedgerCommit.put(waitingForResponse.hash, fiber) } + override fun removeFlow(flowId: StateMachineRunId, removalReason: FlowRemovalReason, lastState: StateMachineState) { + mutex.locked { + val flow = flows.remove(flowId) + if (flow != null) { + decrementLiveFibers() + unfinishedFibers.countDown() + return when (removalReason) { + is FlowRemovalReason.OrderlyFinish -> removeFlowOrderly(flow, removalReason, lastState) + is FlowRemovalReason.ErrorFinish -> removeFlowError(flow, removalReason, lastState) + FlowRemovalReason.SoftShutdown -> flow.fiber.scheduleEvent(Event.SoftShutdown) } } else { - fiber.logger.info("Restored, pending on receive") + logger.warn("Flow $flowId re-finished") } - } else { - resumeFiber(fiber) } } - private fun onSessionMessage(message: ReceivedMessage) { + override fun signalFlowHasStarted(flowId: StateMachineRunId) { + mutex.locked { + startedFutures.remove(flowId)?.set(Unit) + } + } + + private fun checkQuasarJavaAgentPresence() { + check(SuspendableHelper.isJavaAgentActive(), { + """Missing the '-javaagent' JVM argument. Make sure you run the tests with the Quasar java agent attached to your JVM. + #See https://docs.corda.net/troubleshooting.html - 'Fiber classes not instrumented' for more details.""".trimMargin("#") + }) + } + + private fun decrementLiveFibers() { + liveFibers.countDown() + } + + private fun incrementLiveFibers() { + liveFibers.countUp() + } + + private fun restoreFlowsFromCheckpoints(): List { + return checkpointStorage.getAllCheckpoints().map { (id, serializedCheckpoint) -> + // If a flow is added before start() then don't attempt to restore it + mutex.locked { if (flows.containsKey(id)) return@map null } + val checkpoint = deserializeCheckpoint(serializedCheckpoint) + if (checkpoint == null) return@map null + createFlowFromCheckpoint( + id = id, + checkpoint = checkpoint, + initialUnacknowledgedMessage = null, + isAnyCheckpointPersisted = true, + isStartIdempotent = false + ) + }.toList().filterNotNull() + } + + private fun resumeRestoredFlows(flows: List) { + for (flow in flows) { + addAndStartFlow(flow.fiber.id, flow) + } + } + + private fun onSessionMessage(message: ReceivedMessage, acknowledgeHandle: AcknowledgeHandle) { val peer = message.peer val sessionMessage = try { message.data.deserialize() } catch (ex: Exception) { logger.error("Received corrupt SessionMessage data from $peer") + acknowledgeHandle.acknowledge() return } val sender = serviceHub.networkMapCache.getPeerByLegalName(peer) if (sender != null) { when (sessionMessage) { - is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender) - is SessionInit -> onSessionInit(sessionMessage, message, sender) + is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, acknowledgeHandle, sender) + is InitialSessionMessage -> onSessionInit(sessionMessage, message.platformVersion, acknowledgeHandle, sender) } } else { logger.error("Unknown peer $peer in $sessionMessage") } } - private fun onExistingSessionMessage(message: ExistingSessionMessage, sender: Party) { - val session = openSessions[message.recipientSessionId] - if (session != null) { - session.fiber.pushToLoggingContext() - session.fiber.logger.trace { "Received $message on $session from $sender" } - if (session.retryable) { - if (message is SessionConfirm && session.state is FlowSessionState.Initiated) { - session.fiber.logger.trace { "Ignoring duplicate confirmation for session ${session.ourSessionId} – session is idempotent" } - return - } - if (message !is SessionConfirm) { - serviceHub.networkService.cancelRedelivery(session.ourSessionId) - } - } - if (message is SessionEnd) { - openSessions.remove(message.recipientSessionId) - } - session.receivedMessages += ReceivedSessionMessage(sender, message) - if (resumeOnMessage(message, session)) { - // It's important that we reset here and not after the fiber's resumed, in case we receive another message - // before then. - session.fiber.waitingForResponse = null - updateCheckpoint(session.fiber) - session.fiber.logger.trace { "Resuming due to $message" } - resumeFiber(session.fiber) - } - } else { - val peerParty = recentlyClosedSessions.remove(message.recipientSessionId) - if (peerParty != null) { - if (message is SessionConfirm) { - logger.trace { "Received session confirmation but associated fiber has already terminated, so sending session end" } - sendSessionMessage(peerParty, NormalSessionEnd(message.initiatedSessionId)) + private fun onExistingSessionMessage(sessionMessage: ExistingSessionMessage, acknowledgeHandle: AcknowledgeHandle, sender: Party) { + try { + executor.checkOnThread() + val recipientId = sessionMessage.recipientSessionId + val flowId = sessionToFlow[recipientId] + if (flowId == null) { + if (sessionMessage.payload is EndSessionMessage) { + logger.debug { + "Got ${EndSessionMessage::class.java.simpleName} for " + + "unknown session $recipientId, discarding..." + } } else { - logger.trace { "Ignoring session end message for already closed session: $message" } + throw IllegalArgumentException("Cannot find flow corresponding to session ID $recipientId") } } else { - logger.warn("Received a session message for unknown session: $message, from $sender") + val flow = mutex.locked { flows[flowId] } ?: throw IllegalStateException("Cannot find fiber corresponding to ID $flowId") + flow.fiber.scheduleEvent(Event.DeliverSessionMessage(sessionMessage, acknowledgeHandle, sender)) } + } catch (exception: Exception) { + logger.error("Exception while routing $sessionMessage", exception) + throw exception } } - // We resume the fiber if it's received a response for which it was waiting for or it's waiting for a ledger - // commit but a counterparty flow has ended with an error (in which case our flow also has to end) - private fun resumeOnMessage(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean { - val waitingForResponse = session.fiber.waitingForResponse - return waitingForResponse?.shouldResume(message, session) ?: false - } - - private fun onSessionInit(sessionInit: SessionInit, receivedMessage: ReceivedMessage, sender: Party) { - - logger.trace { "Received $sessionInit from $sender" } - val senderSessionId = sessionInit.initiatorSessionId - - fun sendSessionReject(message: String) = sendSessionMessage(sender, SessionReject(senderSessionId, message)) - - val (session, initiatedFlowFactory) = try { - val initiatedFlowFactory = getInitiatedFlowFactory(sessionInit) - val flowSession = FlowSessionImpl(sender) - val flow = initiatedFlowFactory.createFlow(flowSession) - val senderFlowVersion = when (initiatedFlowFactory) { - is InitiatedFlowFactory.Core -> receivedMessage.platformVersion // The flow version for the core flows is the platform version - is InitiatedFlowFactory.CorDapp -> sessionInit.flowVersion + private fun onSessionInit(sessionMessage: InitialSessionMessage, senderPlatformVersion: Int, acknowledgeHandle: AcknowledgeHandle, sender: Party) { + fun createErrorMessage(initiatorSessionId: SessionId, message: String): ExistingSessionMessage { + val errorId = secureRandom.nextLong() + val payload = RejectSessionMessage(message, errorId) + return ExistingSessionMessage(initiatorSessionId, payload) + } + val replyError = try { + val initiatedFlowFactory = getInitiatedFlowFactory(sessionMessage) + val initiatedSessionId = SessionId.createRandom(secureRandom) + val senderSession = FlowSessionImpl(sender, initiatedSessionId) + val flowLogic = initiatedFlowFactory.createFlow(senderSession) + val initiatedFlowInfo = when (initiatedFlowFactory) { + is InitiatedFlowFactory.Core -> FlowInfo(serviceHub.myInfo.platformVersion, "corda") + is InitiatedFlowFactory.CorDapp -> FlowInfo(initiatedFlowFactory.flowVersion, initiatedFlowFactory.appName) } - val session = FlowSessionInternal( - flow, - flowSession, - random63BitValue(), - sender, - FlowSessionState.Initiated(sender, senderSessionId, FlowInfo(senderFlowVersion, sessionInit.appName))) - if (sessionInit.firstPayload != null) { - session.receivedMessages += ReceivedSessionMessage(sender, SessionData(session.ourSessionId, sessionInit.firstPayload)) + val senderCoreFlowVersion = when (initiatedFlowFactory) { + is InitiatedFlowFactory.Core -> senderPlatformVersion + is InitiatedFlowFactory.CorDapp -> null } - openSessions[session.ourSessionId] = session - val context = InvocationContext.peer(sender.name) - val fiber = createFiber(flow, context) - fiber.pushToLoggingContext() - logger.info("Accepting flow session from party ${sender.name}. Session id for tracing purposes is ${sessionInit.initiatorSessionId}.") - flowSession.sessionFlow = flow - flowSession.stateMachine = fiber - fiber.openSessions[Pair(flow, sender)] = session - updateCheckpoint(fiber) - session to initiatedFlowFactory - } catch (e: SessionRejectException) { - logger.warn("${e.logMessage}: $sessionInit") - sendSessionReject(e.rejectMessage) - return - } catch (e: Exception) { - logger.warn("Couldn't start flow session from $sessionInit", e) - sendSessionReject("Unable to establish session") - return + startInitiatedFlow(flowLogic, acknowledgeHandle, senderSession, initiatedSessionId, sessionMessage, senderCoreFlowVersion, initiatedFlowInfo) + null + } catch (exception: Exception) { + logger.warn("Exception while creating initiated flow", exception) + createErrorMessage( + sessionMessage.initiatorSessionId, + (exception as? SessionRejectException)?.message ?: "Unable to establish session" + ) } - val (ourFlowVersion, appName) = when (initiatedFlowFactory) { - // The flow version for the core flows is the platform version - is InitiatedFlowFactory.Core -> serviceHub.myInfo.platformVersion to "corda" - is InitiatedFlowFactory.CorDapp -> initiatedFlowFactory.flowVersion to initiatedFlowFactory.appName + if (replyError != null) { + flowMessaging.sendSessionMessage(sender, replyError, DeduplicationId.createRandom(secureRandom), null) + acknowledgeHandle.acknowledge() } - - sendSessionMessage(sender, SessionConfirm(senderSessionId, session.ourSessionId, ourFlowVersion, appName), session.fiber) - session.fiber.logger.debug { "Initiated by $sender using ${sessionInit.initiatingFlowClass}" } - session.fiber.logger.trace { "Initiated from $sessionInit on $session" } - resumeFiber(session.fiber) } - private fun getInitiatedFlowFactory(sessionInit: SessionInit): InitiatedFlowFactory<*> { + // TODO this is a temporary hack until we figure out multiple identities + private fun getOurFirstIdentity(): Party { + return serviceHub.myInfo.legalIdentities[0] + } + + private fun getInitiatedFlowFactory(message: InitialSessionMessage): InitiatedFlowFactory<*> { val initiatingFlowClass = try { - Class.forName(sessionInit.initiatingFlowClass, true, classloader).asSubclass(FlowLogic::class.java) + Class.forName(message.initiatorFlowClassName, true, classloader).asSubclass(FlowLogic::class.java) } catch (e: ClassNotFoundException) { - throw SessionRejectException("Don't know ${sessionInit.initiatingFlowClass}") + throw SessionRejectException("Don't know ${message.initiatorFlowClassName}") } catch (e: ClassCastException) { - throw SessionRejectException("${sessionInit.initiatingFlowClass} is not a flow") + throw SessionRejectException("${message.initiatorFlowClassName} is not a flow") } return serviceHub.getFlowFactory(initiatingFlowClass) ?: throw SessionRejectException("$initiatingFlowClass is not registered") } - private fun serializeFiber(fiber: FlowStateMachineImpl<*>): SerializedBytes> { - return fiber.serialize(context = CHECKPOINT_CONTEXT.withTokenContext(serializationContext)) + private fun startInitiatedFlow( + flowLogic: FlowLogic, + triggeringUnacknowledgedMessage: AcknowledgeHandle, + peerSession: FlowSessionImpl, + initiatedSessionId: SessionId, + initiatingMessage: InitialSessionMessage, + senderCoreFlowVersion: Int?, + initiatedFlowInfo: FlowInfo + ) { + val flowStart = FlowStart.Initiated(peerSession, initiatedSessionId, initiatingMessage, senderCoreFlowVersion, initiatedFlowInfo) + val ourIdentity = getOurFirstIdentity() + startFlowInternal( + InvocationContext.peer(peerSession.counterparty.name), flowLogic, flowStart, ourIdentity, + triggeringUnacknowledgedMessage, + isStartIdempotent = false + ) } - private fun deserializeFiber(checkpoint: Checkpoint, logger: Logger): FlowStateMachineImpl<*>? { - return try { - checkpoint.serializedFiber.deserialize(context = CHECKPOINT_CONTEXT.withTokenContext(serializationContext)).apply { - fromCheckpoint = true - } - } catch (t: Throwable) { - logger.error("Encountered unrestorable checkpoint!", t) - null + private fun startFlowInternal( + invocationContext: InvocationContext, + flowLogic: FlowLogic, + flowStart: FlowStart, + ourIdentity: Party, + initialUnacknowledgedMessage: AcknowledgeHandle?, + isStartIdempotent: Boolean + ): CordaFuture> { + val flowId = StateMachineRunId.createRandom() + val deduplicationSeed = when (flowStart) { + FlowStart.Explicit -> flowId.uuid.toString() + is FlowStart.Initiated -> + "${flowStart.initiatingMessage.initiatorSessionId.toLong}-" + + "${flowStart.initiatingMessage.initiationEntropy}" } - } - private fun createFiber(logic: FlowLogic, context: InvocationContext, ourIdentity: Party? = null): FlowStateMachineImpl { - val fsm = FlowStateMachineImpl( - StateMachineRunId.createRandom(), - logic, - scheduler, - ourIdentity ?: serviceHub.myInfo.legalIdentities[0], - context) - initFiber(fsm) - return fsm - } + // Before we construct the state machine state by freezing the FlowLogic we need to make sure that lazy properties + // have access to the fiber (and thereby the service hub) + val flowStateMachineImpl = FlowStateMachineImpl(flowId, flowLogic, scheduler) + val resultFuture = openFuture() + flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture)) + flowLogic.stateMachine = flowStateMachineImpl + val frozenFlowLogic = (flowLogic as FlowLogic<*>).serialize(context = checkpointSerializationContext!!) - private fun initFiber(fiber: FlowStateMachineImpl<*>) { - verifyFlowLogicIsSuspendable(fiber.logic) - fiber.database = database - fiber.serviceHub = serviceHub - fiber.ourIdentityAndCert = serviceHub.myInfo.legalIdentitiesAndCerts.find { it.party == fiber.ourIdentity } - ?: throw IllegalStateException("Identity specified by ${fiber.id} (${fiber.ourIdentity.name}) is not one of ours!") - fiber.actionOnSuspend = { ioRequest -> - updateCheckpoint(fiber) - // We commit on the fibers transaction that was copied across ThreadLocals during suspend - // This will free up the ThreadLocal so on return the caller can carry on with other transactions - fiber.commitTransaction() - processIORequest(ioRequest) - decrementLiveFibers() - } - fiber.actionOnEnd = { result, propagated -> - try { - mutex.locked { - stateMachines.remove(fiber)?.let { checkpointStorage.removeCheckpoint(it) } - notifyChangeObservers(StateMachineManager.Change.Removed(fiber.logic, result)) - } - endAllFiberSessions(fiber, result, propagated) - } finally { - fiber.commitTransaction() - decrementLiveFibers() - totalFinishedFlows.inc() - unfinishedFibers.countDown() - } - } + val initialCheckpoint = Checkpoint.create(invocationContext, flowStart, flowLogic.javaClass, frozenFlowLogic, ourIdentity, deduplicationSeed).getOrThrow() + val startedFuture = openFuture() + val initialState = StateMachineState( + checkpoint = initialCheckpoint, + unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(), + isFlowResumed = false, + isTransactionTracked = false, + isAnyCheckpointPersisted = false, + isStartIdempotent = isStartIdempotent, + isRemoved = false, + flowLogic = flowLogic + ) + flowStateMachineImpl.transientState = TransientReference(initialState) mutex.locked { - totalStartedFlows.inc() - unfinishedFibers.countUp() - notifyChangeObservers(StateMachineManager.Change.Add(fiber.logic)) + startedFutures[flowId] = startedFuture + } + addAndStartFlow(flowId, Flow(flowStateMachineImpl, resultFuture)) + return startedFuture.map { flowStateMachineImpl as FlowStateMachine } + } + + private fun deserializeCheckpoint(serializedCheckpoint: SerializedBytes): Checkpoint? { + return try { + serializedCheckpoint.deserialize(context = checkpointSerializationContext!!) + } catch (exception: Throwable) { + logger.error("Encountered unrestorable checkpoint!", exception) + null } } @@ -478,170 +454,160 @@ class StateMachineManagerImpl( } } - private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>, result: Try<*>, propagated: Boolean) { - openSessions.values.removeIf { session -> - if (session.fiber == fiber) { - session.endSession(fiber.context, (result as? Try.Failure)?.exception, propagated) - true - } else { - false - } - } + private fun createTransientValues(id: StateMachineRunId, resultFuture: CordaFuture): FlowStateMachineImpl.TransientValues { + return FlowStateMachineImpl.TransientValues( + eventQueue = Channels.newChannel(16, Channels.OverflowPolicy.BLOCK), + resultFuture = resultFuture, + database = database, + transitionExecutor = transitionExecutor, + actionExecutor = actionExecutor!!, + stateMachine = StateMachine(id, StateMachineConfiguration.default, secureRandom), + serviceHub = serviceHub, + checkpointSerializationContext = checkpointSerializationContext!! + ) } - private fun FlowSessionInternal.endSession(context: InvocationContext, exception: Throwable?, propagated: Boolean) { - val initiatedState = state as? FlowSessionState.Initiated ?: return - val sessionEnd = if (exception == null) { - NormalSessionEnd(initiatedState.peerSessionId) - } else { - val errorResponse = if (exception is FlowException && (!propagated || initiatingParty != null)) { - // Only propagate this FlowException if our local flow threw it or it was propagated to us and we only - // pass it down invocation chain to the flow that initiated us, not to flows we've started sessions with. - exception - } else { - null + private fun createFlowFromCheckpoint( + id: StateMachineRunId, + checkpoint: Checkpoint, + isAnyCheckpointPersisted: Boolean, + isStartIdempotent: Boolean, + initialUnacknowledgedMessage: AcknowledgeHandle? + ): Flow { + val flowState = checkpoint.flowState + val resultFuture = openFuture() + val fiber = when (flowState) { + is FlowState.Unstarted -> { + val logic = flowState.frozenFlowLogic.deserialize(context = checkpointSerializationContext!!) + val state = StateMachineState( + checkpoint = checkpoint, + unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(), + isFlowResumed = false, + isTransactionTracked = false, + isAnyCheckpointPersisted = isAnyCheckpointPersisted, + isStartIdempotent = isStartIdempotent, + isRemoved = false, + flowLogic = logic + ) + val fiber = FlowStateMachineImpl(id, logic, scheduler) + fiber.transientValues = TransientReference(createTransientValues(id, resultFuture)) + fiber.transientState = TransientReference(state) + fiber.logic.stateMachine = fiber + fiber + } + is FlowState.Started -> { + val fiber = flowState.frozenFiber.deserialize(context = checkpointSerializationContext!!) + val state = StateMachineState( + checkpoint = checkpoint, + unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(), + isFlowResumed = false, + isTransactionTracked = false, + isAnyCheckpointPersisted = isAnyCheckpointPersisted, + isStartIdempotent = isStartIdempotent, + isRemoved = false, + flowLogic = fiber.logic + ) + fiber.transientValues = TransientReference(createTransientValues(id, resultFuture)) + fiber.transientState = TransientReference(state) + fiber.logic.stateMachine = fiber + fiber } - ErrorSessionEnd(initiatedState.peerSessionId, errorResponse) } - sendSessionMessage(initiatedState.peerParty, sessionEnd, fiber) - recentlyClosedSessions[ourSessionId] = initiatedState.peerParty + + verifyFlowLogicIsSuspendable(fiber.logic) + + return Flow(fiber, resultFuture) } - /** - * Kicks off a brand new state machine of the given class. - * The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is - * restarted with checkpointed state machines in the storage service. - * - * Note that you must be on the [executor] thread. - */ - override fun startFlow(flowLogic: FlowLogic, context: InvocationContext): CordaFuture> { - // TODO: Check that logic has @Suspendable on its call method. - executor.checkOnThread() - val fiber = database.transaction { - val fiber = createFiber(flowLogic, context) - updateCheckpoint(fiber) - fiber + private fun addAndStartFlow(id: StateMachineRunId, flow: Flow) { + val checkpoint = flow.fiber.snapshot().checkpoint + for (sessionId in getFlowSessionIds(checkpoint)) { + sessionToFlow.put(sessionId, id) } - // If we are not started then our checkpoint will be picked up during start mutex.locked { - if (started) { - resumeFiber(fiber) - } - } - return doneFuture(fiber) - } - - private fun updateCheckpoint(fiber: FlowStateMachineImpl<*>) { - check(fiber.state != Strand.State.RUNNING) { "Fiber cannot be running when checkpointing" } - val newCheckpoint = Checkpoint(serializeFiber(fiber)) - val previousCheckpoint = mutex.locked { stateMachines.put(fiber, newCheckpoint) } - if (previousCheckpoint != null) { - checkpointStorage.removeCheckpoint(previousCheckpoint) - } - checkpointStorage.addCheckpoint(newCheckpoint) - checkpointingMeter.mark() - - checkpointCheckerThread?.execute { - // Immediately check that the checkpoint is valid by deserialising it. The idea is to plug any holes we have - // in our testing by failing any test where unrestorable checkpoints are created. - if (deserializeFiber(newCheckpoint, fiber.logger) == null) { - unrestorableCheckpoints = true - } - } - } - - private fun resumeFiber(fiber: FlowStateMachineImpl<*>) { - // Avoid race condition when setting stopping to true and then checking liveFibers - incrementLiveFibers() - if (!stopping) { - executor.executeASAP { - fiber.resume(scheduler) - } - } else { - fiber.logger.trace("Not resuming as SMM is stopping.") - decrementLiveFibers() - } - } - - private fun processIORequest(ioRequest: FlowIORequest) { - executor.checkOnThread() - when (ioRequest) { - is SendRequest -> processSendRequest(ioRequest) - is WaitForLedgerCommit -> processWaitForCommitRequest(ioRequest) - is Sleep -> processSleepRequest(ioRequest) - } - } - - private fun processSendRequest(ioRequest: SendRequest) { - val retryId = if (ioRequest.message is SessionInit) { - with(ioRequest.session) { - openSessions[ourSessionId] = this - if (retryable) ourSessionId else null - } - } else null - sendSessionMessage(ioRequest.session.state.sendToParty, ioRequest.message, ioRequest.session.fiber, retryId) - if (ioRequest !is ReceiveRequest<*>) { - // We sent a message, but don't expect a response, so re-enter the continuation to let it keep going. - resumeFiber(ioRequest.session.fiber) - } - } - - private fun processWaitForCommitRequest(ioRequest: WaitForLedgerCommit) { - // Is it already committed? - val stx = database.transaction { - serviceHub.validatedTransactions.getTransaction(ioRequest.hash) - } - if (stx != null) { - resumeFiber(ioRequest.fiber) - } else { - // No, then register to wait. - // - // We assume this code runs on the server thread, which is the only place transactions are committed - // currently. When we liberalise our threading somewhat, handing of wait requests will need to be - // reworked to make the wait atomic in another way. Otherwise there is a race between checking the - // database and updating the waiting list. - mutex.locked { - fibersWaitingForLedgerCommit[ioRequest.hash] += ioRequest.fiber - } - } - } - - private fun processSleepRequest(ioRequest: Sleep) { - // Resume the fiber now we have checkpointed, so we can sleep on the Fiber. - resumeFiber(ioRequest.fiber) - } - - private fun sendSessionMessage(party: Party, message: SessionMessage, fiber: FlowStateMachineImpl<*>? = null, retryId: Long? = null) { - val partyInfo = serviceHub.networkMapCache.getPartyInfo(party) - ?: throw IllegalArgumentException("Don't know about party $party") - val address = serviceHub.networkService.getAddressOfParty(partyInfo) - val logger = fiber?.logger ?: logger - logger.trace { "Sending $message to party $party @ $address" + if (retryId != null) " with retry $retryId" else "" } - - val serialized = try { - message.serialize() - } catch (e: Exception) { - when (e) { - // Handling Kryo and AMQP serialization problems. Unfortunately the two exception types do not share much of a common exception interface. - is KryoException, - is NotSerializableException -> { - if (message !is ErrorSessionEnd || message.errorResponse == null) throw e - logger.warn("Something in ${message.errorResponse.javaClass.name} is not serialisable. " + - "Instead sending back an exception which is serialisable to ensure session end occurs properly.", e) - // The subclass may have overridden toString so we use that - val exMessage = message.errorResponse.let { if (it.javaClass != FlowException::class.java) it.toString() else it.message } - message.copy(errorResponse = FlowException(exMessage)).serialize() + if (stopping) { + startedFutures[id]?.setException(IllegalStateException("Will not start flow as SMM is stopping")) + logger.trace("Not resuming as SMM is stopping.") + } else { + incrementLiveFibers() + unfinishedFibers.countUp() + flows.put(id, flow) + flow.fiber.scheduleEvent(Event.DoRemainingWork) + when (checkpoint.flowState) { + is FlowState.Unstarted -> { + flow.fiber.start() + } + is FlowState.Started -> { + Fiber.unparkDeserialized(flow.fiber, scheduler) + } } - else -> throw e + changesPublisher.onNext(StateMachineManager.Change.Add(flow.fiber.logic)) } } + } - serviceHub.networkService.apply { - send(createMessage(sessionTopic, serialized.bytes), address, retryId = retryId) + private fun getFlowSessionIds(checkpoint: Checkpoint): Set { + val initiatedFlowStart = (checkpoint.flowState as? FlowState.Unstarted)?.flowStart as? FlowStart.Initiated + return if (initiatedFlowStart == null) { + checkpoint.sessions.keys + } else { + checkpoint.sessions.keys + initiatedFlowStart.initiatedSessionId } } + + private fun makeActionExecutor(checkpointSerializationContext: SerializationContext): ActionExecutor { + return ActionExecutorImpl( + serviceHub, + checkpointStorage, + flowMessaging, + this, + checkpointSerializationContext + ) + } + + private fun makeTransitionExecutor(): TransitionExecutor { + val interceptors = ArrayList() + interceptors.add { HospitalisingInterceptor(PropagatingFlowHospital, it) } + if (serviceHub.configuration.devMode) { + interceptors.add { DumpHistoryOnErrorInterceptor(it) } + } + if (serviceHub.configuration.shouldCheckCheckpoints()) { + interceptors.add { FiberDeserializationCheckingInterceptor(fiberDeserializationChecker!!, it) } + } + if (logger.isDebugEnabled) { + interceptors.add { PrintingInterceptor(it) } + } + val transitionExecutor: TransitionExecutor = TransitionExecutorImpl(secureRandom, database) + return interceptors.fold(transitionExecutor) { executor, interceptor -> interceptor(executor) } + } + + private fun InnerState.removeFlowOrderly( + flow: Flow, + removalReason: FlowRemovalReason.OrderlyFinish, + lastState: StateMachineState + ) { + // final sanity checks + require(lastState.unacknowledgedMessages.isEmpty()) + require(lastState.isRemoved) + require(lastState.checkpoint.subFlowStack.size == 1) + sessionToFlow.none { it.value == flow.fiber.id } + flow.resultFuture.set(removalReason.flowReturnValue) + lastState.flowLogic.progressTracker?.currentStep = ProgressTracker.DONE + changesPublisher.onNext(StateMachineManager.Change.Removed(lastState.flowLogic, Try.Success(removalReason.flowReturnValue))) + } + + private fun InnerState.removeFlowError( + flow: Flow, + removalReason: FlowRemovalReason.ErrorFinish, + lastState: StateMachineState + ) { + val flowError = removalReason.flowErrors[0] // TODO what to do with several? + val exception = flowError.exception + (exception as? FlowException)?.originalErrorId = flowError.errorId + flow.resultFuture.setException(exception) + lastState.flowLogic.progressTracker?.endWithError(exception) + changesPublisher.onNext(StateMachineManager.Change.Removed(lastState.flowLogic, Try.Failure(exception))) + } } -class SessionRejectException(val rejectMessage: String, val logMessage: String) : CordaException(rejectMessage) { - constructor(message: String) : this(message, message) -} +class SessionRejectException(reason: String) : CordaException(reason) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt new file mode 100644 index 0000000000..d6449865be --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt @@ -0,0 +1,232 @@ +package net.corda.node.services.statemachine + +import net.corda.core.context.InvocationContext +import net.corda.core.flows.FlowInfo +import net.corda.core.flows.FlowLogic +import net.corda.core.identity.Party +import net.corda.core.internal.FlowIORequest +import net.corda.core.serialization.SerializedBytes +import net.corda.core.utilities.Try +import net.corda.node.services.messaging.AcknowledgeHandle + +/** + * The state of the state machine, capturing the state of a flow. It consists of two parts, an *immutable* part that is + * persisted to the database ([Checkpoint]), and the rest, which is an in-memory-only state. + * + * @param checkpoint the persisted part of the state. + * @param flowLogic the [FlowLogic] associated with the flow. Note that this is mutable by the user. + * @param unacknowledgedMessages the list of currently unacknowledged messages. + * @param isFlowResumed true if the control is returned (or being returned) to "user-space" flow code. This is used + * to make [Event.DoRemainingWork] idempotent. + * @param isTransactionTracked true if a ledger transaction has been tracked as part of a + * [FlowIORequest.WaitForLedgerCommit]. This used is to make tracking idempotent. + * @param isAnyCheckpointPersisted true if at least a single checkpoint has been persisted. This is used to determine + * whether we should DELETE the checkpoint at the end of the flow. + * @param isStartIdempotent true if the start of the flow is idempotent, making the skipping of the initial checkpoint + * possible. + * @param isRemoved true if the flow has been removed from the state machine manager. This is used to avoid any further + * work. + */ +// TODO perhaps add a read-only environment to the state machine for things that don't change over time? +// TODO evaluate persistent datastructure libraries to replace the inefficient copying we currently do. +data class StateMachineState( + val checkpoint: Checkpoint, + val flowLogic: FlowLogic<*>, + val unacknowledgedMessages: List, + val isFlowResumed: Boolean, + val isTransactionTracked: Boolean, + val isAnyCheckpointPersisted: Boolean, + val isStartIdempotent: Boolean, + val isRemoved: Boolean +) + +/** + * @param invocationContext the initiator of the flow. + * @param ourIdentity the identity the flow is run as. + * @param sessions map of source session ID to session state. + * @param subFlowStack the stack of currently executing subflows. + * @param flowState the state of the flow itself, including the frozen fiber/FlowLogic. + * @param errorState the "dirtiness" state including the involved errors and their propagation status. + * @param numberOfSuspends the number of flow suspends due to IO API calls. + * @param deduplicationSeed the basis seed for the deduplication ID. This is used to produce replayable IDs. + */ +data class Checkpoint( + val invocationContext: InvocationContext, + val ourIdentity: Party, + val sessions: SessionMap, // This must preserve the insertion order! + val subFlowStack: List, + val flowState: FlowState, + val errorState: ErrorState, + val numberOfSuspends: Int, + val deduplicationSeed: String +) { + companion object { + + fun create( + invocationContext: InvocationContext, + flowStart: FlowStart, + flowLogicClass: Class>, + frozenFlowLogic: SerializedBytes>, + ourIdentity: Party, + deduplicationSeed: String + ): Try { + return SubFlow.create(flowLogicClass).map { topLevelSubFlow -> + Checkpoint( + invocationContext = invocationContext, + ourIdentity = ourIdentity, + sessions = emptyMap(), + subFlowStack = listOf(topLevelSubFlow), + flowState = FlowState.Unstarted(flowStart, frozenFlowLogic), + errorState = ErrorState.Clean, + numberOfSuspends = 0, + deduplicationSeed = deduplicationSeed + ) + } + } + } +} + +/** + * The state of a session. + */ +sealed class SessionState { + + /** + * We haven't yet sent the initialisation message + */ + data class Uninitiated( + val party: Party, + val initiatingSubFlow: SubFlow.Initiating + ) : SessionState() + + /** + * We have sent the initialisation message but have not yet received a confirmation. + * @property rejectionError if non-null the initiation failed. + */ + data class Initiating( + val bufferedMessages: List>, + val rejectionError: FlowError? + ) : SessionState() + + /** + * We have received a confirmation, the peer party and session id is resolved. + * @property errors if not empty the session is in an errored state. + */ + data class Initiated( + val peerParty: Party, + val peerFlowInfo: FlowInfo, + val receivedMessages: List, + val initiatedState: InitiatedSessionState, + val errors: List + ) : SessionState() +} + +typealias SessionMap = Map + +/** + * Tracks whether an initiated session state is live or has ended. This is a separate state, as we still need the rest + * of [SessionState.Initiated], even when the session has ended, for un-drained session messages and potential future + * [FlowInfo] requests. + */ +sealed class InitiatedSessionState { + data class Live(val peerSinkSessionId: SessionId) : InitiatedSessionState() + object Ended : InitiatedSessionState() { override fun toString() = "Ended" } +} + +/** + * Represents the way the flow has started. + */ +sealed class FlowStart { + /** + * The flow was started explicitly e.g. through RPC or a scheduled state. + */ + object Explicit : FlowStart() { override fun toString() = "Explicit" } + + /** + * The flow was started implicitly as part of session initiation. + */ + data class Initiated( + val peerSession: FlowSessionImpl, + val initiatedSessionId: SessionId, + val initiatingMessage: InitialSessionMessage, + val senderCoreFlowVersion: Int?, + val initiatedFlowInfo: FlowInfo + ) : FlowStart() { override fun toString() = "Initiated" } +} + +/** + * Represents the user-space related state of the flow. + */ +sealed class FlowState { + + /** + * The flow's unstarted state. We should always be able to start a fresh flow fiber from this datastructure. + * + * @param flowStart How the flow was started. + * @param frozenFlowLogic The serialized user-provided [FlowLogic]. + */ + data class Unstarted( + val flowStart: FlowStart, + val frozenFlowLogic: SerializedBytes> + ) : FlowState() { + override fun toString() = "Unstarted(flowStart=$flowStart, frozenFlowLogic=${frozenFlowLogic.hash}" + } + + /** + * The flow's started state, this means the user-code has suspended on an IO request. + * + * @param flowIORequest what IO request the flow has suspended on. + * @param frozenFiber the serialized fiber itself. + */ + data class Started( + val flowIORequest: FlowIORequest<*>, + val frozenFiber: SerializedBytes> + ) : FlowState() { + override fun toString() = "Started(flowIORequest=$flowIORequest, frozenFiber=${frozenFiber.hash}" + } +} + +/** + * @param errorId the ID of the error. This is generated once for the source error and is propagated to neighbour + * sessions. + * @param exception the exception itself. Note that this may not contain information about the source error depending + * on whether the source error was a FlowException or otherwise. + */ +data class FlowError(val errorId: Long, val exception: Throwable) + +/** + * The flow's error state. + */ +sealed class ErrorState { + abstract fun addErrors(newErrors: List): ErrorState + + /** + * The flow is in a clean state. + */ + object Clean : ErrorState() { + override fun addErrors(newErrors: List): ErrorState { + return Errored(newErrors, 0, false) + } + override fun toString() = "Clean" + } + + /** + * The flow has dirtied because of an uncaught exception from user code or other error condition during a state + * transition. + * @param errors the list of errors. Multiple errors may be associated with the errored flow e.g. when multiple + * sessions are errored and have been waited on. + * @param propagatedIndex the index of the first error that hasn't yet been propagated. + * @param propagating true if error propagation was triggered. If this is set the dirtiness is permanent as the + * sessions associated with the flow have been (or about to be) dirtied in counter-flows. + */ + data class Errored( + val errors: List, + val propagatedIndex: Int, + val propagating: Boolean + ) : ErrorState() { + override fun addErrors(newErrors: List): ErrorState { + return copy(errors = errors + newErrors) + } + } +} + diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SubFlow.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SubFlow.kt new file mode 100644 index 0000000000..e48d3acb21 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SubFlow.kt @@ -0,0 +1,74 @@ +package net.corda.node.services.statemachine + +import net.corda.core.flows.FlowInfo +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.InitiatingFlow +import net.corda.core.utilities.Try + +/** + * A [SubFlow] contains metadata about a currently executing sub-flow. At any point the flow execution is + * characterised with a stack of [SubFlow]s. This stack is used to determine the initiating-initiated flow mapping. + * + * Note that Initiat*ed*ness is an orthogonal property of the top-level subflow, so we don't store any information about + * it here. + */ +sealed class SubFlow { + abstract val flowClass: Class> + + /** + * An inlined subflow. + */ + data class Inlined(override val flowClass: Class>) : SubFlow() + + /** + * An initiating subflow. + * @param [flowClass] the concrete class of the subflow. + * @param [classToInitiateWith] an ancestor class of [flowClass] with the [InitiatingFlow] annotation, to be sent + * to the initiated side. + * @param flowInfo the [FlowInfo] associated with the initiating flow. + */ + data class Initiating( + override val flowClass: Class>, + val classToInitiateWith: Class>, + val flowInfo: FlowInfo + ) : SubFlow() + + companion object { + fun create(flowClass: Class>): Try { + // Are we an InitiatingFlow? + val initiatingAnnotations = getInitiatingFlowAnnotations(flowClass) + return when (initiatingAnnotations.size) { + 0 -> { + Try.Success(Inlined(flowClass)) + } + 1 -> { + val initiatingAnnotation = initiatingAnnotations[0] + val flowContext = FlowInfo(initiatingAnnotation.second.version, flowClass.appName) + Try.Success(Initiating(flowClass, initiatingAnnotation.first, flowContext)) + } + else -> { + Try.Failure(IllegalArgumentException("${InitiatingFlow::class.java.name} can only be annotated " + + "once, however the following classes all have the annotation: " + + "${initiatingAnnotations.map { it.first }}")) + } + } + } + + private fun getSuperClasses(clazz: Class): List> { + var currentClass: Class? = clazz + val result = ArrayList>() + while (currentClass != null) { + result.add(currentClass) + currentClass = currentClass.superclass + } + return result + } + + private fun getInitiatingFlowAnnotations(flowClass: Class>): List>, InitiatingFlow>> { + return getSuperClasses(flowClass).mapNotNull { clazz -> + val initiatingAnnotation = clazz.getDeclaredAnnotation(InitiatingFlow::class.java) + initiatingAnnotation?.let { Pair(clazz, it) } + } + } + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutor.kt new file mode 100644 index 0000000000..127cd5a286 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutor.kt @@ -0,0 +1,25 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Suspendable +import net.corda.node.services.statemachine.transitions.FlowContinuation +import net.corda.node.services.statemachine.transitions.TransitionResult + +/** + * An executor of state machine transitions. This is mostly a wrapper interface around an [ActionExecutor], but can be + * used to create interceptors of transitions. + */ +interface TransitionExecutor { + @Suspendable + fun executeTransition( + fiber: FlowFiber, + previousState: StateMachineState, + event: Event, + transition: TransitionResult, + actionExecutor: ActionExecutor + ): Pair +} + +/** + * An interceptor of a transition. These are currently explicitly hooked up in [StateMachineManagerImpl]. + */ +typealias TransitionInterceptor = (TransitionExecutor) -> TransitionExecutor diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutorImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutorImpl.kt new file mode 100644 index 0000000000..b895a40863 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutorImpl.kt @@ -0,0 +1,66 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.utilities.contextLogger +import net.corda.node.services.statemachine.transitions.FlowContinuation +import net.corda.node.services.statemachine.transitions.TransitionResult +import net.corda.nodeapi.internal.persistence.CordaPersistence +import net.corda.nodeapi.internal.persistence.DatabaseTransactionManager +import java.security.SecureRandom + +/** + * This [TransitionExecutor] runs the transition actions using the passed in [ActionExecutor] and manually dirties the + * state on failure. + * + * If a failure happens when we're already transitioning into a errored state then the transition and the flow fiber is + * completely aborted to avoid error loops. + */ +class TransitionExecutorImpl( + val secureRandom: SecureRandom, + val database: CordaPersistence +) : TransitionExecutor { + private companion object { + val log = contextLogger() + } + + @Suspendable + override fun executeTransition( + fiber: FlowFiber, + previousState: StateMachineState, + event: Event, + transition: TransitionResult, + actionExecutor: ActionExecutor + ): Pair { + DatabaseTransactionManager.dataSource = database + for (action in transition.actions) { + try { + actionExecutor.executeAction(fiber, action) + } catch (exception: Throwable) { + DatabaseTransactionManager.currentOrNull()?.close() + if (transition.newState.checkpoint.errorState is ErrorState.Errored) { + // If we errored while transitioning to an error state then we cannot record the additional + // error as that may result in an infinite loop, e.g. error propagation fails -> record error -> propagate fails again. + // Instead we just keep around the old error state and wait for a new schedule, perhaps + // triggered from a flow hospital + log.error("Error while executing $action during transition to errored state, aborting transition", exception) + return Pair(FlowContinuation.Abort, previousState.copy(isFlowResumed = false)) + } else { + // Otherwise error the state manually keeping the old flow state and schedule a DoRemainingWork + // to trigger error propagation + log.error("Error while executing $action, erroring state", exception) + val newState = previousState.copy( + checkpoint = previousState.checkpoint.copy( + errorState = previousState.checkpoint.errorState.addErrors( + listOf(FlowError(secureRandom.nextLong(), exception)) + ) + ), + isFlowResumed = false + ) + fiber.scheduleEvent(Event.DoRemainingWork) + return Pair(FlowContinuation.ProcessEvents, newState) + } + } + } + return Pair(transition.continuation, transition.newState) + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/DumpHistoryOnErrorInterceptor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/DumpHistoryOnErrorInterceptor.kt new file mode 100644 index 0000000000..e36bfcc516 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/DumpHistoryOnErrorInterceptor.kt @@ -0,0 +1,47 @@ +package net.corda.node.services.statemachine.interceptors + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.flows.StateMachineRunId +import net.corda.core.utilities.contextLogger +import net.corda.node.services.statemachine.* +import net.corda.node.services.statemachine.transitions.FlowContinuation +import net.corda.node.services.statemachine.transitions.TransitionResult +import java.util.concurrent.ConcurrentHashMap + +/** + * This interceptor records a trace of all of the flows' states and transitions. If the flow dirties it dumps the trace + * transition to the logger. + */ +class DumpHistoryOnErrorInterceptor(val delegate: TransitionExecutor) : TransitionExecutor { + companion object { + private val log = contextLogger() + } + + private val records = ConcurrentHashMap>() + + @Suspendable + override fun executeTransition( + fiber: FlowFiber, + previousState: StateMachineState, + event: Event, + transition: TransitionResult, + actionExecutor: ActionExecutor + ): Pair { + val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor) + val transitionRecord = TransitionDiagnosticRecord(fiber.id, previousState, nextState, event, transition, continuation) + val record = records.compute(fiber.id) { _, record -> + (record ?: ArrayList()).apply { add(transitionRecord) } + } + + if (nextState.checkpoint.errorState is ErrorState.Errored) { + log.warn("Flow ${fiber.id} dirtied, dumping all transitions:\n${record!!.joinToString("\n")}") + } + + if (transition.newState.isRemoved) { + records.remove(fiber.id) + } + + return Pair(continuation, nextState) + } + +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/FiberDeserializationCheckingInterceptor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/FiberDeserializationCheckingInterceptor.kt new file mode 100644 index 0000000000..cbde382f4d --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/FiberDeserializationCheckingInterceptor.kt @@ -0,0 +1,94 @@ +package net.corda.node.services.statemachine.interceptors + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.deserialize +import net.corda.core.utilities.contextLogger +import net.corda.node.services.statemachine.* +import net.corda.node.services.statemachine.transitions.FlowContinuation +import net.corda.node.services.statemachine.transitions.TransitionResult +import java.util.concurrent.LinkedBlockingQueue +import kotlin.concurrent.thread + +/** + * This interceptor checks whether a checkpointed fiber state can be deserialised in a separate thread. + */ +class FiberDeserializationCheckingInterceptor( + val fiberDeserializationChecker: FiberDeserializationChecker, + val delegate: TransitionExecutor +) : TransitionExecutor { + @Suspendable + override fun executeTransition( + fiber: FlowFiber, + previousState: StateMachineState, + event: Event, + transition: TransitionResult, + actionExecutor: ActionExecutor + ): Pair { + val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor) + val previousFlowState = previousState.checkpoint.flowState + val nextFlowState = nextState.checkpoint.flowState + if (nextFlowState is FlowState.Started) { + if (previousFlowState !is FlowState.Started || previousFlowState.frozenFiber != nextFlowState.frozenFiber) { + fiberDeserializationChecker.submitCheck(nextFlowState.frozenFiber) + } + } + return Pair(continuation, nextState) + } +} + +/** + * A fiber deserialisation checker thread. It checks the queued up serialised checkpoints to see if they can be + * deserialised. This is only run in development mode to allow detecting of corrupt serialised checkpoints before they + * are actually used. + */ +class FiberDeserializationChecker { + companion object { + val log = contextLogger() + } + + private sealed class Job { + class Check(val serializedFiber: SerializedBytes>) : Job() + object Finish : Job() + } + + private var checkerThread: Thread? = null + private val jobQueue = LinkedBlockingQueue() + private var foundUnrestorableFibers: Boolean = false + + fun start(checkpointSerializationContext: SerializationContext) { + require(checkerThread == null) + checkerThread = thread(name = "FiberDeserializationChecker") { + while (true) { + val job = jobQueue.take() + when (job) { + is Job.Check -> { + try { + job.serializedFiber.deserialize(context = checkpointSerializationContext) + } catch (throwable: Throwable) { + log.error("Encountered unrestorable checkpoint!", throwable) + foundUnrestorableFibers = true + } + } + Job.Finish -> { + return@thread + } + } + } + } + } + + fun submitCheck(serializedFiber: SerializedBytes>) { + jobQueue.add(Job.Check(serializedFiber)) + } + + /** + * Returns true if some unrestorable checkpoints were encountered, false otherwise + */ + fun stop(): Boolean { + jobQueue.add(Job.Finish) + checkerThread?.join() + return foundUnrestorableFibers + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/HospitalisingInterceptor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/HospitalisingInterceptor.kt new file mode 100644 index 0000000000..8573f937e0 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/HospitalisingInterceptor.kt @@ -0,0 +1,43 @@ +package net.corda.node.services.statemachine.interceptors + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.flows.StateMachineRunId +import net.corda.node.services.statemachine.* +import net.corda.node.services.statemachine.transitions.FlowContinuation +import net.corda.node.services.statemachine.transitions.TransitionResult +import java.util.concurrent.ConcurrentHashMap + +/** + * This interceptor notifies the passed in [flowHospital] in case a flow went through a clean->errored or a errored->clean + * transition. + */ +class HospitalisingInterceptor( + private val flowHospital: FlowHospital, + private val delegate: TransitionExecutor +) : TransitionExecutor { + private val hospitalisedFlows = ConcurrentHashMap() + + @Suspendable + override fun executeTransition( + fiber: FlowFiber, + previousState: StateMachineState, + event: Event, + transition: TransitionResult, + actionExecutor: ActionExecutor + ): Pair { + val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor) + when (nextState.checkpoint.errorState) { + ErrorState.Clean -> { + if (hospitalisedFlows.remove(fiber.id) != null) { + flowHospital.flowCleaned(fiber) + } + } + is ErrorState.Errored -> { + if (hospitalisedFlows.putIfAbsent(fiber.id, fiber) == null) { + flowHospital.flowErrored(fiber) + } + } + } + return Pair(continuation, nextState) + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/PrintingInterceptor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/PrintingInterceptor.kt new file mode 100644 index 0000000000..1f824aec70 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/PrintingInterceptor.kt @@ -0,0 +1,30 @@ +package net.corda.node.services.statemachine.interceptors + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.utilities.contextLogger +import net.corda.node.services.statemachine.* +import net.corda.node.services.statemachine.transitions.FlowContinuation +import net.corda.node.services.statemachine.transitions.TransitionResult + +/** + * This interceptor simply prints all state machine transitions. Useful for debugging. + */ +class PrintingInterceptor(val delegate: TransitionExecutor) : TransitionExecutor { + companion object { + val log = contextLogger() + } + + @Suspendable + override fun executeTransition( + fiber: FlowFiber, + previousState: StateMachineState, + event: Event, + transition: TransitionResult, + actionExecutor: ActionExecutor + ): Pair { + val (continuation, nextState) = delegate.executeTransition(fiber, previousState, event, transition, actionExecutor) + val transitionRecord = TransitionDiagnosticRecord(fiber.id, previousState, nextState, event, transition, continuation) + log.info("Transition for flow ${fiber.id} $transitionRecord") + return Pair(continuation, nextState) + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/TransitionDiagnosticRecord.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/TransitionDiagnosticRecord.kt new file mode 100644 index 0000000000..96235b3eb2 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/TransitionDiagnosticRecord.kt @@ -0,0 +1,48 @@ +package net.corda.node.services.statemachine.interceptors + +import net.corda.core.flows.StateMachineRunId +import net.corda.node.services.statemachine.transitions.FlowContinuation +import net.corda.node.services.statemachine.Event +import net.corda.node.services.statemachine.StateMachineState +import net.corda.node.services.statemachine.transitions.TransitionResult +import net.corda.node.utilities.ObjectDiffer + +/** + * This is a diagnostic record that stores information about a state machine transition and provides pretty printing + * by diffing the two states. + */ +data class TransitionDiagnosticRecord( + val flowId: StateMachineRunId, + val previousState: StateMachineState, + val nextState: StateMachineState, + val event: Event, + val transition: TransitionResult, + val continuation: FlowContinuation +) { + override fun toString(): String { + val diffIntended = ObjectDiffer.diff(previousState, transition.newState) + val diffNext = ObjectDiffer.diff(previousState, nextState) + return ( + listOf( + "", + " --- Transition of flow $flowId ---", + " Event: $event", + " Actions: ", + " ${transition.actions.joinToString("\n ")}", + " Continuation: ${transition.continuation}" + ) + + if (diffIntended != diffNext) { + listOf( + " Diff between previous and intended state:", + "${diffIntended?.toPaths()?.joinToString("")}" + ) + } else { + emptyList() + } + listOf( + + " Diff between previous and next state:", + "${diffNext?.toPaths()?.joinToString("")}" + ) + ).joinToString("\n") + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt new file mode 100644 index 0000000000..0796a5349a --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt @@ -0,0 +1,189 @@ +package net.corda.node.services.statemachine.transitions + +import net.corda.core.flows.UnexpectedFlowEndException +import net.corda.node.services.statemachine.* + +/** + * This transition handles incoming session messages. It handles the following cases: + * - DataSessionMessage: these arrive to initiated and confirmed sessions and are expected to be received by the flow. + * - ConfirmSessionMessage: these arrive as a response to an InitialSessionMessage and include information about the + * counterparty flow's session ID as well as their [FlowInfo]. + * - ErrorSessionMessage: these arrive to initiated and confirmed sessions and put the corresponding session into an + * "errored" state. This means that whenever that session is subsequently interacted with the error will be thrown + * in the flow. + * - RejectSessionMessage: these arrive as a response to an InitialSessionMessage when the initiation failed. It + * behaves similarly to ErrorSessionMessage aside from the type of exceptions stored/raised. + * - EndSessionMessage: these are sent when the counterparty flow has finished. They put the corresponding session into + * an "ended" state. This means that subsequent sends on this session will fail, and receives will start failing + * after the buffer of already received messages is drained. + */ +class DeliverSessionMessageTransition( + override val context: TransitionContext, + override val startingState: StateMachineState, + val event: Event.DeliverSessionMessage +) : Transition { + override fun transition(): TransitionResult { + return builder { + // Add the AcknowledgeHandle to the unacknowledged messages ASAP so in case an error happens we still know + // about the message. Note that in case of an error during deliver this message *will be acked*. + // For example if the session corresponding to the message is not found the message is still acked to free + // up the broker but the flow will error. + currentState = currentState.copy( + unacknowledgedMessages = currentState.unacknowledgedMessages + event.acknowledgeHandle + ) + // Check whether we have a session corresponding to the message. + val existingSession = startingState.checkpoint.sessions[event.sessionMessage.recipientSessionId] + if (existingSession == null) { + freshErrorTransition(CannotFindSessionException(event.sessionMessage.recipientSessionId)) + } else { + val payload = event.sessionMessage.payload + // Dispatch based on what kind of message it is. + val _exhaustive = when (payload) { + is ConfirmSessionMessage -> confirmMessageTransition(existingSession, payload) + is DataSessionMessage -> dataMessageTransition(existingSession, payload) + is ErrorSessionMessage -> errorMessageTransition(existingSession, payload) + is RejectSessionMessage -> rejectMessageTransition(existingSession, payload) + is EndSessionMessage -> endMessageTransition() + } + } + if (!isErrored()) { + persistCheckpointIfNeeded() + } + // Schedule a DoRemainingWork to check whether the flow needs to be woken up. + actions.add(Action.ScheduleEvent(Event.DoRemainingWork)) + FlowContinuation.ProcessEvents + } + } + + private fun TransitionBuilder.confirmMessageTransition(sessionState: SessionState, message: ConfirmSessionMessage) { + // We received a confirmation message. The corresponding session state must be Initiating. + when (sessionState) { + is SessionState.Initiating -> { + // Create the new session state that is now Initiated. + val initiatedSession = SessionState.Initiated( + peerParty = event.sender, + peerFlowInfo = message.initiatedFlowInfo, + receivedMessages = emptyList(), + initiatedState = InitiatedSessionState.Live(message.initiatedSessionId), + errors = emptyList() + ) + val newCheckpoint = currentState.checkpoint.copy( + sessions = currentState.checkpoint.sessions + (event.sessionMessage.recipientSessionId to initiatedSession) + ) + // Send messages that were buffered pending confirmation of session. + val sendActions = sessionState.bufferedMessages.map { (deduplicationId, bufferedMessage) -> + val existingMessage = ExistingSessionMessage(message.initiatedSessionId, bufferedMessage) + Action.SendExisting(initiatedSession.peerParty, existingMessage, deduplicationId) + } + actions.addAll(sendActions) + currentState = currentState.copy(checkpoint = newCheckpoint) + } + else -> freshErrorTransition(UnexpectedEventInState()) + } + } + + private fun TransitionBuilder.dataMessageTransition(sessionState: SessionState, message: DataSessionMessage) { + // We received a data message. The corresponding session must be Initiated. + return when (sessionState) { + is SessionState.Initiated -> { + // Buffer the message in the session's receivedMessages buffer. + val newSessionState = sessionState.copy( + receivedMessages = sessionState.receivedMessages + message + ) + currentState = currentState.copy( + checkpoint = currentState.checkpoint.copy( + sessions = startingState.checkpoint.sessions + (event.sessionMessage.recipientSessionId to newSessionState) + ) + ) + } + else -> freshErrorTransition(UnexpectedEventInState()) + } + } + + private fun TransitionBuilder.errorMessageTransition(sessionState: SessionState, payload: ErrorSessionMessage) { + val exception: Throwable = if (payload.flowException == null) { + UnexpectedFlowEndException("Counter-flow errored", cause = null, originalErrorId = payload.errorId) + } else { + payload.flowException.originalErrorId = payload.errorId + payload.flowException + } + + return when (sessionState) { + is SessionState.Initiated -> { + val checkpoint = currentState.checkpoint + val sessionId = event.sessionMessage.recipientSessionId + val flowError = FlowError(payload.errorId, exception) + val newSessionState = sessionState.copy(errors = sessionState.errors + flowError) + currentState = currentState.copy( + checkpoint = checkpoint.copy( + sessions = checkpoint.sessions + (sessionId to newSessionState) + ) + ) + } + else -> freshErrorTransition(UnexpectedEventInState()) + } + } + + private fun TransitionBuilder.rejectMessageTransition(sessionState: SessionState, payload: RejectSessionMessage) { + val exception = UnexpectedFlowEndException(payload.message, cause = null, originalErrorId = payload.errorId) + return when (sessionState) { + is SessionState.Initiating -> { + if (sessionState.rejectionError != null) { + // Double reject + freshErrorTransition(UnexpectedEventInState()) + } else { + val checkpoint = currentState.checkpoint + val sessionId = event.sessionMessage.recipientSessionId + val flowError = FlowError(payload.errorId, exception) + currentState = currentState.copy( + checkpoint = checkpoint.copy( + sessions = checkpoint.sessions + (sessionId to sessionState.copy(rejectionError = flowError)) + ) + ) + } + } + else -> freshErrorTransition(UnexpectedEventInState()) + } + } + + private fun TransitionBuilder.persistCheckpointIfNeeded() { + // We persist the message as soon as it arrives. + if (context.configuration.sessionDeliverPersistenceStrategy == SessionDeliverPersistenceStrategy.OnDeliver && + event.sessionMessage.payload !is EndSessionMessage) { + actions.addAll(arrayOf( + Action.CreateTransaction, + Action.PersistCheckpoint(context.id, currentState.checkpoint), + Action.PersistDeduplicationIds(currentState.unacknowledgedMessages), + Action.CommitTransaction, + Action.AcknowledgeMessages(currentState.unacknowledgedMessages) + )) + currentState = currentState.copy( + unacknowledgedMessages = emptyList(), + isAnyCheckpointPersisted = true + ) + } + } + + private fun TransitionBuilder.endMessageTransition() { + val sessionId = event.sessionMessage.recipientSessionId + val sessions = currentState.checkpoint.sessions + val sessionState = sessions[sessionId] + if (sessionState == null) { + return freshErrorTransition(CannotFindSessionException(sessionId)) + } + when (sessionState) { + is SessionState.Initiated -> { + val newSessionState = sessionState.copy(initiatedState = InitiatedSessionState.Ended) + currentState = currentState.copy( + checkpoint = currentState.checkpoint.copy( + sessions = sessions + (sessionId to newSessionState) + ) + ) + } + else -> { + freshErrorTransition(UnexpectedEventInState()) + } + } + } + +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DoRemainingWorkTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DoRemainingWorkTransition.kt new file mode 100644 index 0000000000..ced7867bc1 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DoRemainingWorkTransition.kt @@ -0,0 +1,37 @@ +package net.corda.node.services.statemachine.transitions + +import net.corda.node.services.statemachine.* + +/** + * This transition checks the current state of the flow and determines whether anything needs to be done. + */ +class DoRemainingWorkTransition( + override val context: TransitionContext, + override val startingState: StateMachineState +) : Transition { + override fun transition(): TransitionResult { + val checkpoint = startingState.checkpoint + // If the flow is removed or has been resumed don't do work. + if (startingState.isFlowResumed || startingState.isRemoved) { + return TransitionResult(startingState) + } + // Check whether the flow is errored + return when (checkpoint.errorState) { + is ErrorState.Clean -> cleanTransition() + is ErrorState.Errored -> erroredTransition(checkpoint.errorState) + } + } + + // If the flow is clean check the FlowState + private fun cleanTransition(): TransitionResult { + val checkpoint = startingState.checkpoint + return when (checkpoint.flowState) { + is FlowState.Unstarted -> UnstartedFlowTransition(context, startingState, checkpoint.flowState).transition() + is FlowState.Started -> StartedFlowTransition(context, startingState, checkpoint.flowState).transition() + } + } + + private fun erroredTransition(errorState: ErrorState.Errored): TransitionResult { + return ErrorFlowTransition(context, startingState, errorState).transition() + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt new file mode 100644 index 0000000000..70825807d0 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt @@ -0,0 +1,124 @@ +package net.corda.node.services.statemachine.transitions + +import net.corda.core.flows.FlowException +import net.corda.node.services.statemachine.* + +/** + * This transition defines what should happen when a flow has errored. + * + * In general there are two flow-level error conditions: + * + * - Internal exceptions. These may arise due to problems in the flow framework or errors during state machine + * transitions e.g. network or database failure. + * - User-raised exceptions. These are exceptions that are (re)raised in user code, allowing the user to catch them. + * These may come from illegal flow API calls, and FlowExceptions or other counterparty failures that are re-raised + * when the flow tries to use the corresponding sessions. + * + * Both internal exceptions and uncaught user-raised exceptions cause the flow to be errored. This flags the flow as + * unable to be resumed. When a flow is in this state an external source (e.g. Flow hospital) may decide to + * + * 1. Retry it (not implemented yet). This throws away the errored state and re-tries from the last clean checkpoint. + * 2. Start error propagation. This seals the flow as errored permanently and propagates the associated error(s) to + * all live sessions. This causes these sessions to errored on the other side, which may in turn cause the + * counter-flows themselves to errored. + * + * See [net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor] for how to detect flow errors. + * + * Note that in general we handle multiple errors at a time as several error conditions may arise at the same time and + * new errors may arise while the flow is in the errored state already. + */ +class ErrorFlowTransition( + override val context: TransitionContext, + override val startingState: StateMachineState, + private val errorState: ErrorState.Errored +) : Transition { + override fun transition(): TransitionResult { + val allErrors: List = errorState.errors + val remainingErrorsToPropagate: List = allErrors.subList(errorState.propagatedIndex, allErrors.size) + val errorMessages: List = remainingErrorsToPropagate.map(this::createErrorMessageFromError) + + return builder { + // If we're errored and propagating do the actual propagation and update the index. + if (remainingErrorsToPropagate.isNotEmpty() && errorState.propagating) { + val (initiatedSessions, newSessions) = bufferErrorMessagesInInitiatingSessions(startingState.checkpoint.sessions, errorMessages) + val newCheckpoint = startingState.checkpoint.copy( + errorState = errorState.copy(propagatedIndex = allErrors.size), + sessions = newSessions + ) + currentState = currentState.copy(checkpoint = newCheckpoint) + actions.add(Action.PropagateErrors(errorMessages, initiatedSessions)) + } + + // If we're errored but not propagating keep processing events. + if (remainingErrorsToPropagate.isNotEmpty() && !errorState.propagating) { + return@builder FlowContinuation.ProcessEvents + } + + // If we haven't been removed yet remove the flow. + if (!currentState.isRemoved) { + actions.add(Action.CreateTransaction) + if (currentState.isAnyCheckpointPersisted) { + actions.add(Action.RemoveCheckpoint(context.id)) + } + actions.addAll(arrayOf( + Action.PersistDeduplicationIds(currentState.unacknowledgedMessages), + Action.CommitTransaction, + Action.AcknowledgeMessages(currentState.unacknowledgedMessages), + Action.RemoveSessionBindings(currentState.checkpoint.sessions.keys) + )) + + currentState = currentState.copy( + unacknowledgedMessages = emptyList(), + isRemoved = true + ) + + val removalReason = FlowRemovalReason.ErrorFinish(allErrors) + actions.add(Action.RemoveFlow(context.id, removalReason, currentState)) + FlowContinuation.Abort + } else { + // Otherwise keep processing events. This branch happens when there are some outstanding initiating + // sessions that prevent the removal of the flow. + FlowContinuation.ProcessEvents + } + } + } + + private fun createErrorMessageFromError(error: FlowError): ErrorSessionMessage { + val exception = error.exception + // If the exception doesn't contain an originalErrorId that means it's a fresh FlowException that should + // propagate to the neighbouring flows. If it has the ID filled in that means it's a rethrown FlowException and + // shouldn't be propagated. + return if (exception is FlowException && exception.originalErrorId == null) { + ErrorSessionMessage(flowException = exception, errorId = error.errorId) + } else { + ErrorSessionMessage(flowException = null, errorId = error.errorId) + } + } + + // Buffer error messages in Initiating sessions, return the initialised ones. + private fun bufferErrorMessagesInInitiatingSessions( + sessions: Map, + errorMessages: List + ): Pair, Map> { + val newSessions = sessions.mapValues { (sourceSessionId, sessionState) -> + if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) { + // *prepend* the error messages in order to error the other sessions ASAP. The other messages will + // be delivered all the same, they just won't trigger flow resumption because of dirtiness. + val errorMessagesWithDeduplication = errorMessages.map { + DeduplicationId.createForError(it.errorId, sourceSessionId) to it + } + sessionState.copy(bufferedMessages = errorMessagesWithDeduplication + sessionState.bufferedMessages) + } else { + sessionState + } + } + val initiatedSessions = sessions.values.mapNotNull { session -> + if (session is SessionState.Initiated && session.errors.isEmpty()) { + session + } else { + null + } + } + return Pair(initiatedSessions, newSessions) + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt new file mode 100644 index 0000000000..0821773dd1 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt @@ -0,0 +1,399 @@ +package net.corda.node.services.statemachine.transitions + +import net.corda.core.flows.FlowInfo +import net.corda.core.flows.FlowSession +import net.corda.core.flows.UnexpectedFlowEndException +import net.corda.core.internal.FlowIORequest +import net.corda.core.serialization.SerializedBytes +import net.corda.core.utilities.toNonEmptySet +import net.corda.node.services.statemachine.* + +/** + * This transition describes what should happen with a specific [FlowIORequest]. Note that at this time the request + * is persisted (unless checkpoint was skipped) and the user-space DB transaction is commited. + * + * Before this transition we either did a checkpoint or the checkpoint was restored from the database. + */ +class StartedFlowTransition( + override val context: TransitionContext, + override val startingState: StateMachineState, + val started: FlowState.Started +) : Transition { + override fun transition(): TransitionResult { + val flowIORequest = started.flowIORequest + val checkpoint = startingState.checkpoint + val errorsToThrow = collectRelevantErrorsToThrow(flowIORequest, checkpoint) + if (errorsToThrow.isNotEmpty()) { + return TransitionResult( + newState = startingState.copy(isFlowResumed = true), + // throw the first exception. TODO should this aggregate all of them somehow? + actions = listOf(Action.CreateTransaction), + continuation = FlowContinuation.Throw(errorsToThrow[0]) + ) + } + return when (flowIORequest) { + is FlowIORequest.Send -> sendTransition(flowIORequest) + is FlowIORequest.Receive -> receiveTransition(flowIORequest) + is FlowIORequest.SendAndReceive -> sendAndReceiveTransition(flowIORequest) + is FlowIORequest.WaitForLedgerCommit -> waitForLedgerCommitTransition(flowIORequest) + is FlowIORequest.Sleep -> sleepTransition(flowIORequest) + is FlowIORequest.GetFlowInfo -> getFlowInfoTransition(flowIORequest) + is FlowIORequest.WaitForSessionConfirmations -> waitForSessionConfirmationsTransition() + } + } + + private fun waitForSessionConfirmationsTransition(): TransitionResult { + return builder { + if (currentState.checkpoint.sessions.values.any { it is SessionState.Initiating }) { + FlowContinuation.ProcessEvents + } else { + resumeFlowLogic(Unit) + } + } + } + + private fun getFlowInfoTransition(flowIORequest: FlowIORequest.GetFlowInfo): TransitionResult { + val sessionIdToSession = LinkedHashMap() + for (session in flowIORequest.sessions) { + sessionIdToSession[(session as FlowSessionImpl).sourceSessionId] = session + } + return builder { + // Initialise uninitialised sessions in order to receive the associated FlowInfo. Some or all sessions may + // not be initialised yet. + sendInitialSessionMessagesIfNeeded(sessionIdToSession.keys) + val flowInfoMap = getFlowInfoFromSessions(sessionIdToSession) + if (flowInfoMap == null) { + FlowContinuation.ProcessEvents + } else { + resumeFlowLogic(flowInfoMap) + } + } + } + + private fun TransitionBuilder.getFlowInfoFromSessions(sessionIdToSession: Map): Map? { + val checkpoint = currentState.checkpoint + val resultMap = LinkedHashMap() + for ((sessionId, session) in sessionIdToSession) { + val sessionState = checkpoint.sessions[sessionId] + if (sessionState is SessionState.Initiated) { + resultMap[session] = sessionState.peerFlowInfo + } else { + return null + } + } + return resultMap + } + + private fun sleepTransition(flowIORequest: FlowIORequest.Sleep): TransitionResult { + return builder { + actions.add(Action.SleepUntil(flowIORequest.wakeUpAfter)) + resumeFlowLogic(Unit) + } + } + + private fun waitForLedgerCommitTransition(flowIORequest: FlowIORequest.WaitForLedgerCommit): TransitionResult { + return if (!startingState.isTransactionTracked) { + TransitionResult( + newState = startingState.copy(isTransactionTracked = true), + actions = listOf( + Action.CreateTransaction, + Action.TrackTransaction(flowIORequest.hash), + Action.CommitTransaction + ) + ) + } else { + TransitionResult(startingState) + } + } + + private fun sendAndReceiveTransition(flowIORequest: FlowIORequest.SendAndReceive): TransitionResult { + val sessionIdToMessage = LinkedHashMap>() + val sessionIdToSession = LinkedHashMap() + for ((session, message) in flowIORequest.sessionToMessage) { + val sessionId = (session as FlowSessionImpl).sourceSessionId + sessionIdToMessage[sessionId] = message + sessionIdToSession[sessionId] = session + } + return builder { + sendToSessionsTransition(sessionIdToMessage) + if (isErrored()) { + FlowContinuation.ProcessEvents + } else { + val receivedMap = receiveFromSessionsTransition(sessionIdToSession) + if (receivedMap == null) { + // We don't yet have the messages, change the suspension to be on Receive + val newIoRequest = FlowIORequest.Receive(flowIORequest.sessionToMessage.keys.toNonEmptySet()) + currentState = currentState.copy( + checkpoint = currentState.checkpoint.copy( + flowState = FlowState.Started(newIoRequest, started.frozenFiber) + ) + ) + FlowContinuation.ProcessEvents + } else { + resumeFlowLogic(receivedMap) + } + } + } + } + + private fun receiveTransition(flowIORequest: FlowIORequest.Receive): TransitionResult { + return builder { + val sessionIdToSession = LinkedHashMap() + for (session in flowIORequest.sessions) { + sessionIdToSession[(session as FlowSessionImpl).sourceSessionId] = session + } + // send initialises to uninitialised sessions + sendInitialSessionMessagesIfNeeded(sessionIdToSession.keys) + val receivedMap = receiveFromSessionsTransition(sessionIdToSession) + if (receivedMap == null) { + FlowContinuation.ProcessEvents + } else { + resumeFlowLogic(receivedMap) + } + } + } + + private fun TransitionBuilder.receiveFromSessionsTransition( + sourceSessionIdToSessionMap: Map + ): Map>? { + val checkpoint = currentState.checkpoint + val pollResult = pollSessionMessages(checkpoint.sessions, sourceSessionIdToSessionMap.keys) ?: return null + val resultMap = LinkedHashMap>() + for ((sessionId, message) in pollResult.messages) { + val session = sourceSessionIdToSessionMap[sessionId]!! + resultMap[session] = message + } + currentState = currentState.copy( + checkpoint = checkpoint.copy(sessions = pollResult.newSessionMap) + ) + return resultMap + } + + data class PollResult( + val messages: Map>, + val newSessionMap: SessionMap + ) + private fun pollSessionMessages(sessions: SessionMap, sessionIds: Set): PollResult? { + val newSessionMessages = LinkedHashMap(sessions) + val resultMessages = LinkedHashMap>() + var someNotFound = false + for (sessionId in sessionIds) { + val sessionState = sessions[sessionId] + when (sessionState) { + is SessionState.Initiated -> { + val messages = sessionState.receivedMessages + if (messages.isEmpty()) { + someNotFound = true + } else { + newSessionMessages[sessionId] = sessionState.copy(receivedMessages = messages.subList(1, messages.size).toList()) + resultMessages[sessionId] = messages[0].payload + } + } + else -> { + someNotFound = true + } + } + } + return if (someNotFound) { + return null + } else { + PollResult(resultMessages, newSessionMessages) + } + } + + private fun TransitionBuilder.sendInitialSessionMessagesIfNeeded(sourceSessions: Set) { + val checkpoint = startingState.checkpoint + val newSessions = LinkedHashMap(checkpoint.sessions) + var index = 0 + for (sourceSessionId in sourceSessions) { + val sessionState = checkpoint.sessions[sourceSessionId] + if (sessionState == null) { + return freshErrorTransition(CannotFindSessionException(sourceSessionId)) + } + if (sessionState !is SessionState.Uninitiated) { + continue + } + val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++) + val initialMessage = createInitialSessionMessage(sessionState.initiatingSubFlow, sourceSessionId, null) + actions.add(Action.SendInitial(sessionState.party, initialMessage, deduplicationId)) + newSessions[sourceSessionId] = SessionState.Initiating( + bufferedMessages = emptyList(), + rejectionError = null + ) + } + currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions)) + } + + private fun sendTransition(flowIORequest: FlowIORequest.Send): TransitionResult { + return builder { + val sessionIdToMessage = flowIORequest.sessionToMessage.mapKeys { + sessionToSessionId(it.key) + } + sendToSessionsTransition(sessionIdToMessage) + if (isErrored()) { + FlowContinuation.ProcessEvents + } else { + resumeFlowLogic(Unit) + } + } + } + + private fun TransitionBuilder.sendToSessionsTransition(sourceSessionIdToMessage: Map>) { + val checkpoint = startingState.checkpoint + val newSessions = LinkedHashMap(checkpoint.sessions) + var index = 0 + for ((sourceSessionId, message) in sourceSessionIdToMessage) { + val existingSessionState = checkpoint.sessions[sourceSessionId] + if (existingSessionState == null) { + return freshErrorTransition(CannotFindSessionException(sourceSessionId)) + } else { + val sessionMessage = DataSessionMessage(message) + val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++) + val _exhaustive = when (existingSessionState) { + is SessionState.Uninitiated -> { + val initialMessage = createInitialSessionMessage(existingSessionState.initiatingSubFlow, sourceSessionId, message) + actions.add(Action.SendInitial(existingSessionState.party, initialMessage, deduplicationId)) + newSessions[sourceSessionId] = SessionState.Initiating( + bufferedMessages = emptyList(), + rejectionError = null + ) + Unit + } + is SessionState.Initiating -> { + // We're initiating this session, buffer the message + val newBufferedMessages = existingSessionState.bufferedMessages + Pair(deduplicationId, sessionMessage) + newSessions[sourceSessionId] = existingSessionState.copy(bufferedMessages = newBufferedMessages) + } + is SessionState.Initiated -> { + when (existingSessionState.initiatedState) { + is InitiatedSessionState.Live -> { + val sinkSessionId = existingSessionState.initiatedState.peerSinkSessionId + val existingMessage = ExistingSessionMessage(sinkSessionId, sessionMessage) + actions.add(Action.SendExisting(existingSessionState.peerParty, existingMessage, deduplicationId)) + Unit + } + InitiatedSessionState.Ended -> { + return freshErrorTransition(IllegalStateException("Tried to send to ended session $sourceSessionId")) + } + } + } + } + } + + } + currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions)) + } + + private fun sessionToSessionId(session: FlowSession): SessionId { + return (session as FlowSessionImpl).sourceSessionId + } + + private fun collectErroredSessionErrors(sessionIds: Collection, checkpoint: Checkpoint): List { + return sessionIds.flatMap { sessionId -> + val sessionState = checkpoint.sessions[sessionId]!! + when (sessionState) { + is SessionState.Uninitiated -> emptyList() + is SessionState.Initiating -> { + if (sessionState.rejectionError == null) { + emptyList() + } else { + listOf(sessionState.rejectionError.exception) + } + } + is SessionState.Initiated -> sessionState.errors.map(FlowError::exception) + } + } + } + + private fun collectErroredInitiatingSessionErrors(checkpoint: Checkpoint): List { + return checkpoint.sessions.values.mapNotNull { sessionState -> + (sessionState as? SessionState.Initiating)?.rejectionError?.exception + } + } + + private fun collectEndedSessionErrors(sessionIds: Collection, checkpoint: Checkpoint): List { + return sessionIds.mapNotNull { sessionId -> + val sessionState = checkpoint.sessions[sessionId]!! + when (sessionState) { + is SessionState.Initiated -> { + if (sessionState.initiatedState is InitiatedSessionState.Ended) { + UnexpectedFlowEndException( + "Tried to access ended session $sessionId", + cause = null, + originalErrorId = context.secureRandom.nextLong() + ) + } else { + null + } + } + else -> null + } + } + } + + private fun collectEndedEmptySessionErrors(sessionIds: Collection, checkpoint: Checkpoint): List { + return sessionIds.mapNotNull { sessionId -> + val sessionState = checkpoint.sessions[sessionId]!! + when (sessionState) { + is SessionState.Initiated -> { + if (sessionState.initiatedState is InitiatedSessionState.Ended && + sessionState.receivedMessages.isEmpty()) { + UnexpectedFlowEndException( + "Tried to access ended session $sessionId with empty buffer", + cause = null, + originalErrorId = context.secureRandom.nextLong() + ) + } else { + null + } + } + else -> null + } + } + } + + private fun collectRelevantErrorsToThrow(flowIORequest: FlowIORequest<*>, checkpoint: Checkpoint): List { + return when (flowIORequest) { + is FlowIORequest.Send -> { + val sessionIds = flowIORequest.sessionToMessage.keys.map(this::sessionToSessionId) + collectErroredSessionErrors(sessionIds, checkpoint) + collectEndedSessionErrors(sessionIds, checkpoint) + } + is FlowIORequest.Receive -> { + val sessionIds = flowIORequest.sessions.map(this::sessionToSessionId) + collectErroredSessionErrors(sessionIds, checkpoint) + collectEndedEmptySessionErrors(sessionIds, checkpoint) + } + is FlowIORequest.SendAndReceive -> { + val sessionIds = flowIORequest.sessionToMessage.keys.map(this::sessionToSessionId) + collectErroredSessionErrors(sessionIds, checkpoint) + collectEndedSessionErrors(sessionIds, checkpoint) + } + is FlowIORequest.WaitForLedgerCommit -> { + collectErroredSessionErrors(checkpoint.sessions.keys, checkpoint) + } + is FlowIORequest.GetFlowInfo -> { + collectErroredSessionErrors(flowIORequest.sessions.map(this::sessionToSessionId), checkpoint) + } + is FlowIORequest.Sleep -> { + emptyList() + } + is FlowIORequest.WaitForSessionConfirmations -> { + collectErroredInitiatingSessionErrors(checkpoint) + } + } + } + + private fun createInitialSessionMessage( + initiatingSubFlow: SubFlow.Initiating, + sourceSessionId: SessionId, + payload: SerializedBytes? + ): InitialSessionMessage { + return InitialSessionMessage( + initiatorSessionId = sourceSessionId, + // We add additional entropy to add to the initiated side's deduplication seed. + initiationEntropy = context.secureRandom.nextLong(), + initiatorFlowClassName = initiatingSubFlow.classToInitiateWith.name, + flowVersion = initiatingSubFlow.flowInfo.flowVersion, + appName = initiatingSubFlow.flowInfo.appName, + firstPayload = payload + ) + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StateMachine.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StateMachine.kt new file mode 100644 index 0000000000..8b37972423 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StateMachine.kt @@ -0,0 +1,30 @@ +package net.corda.node.services.statemachine.transitions + +import net.corda.core.flows.* +import net.corda.node.services.statemachine.* +import java.security.SecureRandom + +enum class SessionDeliverPersistenceStrategy { + OnDeliver, + OnNextCommit +} + +data class StateMachineConfiguration( + val sessionDeliverPersistenceStrategy: SessionDeliverPersistenceStrategy +) { + companion object { + val default = StateMachineConfiguration( + sessionDeliverPersistenceStrategy = SessionDeliverPersistenceStrategy.OnDeliver + ) + } +} + +class StateMachine( + val id: StateMachineRunId, + val configuration: StateMachineConfiguration, + val secureRandom: SecureRandom +) { + fun transition(event: Event, state: StateMachineState): TransitionResult { + return TopLevelTransition(TransitionContext(id, configuration, secureRandom), state, event).transition() + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt new file mode 100644 index 0000000000..1366bb018f --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt @@ -0,0 +1,236 @@ +package net.corda.node.services.statemachine.transitions + +import net.corda.core.flows.InitiatingFlow +import net.corda.core.internal.FlowIORequest +import net.corda.core.utilities.Try +import net.corda.node.services.statemachine.* + +/** + * This is the top level event-handling transition function capable of handling any [Event]. + * + * It is a *pure* function taking a state machine state and an event, returning the next state along with a list of IO + * actions to execute. + */ +class TopLevelTransition( + override val context: TransitionContext, + override val startingState: StateMachineState, + val event: Event +) : Transition { + override fun transition(): TransitionResult { + return when (event) { + is Event.DoRemainingWork -> DoRemainingWorkTransition(context, startingState).transition() + is Event.DeliverSessionMessage -> DeliverSessionMessageTransition(context, startingState, event).transition() + is Event.Error -> errorTransition(event) + is Event.TransactionCommitted -> transactionCommittedTransition(event) + is Event.SoftShutdown -> softShutdownTransition() + is Event.StartErrorPropagation -> startErrorPropagationTransition() + is Event.EnterSubFlow -> enterSubFlowTransition(event) + is Event.LeaveSubFlow -> leaveSubFlowTransition() + is Event.Suspend -> suspendTransition(event) + is Event.FlowFinish -> flowFinishTransition(event) + is Event.InitiateFlow -> initiateFlowTransition(event) + } + } + + private fun errorTransition(event: Event.Error): TransitionResult { + return builder { + freshErrorTransition(event.exception) + FlowContinuation.ProcessEvents + } + } + + private fun transactionCommittedTransition(event: Event.TransactionCommitted): TransitionResult { + return builder { + val checkpoint = currentState.checkpoint + if (currentState.isTransactionTracked && + checkpoint.flowState is FlowState.Started && + checkpoint.flowState.flowIORequest is FlowIORequest.WaitForLedgerCommit && + checkpoint.flowState.flowIORequest.hash == event.transaction.id) { + currentState = currentState.copy(isTransactionTracked = false) + if (isErrored()) { + return@builder FlowContinuation.ProcessEvents + } + resumeFlowLogic(event.transaction) + } else { + freshErrorTransition(UnexpectedEventInState()) + FlowContinuation.ProcessEvents + } + } + } + + private fun softShutdownTransition(): TransitionResult { + val lastState = startingState.copy(isRemoved = true) + return TransitionResult( + newState = lastState, + actions = listOf( + Action.RemoveSessionBindings(startingState.checkpoint.sessions.keys), + Action.RemoveFlow(context.id, FlowRemovalReason.SoftShutdown, lastState) + ), + continuation = FlowContinuation.Abort + ) + } + + private fun startErrorPropagationTransition(): TransitionResult { + return builder { + val errorState = currentState.checkpoint.errorState + when (errorState) { + ErrorState.Clean -> freshErrorTransition(UnexpectedEventInState()) + is ErrorState.Errored -> { + currentState = currentState.copy( + checkpoint = currentState.checkpoint.copy( + errorState = errorState.copy(propagating = true) + ) + ) + actions.add(Action.ScheduleEvent(Event.DoRemainingWork)) + } + } + FlowContinuation.ProcessEvents + } + } + + private fun enterSubFlowTransition(event: Event.EnterSubFlow): TransitionResult { + return builder { + val subFlow = SubFlow.create(event.subFlowClass) + when (subFlow) { + is Try.Success -> { + currentState = currentState.copy( + checkpoint = currentState.checkpoint.copy( + subFlowStack = currentState.checkpoint.subFlowStack + subFlow.value + ) + ) + } + is Try.Failure -> { + freshErrorTransition(subFlow.exception) + } + } + FlowContinuation.ProcessEvents + } + } + + private fun leaveSubFlowTransition(): TransitionResult { + return builder { + val checkpoint = currentState.checkpoint + if (checkpoint.subFlowStack.isEmpty()) { + freshErrorTransition(UnexpectedEventInState()) + } else { + currentState = currentState.copy( + checkpoint = checkpoint.copy( + subFlowStack = checkpoint.subFlowStack.subList(0, checkpoint.subFlowStack.size - 1).toList() + ) + ) + } + FlowContinuation.ProcessEvents + } + } + + private fun suspendTransition(event: Event.Suspend): TransitionResult { + return builder { + val newCheckpoint = currentState.checkpoint.copy( + flowState = FlowState.Started(event.ioRequest, event.fiber), + numberOfSuspends = currentState.checkpoint.numberOfSuspends + 1 + ) + if (event.maySkipCheckpoint) { + actions.addAll(arrayOf( + Action.CommitTransaction, + Action.ScheduleEvent(Event.DoRemainingWork) + )) + currentState = currentState.copy( + checkpoint = newCheckpoint, + isFlowResumed = false + ) + } else { + actions.addAll(arrayOf( + Action.PersistCheckpoint(context.id, newCheckpoint), + Action.PersistDeduplicationIds(currentState.unacknowledgedMessages), + Action.CommitTransaction, + Action.AcknowledgeMessages(currentState.unacknowledgedMessages), + Action.ScheduleEvent(Event.DoRemainingWork) + )) + currentState = currentState.copy( + checkpoint = newCheckpoint, + unacknowledgedMessages = emptyList(), + isFlowResumed = false, + isAnyCheckpointPersisted = true + ) + } + FlowContinuation.ProcessEvents + } + } + + private fun flowFinishTransition(event: Event.FlowFinish): TransitionResult { + return builder { + val checkpoint = currentState.checkpoint + when (checkpoint.errorState) { + ErrorState.Clean -> { + val unacknowledgedMessages = currentState.unacknowledgedMessages + currentState = currentState.copy( + checkpoint = checkpoint.copy( + numberOfSuspends = checkpoint.numberOfSuspends + 1 + ), + unacknowledgedMessages = emptyList(), + isFlowResumed = false, + isRemoved = true + ) + val allSourceSessionIds = checkpoint.sessions.keys + if (currentState.isAnyCheckpointPersisted) { + actions.add(Action.RemoveCheckpoint(context.id)) + } + actions.addAll(arrayOf( + Action.PersistDeduplicationIds(unacknowledgedMessages), + Action.CommitTransaction, + Action.AcknowledgeMessages(unacknowledgedMessages), + Action.RemoveSessionBindings(allSourceSessionIds), + Action.RemoveFlow(context.id, FlowRemovalReason.OrderlyFinish(event.returnValue), currentState) + )) + sendEndMessages() + // Resume to end fiber + FlowContinuation.Resume(null) + } + is ErrorState.Errored -> { + currentState = currentState.copy(isFlowResumed = false) + actions.add(Action.RollbackTransaction) + FlowContinuation.ProcessEvents + } + } + } + } + + private fun TransitionBuilder.sendEndMessages() { + val sendEndMessageActions = currentState.checkpoint.sessions.values.mapIndexed { index, state -> + if (state is SessionState.Initiated && state.initiatedState is InitiatedSessionState.Live) { + val message = ExistingSessionMessage(state.initiatedState.peerSinkSessionId, EndSessionMessage) + val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index) + Action.SendExisting(state.peerParty, message, deduplicationId) + } else { + null + } + }.filterNotNull() + actions.addAll(sendEndMessageActions) + } + + private fun initiateFlowTransition(event: Event.InitiateFlow): TransitionResult { + return builder { + val checkpoint = currentState.checkpoint + val initiatingSubFlow = getClosestAncestorInitiatingSubFlow(checkpoint) + if (initiatingSubFlow == null) { + freshErrorTransition(IllegalStateException("Tried to initiate in a flow not annotated with @${InitiatingFlow::class.java.simpleName}")) + return@builder FlowContinuation.ProcessEvents + } + val sourceSessionId = SessionId.createRandom(context.secureRandom) + val sessionImpl = FlowSessionImpl(event.party, sourceSessionId) + val newSessions = checkpoint.sessions + (sourceSessionId to SessionState.Uninitiated(event.party, initiatingSubFlow)) + currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions)) + actions.add(Action.AddSessionBinding(context.id, sourceSessionId)) + FlowContinuation.Resume(sessionImpl) + } + } + + private fun getClosestAncestorInitiatingSubFlow(checkpoint: Checkpoint): SubFlow.Initiating? { + for (subFlow in checkpoint.subFlowStack.asReversed()) { + if (subFlow is SubFlow.Initiating) { + return subFlow + } + } + return null + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/Transition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/Transition.kt new file mode 100644 index 0000000000..20441dbab3 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/Transition.kt @@ -0,0 +1,32 @@ +package net.corda.node.services.statemachine.transitions + +import net.corda.core.flows.StateMachineRunId +import net.corda.node.services.statemachine.StateMachineState +import java.security.SecureRandom + +/** + * An interface used to separate out different parts of the state machine transition function. + */ +interface Transition { + /** The context of the transition. */ + val context: TransitionContext + /** The state the transition is starting in. */ + val startingState: StateMachineState + /** The (almost) pure transition function. The only side-effect we allow is random number generation. */ + fun transition(): TransitionResult + + /** + * A helper + */ + fun builder(build: TransitionBuilder.() -> FlowContinuation): TransitionResult { + val builder = TransitionBuilder(context, startingState) + val continuation = build(builder) + return TransitionResult(builder.currentState, builder.actions, continuation) + } +} + +class TransitionContext( + val id: StateMachineRunId, + val configuration: StateMachineConfiguration, + val secureRandom: SecureRandom +) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TransitionBuilder.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TransitionBuilder.kt new file mode 100644 index 0000000000..01715adde6 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TransitionBuilder.kt @@ -0,0 +1,74 @@ +package net.corda.node.services.statemachine.transitions + +import net.corda.core.flows.IdentifiableException +import net.corda.node.services.statemachine.* + +// This is a file defining some common utilities for creating state machine transitions. + +/** + * A builder that helps creating [Transition]s. This allows for a more imperative style of specifying the transition. + */ +class TransitionBuilder(val context: TransitionContext, initialState: StateMachineState) { + /** The current state machine state of the builder */ + var currentState = initialState + /** The list of actions to execute */ + val actions = ArrayList() + + /** Check if [currentState] state is errored */ + fun isErrored(): Boolean = currentState.checkpoint.errorState is ErrorState.Errored + + /** + * Transition the builder into an error state because of a fresh error that happened. + * Existing actions and the current state are thrown away, and the initial state is dirtied. + * + * @param error the error. + */ + fun freshErrorTransition(error: Throwable) { + val flowError = FlowError( + errorId = (error as? IdentifiableException)?.errorId ?: context.secureRandom.nextLong(), + exception = error + ) + errorTransition(flowError) + } + + /** + * Transition the builder into an error state because of a list of errors that happened. + * Existing actions and the current state are thrown away, and the initial state is dirtied. + * + * @param error the error. + */ + fun errorsTransition(errors: List) { + currentState = currentState.copy( + checkpoint = currentState.checkpoint.copy( + errorState = currentState.checkpoint.errorState.addErrors(errors) + ), + isFlowResumed = false + ) + actions.clear() + actions.addAll(arrayOf( + Action.RollbackTransaction, + Action.ScheduleEvent(Event.DoRemainingWork) + )) + } + + /** + * Transition the builder into an error state because of a non-fresh error has happened. + * Existing actions and the current state are thrown away, and the initial state is dirtied. + * + * @param error the error. + */ + fun errorTransition(error: FlowError) { + errorsTransition(listOf(error)) + } + + fun resumeFlowLogic(result: Any?): FlowContinuation { + actions.add(Action.CreateTransaction) + currentState = currentState.copy(isFlowResumed = true) + return FlowContinuation.Resume(result) + } +} + + + +class CannotFindSessionException(sessionId: SessionId) : IllegalStateException("Couldn't find session with id $sessionId") +class UnexpectedEventInState : IllegalStateException("Unexpected event") diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TransitionResult.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TransitionResult.kt new file mode 100644 index 0000000000..43e934634b --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TransitionResult.kt @@ -0,0 +1,46 @@ +package net.corda.node.services.statemachine.transitions + +import net.corda.node.services.statemachine.Action +import net.corda.node.services.statemachine.StateMachineState + +/** + * A datastructure capturing the intended new state of the flow, the actions to be executed as part of the transition + * and a [FlowContinuation]. + * + * Read this datastructure as an instruction to the state machine executor: + * "Transition to [newState] *if* [actions] execute cleanly. If so, use [continuation] to decide what to do next. If + * there was an error it's up to you what to do". + * Also see [net.corda.node.services.statemachine.TransitionExecutorImpl] on how this is interpreted. + */ +data class TransitionResult( + val newState: StateMachineState, + val actions: List = emptyList(), + val continuation: FlowContinuation = FlowContinuation.ProcessEvents +) + +/** + * A datastructure describing what to do after a transition has succeeded. + */ +sealed class FlowContinuation { + /** + * Return to user code with the supplied [result]. + */ + data class Resume(val result: Any?) : FlowContinuation() { + override fun toString() = "Resume(result=${result?.javaClass})" + } + + /** + * Throw an exception [throwable] in user code. + */ + data class Throw(val throwable: Throwable) : FlowContinuation() + + /** + * Keep processing pending events. + */ + object ProcessEvents : FlowContinuation() { override fun toString() = "ProcessEvents" } + + /** + * Immediately abort the flow. Note that this does not imply an error condition. + */ + object Abort : FlowContinuation() { override fun toString() = "Abort" } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt new file mode 100644 index 0000000000..9fb6b31107 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt @@ -0,0 +1,80 @@ +package net.corda.node.services.statemachine.transitions + +import net.corda.core.flows.FlowInfo +import net.corda.node.services.statemachine.* + +/** + * This transition is responsible for starting the flow from a FlowLogic instance. It creates the first checkpoint and + * initialises the initiated session in case the flow is an initiated one. + */ +class UnstartedFlowTransition( + override val context: TransitionContext, + override val startingState: StateMachineState, + val unstarted: FlowState.Unstarted +) : Transition { + override fun transition(): TransitionResult { + return builder { + if (!currentState.isAnyCheckpointPersisted && !currentState.isStartIdempotent) { + createInitialCheckpoint() + } + + actions.add(Action.SignalFlowHasStarted(context.id)) + + if (unstarted.flowStart is FlowStart.Initiated) { + initialiseInitiatedSession(unstarted.flowStart) + } + + currentState = currentState.copy(isFlowResumed = true) + actions.add(Action.CreateTransaction) + FlowContinuation.Resume(null) + } + } + + // Initialise initiated session, store initial payload, send confirmation back. + private fun TransitionBuilder.initialiseInitiatedSession(flowStart: FlowStart.Initiated) { + val initiatingMessage = flowStart.initiatingMessage + val initiatedState = SessionState.Initiated( + peerParty = flowStart.peerSession.counterparty, + initiatedState = InitiatedSessionState.Live(initiatingMessage.initiatorSessionId), + peerFlowInfo = FlowInfo( + flowVersion = flowStart.senderCoreFlowVersion ?: initiatingMessage.flowVersion, + appName = initiatingMessage.appName + ), + receivedMessages = if (initiatingMessage.firstPayload == null) { + emptyList() + } else { + listOf(DataSessionMessage(initiatingMessage.firstPayload)) + }, + errors = emptyList() + ) + val confirmationMessage = ConfirmSessionMessage(flowStart.initiatedSessionId, flowStart.initiatedFlowInfo) + val sessionMessage = ExistingSessionMessage(initiatingMessage.initiatorSessionId, confirmationMessage) + currentState = currentState.copy( + checkpoint = currentState.checkpoint.copy( + sessions = mapOf(flowStart.initiatedSessionId to initiatedState) + ) + ) + actions.add( + Action.SendExisting( + flowStart.peerSession.counterparty, + sessionMessage, + DeduplicationId.createForNormal(currentState.checkpoint, 0) + ) + ) + } + + // Create initial checkpoint and acknowledge triggering messages. + private fun TransitionBuilder.createInitialCheckpoint() { + actions.addAll(arrayOf( + Action.CreateTransaction, + Action.PersistCheckpoint(context.id, currentState.checkpoint), + Action.PersistDeduplicationIds(currentState.unacknowledgedMessages), + Action.CommitTransaction, + Action.AcknowledgeMessages(currentState.unacknowledgedMessages) + )) + currentState = currentState.copy( + unacknowledgedMessages = emptyList(), + isAnyCheckpointPersisted = true + ) + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/vault/VaultSoftLockManager.kt b/node/src/main/kotlin/net/corda/node/services/vault/VaultSoftLockManager.kt index b57ac3a45c..7e13cfab85 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/VaultSoftLockManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/VaultSoftLockManager.kt @@ -4,9 +4,13 @@ import net.corda.core.contracts.FungibleAsset import net.corda.core.contracts.StateRef import net.corda.core.flows.FlowLogic import net.corda.core.node.services.VaultService -import net.corda.core.utilities.* +import net.corda.core.utilities.NonEmptySet +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.toNonEmptySet +import net.corda.core.utilities.trace import net.corda.node.services.statemachine.FlowStateMachineImpl import net.corda.node.services.statemachine.StateMachineManager +import net.corda.nodeapi.internal.persistence.DatabaseTransactionManager import java.util.* class VaultSoftLockManager private constructor(private val vault: VaultService) { @@ -48,11 +52,15 @@ class VaultSoftLockManager private constructor(private val vault: VaultService) private fun registerSoftLocks(flowId: UUID, stateRefs: NonEmptySet) { log.trace { "Reserving soft locks for flow id $flowId and states $stateRefs" } - vault.softLockReserve(flowId, stateRefs) + DatabaseTransactionManager.dataSource.transaction { + vault.softLockReserve(flowId, stateRefs) + } } private fun unregisterSoftLocks(flowId: UUID, logic: FlowLogic<*>) { log.trace { "Releasing soft locks for flow ${logic.javaClass.simpleName} with flow id $flowId" } - vault.softLockRelease(flowId) + DatabaseTransactionManager.dataSource.transaction { + vault.softLockRelease(flowId) + } } } \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/utilities/ObjectDiffer.kt b/node/src/main/kotlin/net/corda/node/utilities/ObjectDiffer.kt new file mode 100644 index 0000000000..3f0d73d2ed --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/utilities/ObjectDiffer.kt @@ -0,0 +1,144 @@ +package net.corda.node.utilities + +import java.lang.reflect.Method +import java.lang.reflect.Modifier +import java.lang.reflect.Type +import java.time.Instant + +/** + * A tree describing the diff between two objects. + * + * For example: + * data class A(val field1: Int, val field2: String, val field3: Unit) + * fun main(args: Array) { + * val someA = A(1, "hello", Unit) + * val someOtherA = A(2, "bello", Unit) + * println(ObjectDiffer.diff(someA, someOtherA)) + * } + * + * Will give back Step(branches=[(field1, Last(a=1, b=2)), (field2, Last(a=hello, b=bello))]) + */ +sealed class DiffTree { + /** + * Describes a "step" from the object root. It contains a list of field-subtree pairs. + */ + data class Step(val branches: List>) : DiffTree() + + /** + * Describes the leaf of the diff. This is either where the diffing was cutoff (e.g. primitives) or where it failed. + */ + data class Last(val a: Any?, val b: Any?) : DiffTree() + + /** + * Flattens the [DiffTree] into a list of [DiffPath]s + */ + fun toPaths(): List { + return when (this) { + is Step -> branches.flatMap { (step, tree) -> tree.toPaths().map { it.copy(path = listOf(step) + it.path) } } + is Last -> listOf(DiffPath(emptyList(), a, b)) + } + } +} + +/** + * A diff focused on a single [DiffTree.Last] diff, including the path leading there. + */ +data class DiffPath( + val path: List, + val a: Any?, + val b: Any? +) { + override fun toString(): String { + return "${path.joinToString(".")}: \n $a\n $b\n" + } +} + +/** + * This is a very simple differ used to diff objects of any kind, to be used for diagnostic. + */ +object ObjectDiffer { + fun diff(a: Any?, b: Any?): DiffTree? { + if (a == null || b == null) { + if (a == b) { + return null + } else { + return DiffTree.Last(a, b) + } + } + if (a != b) { + if (a.javaClass.isPrimitive || a.javaClass in diffCutoffClasses) { + return DiffTree.Last(a, b) + } + // TODO deduplicate this code + if (a is Map<*, *> && b is Map<*, *>) { + val allKeys = a.keys + b.keys + val branches = allKeys.mapNotNull { key -> diff(a.get(key), b.get(key))?.let { key.toString() to it } } + if (branches.isEmpty()) { + return null + } else { + return DiffTree.Step(branches) + } + } + if (a is java.util.Map<*, *> && b is java.util.Map<*, *>) { + val allKeys = a.keySet() + b.keySet() + val branches = allKeys.mapNotNull { key -> diff(a.get(key), b.get(key))?.let { key.toString() to it } } + if (branches.isEmpty()) { + return null + } else { + return DiffTree.Step(branches) + } + } + val aFields = getFieldFoci(a) + val bFields = getFieldFoci(b) + try { + if (aFields != bFields) { + return DiffTree.Last(a, b) + } else { + // TODO need to account for cases where the fields don't match up (different subclasses) + val branches = aFields.map { field -> diff(field.get(a), field.get(b))?.let { field.name to it } }.filterNotNull() + if (branches.isEmpty()) { + return DiffTree.Last(a, b) + } else { + return DiffTree.Step(branches) + } + } + } catch (throwable: Exception) { + Exception("Error while diffing $a with $b", throwable).printStackTrace(System.out) + return DiffTree.Last(a, b) + } + } else { + return null + } + } + + // List of types to cutoff the diffing at. + private val diffCutoffClasses: Set> = setOf( + String::class.java, + Class::class.java, + Instant::class.java + ) + + // A type capturing the accessor to a field. This is a separate abstraction to simple reflection as we identify + // getX() and isX() calls as fields as well. + private data class FieldFocus(val name: String, val type: Type, val getter: Method) { + fun get(obj: Any): Any? { + return getter.invoke(obj) + } + } + + private fun getFieldFoci(obj: Any) : List { + val foci = ArrayList() + for (method in obj.javaClass.declaredMethods) { + if (Modifier.isStatic(method.modifiers)) { + continue + } + if (method.name.startsWith("get") && method.name.length > 3 && method.parameterCount == 0) { + val fieldName = method.name[3].toLowerCase() + method.name.substring(4) + foci.add(FieldFocus(fieldName, method.returnType, method)) + } else if (method.name.startsWith("is") && method.parameterCount == 0) { + foci.add(FieldFocus(method.name, method.returnType, method)) + } + } + return foci + } +} diff --git a/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt b/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt index 18ed176857..1965e3dc12 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt @@ -2,7 +2,6 @@ package net.corda.node.messaging import net.corda.node.services.messaging.Message import net.corda.node.services.messaging.TopicStringValidator -import net.corda.node.services.messaging.createMessage import net.corda.testing.node.MockNetwork import org.junit.After import org.junit.Before @@ -48,10 +47,10 @@ class InMemoryMessagingTests { val bits = "test-content".toByteArray() var finalDelivery: Message? = null - node2.network.addMessageHandler { msg, _ -> + node2.network.addMessageHandler("test.topic") { msg, _, _ -> node2.network.send(msg, node3.network.myAddress) } - node3.network.addMessageHandler { msg, _ -> + node3.network.addMessageHandler("test.topic") { msg, _, _ -> finalDelivery = msg } @@ -60,7 +59,7 @@ class InMemoryMessagingTests { mockNet.runNetwork(rounds = 1) - assertTrue(Arrays.equals(finalDelivery!!.data, bits)) + assertTrue(Arrays.equals(finalDelivery!!.data.bytes, bits)) } @Test @@ -72,7 +71,7 @@ class InMemoryMessagingTests { val bits = "test-content".toByteArray() var counter = 0 - listOf(node1, node2, node3).forEach { it.network.addMessageHandler { _, _ -> counter++ } } + listOf(node1, node2, node3).forEach { it.network.addMessageHandler("test.topic") { _, _, _ -> counter++ } } node1.network.send(node2.network.createMessage("test.topic", data = bits), mockNet.messagingNetwork.everyoneOnline) mockNet.runNetwork(rounds = 1) assertEquals(3, counter) @@ -88,12 +87,12 @@ class InMemoryMessagingTests { val node2 = mockNet.createNode() var received = 0 - node1.network.addMessageHandler("valid_message") { _, _ -> + node1.network.addMessageHandler("valid_message") { _, _, _ -> received++ } - val invalidMessage = node2.network.createMessage("invalid_message", data = ByteArray(0)) - val validMessage = node2.network.createMessage("valid_message", data = ByteArray(0)) + val invalidMessage = node2.network.createMessage("invalid_message", data = ByteArray(1)) + val validMessage = node2.network.createMessage("valid_message", data = ByteArray(1)) node2.network.send(invalidMessage, node1.network.myAddress) mockNet.runNetwork() assertEquals(0, received) @@ -104,8 +103,8 @@ class InMemoryMessagingTests { // Here's the core of the test; previously the unhandled message would cause runNetwork() to abort early, so // this would fail. Make fresh messages to stop duplicate uniqueMessageId causing drops - val invalidMessage2 = node2.network.createMessage("invalid_message", data = ByteArray(0)) - val validMessage2 = node2.network.createMessage("valid_message", data = ByteArray(0)) + val invalidMessage2 = node2.network.createMessage("invalid_message", data = ByteArray(1)) + val validMessage2 = node2.network.createMessage("valid_message", data = ByteArray(1)) node2.network.send(invalidMessage2, node1.network.myAddress) node2.network.send(validMessage2, node1.network.myAddress) mockNet.runNetwork() diff --git a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt index 145a7b933b..8fdb10a1cb 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt @@ -720,6 +720,12 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { private val database: CordaPersistence, private val delegate: WritableTransactionStorage ) : WritableTransactionStorage, SingletonSerializeAsToken() { + override fun trackTransaction(id: SecureHash): CordaFuture { + return database.transaction { + delegate.trackTransaction(id) + } + } + override fun track(): DataFeed, SignedTransaction> { return database.transaction { delegate.track() diff --git a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt index 5796609f1c..1864d716e7 100644 --- a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt @@ -4,6 +4,7 @@ import co.paralleluniverse.fibers.Suspendable import com.codahale.metrics.MetricRegistry import com.nhaarman.mockito_kotlin.* import net.corda.core.contracts.* +import net.corda.core.crypto.newSecureRandom import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogicRef import net.corda.core.flows.FlowLogicRefFactory @@ -117,7 +118,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() { } smmExecutor = AffinityExecutor.ServiceAffinityExecutor("test", 1) - mockSMM = StateMachineManagerImpl(services, DBCheckpointStorage(), smmExecutor, database) + mockSMM = StateMachineManagerImpl(services, DBCheckpointStorage(), smmExecutor, database, newSecureRandom()) scheduler = NodeSchedulerService(testClock, database, FlowStarterImpl(smmExecutor, mockSMM), stateLoader, schedulerGatedExecutor, serverThread = smmExecutor) mockSMM.changes.subscribe { change -> if (change is StateMachineManager.Change.Removed && mockSMM.allStateMachines.isEmpty()) { diff --git a/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt b/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt index 4b7972d02d..f8c1d0b5ed 100644 --- a/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt @@ -1,6 +1,10 @@ package net.corda.node.services.messaging import net.corda.core.crypto.generateKeyPair +import net.corda.core.concurrent.CordaFuture +import com.codahale.metrics.MetricRegistry +import net.corda.core.crypto.generateKeyPair +import net.corda.core.internal.concurrent.openFuture import net.corda.core.utilities.NetworkHostAndPort import net.corda.node.services.RPCUserService import net.corda.node.services.RPCUserServiceImpl @@ -118,7 +122,7 @@ class ArtemisMessagingTests { messagingClient.send(message, messagingClient.myAddress) val actual: Message = receivedMessages.take() - assertEquals("first msg", String(actual.data)) + assertEquals("first msg", String(actual.data.bytes)) assertNull(receivedMessages.poll(200, MILLISECONDS)) } @@ -143,7 +147,8 @@ class ArtemisMessagingTests { val messagingClient = createMessagingClient(platformVersion = platformVersion) startNodeMessagingClient() - messagingClient.addMessageHandler(TOPIC) { message, _ -> + messagingClient.addMessageHandler(TOPIC) { message, _, handle -> + handle.acknowledge() // We ACK first so that if it fails we won't get a duplicate in [receivedMessages] receivedMessages.add(message) } // Run after the handlers are added, otherwise (some of) the messages get delivered and discarded / dead-lettered. diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt index 41190a8355..6d74898055 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt @@ -1,13 +1,19 @@ package net.corda.node.services.persistence -import com.google.common.primitives.Ints +import net.corda.core.context.InvocationContext +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.StateMachineRunId +import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializedBytes -import net.corda.node.services.api.Checkpoint -import net.corda.node.services.api.CheckpointStorage -import net.corda.node.services.transactions.PersistentUniquenessProvider +import net.corda.core.serialization.serialize import net.corda.node.internal.configureDatabase +import net.corda.node.services.api.CheckpointStorage +import net.corda.node.services.statemachine.Checkpoint +import net.corda.node.services.statemachine.FlowStart +import net.corda.node.services.transactions.PersistentUniquenessProvider import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseConfig +import net.corda.testing.ALICE import net.corda.testing.LogHelper import net.corda.testing.SerializationEnvironmentRule import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties @@ -17,14 +23,11 @@ import org.junit.After import org.junit.Before import org.junit.Rule import org.junit.Test +import kotlin.streams.toList -internal fun CheckpointStorage.checkpoints(): List { - val checkpoints = mutableListOf() - forEach { - checkpoints += it - true - } - return checkpoints +internal fun CheckpointStorage.checkpoints(): List> { + val checkpoints = getAllCheckpoints().toList() + return checkpoints.map { it.second } } class DBCheckpointStorageTests { @@ -50,9 +53,9 @@ class DBCheckpointStorageTests { @Test fun `add new checkpoint`() { - val checkpoint = newCheckpoint() + val (id, checkpoint) = newCheckpoint() database.transaction { - checkpointStorage.addCheckpoint(checkpoint) + checkpointStorage.addCheckpoint(id, checkpoint) } database.transaction { assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint) @@ -65,12 +68,12 @@ class DBCheckpointStorageTests { @Test fun `remove checkpoint`() { - val checkpoint = newCheckpoint() + val (id, checkpoint) = newCheckpoint() database.transaction { - checkpointStorage.addCheckpoint(checkpoint) + checkpointStorage.addCheckpoint(id, checkpoint) } database.transaction { - checkpointStorage.removeCheckpoint(checkpoint) + checkpointStorage.removeCheckpoint(id) } database.transaction { assertThat(checkpointStorage.checkpoints()).isEmpty() @@ -83,12 +86,12 @@ class DBCheckpointStorageTests { @Test fun `add and remove checkpoint in single commit operate`() { - val checkpoint = newCheckpoint() - val checkpoint2 = newCheckpoint() + val (id, checkpoint) = newCheckpoint() + val (id2, checkpoint2) = newCheckpoint() database.transaction { - checkpointStorage.addCheckpoint(checkpoint) - checkpointStorage.addCheckpoint(checkpoint2) - checkpointStorage.removeCheckpoint(checkpoint) + checkpointStorage.addCheckpoint(id, checkpoint) + checkpointStorage.addCheckpoint(id2, checkpoint2) + checkpointStorage.removeCheckpoint(id) } database.transaction { assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint2) @@ -101,16 +104,16 @@ class DBCheckpointStorageTests { @Test fun `add two checkpoints then remove first one`() { - val firstCheckpoint = newCheckpoint() + val (id, firstCheckpoint) = newCheckpoint() database.transaction { - checkpointStorage.addCheckpoint(firstCheckpoint) + checkpointStorage.addCheckpoint(id, firstCheckpoint) } - val secondCheckpoint = newCheckpoint() + val (id2, secondCheckpoint) = newCheckpoint() database.transaction { - checkpointStorage.addCheckpoint(secondCheckpoint) + checkpointStorage.addCheckpoint(id2, secondCheckpoint) } database.transaction { - checkpointStorage.removeCheckpoint(firstCheckpoint) + checkpointStorage.removeCheckpoint(id) } database.transaction { assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint) @@ -123,9 +126,9 @@ class DBCheckpointStorageTests { @Test fun `add checkpoint and then remove after 'restart'`() { - val originalCheckpoint = newCheckpoint() + val (id, originalCheckpoint) = newCheckpoint() database.transaction { - checkpointStorage.addCheckpoint(originalCheckpoint) + checkpointStorage.addCheckpoint(id, originalCheckpoint) } newCheckpointStorage() val reconstructedCheckpoint = database.transaction { @@ -135,7 +138,7 @@ class DBCheckpointStorageTests { assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint) } database.transaction { - checkpointStorage.removeCheckpoint(reconstructedCheckpoint) + checkpointStorage.removeCheckpoint(id) } database.transaction { assertThat(checkpointStorage.checkpoints()).isEmpty() @@ -148,7 +151,14 @@ class DBCheckpointStorageTests { } } - private var checkpointCount = 1 - private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++))) + private fun newCheckpoint(): Pair> { + val id = StateMachineRunId.createRandom() + val logic: FlowLogic<*> = object : FlowLogic() { + override fun call() {} + } + val frozenLogic = logic.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT) + val checkpoint = Checkpoint.create(InvocationContext.shell(), FlowStart.Explicit, logic.javaClass, frozenLogic, ALICE, "").getOrThrow() + return id to checkpoint.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT) + } } diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt index c1efc8e047..ca84d944c3 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt @@ -2,6 +2,7 @@ package net.corda.node.services.statemachine import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Suspendable +import co.paralleluniverse.strands.Strand import co.paralleluniverse.strands.concurrent.Semaphore import net.corda.core.concurrent.CordaFuture import net.corda.core.contracts.ContractState @@ -48,6 +49,7 @@ import rx.Notification import rx.Observable import java.time.Instant import java.util.* +import java.util.concurrent.ExecutionException import kotlin.reflect.KClass import kotlin.test.assertEquals import kotlin.test.assertFailsWith @@ -110,6 +112,19 @@ class FlowFrameworkTests { assertThat(flow.lazyTime).isNotNull() } + class ThrowingActionExecutor(private val exception: Exception, val delegate: ActionExecutor) : ActionExecutor { + var thrown = false + @Suspendable + override fun executeAction(fiber: FlowFiber, action: Action) { + if (thrown) { + delegate.executeAction(fiber, action) + } else { + thrown = true + throw exception + } + } + } + @Test fun `exception while fiber suspended`() { bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) } @@ -117,16 +132,15 @@ class FlowFrameworkTests { val fiber = aliceNode.services.startFlow(flow) as FlowStateMachineImpl // Before the flow runs change the suspend action to throw an exception val exceptionDuringSuspend = Exception("Thrown during suspend") - fiber.actionOnSuspend = { - throw exceptionDuringSuspend - } + val throwingActionExecutor = ThrowingActionExecutor(exceptionDuringSuspend, fiber.transientValues!!.value.actionExecutor) + fiber.transientValues = TransientReference(fiber.transientValues!!.value.copy(actionExecutor = throwingActionExecutor)) mockNet.runNetwork() assertThatThrownBy { fiber.resultFuture.getOrThrow() }.isSameAs(exceptionDuringSuspend) assertThat(aliceNode.smm.allStateMachines).isEmpty() // Make sure the fiber does actually terminate - assertThat(fiber.isTerminated).isTrue() + assertThat(fiber.state).isEqualTo(Strand.State.WAITING) } @Test @@ -217,6 +231,8 @@ class FlowFrameworkTests { val payload = "Hello World" aliceNode.services.startFlow(SendFlow(payload, bob, charlie)) mockNet.runNetwork() + bobNode.internals.acceptableLiveFiberCountOnStop = 1 + charlieNode.internals.acceptableLiveFiberCountOnStop = 1 val bobFlow = bobNode.getSingleFlow().first val charlieFlow = charlieNode.getSingleFlow().first assertThat(bobFlow.receivedPayloads[0]).isEqualTo(payload) @@ -235,9 +251,6 @@ class FlowFrameworkTests { aliceNode sent normalEnd to charlieNode //There's no session end from the other flows as they're manually suspended ) - - bobNode.internals.acceptableLiveFiberCountOnStop = 1 - charlieNode.internals.acceptableLiveFiberCountOnStop = 1 } @Test @@ -294,14 +307,16 @@ class FlowFrameworkTests { mockNet.runNetwork() assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy { resultFuture.getOrThrow() - }.withMessageContaining(String::class.java.name) // Make sure the exception message mentions the type the flow was expecting to receive + } } @Test fun `receiving unexpected session end before entering sendAndReceive`() { bobNode.registerFlowFactory(WaitForOtherSideEndBeforeSendAndReceive::class) { NoOpFlow() } val sessionEndReceived = Semaphore(0) - receivedSessionMessagesObservable().filter { it.message is SessionEnd }.subscribe { sessionEndReceived.release() } + receivedSessionMessagesObservable().filter { + it.message is ExistingSessionMessage && it.message.payload is EndSessionMessage + }.subscribe { sessionEndReceived.release() } val resultFuture = aliceNode.services.startFlow( WaitForOtherSideEndBeforeSendAndReceive(bob, sessionEndReceived)).resultFuture mockNet.runNetwork() @@ -337,7 +352,9 @@ class FlowFrameworkTests { mockNet.runNetwork() - assertThat(erroringFlowSteps.get()).containsExactly( + erroringFlowFuture.getOrThrow() + val flowSteps = erroringFlowSteps.get() + assertThat(flowSteps).containsExactly( Notification.createOnNext(ExceptionFlow.START_STEP), Notification.createOnError(erroringFlowFuture.get().exceptionThrown) ) @@ -354,7 +371,7 @@ class FlowFrameworkTests { assertSessionTransfers( aliceNode sent sessionInit(ReceiveFlow::class) to bobNode, bobNode sent sessionConfirm() to aliceNode, - bobNode sent erroredEnd() to aliceNode + bobNode sent errorMessage() to aliceNode ) } @@ -377,8 +394,8 @@ class FlowFrameworkTests { assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty() } - assertThat(receivingFiber.isTerminated).isTrue() - assertThat((erroringFlow.get().stateMachine as FlowStateMachineImpl).isTerminated).isTrue() + assertThat(receivingFiber.state).isEqualTo(Strand.State.WAITING) + assertThat((erroringFlow.get().stateMachine as FlowStateMachineImpl).state).isEqualTo(Strand.State.WAITING) assertThat(erroringFlowSteps.get()).containsExactly( Notification.createOnNext(ExceptionFlow.START_STEP), Notification.createOnError(erroringFlow.get().exceptionThrown) @@ -387,14 +404,15 @@ class FlowFrameworkTests { assertSessionTransfers( aliceNode sent sessionInit(ReceiveFlow::class) to bobNode, bobNode sent sessionConfirm() to aliceNode, - bobNode sent erroredEnd(erroringFlow.get().exceptionThrown) to aliceNode + bobNode sent errorMessage(erroringFlow.get().exceptionThrown) to aliceNode ) // Make sure the original stack trace isn't sent down the wire - assertThat((receivedSessionMessages.last().message as ErrorSessionEnd).errorResponse!!.stackTrace).isEmpty() + val lastMessage = receivedSessionMessages.last().message as ExistingSessionMessage + assertThat((lastMessage.payload as ErrorSessionMessage).flowException!!.stackTrace).isEmpty() } @Test - fun `FlowException propagated in invocation chain`() { + fun `FlowException only propagated to parent`() { val charlieNode = mockNet.createNode(MockNodeParameters(legalName = CHARLIE_NAME)) val charlie = charlieNode.info.singleIdentity() @@ -402,9 +420,8 @@ class FlowFrameworkTests { bobNode.registerFlowFactory(ReceiveFlow::class) { ReceiveFlow(charlie) } val receivingFiber = aliceNode.services.startFlow(ReceiveFlow(bob)) mockNet.runNetwork() - assertThatExceptionOfType(MyFlowException::class.java) + assertThatExceptionOfType(UnexpectedFlowEndException::class.java) .isThrownBy { receivingFiber.resultFuture.getOrThrow() } - .withMessage("Chain") } @Test @@ -436,7 +453,7 @@ class FlowFrameworkTests { aliceNode sent sessionInit(ReceiveFlow::class) to bobNode, bobNode sent sessionConfirm() to aliceNode, bobNode sent sessionData("Hello") to aliceNode, - aliceNode sent erroredEnd() to bobNode + aliceNode sent errorMessage() to bobNode ) } @@ -556,10 +573,8 @@ class FlowFrameworkTests { @Test fun `customised client flow which has annotated @InitiatingFlow again`() { - val result = aliceNode.services.startFlow(IncorrectCustomSendFlow("Hello", bob)).resultFuture - mockNet.runNetwork() - assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy { - result.getOrThrow() + assertThatExceptionOfType(ExecutionException::class.java).isThrownBy { + aliceNode.services.startFlow(IncorrectCustomSendFlow("Hello", bob)).resultFuture }.withMessageContaining(InitiatingFlow::class.java.simpleName) } @@ -601,20 +616,20 @@ class FlowFrameworkTests { @Test fun `unknown class in session init`() { - aliceNode.sendSessionMessage(SessionInit(random63BitValue(), "not.a.real.Class", 1, "version", null), bob) + aliceNode.sendSessionMessage(InitialSessionMessage(SessionId(random63BitValue()), 0, "not.a.real.Class", 1, "", null), bob) mockNet.runNetwork() assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected - val reject = receivedSessionMessages.last().message as SessionReject - assertThat(reject.errorMessage).isEqualTo("Don't know not.a.real.Class") + val lastMessage = receivedSessionMessages.last().message as ExistingSessionMessage + assertThat((lastMessage.payload as RejectSessionMessage).message).isEqualTo("Don't know not.a.real.Class") } @Test fun `non-flow class in session init`() { - aliceNode.sendSessionMessage(SessionInit(random63BitValue(), String::class.java.name, 1, "version", null), bob) + aliceNode.sendSessionMessage(InitialSessionMessage(SessionId(random63BitValue()), 0, String::class.java.name, 1, "", null), bob) mockNet.runNetwork() assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected - val reject = receivedSessionMessages.last().message as SessionReject - assertThat(reject.errorMessage).isEqualTo("${String::class.java.name} is not a flow") + val lastMessage = receivedSessionMessages.last().message as ExistingSessionMessage + assertThat((lastMessage.payload as RejectSessionMessage).message).isEqualTo("${String::class.java.name} is not a flow") } @Test @@ -633,24 +648,6 @@ class FlowFrameworkTests { assertThat(result.getOrThrow()).isEqualTo("HelloHello") } - @Test - fun `double initiateFlow throws`() { - val future = aliceNode.services.startFlow(DoubleInitiatingFlow()).resultFuture - mockNet.runNetwork() - assertThatExceptionOfType(IllegalStateException::class.java) - .isThrownBy { future.getOrThrow() } - .withMessageContaining("Attempted to initiateFlow() twice") - } - - @InitiatingFlow - private class DoubleInitiatingFlow : FlowLogic() { - @Suspendable - override fun call() { - initiateFlow(ourIdentity) - initiateFlow(ourIdentity) - } - } - //////////////////////////////////////////////////////////////////////////////////////////////////////////// //region Helpers @@ -680,19 +677,18 @@ class FlowFrameworkTests { return observable.toFuture() } - private fun sessionInit(clientFlowClass: KClass>, flowVersion: Int = 1, payload: Any? = null): SessionInit { - return SessionInit(0, clientFlowClass.java.name, flowVersion, "", payload?.serialize()) + private fun sessionInit(clientFlowClass: KClass>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage { + return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize()) } - - private fun sessionConfirm(flowVersion: Int = 1) = SessionConfirm(0, 0, flowVersion, "") - private fun sessionData(payload: Any) = SessionData(0, payload.serialize()) - private val normalEnd = NormalSessionEnd(0) - private fun erroredEnd(errorResponse: FlowException? = null) = ErrorSessionEnd(0, errorResponse) + private fun sessionConfirm(flowVersion: Int = 1) = ExistingSessionMessage(SessionId(0), ConfirmSessionMessage(SessionId(0), FlowInfo(flowVersion, ""))) + private fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(0), DataSessionMessage(payload.serialize())) + private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0) + private fun errorMessage(errorResponse: FlowException? = null) = ExistingSessionMessage(SessionId(0), ErrorSessionMessage(errorResponse, 0)) private fun StartedNode<*>.sendSessionMessage(message: SessionMessage, destination: Party) { services.networkService.apply { val address = getAddressOfParty(PartyInfo.SingleNode(destination, emptyList())) - send(createMessage(StateMachineManagerImpl.sessionTopic, message.serialize().bytes), address) + send(createMessage(FlowMessagingImpl.sessionTopic, message.serialize().bytes), address) } } @@ -707,7 +703,9 @@ class FlowFrameworkTests { } private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: MessageRecipients) { - val isPayloadTransfer: Boolean get() = message is SessionData || message is SessionInit && message.firstPayload != null + val isPayloadTransfer: Boolean get() = + message is ExistingSessionMessage && message.payload is DataSessionMessage || + message is InitialSessionMessage && message.firstPayload != null override fun toString(): String = "$from sent $message to $to" } @@ -716,7 +714,7 @@ class FlowFrameworkTests { } private fun Observable.toSessionTransfers(): Observable { - return filter { it.message.topicSession == StateMachineManagerImpl.sessionTopic }.map { + return filter { it.message.topic == FlowMessagingImpl.sessionTopic }.map { val from = it.sender.id val message = it.message.data.deserialize() SessionTransfer(from, sanitise(message), it.recipients) @@ -724,12 +722,23 @@ class FlowFrameworkTests { } private fun sanitise(message: SessionMessage) = when (message) { - is SessionData -> message.copy(recipientSessionId = 0) - is SessionInit -> message.copy(initiatorSessionId = 0, appName = "") - is SessionConfirm -> message.copy(initiatorSessionId = 0, initiatedSessionId = 0, appName = "") - is NormalSessionEnd -> message.copy(recipientSessionId = 0) - is ErrorSessionEnd -> message.copy(recipientSessionId = 0) - else -> message + is InitialSessionMessage -> message.copy(initiatorSessionId = SessionId(0), initiationEntropy = 0, appName = "") + is ExistingSessionMessage -> { + val payload = message.payload + message.copy( + recipientSessionId = SessionId(0), + payload = when (payload) { + is ConfirmSessionMessage -> payload.copy( + initiatedSessionId = SessionId(0), + initiatedFlowInfo = payload.initiatedFlowInfo.copy(appName = "") + ) + is ErrorSessionMessage -> payload.copy( + errorId = 0 + ) + else -> payload + } + ) + } } private infix fun StartedNode.sent(message: SessionMessage): Pair = Pair(internals.id, message) diff --git a/perftestcordapp/src/test/kotlin/com/r3/corda/enterprise/perftestcordapp/flows/TwoPartyTradeFlowTest.kt b/perftestcordapp/src/test/kotlin/com/r3/corda/enterprise/perftestcordapp/flows/TwoPartyTradeFlowTest.kt index bc8b247fd2..3cbd3eb293 100644 --- a/perftestcordapp/src/test/kotlin/com/r3/corda/enterprise/perftestcordapp/flows/TwoPartyTradeFlowTest.kt +++ b/perftestcordapp/src/test/kotlin/com/r3/corda/enterprise/perftestcordapp/flows/TwoPartyTradeFlowTest.kt @@ -27,6 +27,7 @@ import net.corda.core.messaging.DataFeed import net.corda.core.messaging.StateMachineTransactionMapping import net.corda.core.node.services.Vault import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.toFuture import net.corda.core.transactions.SignedTransaction @@ -37,10 +38,10 @@ import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.toNonEmptySet import net.corda.core.utilities.unwrap import net.corda.node.internal.StartedNode -import net.corda.node.services.api.Checkpoint import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.WritableTransactionStorage import net.corda.node.services.persistence.DBTransactionStorage +import net.corda.node.services.statemachine.Checkpoint import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.testing.* import net.corda.testing.node.* @@ -56,21 +57,14 @@ import java.io.ByteArrayOutputStream import java.util.* import java.util.jar.JarOutputStream import java.util.zip.ZipEntry +import kotlin.streams.toList import kotlin.test.assertEquals import kotlin.test.assertFailsWith import kotlin.test.assertTrue - -/** - * Copied from DBCheckpointStorageTests as it is required as helper for this test - */ -internal fun CheckpointStorage.checkpoints(): List { - val checkpoints = mutableListOf() - forEach { - checkpoints += it - true - } - return checkpoints +internal fun CheckpointStorage.checkpoints(): List> { + val checkpoints = getAllCheckpoints().toList() + return checkpoints.map { it.second } } /** @@ -740,6 +734,12 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { private val database: CordaPersistence, private val delegate: WritableTransactionStorage ) : WritableTransactionStorage, SingletonSerializeAsToken() { + override fun trackTransaction(id: SecureHash): CordaFuture { + return database.transaction { + delegate.trackTransaction(id) + } + } + override fun track(): DataFeed, SignedTransaction> { return database.transaction { delegate.track() diff --git a/samples/network-visualiser/src/main/kotlin/net/corda/netmap/NetworkMapVisualiser.kt b/samples/network-visualiser/src/main/kotlin/net/corda/netmap/NetworkMapVisualiser.kt index 3411eb5d0d..5d8b7f2ad4 100644 --- a/samples/network-visualiser/src/main/kotlin/net/corda/netmap/NetworkMapVisualiser.kt +++ b/samples/network-visualiser/src/main/kotlin/net/corda/netmap/NetworkMapVisualiser.kt @@ -15,9 +15,7 @@ import net.corda.core.serialization.deserialize import net.corda.core.utilities.ProgressTracker import net.corda.netmap.VisualiserViewModel.Style import net.corda.netmap.simulation.IRSSimulation -import net.corda.node.services.statemachine.SessionConfirm -import net.corda.node.services.statemachine.SessionEnd -import net.corda.node.services.statemachine.SessionInit +import net.corda.node.services.statemachine.* import net.corda.testing.chooseIdentity import net.corda.testing.node.InMemoryMessagingNetwork import net.corda.testing.node.MockNetwork @@ -342,12 +340,16 @@ class NetworkMapVisualiser : Application() { private fun transferIsInteresting(transfer: InMemoryMessagingNetwork.MessageTransfer): Boolean { // Loopback messages are boring. if (transfer.sender == transfer.recipients) return false - val message = transfer.message.data.deserialize() + val message = transfer.message.data.deserialize() return when (message) { - is SessionEnd -> false - is SessionConfirm -> false - is SessionInit -> message.firstPayload != null - else -> true + is InitialSessionMessage -> message.firstPayload != null + is ExistingSessionMessage -> when (message.payload) { + is ConfirmSessionMessage -> false + is DataSessionMessage -> true + is ErrorSessionMessage -> true + is RejectSessionMessage -> true + is EndSessionMessage -> false + } } } } diff --git a/settings.gradle b/settings.gradle index 34db991711..6a1ca989ff 100644 --- a/settings.gradle +++ b/settings.gradle @@ -20,6 +20,7 @@ include 'experimental:sandbox' include 'experimental:quasar-hook' include 'experimental:kryo-hook' include 'experimental:intellij-plugin' +include 'experimental:flow-hook' include 'verifier' include 'test-common' include 'test-utils' diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/driver/Driver.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/driver/Driver.kt index dc0971484f..d8077fc978 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/driver/Driver.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/driver/Driver.kt @@ -634,7 +634,7 @@ class DriverDSL( throw ListenProcessDeathException(rpcAddress, processDeathFuture.getOrThrow()) } val connection = connectionFuture.getOrThrow() - shutdownManager.registerShutdown(connection::close) + shutdownManager.registerShutdown(connection::forceClose) connection.proxy } } diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt index 427179a1ef..6135760dbb 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt @@ -13,9 +13,12 @@ import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.node.services.PartyInfo import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.SingletonSerializeAsToken +import net.corda.core.utilities.ByteSequence +import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.trace import net.corda.node.services.messaging.* +import net.corda.node.services.statemachine.DeduplicationId import net.corda.node.utilities.AffinityExecutor import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.testing.node.InMemoryMessagingNetwork.InMemoryMessaging @@ -57,7 +60,7 @@ class InMemoryMessagingNetwork( @CordaSerializable data class MessageTransfer(val sender: PeerHandle, val message: Message, val recipients: MessageRecipients) { - override fun toString() = "${message.topicSession} from '$sender' to '$recipients'" + override fun toString() = "${message.topic} from '$sender' to '$recipients'" } // All sent messages are kept here until pumpSend is called, or manuallyPumped is set to false @@ -242,17 +245,17 @@ class InMemoryMessagingNetwork( _sentMessages.onNext(transfer) } - data class InMemoryMessage(override val topicSession: TopicSession, - override val data: ByteArray, - override val uniqueMessageId: UUID, - override val debugTimestamp: Instant = Instant.now()) : Message { - override fun toString() = "$topicSession#${String(data)}" + data class InMemoryMessage(override val topic: String, + override val data: ByteSequence, + override val uniqueMessageId: DeduplicationId, + override val debugTimestamp: Instant = Instant.now()) : Message { + override fun toString() = "$topic#${String(data.bytes)}" } - private data class InMemoryReceivedMessage(override val topicSession: TopicSession, - override val data: ByteArray, + private data class InMemoryReceivedMessage(override val topic: String, + override val data: ByteSequence, override val platformVersion: Int, - override val uniqueMessageId: UUID, + override val uniqueMessageId: DeduplicationId, override val debugTimestamp: Instant, override val peer: CordaX500Name) : ReceivedMessage @@ -268,8 +271,7 @@ class InMemoryMessagingNetwork( private val peerHandle: PeerHandle, private val executor: AffinityExecutor, private val database: CordaPersistence) : SingletonSerializeAsToken(), MessagingService { - inner class Handler(val topicSession: TopicSession, - val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration + inner class Handler(val topicSession: String, val callback: MessageHandler) : MessageHandlerRegistration @Volatile private var running = true @@ -280,7 +282,7 @@ class InMemoryMessagingNetwork( } private val state = ThreadBox(InnerState()) - private val processedMessages: MutableSet = Collections.synchronizedSet(HashSet()) + private val processedMessages: MutableSet = Collections.synchronizedSet(HashSet()) override val myAddress: PeerHandle get() = peerHandle @@ -302,13 +304,10 @@ class InMemoryMessagingNetwork( } } - override fun addMessageHandler(topic: String, sessionID: Long, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration - = addMessageHandler(TopicSession(topic, sessionID), callback) - - override fun addMessageHandler(topicSession: TopicSession, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration { + override fun addMessageHandler(topic: String, callback: MessageHandler): MessageHandlerRegistration { check(running) val (handler, transfers) = state.locked { - val handler = Handler(topicSession, callback).apply { handlers.add(this) } + val handler = Handler(topic, callback).apply { handlers.add(this) } val pending = ArrayList() database.transaction { pending.addAll(pendingRedelivery) @@ -354,8 +353,8 @@ class InMemoryMessagingNetwork( override fun cancelRedelivery(retryId: Long) {} /** Returns the given (topic & session, data) pair as a newly created message object. */ - override fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID): Message { - return InMemoryMessage(topicSession, data, uuid) + override fun createMessage(topic: String, data: ByteArray, deduplicationId: DeduplicationId): Message { + return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId) } /** @@ -388,14 +387,14 @@ class InMemoryMessagingNetwork( while (deliverTo == null) { val transfer = (if (block) q.take() else q.poll()) ?: return null deliverTo = state.locked { - val matchingHandlers = handlers.filter { it.topicSession.isBlank() || transfer.message.topicSession == it.topicSession } + val matchingHandlers = handlers.filter { it.topicSession.isBlank() || transfer.message.topic == it.topicSession } if (matchingHandlers.isEmpty()) { // Got no handlers for this message yet. Keep the message around and attempt redelivery after a new // handler has been registered. The purpose of this path is to make unit tests that have multi-threading // reliable, as a sender may attempt to send a message to a receiver that hasn't finished setting // up a handler for yet. Most unit tests don't run threaded, but we want to test true parallelism at // least sometimes. - log.warn("Message to ${transfer.message.topicSession} could not be delivered") + log.warn("Message to ${transfer.message.topic} could not be delivered") database.transaction { pendingRedelivery.add(transfer) } @@ -419,7 +418,13 @@ class InMemoryMessagingNetwork( database.transaction { for (handler in deliverTo) { try { - handler.callback(transfer.toReceivedMessage(), handler) + val acknowledgeHandle = object : AcknowledgeHandle { + override fun acknowledge() { + } + override fun persistDeduplicationId() { + } + } + handler.callback(transfer.toReceivedMessage(), handler, acknowledgeHandle) } catch (e: Exception) { log.error("Caught exception in handler for $this/${handler.topicSession}", e) } @@ -436,8 +441,8 @@ class InMemoryMessagingNetwork( } private fun MessageTransfer.toReceivedMessage(): ReceivedMessage = InMemoryReceivedMessage( - message.topicSession, - message.data.copyOf(), // Kryo messes with the buffer so give each client a unique copy + message.topic, + OpaqueBytes(message.data.bytes.copyOf()), // Kryo messes with the buffer so give each client a unique copy 1, message.uniqueMessageId, message.debugTimestamp, diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockServices.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockServices.kt index 007a78c542..588551a145 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockServices.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockServices.kt @@ -4,12 +4,14 @@ import com.google.common.collect.MutableClassToInstanceMap import com.typesafe.config.Config import com.typesafe.config.ConfigFactory import com.typesafe.config.ConfigParseOptions +import net.corda.core.concurrent.CordaFuture import net.corda.core.cordapp.CordappProvider import net.corda.core.crypto.* import net.corda.core.flows.FlowLogic import net.corda.core.flows.StateMachineRunId import net.corda.core.identity.CordaX500Name import net.corda.core.identity.PartyAndCertificate +import net.corda.core.internal.concurrent.doneFuture import net.corda.core.messaging.DataFeed import net.corda.core.messaging.FlowHandle import net.corda.core.messaging.FlowProgressHandle @@ -17,6 +19,7 @@ import net.corda.core.node.* import net.corda.core.node.services.* import net.corda.core.serialization.SerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken +import net.corda.core.toFuture import net.corda.core.transactions.SignedTransaction import net.corda.node.VersionInfo import net.corda.node.internal.StateLoaderImpl @@ -285,6 +288,10 @@ class MockStateMachineRecordedTransactionMappingStorage( ) : StateMachineRecordedTransactionMappingStorage by storage open class MockTransactionStorage : WritableTransactionStorage, SingletonSerializeAsToken() { + override fun trackTransaction(id: SecureHash): CordaFuture { + return txns[id]?.let { doneFuture(it) } ?: _updatesPublisher.filter { it.id == id }.toFuture() + } + override fun track(): DataFeed, SignedTransaction> { return DataFeed(txns.values.toList(), _updatesPublisher) } @@ -327,4 +334,4 @@ fun createMockCordaService(serviceHub: MockServices, serv } } return MockAppServiceHubImpl(serviceHub, serviceConstructor).serviceInstance -} \ No newline at end of file +} diff --git a/tools/jmeter/src/main/kotlin/com/r3/corda/jmeter/Launcher.kt b/tools/jmeter/src/main/kotlin/com/r3/corda/jmeter/Launcher.kt index 59e9ce5d92..36a02344e6 100644 --- a/tools/jmeter/src/main/kotlin/com/r3/corda/jmeter/Launcher.kt +++ b/tools/jmeter/src/main/kotlin/com/r3/corda/jmeter/Launcher.kt @@ -1,6 +1,5 @@ package com.r3.corda.jmeter -import com.sun.javaws.exceptions.InvalidArgumentException import net.corda.core.internal.div import org.apache.jmeter.JMeter import org.slf4j.LoggerFactory @@ -68,7 +67,7 @@ class Launcher { if (args[index] == "-XsshUser") { ++index if (index == args.size || args[index].startsWith("-")) { - throw InvalidArgumentException(args) + throw IllegalArgumentException(args.toList().toString()) } userName = args[index] } else if (args[index] == "-Xssh") { diff --git a/tools/jmeter/src/main/kotlin/com/r3/corda/jmeter/Ssh.kt b/tools/jmeter/src/main/kotlin/com/r3/corda/jmeter/Ssh.kt index a01af325e5..c1bea9430b 100644 --- a/tools/jmeter/src/main/kotlin/com/r3/corda/jmeter/Ssh.kt +++ b/tools/jmeter/src/main/kotlin/com/r3/corda/jmeter/Ssh.kt @@ -2,7 +2,6 @@ package com.r3.corda.jmeter import com.jcraft.jsch.JSch import com.jcraft.jsch.Session -import com.sun.javaws.exceptions.InvalidArgumentException import net.corda.core.utilities.NetworkHostAndPort import net.corda.nodeapi.internal.addShutdownHook import org.slf4j.LoggerFactory @@ -28,7 +27,7 @@ class Ssh { if (args[index] == "-XsshUser") { ++index if (index == args.size || args[index].startsWith("-")) { - throw InvalidArgumentException(args) + throw IllegalArgumentException(args.toList().toString()) } userName = args[index] } else if (args[index] == "-Xssh") {