Merge remote-tracking branch 'origin/release/os/4.6' into christians/ENT-5273-update-from-os-4.6

This commit is contained in:
Christian Sailer 2020-07-23 14:02:18 +01:00
commit db94f65d8a
55 changed files with 1579 additions and 292 deletions

View File

@ -5398,6 +5398,10 @@ public interface net.corda.core.schemas.QueryableState extends net.corda.core.co
## ##
public interface net.corda.core.schemas.StatePersistable public interface net.corda.core.schemas.StatePersistable
## ##
public interface net.corda.core.serialization.CheckpointCustomSerializer
public abstract OBJ fromProxy(PROXY)
public abstract PROXY toProxy(OBJ)
##
public interface net.corda.core.serialization.ClassWhitelist public interface net.corda.core.serialization.ClassWhitelist
public abstract boolean hasListed(Class<?>) public abstract boolean hasListed(Class<?>)
## ##

View File

@ -1,3 +1,14 @@
#!groovy
/**
* Jenkins pipeline to build Corda OS release with JDK11
*/
/**
* Kill already started job.
* Assume new commit takes precendence and results from previous
* unfinished builds are not required.
* This feature doesn't play well with disableConcurrentBuilds() option
*/
@Library('corda-shared-build-pipeline-steps') @Library('corda-shared-build-pipeline-steps')
import static com.r3.build.BuildControl.killAllExistingBuildsForJob import static com.r3.build.BuildControl.killAllExistingBuildsForJob
@ -19,16 +30,16 @@ if (isReleaseTag) {
switch (env.TAG_NAME) { switch (env.TAG_NAME) {
case ~/.*-RC\d+(-.*)?/: nexusIqStage = "stage-release"; break; case ~/.*-RC\d+(-.*)?/: nexusIqStage = "stage-release"; break;
case ~/.*-HC\d+(-.*)?/: nexusIqStage = "stage-release"; break; case ~/.*-HC\d+(-.*)?/: nexusIqStage = "stage-release"; break;
default: nexusIqStage = "operate" default: nexusIqStage = "release"
} }
} }
pipeline { pipeline {
agent { agent { label 'k8s' }
label 'k8s'
}
options { options {
timestamps() timestamps()
timeout(time: 3, unit: 'HOURS') timeout(time: 3, unit: 'HOURS')
buildDiscarder(logRotator(daysToKeepStr: '14', artifactDaysToKeepStr: '14'))
} }
environment { environment {
@ -48,14 +59,15 @@ pipeline {
sh "./gradlew --no-daemon clean jar" sh "./gradlew --no-daemon clean jar"
script { script {
sh "./gradlew --no-daemon properties | grep -E '^(version|group):' >version-properties" sh "./gradlew --no-daemon properties | grep -E '^(version|group):' >version-properties"
def version = sh (returnStdout: true, script: "grep ^version: version-properties | sed -e 's/^version: //'").trim() /* every build related to Corda X.Y (GA, RC, HC, patch or snapshot) uses the same NexusIQ application */
def version = sh (returnStdout: true, script: "grep ^version: version-properties | sed -e 's/^version: \\([0-9]\\+\\.[0-9]\\+\\).*\$/\\1/'").trim()
def groupId = sh (returnStdout: true, script: "grep ^group: version-properties | sed -e 's/^group: //'").trim() def groupId = sh (returnStdout: true, script: "grep ^group: version-properties | sed -e 's/^group: //'").trim()
def artifactId = 'corda' def artifactId = 'corda'
nexusAppId = "jenkins-${groupId}-${artifactId}-jdk11-${version}" nexusAppId = "jenkins-${groupId}-${artifactId}-jdk11-${version}"
} }
nexusPolicyEvaluation ( nexusPolicyEvaluation (
failBuildOnNetworkError: false, failBuildOnNetworkError: false,
iqApplication: manualApplication(nexusAppId), iqApplication: selectedApplication(nexusAppId), // application *has* to exist before a build starts!
iqScanPatterns: [[scanPattern: 'node/capsule/build/libs/corda*.jar']], iqScanPatterns: [[scanPattern: 'node/capsule/build/libs/corda*.jar']],
iqStage: nexusIqStage iqStage: nexusIqStage
) )
@ -132,7 +144,7 @@ pipeline {
rtGradleDeployer( rtGradleDeployer(
id: 'deployer', id: 'deployer',
serverId: 'R3-Artifactory', serverId: 'R3-Artifactory',
repo: 'r3-corda-releases' repo: 'corda-releases'
) )
rtGradleRun( rtGradleRun(
usesPlugin: true, usesPlugin: true,
@ -153,7 +165,7 @@ pipeline {
post { post {
always { always {
archiveArtifacts artifacts: '**/pod-logs/**/*.log', fingerprint: false archiveArtifacts artifacts: '**/pod-logs/**/*.log', fingerprint: false
junit testResults: '**/build/test-results-xml/**/*.xml', allowEmptyResults: true junit testResults: '**/build/test-results-xml/**/*.xml', keepLongStdio: true
} }
cleanup { cleanup {
deleteDir() /* clean up our workspace */ deleteDir() /* clean up our workspace */

View File

@ -65,7 +65,7 @@ pipeline {
post { post {
always { always {
archiveArtifacts allowEmptyArchive: true, artifacts: '**/logs/**/*.log' archiveArtifacts allowEmptyArchive: true, artifacts: '**/logs/**/*.log'
junit testResults: '**/build/test-results/**/*.xml', keepLongStdio: true, allowEmptyResults: true junit testResults: '**/build/test-results/**/*.xml', keepLongStdio: true
bat '.ci/kill_corda_procs.cmd' bat '.ci/kill_corda_procs.cmd'
} }
cleanup { cleanup {
@ -87,7 +87,7 @@ pipeline {
post { post {
always { always {
archiveArtifacts allowEmptyArchive: true, artifacts: '**/logs/**/*.log' archiveArtifacts allowEmptyArchive: true, artifacts: '**/logs/**/*.log'
junit testResults: '**/build/test-results/**/*.xml', keepLongStdio: true, allowEmptyResults: true junit testResults: '**/build/test-results/**/*.xml', keepLongStdio: true
bat '.ci/kill_corda_procs.cmd' bat '.ci/kill_corda_procs.cmd'
} }
cleanup { cleanup {

View File

@ -84,7 +84,7 @@ pipeline {
post { post {
always { always {
archiveArtifacts artifacts: '**/pod-logs/**/*.log', fingerprint: false archiveArtifacts artifacts: '**/pod-logs/**/*.log', fingerprint: false
junit testResults: '**/build/test-results-xml/**/*.xml', allowEmptyResults: true, keepLongStdio: true junit testResults: '**/build/test-results-xml/**/*.xml', keepLongStdio: true
} }
cleanup { cleanup {
deleteDir() /* clean up our workspace */ deleteDir() /* clean up our workspace */

View File

@ -4,7 +4,7 @@ import static com.r3.build.BuildControl.killAllExistingBuildsForJob
killAllExistingBuildsForJob(env.JOB_NAME, env.BUILD_NUMBER.toInteger()) killAllExistingBuildsForJob(env.JOB_NAME, env.BUILD_NUMBER.toInteger())
pipeline { pipeline {
agent { label 'k8s' } agent { label 'standard' }
options { options {
timestamps() timestamps()
timeout(time: 3, unit: 'HOURS') timeout(time: 3, unit: 'HOURS')

View File

@ -49,14 +49,15 @@ pipeline {
sh "./gradlew --no-daemon clean jar" sh "./gradlew --no-daemon clean jar"
script { script {
sh "./gradlew --no-daemon properties | grep -E '^(version|group):' >version-properties" sh "./gradlew --no-daemon properties | grep -E '^(version|group):' >version-properties"
def version = sh (returnStdout: true, script: "grep ^version: version-properties | sed -e 's/^version: //'").trim() /* every build related to Corda X.Y (GA, RC, HC, patch or snapshot) uses the same NexusIQ application */
def version = sh (returnStdout: true, script: "grep ^version: version-properties | sed -e 's/^version: \\([0-9]\\+\\.[0-9]\\+\\).*\$/\\1/'").trim()
def groupId = sh (returnStdout: true, script: "grep ^group: version-properties | sed -e 's/^group: //'").trim() def groupId = sh (returnStdout: true, script: "grep ^group: version-properties | sed -e 's/^group: //'").trim()
def artifactId = 'corda' def artifactId = 'corda'
nexusAppId = "jenkins-${groupId}-${artifactId}-${version}" nexusAppId = "jenkins-${groupId}-${artifactId}-${version}"
} }
nexusPolicyEvaluation ( nexusPolicyEvaluation (
failBuildOnNetworkError: false, failBuildOnNetworkError: false,
iqApplication: manualApplication(nexusAppId), iqApplication: selectedApplication(nexusAppId), // application *has* to exist before a build starts!
iqScanPatterns: [[scanPattern: 'node/capsule/build/libs/corda*.jar']], iqScanPatterns: [[scanPattern: 'node/capsule/build/libs/corda*.jar']],
iqStage: nexusIqStage iqStage: nexusIqStage
) )

View File

@ -30,7 +30,7 @@ if (isReleaseTag) {
switch (env.TAG_NAME) { switch (env.TAG_NAME) {
case ~/.*-RC\d+(-.*)?/: nexusIqStage = "stage-release"; break; case ~/.*-RC\d+(-.*)?/: nexusIqStage = "stage-release"; break;
case ~/.*-HC\d+(-.*)?/: nexusIqStage = "stage-release"; break; case ~/.*-HC\d+(-.*)?/: nexusIqStage = "stage-release"; break;
default: nexusIqStage = "operate" default: nexusIqStage = "release"
} }
} }
@ -61,14 +61,15 @@ pipeline {
sh "./gradlew --no-daemon clean jar" sh "./gradlew --no-daemon clean jar"
script { script {
sh "./gradlew --no-daemon properties | grep -E '^(version|group):' >version-properties" sh "./gradlew --no-daemon properties | grep -E '^(version|group):' >version-properties"
def version = sh (returnStdout: true, script: "grep ^version: version-properties | sed -e 's/^version: //'").trim() /* every build related to Corda X.Y (GA, RC, HC, patch or snapshot) uses the same NexusIQ application */
def version = sh (returnStdout: true, script: "grep ^version: version-properties | sed -e 's/^version: \\([0-9]\\+\\.[0-9]\\+\\).*\$/\\1/'").trim()
def groupId = sh (returnStdout: true, script: "grep ^group: version-properties | sed -e 's/^group: //'").trim() def groupId = sh (returnStdout: true, script: "grep ^group: version-properties | sed -e 's/^group: //'").trim()
def artifactId = 'corda' def artifactId = 'corda'
nexusAppId = "jenkins-${groupId}-${artifactId}-${version}" nexusAppId = "jenkins-${groupId}-${artifactId}-${version}"
} }
nexusPolicyEvaluation ( nexusPolicyEvaluation (
failBuildOnNetworkError: false, failBuildOnNetworkError: false,
iqApplication: manualApplication(nexusAppId), iqApplication: selectedApplication(nexusAppId), // application *has* to exist before a build starts!
iqScanPatterns: [[scanPattern: 'node/capsule/build/libs/corda*.jar']], iqScanPatterns: [[scanPattern: 'node/capsule/build/libs/corda*.jar']],
iqStage: nexusIqStage iqStage: nexusIqStage
) )
@ -174,11 +175,10 @@ pipeline {
} }
} }
post { post {
always { always {
archiveArtifacts artifacts: '**/pod-logs/**/*.log', fingerprint: false archiveArtifacts artifacts: '**/pod-logs/**/*.log', fingerprint: false
junit testResults: '**/build/test-results-xml/**/*.xml', keepLongStdio: true, allowEmptyResults: true junit testResults: '**/build/test-results-xml/**/*.xml', keepLongStdio: true
script { script {
try { try {

2
Jenkinsfile vendored
View File

@ -79,7 +79,7 @@ pipeline {
post { post {
always { always {
archiveArtifacts artifacts: '**/pod-logs/**/*.log', fingerprint: false archiveArtifacts artifacts: '**/pod-logs/**/*.log', fingerprint: false
junit testResults: '**/build/test-results-xml/**/*.xml', keepLongStdio: true, allowEmptyResults: true junit testResults: '**/build/test-results-xml/**/*.xml', keepLongStdio: true
} }
cleanup { cleanup {
deleteDir() /* clean up our workspace */ deleteDir() /* clean up our workspace */

View File

@ -11,7 +11,7 @@ java8MinUpdateVersion=171
# When incrementing platformVersion make sure to update # # When incrementing platformVersion make sure to update #
# net.corda.core.internal.CordaUtilsKt.PLATFORM_VERSION as well. # # net.corda.core.internal.CordaUtilsKt.PLATFORM_VERSION as well. #
# ***************************************************************# # ***************************************************************#
platformVersion=7 platformVersion=8
guavaVersion=28.0-jre guavaVersion=28.0-jre
# Quasar version to use with Java 8: # Quasar version to use with Java 8:
quasarVersion=0.7.12_r3 quasarVersion=0.7.12_r3

View File

@ -167,7 +167,7 @@ class FlowExternalAsyncOperationTest : AbstractFlowExternalOperationTest() {
@Suspendable @Suspendable
override fun testCode(): Any = override fun testCode(): Any =
await(ExternalAsyncOperation(serviceHub) { _, _ -> await(ExternalAsyncOperation(serviceHub) { serviceHub, _ ->
serviceHub.cordaService(FutureService::class.java).createFuture() serviceHub.cordaService(FutureService::class.java).createFuture()
}) })
} }

View File

@ -7,6 +7,7 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.internal.cordapp.CordappImpl.Companion.UNKNOWN_VALUE import net.corda.core.internal.cordapp.CordappImpl.Companion.UNKNOWN_VALUE
import net.corda.core.schemas.MappedSchema import net.corda.core.schemas.MappedSchema
import net.corda.core.serialization.CheckpointCustomSerializer
import net.corda.core.serialization.SerializationCustomSerializer import net.corda.core.serialization.SerializationCustomSerializer
import net.corda.core.serialization.SerializationWhitelist import net.corda.core.serialization.SerializationWhitelist
import net.corda.core.serialization.SerializeAsToken import net.corda.core.serialization.SerializeAsToken
@ -29,6 +30,7 @@ import java.net.URL
* @property services List of RPC services * @property services List of RPC services
* @property serializationWhitelists List of Corda plugin registries * @property serializationWhitelists List of Corda plugin registries
* @property serializationCustomSerializers List of serializers * @property serializationCustomSerializers List of serializers
* @property checkpointCustomSerializers List of serializers for checkpoints
* @property customSchemas List of custom schemas * @property customSchemas List of custom schemas
* @property allFlows List of all flow classes * @property allFlows List of all flow classes
* @property jarPath The path to the JAR for this CorDapp * @property jarPath The path to the JAR for this CorDapp
@ -49,6 +51,7 @@ interface Cordapp {
val services: List<Class<out SerializeAsToken>> val services: List<Class<out SerializeAsToken>>
val serializationWhitelists: List<SerializationWhitelist> val serializationWhitelists: List<SerializationWhitelist>
val serializationCustomSerializers: List<SerializationCustomSerializer<*, *>> val serializationCustomSerializers: List<SerializationCustomSerializer<*, *>>
val checkpointCustomSerializers: List<CheckpointCustomSerializer<*, *>>
val customSchemas: Set<MappedSchema> val customSchemas: Set<MappedSchema>
val allFlows: List<Class<out FlowLogic<*>>> val allFlows: List<Class<out FlowLogic<*>>>
val jarPath: URL val jarPath: URL

View File

@ -25,6 +25,7 @@ import net.corda.core.node.NodeInfo
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.UntrustworthyData import net.corda.core.utilities.UntrustworthyData
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
@ -378,6 +379,22 @@ abstract class FlowLogic<out T> {
stateMachine.suspend(request, maySkipCheckpoint) stateMachine.suspend(request, maySkipCheckpoint)
} }
/**
* Closes the provided sessions and performs cleanup of any resources tied to these sessions.
*
* Note that sessions are closed automatically when the corresponding top-level flow terminates.
* So, it's beneficial to eagerly close them in long-lived flows that might have many open sessions that are not needed anymore and consume resources (e.g. memory, disk etc.).
* A closed session cannot be used anymore, e.g. to send or receive messages. So, you have to ensure you are calling this method only when the provided sessions are not going to be used anymore.
* As a result, any operations on a closed session will fail with an [UnexpectedFlowEndException].
* When a session is closed, the other side is informed and the session is closed there too eventually.
* To prevent misuse of the API, if there is an attempt to close an uninitialised session the invocation will fail with an [IllegalStateException].
*/
@Suspendable
fun close(sessions: NonEmptySet<FlowSession>) {
val request = FlowIORequest.CloseSessions(sessions)
stateMachine.suspend(request, false)
}
/** /**
* Invokes the given subflow. This function returns once the subflow completes successfully with the result * Invokes the given subflow. This function returns once the subflow completes successfully with the result
* returned by that subflow's [call] method. If the subflow has a progress tracker, it is attached to the * returned by that subflow's [call] method. If the subflow has a progress tracker, it is attached to the

View File

@ -191,6 +191,19 @@ abstract class FlowSession {
*/ */
@Suspendable @Suspendable
abstract fun send(payload: Any) abstract fun send(payload: Any)
/**
* Closes this session and performs cleanup of any resources tied to this session.
*
* Note that sessions are closed automatically when the corresponding top-level flow terminates.
* So, it's beneficial to eagerly close them in long-lived flows that might have many open sessions that are not needed anymore and consume resources (e.g. memory, disk etc.).
* A closed session cannot be used anymore, e.g. to send or receive messages. So, you have to ensure you are calling this method only when the session is not going to be used anymore.
* As a result, any operations on a closed session will fail with an [UnexpectedFlowEndException].
* When a session is closed, the other side is informed and the session is closed there too eventually.
* To prevent misuse of the API, if there is an attempt to close an uninitialised session the invocation will fail with an [IllegalStateException].
*/
@Suspendable
abstract fun close()
} }
/** /**

View File

@ -28,7 +28,7 @@ import java.util.jar.JarInputStream
// *Internal* Corda-specific utilities. // *Internal* Corda-specific utilities.
const val PLATFORM_VERSION = 7 const val PLATFORM_VERSION = 8
fun ServicesForResolution.ensureMinimumPlatformVersion(requiredMinPlatformVersion: Int, feature: String) { fun ServicesForResolution.ensureMinimumPlatformVersion(requiredMinPlatformVersion: Int, feature: String) {
checkMinimumPlatformVersion(networkParameters.minimumPlatformVersion, requiredMinPlatformVersion, feature) checkMinimumPlatformVersion(networkParameters.minimumPlatformVersion, requiredMinPlatformVersion, feature)

View File

@ -55,6 +55,13 @@ sealed class FlowIORequest<out R : Any> {
}}, shouldRetrySend=$shouldRetrySend)" }}, shouldRetrySend=$shouldRetrySend)"
} }
/**
* Closes the specified sessions.
*
* @property sessions the sessions to be closed.
*/
data class CloseSessions(val sessions: NonEmptySet<FlowSession>): FlowIORequest<Unit>()
/** /**
* Wait for a transaction to be committed to the database. * Wait for a transaction to be committed to the database.
* *

View File

@ -9,6 +9,7 @@ import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.notary.NotaryService import net.corda.core.internal.notary.NotaryService
import net.corda.core.internal.toPath import net.corda.core.internal.toPath
import net.corda.core.schemas.MappedSchema import net.corda.core.schemas.MappedSchema
import net.corda.core.serialization.CheckpointCustomSerializer
import net.corda.core.serialization.SerializationCustomSerializer import net.corda.core.serialization.SerializationCustomSerializer
import net.corda.core.serialization.SerializationWhitelist import net.corda.core.serialization.SerializationWhitelist
import net.corda.core.serialization.SerializeAsToken import net.corda.core.serialization.SerializeAsToken
@ -25,6 +26,7 @@ data class CordappImpl(
override val services: List<Class<out SerializeAsToken>>, override val services: List<Class<out SerializeAsToken>>,
override val serializationWhitelists: List<SerializationWhitelist>, override val serializationWhitelists: List<SerializationWhitelist>,
override val serializationCustomSerializers: List<SerializationCustomSerializer<*, *>>, override val serializationCustomSerializers: List<SerializationCustomSerializer<*, *>>,
override val checkpointCustomSerializers: List<CheckpointCustomSerializer<*, *>>,
override val customSchemas: Set<MappedSchema>, override val customSchemas: Set<MappedSchema>,
override val allFlows: List<Class<out FlowLogic<*>>>, override val allFlows: List<Class<out FlowLogic<*>>>,
override val jarPath: URL, override val jarPath: URL,
@ -79,6 +81,7 @@ data class CordappImpl(
services = emptyList(), services = emptyList(),
serializationWhitelists = emptyList(), serializationWhitelists = emptyList(),
serializationCustomSerializers = emptyList(), serializationCustomSerializers = emptyList(),
checkpointCustomSerializers = emptyList(),
customSchemas = emptySet(), customSchemas = emptySet(),
jarPath = Paths.get("").toUri().toURL(), jarPath = Paths.get("").toUri().toURL(),
info = UNKNOWN_INFO, info = UNKNOWN_INFO,

View File

@ -25,3 +25,26 @@ interface SerializationCustomSerializer<OBJ, PROXY> {
*/ */
fun fromProxy(proxy: PROXY): OBJ fun fromProxy(proxy: PROXY): OBJ
} }
/**
* Allows CorDapps to provide custom serializers for classes that do not serialize successfully during a checkpoint.
* In this case, a proxy serializer can be written that implements this interface whose purpose is to move between
* unserializable types and an intermediate representation.
*
* NOTE: Only implement this interface if you have a class that triggers an error during normal checkpoint
* serialization/deserialization.
*/
@KeepForDJVM
interface CheckpointCustomSerializer<OBJ, PROXY> {
/**
* Should facilitate the conversion of the third party object into the serializable
* local class specified by [PROXY]
*/
fun toProxy(obj: OBJ): PROXY
/**
* Should facilitate the conversion of the proxy object into a new instance of the
* unserializable type
*/
fun fromProxy(proxy: PROXY): OBJ
}

View File

@ -56,6 +56,10 @@ interface CheckpointSerializationContext {
* otherwise they appear as new copies of the object. * otherwise they appear as new copies of the object.
*/ */
val objectReferencesEnabled: Boolean val objectReferencesEnabled: Boolean
/**
* User defined custom serializers for use in checkpoint serialization.
*/
val checkpointCustomSerializers: Iterable<CheckpointCustomSerializer<*,*>>
/** /**
* Helper method to return a new context based on this context with the property added. * Helper method to return a new context based on this context with the property added.
@ -86,6 +90,11 @@ interface CheckpointSerializationContext {
* A shallow copy of this context but with the given encoding whitelist. * A shallow copy of this context but with the given encoding whitelist.
*/ */
fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist): CheckpointSerializationContext fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist): CheckpointSerializationContext
/**
* A shallow copy of this context but with the given custom serializers.
*/
fun withCheckpointCustomSerializers(checkpointCustomSerializers: Iterable<CheckpointCustomSerializer<*, *>>): CheckpointSerializationContext
} }
/* /*

View File

@ -0,0 +1,103 @@
package net.corda.nodeapi.internal.serialization.kryo
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import net.corda.core.serialization.CheckpointCustomSerializer
import net.corda.serialization.internal.amqp.CORDAPP_TYPE
import java.lang.reflect.Type
import kotlin.reflect.jvm.javaType
import kotlin.reflect.jvm.jvmErasure
/**
* Adapts CheckpointCustomSerializer for use in Kryo
*/
internal class CustomSerializerCheckpointAdaptor<OBJ, PROXY>(private val userSerializer : CheckpointCustomSerializer<OBJ, PROXY>) : Serializer<OBJ>() {
/**
* The class name of the serializer we are adapting.
*/
val serializerName: String = userSerializer.javaClass.name
/**
* The input type of this custom serializer.
*/
val cordappType: Type
/**
* Check we have access to the types specified on the CheckpointCustomSerializer interface.
*
* Throws UnableToDetermineSerializerTypesException if the types are missing.
*/
init {
val types: List<Type> = userSerializer::class
.supertypes
.filter { it.jvmErasure == CheckpointCustomSerializer::class }
.flatMap { it.arguments }
.mapNotNull { it.type?.javaType }
// We are expecting a cordapp type and a proxy type.
// We will only use the cordapp type in this class
// but we want to check both are present.
val typeParameterCount = 2
if (types.size != typeParameterCount) {
throw UnableToDetermineSerializerTypesException("Unable to determine serializer parent types")
}
cordappType = types[CORDAPP_TYPE]
}
/**
* Serialize obj to the Kryo stream.
*/
override fun write(kryo: Kryo, output: Output, obj: OBJ) {
fun <T> writeToKryo(obj: T) = kryo.writeClassAndObject(output, obj)
// Write serializer type
writeToKryo(serializerName)
// Write proxy object
writeToKryo(userSerializer.toProxy(obj))
}
/**
* Deserialize an object from the Kryo stream.
*/
override fun read(kryo: Kryo, input: Input, type: Class<OBJ>): OBJ {
@Suppress("UNCHECKED_CAST")
fun <T> readFromKryo() = kryo.readClassAndObject(input) as T
// Check the serializer type
checkSerializerType(readFromKryo())
// Read the proxy object
return userSerializer.fromProxy(readFromKryo())
}
/**
* Throws a `CustomCheckpointSerializersHaveChangedException` if the serializer type in the kryo stream does not match the serializer
* type for this custom serializer.
*
* @param checkpointSerializerType Serializer type from the Kryo stream
*/
private fun checkSerializerType(checkpointSerializerType: String) {
if (checkpointSerializerType != serializerName)
throw CustomCheckpointSerializersHaveChangedException("The custom checkpoint serializers have changed while checkpoints exist. " +
"Please restore the CorDapps to when this checkpoint was created.")
}
}
/**
* Thrown when the input/output types are missing from the custom serializer.
*/
class UnableToDetermineSerializerTypesException(message: String) : RuntimeException(message)
/**
* Thrown when the custom serializer is found to be reading data from another type of custom serializer.
*
* This was expected to happen if the user adds or removes CorDapps while checkpoints exist but it turned out that registering serializers
* as default made the system reliable.
*/
class CustomCheckpointSerializersHaveChangedException(message: String) : RuntimeException(message)

View File

@ -10,12 +10,14 @@ import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool import com.esotericsoftware.kryo.pool.KryoPool
import com.esotericsoftware.kryo.serializers.ClosureSerializer import com.esotericsoftware.kryo.serializers.ClosureSerializer
import net.corda.core.internal.uncheckedCast import net.corda.core.internal.uncheckedCast
import net.corda.core.serialization.CheckpointCustomSerializer
import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.internal.CheckpointSerializationContext import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializer import net.corda.core.serialization.internal.CheckpointSerializer
import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.loggerFor
import net.corda.serialization.internal.AlwaysAcceptEncodingWhitelist import net.corda.serialization.internal.AlwaysAcceptEncodingWhitelist
import net.corda.serialization.internal.ByteBufferInputStream import net.corda.serialization.internal.ByteBufferInputStream
import net.corda.serialization.internal.CheckpointSerializationContextImpl import net.corda.serialization.internal.CheckpointSerializationContextImpl
@ -40,10 +42,10 @@ private object AutoCloseableSerialisationDetector : Serializer<AutoCloseable>()
} }
object KryoCheckpointSerializer : CheckpointSerializer { object KryoCheckpointSerializer : CheckpointSerializer {
private val kryoPoolsForContexts = ConcurrentHashMap<Pair<ClassWhitelist, ClassLoader>, KryoPool>() private val kryoPoolsForContexts = ConcurrentHashMap<Triple<ClassWhitelist, ClassLoader, Iterable<CheckpointCustomSerializer<*,*>>>, KryoPool>()
private fun getPool(context: CheckpointSerializationContext): KryoPool { private fun getPool(context: CheckpointSerializationContext): KryoPool {
return kryoPoolsForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) { return kryoPoolsForContexts.computeIfAbsent(Triple(context.whitelist, context.deserializationClassLoader, context.checkpointCustomSerializers)) {
KryoPool.Builder { KryoPool.Builder {
val serializer = Fiber.getFiberSerializer(false) as KryoSerializer val serializer = Fiber.getFiberSerializer(false) as KryoSerializer
val classResolver = CordaClassResolver(context).apply { setKryo(serializer.kryo) } val classResolver = CordaClassResolver(context).apply { setKryo(serializer.kryo) }
@ -56,12 +58,60 @@ object KryoCheckpointSerializer : CheckpointSerializer {
addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector) addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector)
register(ClosureSerializer.Closure::class.java, CordaClosureSerializer) register(ClosureSerializer.Closure::class.java, CordaClosureSerializer)
classLoader = it.second classLoader = it.second
// Add custom serializers
val customSerializers = buildCustomSerializerAdaptors(context)
warnAboutDuplicateSerializers(customSerializers)
val classToSerializer = mapInputClassToCustomSerializer(context.deserializationClassLoader, customSerializers)
addDefaultCustomSerializers(this, classToSerializer)
} }
}.build() }.build()
} }
} }
/**
* Returns a sorted list of CustomSerializerCheckpointAdaptor based on the custom serializers inside context.
*
* The adaptors are sorted by serializerName which maps to javaClass.name for the serializer class
*/
private fun buildCustomSerializerAdaptors(context: CheckpointSerializationContext) =
context.checkpointCustomSerializers.map { CustomSerializerCheckpointAdaptor(it) }.sortedBy { it.serializerName }
/**
* Returns a list of pairs where the first element is the input class of the custom serializer and the second element is the
* custom serializer.
*/
private fun mapInputClassToCustomSerializer(classLoader: ClassLoader, customSerializers: Iterable<CustomSerializerCheckpointAdaptor<*, *>>) =
customSerializers.map { getInputClassForCustomSerializer(classLoader, it) to it }
/**
* Returns the Class object for the serializers input type.
*/
private fun getInputClassForCustomSerializer(classLoader: ClassLoader, customSerializer: CustomSerializerCheckpointAdaptor<*, *>): Class<*> {
val typeNameWithoutGenerics = customSerializer.cordappType.typeName.substringBefore('<')
return classLoader.loadClass(typeNameWithoutGenerics)
}
/**
* Emit a warning if two or more custom serializers are found for the same input type.
*/
private fun warnAboutDuplicateSerializers(customSerializers: Iterable<CustomSerializerCheckpointAdaptor<*,*>>) =
customSerializers
.groupBy({ it.cordappType }, { it.serializerName })
.filter { (_, serializerNames) -> serializerNames.distinct().size > 1 }
.forEach { (inputType, serializerNames) -> loggerFor<KryoCheckpointSerializer>().warn("Duplicate custom checkpoint serializer for type $inputType. Serializers: ${serializerNames.joinToString(", ")}") }
/**
* Register all custom serializers as default, this class + subclass, registrations.
*
* Serializers registered before this will take priority. This needs to run after registrations we want to keep otherwise it may
* replace them.
*/
private fun addDefaultCustomSerializers(kryo: Kryo, classToSerializer: Iterable<Pair<Class<*>, CustomSerializerCheckpointAdaptor<*, *>>>) =
classToSerializer
.forEach { (clazz, customSerializer) -> kryo.addDefaultSerializer(clazz, customSerializer) }
private fun <T : Any> CheckpointSerializationContext.kryo(task: Kryo.() -> T): T { private fun <T : Any> CheckpointSerializationContext.kryo(task: Kryo.() -> T): T {
return getPool(this).run { kryo -> return getPool(this).run { kryo ->
kryo.context.ensureCapacity(properties.size) kryo.context.ensureCapacity(properties.size)

View File

@ -0,0 +1,99 @@
package net.corda.node.customcheckpointserializer
import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.crypto.generateKeyPair
import net.corda.core.serialization.EncodingWhitelist
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.coretesting.internal.rigorousMock
import net.corda.serialization.internal.AllWhitelist
import net.corda.serialization.internal.CheckpointSerializationContextImpl
import net.corda.serialization.internal.CordaSerializationEncoding
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule
import org.junit.Assert
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
@RunWith(Parameterized::class)
class CustomCheckpointSerializerTest(private val compression: CordaSerializationEncoding?) {
companion object {
@Parameterized.Parameters(name = "{0}")
@JvmStatic
fun compression() = arrayOf<CordaSerializationEncoding?>(null) + CordaSerializationEncoding.values()
}
@get:Rule
val serializationRule = CheckpointSerializationEnvironmentRule(inheritable = true)
private val context: CheckpointSerializationContext = CheckpointSerializationContextImpl(
deserializationClassLoader = javaClass.classLoader,
whitelist = AllWhitelist,
properties = emptyMap(),
objectReferencesEnabled = true,
encoding = compression,
encodingWhitelist = rigorousMock<EncodingWhitelist>().also {
if (compression != null) doReturn(true).whenever(it).acceptEncoding(compression)
},
checkpointCustomSerializers = listOf(
TestCorDapp.TestAbstractClassSerializer(),
TestCorDapp.TestClassSerializer(),
TestCorDapp.TestInterfaceSerializer(),
TestCorDapp.TestFinalClassSerializer(),
TestCorDapp.BrokenPublicKeySerializer()
)
)
@Test(timeout=300_000)
fun `test custom checkpoint serialization`() {
testBrokenMapSerialization(DifficultToSerialize.BrokenMapClass())
}
@Test(timeout=300_000)
fun `test custom checkpoint serialization using interface`() {
testBrokenMapSerialization(DifficultToSerialize.BrokenMapInterfaceImpl())
}
@Test(timeout=300_000)
fun `test custom checkpoint serialization using abstract class`() {
testBrokenMapSerialization(DifficultToSerialize.BrokenMapAbstractImpl())
}
@Test(timeout=300_000)
fun `test custom checkpoint serialization using final class`() {
testBrokenMapSerialization(DifficultToSerialize.BrokenMapFinal())
}
@Test(timeout=300_000)
fun `test PublicKey serializer has not been overridden`() {
val publicKey = generateKeyPair().public
// Serialize/deserialize
val checkpoint = publicKey.checkpointSerialize(context)
val deserializedCheckpoint = checkpoint.checkpointDeserialize(context)
// Check the elements are as expected
Assert.assertArrayEquals(publicKey.encoded, deserializedCheckpoint.encoded)
}
private fun testBrokenMapSerialization(brokenMap : MutableMap<String, String>): MutableMap<String, String> {
// Add elements to the map
brokenMap.putAll(mapOf("key" to "value"))
// Serialize/deserialize
val checkpoint = brokenMap.checkpointSerialize(context)
val deserializedCheckpoint = checkpoint.checkpointDeserialize(context)
// Check the elements are as expected
Assert.assertEquals(1, deserializedCheckpoint.size)
Assert.assertEquals("value", deserializedCheckpoint.get("key"))
// Return map for extra checks
return deserializedCheckpoint
}
}

View File

@ -0,0 +1,27 @@
package net.corda.node.customcheckpointserializer
import net.corda.core.flows.FlowException
class DifficultToSerialize {
// Broken Map
// This map breaks the rules for the put method. Making the normal map serializer fail.
open class BrokenMapBaseImpl<K,V>(delegate: MutableMap<K, V> = mutableMapOf()) : MutableMap<K,V> by delegate {
override fun put(key: K, value: V): V? = throw FlowException("Broken on purpose")
}
// A class to test custom serializers applied to implementations
class BrokenMapClass<K,V> : BrokenMapBaseImpl<K, V>()
// An interface and implementation to test custom serializers applied to interface types
interface BrokenMapInterface<K, V> : MutableMap<K, V>
class BrokenMapInterfaceImpl<K,V> : BrokenMapBaseImpl<K, V>(), BrokenMapInterface<K, V>
// An abstract class and implementation to test custom serializers applied to interface types
abstract class BrokenMapAbstract<K, V> : BrokenMapBaseImpl<K, V>(), MutableMap<K, V>
class BrokenMapAbstractImpl<K,V> : BrokenMapAbstract<K, V>()
// A final class
final class BrokenMapFinal<K, V>: BrokenMapBaseImpl<K, V>()
}

View File

@ -0,0 +1,59 @@
package net.corda.node.customcheckpointserializer
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.InitiatingFlow
import net.corda.core.flows.StartableByRPC
import net.corda.core.messaging.startFlow
import net.corda.core.serialization.CheckpointCustomSerializer
import net.corda.core.utilities.getOrThrow
import net.corda.node.logging.logFile
import net.corda.testing.driver.driver
import org.assertj.core.api.Assertions
import org.junit.Test
import java.time.Duration
class DuplicateSerializerLogTest{
@Test(timeout=300_000)
fun `check duplicate serialisers are logged`() {
driver {
val node = startNode(startInSameProcess = false).getOrThrow()
node.rpc.startFlow(::TestFlow).returnValue.get()
val text = node.logFile().readLines().filter { it.startsWith("[WARN") }
// Initial message is correct
Assertions.assertThat(text).anyMatch {it.contains("Duplicate custom checkpoint serializer for type net.corda.node.customcheckpointserializer.DifficultToSerialize\$BrokenMapInterface<java.lang.Object, java.lang.Object>. Serializers: ")}
// Message mentions TestInterfaceSerializer
Assertions.assertThat(text).anyMatch {it.contains("net.corda.node.customcheckpointserializer.TestCorDapp\$TestInterfaceSerializer")}
// Message mentions DuplicateSerializer
Assertions.assertThat(text).anyMatch {it.contains("net.corda.node.customcheckpointserializer.DuplicateSerializerLogTest\$DuplicateSerializer")}
}
}
@StartableByRPC
@InitiatingFlow
class TestFlow : FlowLogic<DifficultToSerialize.BrokenMapInterface<String, String>>() {
override fun call(): DifficultToSerialize.BrokenMapInterface<String, String> {
val brokenMap: DifficultToSerialize.BrokenMapInterface<String, String> = DifficultToSerialize.BrokenMapInterfaceImpl()
brokenMap.putAll(mapOf("test" to "input"))
sleep(Duration.ofSeconds(0))
return brokenMap
}
}
@Suppress("unused")
class DuplicateSerializer :
CheckpointCustomSerializer<DifficultToSerialize.BrokenMapInterface<Any, Any>, HashMap<Any, Any>> {
override fun toProxy(obj: DifficultToSerialize.BrokenMapInterface<Any, Any>): HashMap<Any, Any> {
val proxy = HashMap<Any, Any>()
return obj.toMap(proxy)
}
override fun fromProxy(proxy: HashMap<Any, Any>): DifficultToSerialize.BrokenMapInterface<Any, Any> {
return DifficultToSerialize.BrokenMapInterfaceImpl<Any, Any>()
.also { it.putAll(proxy) }
}
}
}

View File

@ -0,0 +1,58 @@
package net.corda.node.customcheckpointserializer
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.InitiatingFlow
import net.corda.core.flows.StartableByRPC
import net.corda.core.messaging.startFlow
import net.corda.core.serialization.CheckpointCustomSerializer
import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.getOrThrow
import net.corda.node.logging.logFile
import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.driver
import net.corda.testing.node.internal.enclosedCordapp
import org.assertj.core.api.Assertions
import org.junit.Test
import java.time.Duration
class DuplicateSerializerLogWithSameSerializerTest {
@Test(timeout=300_000)
fun `check duplicate serialisers are logged not logged for the same class`() {
// Duplicate the cordapp in this node
driver(DriverParameters(cordappsForAllNodes = listOf(this.enclosedCordapp(), this.enclosedCordapp()))) {
val node = startNode(startInSameProcess = false).getOrThrow()
node.rpc.startFlow(::TestFlow).returnValue.get()
val text = node.logFile().readLines().filter { it.startsWith("[WARN") }
// Initial message is not logged
Assertions.assertThat(text)
.anyMatch { !it.contains("Duplicate custom checkpoint serializer for type ") }
// Log does not mention DuplicateSerializerThatShouldNotBeLogged
Assertions.assertThat(text)
.anyMatch { !it.contains("DuplicateSerializerThatShouldNotBeLogged") }
}
}
@CordaSerializable
class UnusedClass
@Suppress("unused")
class DuplicateSerializerThatShouldNotBeLogged : CheckpointCustomSerializer<UnusedClass, String> {
override fun toProxy(obj: UnusedClass): String = ""
override fun fromProxy(proxy: String): UnusedClass = UnusedClass()
}
@StartableByRPC
@InitiatingFlow
class TestFlow : FlowLogic<UnusedClass>() {
override fun call(): UnusedClass {
val unusedClass = UnusedClass()
sleep(Duration.ofSeconds(0))
return unusedClass
}
}
}

View File

@ -0,0 +1,75 @@
package net.corda.node.customcheckpointserializer
import co.paralleluniverse.fibers.Suspendable
import net.corda.testing.node.MockNetwork
import net.corda.testing.node.MockNetworkParameters
import org.assertj.core.api.Assertions
import org.junit.After
import org.junit.Before
import org.junit.Test
class MockNetworkCustomCheckpointSerializerTest {
private lateinit var mockNetwork: MockNetwork
@Before
fun setup() {
mockNetwork = MockNetwork(MockNetworkParameters(cordappsForAllNodes = listOf(TestCorDapp.getCorDapp())))
}
@After
fun shutdown() {
mockNetwork.stopNodes()
}
@Test(timeout = 300_000)
fun `flow suspend with custom kryo serializer`() {
val node = mockNetwork.createPartyNode()
val expected = 5
val actual = node.startFlow(TestCorDapp.TestFlowWithDifficultToSerializeLocalVariable(5)).get()
Assertions.assertThat(actual).isEqualTo(expected)
}
@Test(timeout = 300_000)
fun `check references are restored correctly`() {
val node = mockNetwork.createPartyNode()
val expectedReference = DifficultToSerialize.BrokenMapClass<String, Int>()
expectedReference.putAll(mapOf("one" to 1))
val actualReference = node.startFlow(TestCorDapp.TestFlowCheckingReferencesWork(expectedReference)).get()
Assertions.assertThat(actualReference).isSameAs(expectedReference)
Assertions.assertThat(actualReference["one"]).isEqualTo(1)
}
@Test(timeout = 300_000)
@Suspendable
fun `check serialization of interfaces`() {
val node = mockNetwork.createPartyNode()
val result = node.startFlow(TestCorDapp.TestFlowWithDifficultToSerializeLocalVariableAsInterface(5)).get()
Assertions.assertThat(result).isEqualTo(5)
}
@Test(timeout = 300_000)
@Suspendable
fun `check serialization of abstract classes`() {
val node = mockNetwork.createPartyNode()
val result = node.startFlow(TestCorDapp.TestFlowWithDifficultToSerializeLocalVariableAsAbstract(5)).get()
Assertions.assertThat(result).isEqualTo(5)
}
@Test(timeout = 300_000)
@Suspendable
fun `check serialization of final classes`() {
val node = mockNetwork.createPartyNode()
val result = node.startFlow(TestCorDapp.TestFlowWithDifficultToSerializeLocalVariableAsFinal(5)).get()
Assertions.assertThat(result).isEqualTo(5)
}
@Test(timeout = 300_000)
@Suspendable
fun `check PublicKey serializer has not been overridden`() {
val node = mockNetwork.createPartyNode()
val result = node.startFlow(TestCorDapp.TestFlowCheckingPublicKeySerializer()).get()
Assertions.assertThat(result.encoded).isEqualTo(node.info.legalIdentities.first().owningKey.encoded)
}
}

View File

@ -0,0 +1,75 @@
package net.corda.node.customcheckpointserializer
import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.serialization.CheckpointCustomSerializer
import net.corda.core.serialization.EncodingWhitelist
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.coretesting.internal.rigorousMock
import net.corda.serialization.internal.AllWhitelist
import net.corda.serialization.internal.CheckpointSerializationContextImpl
import net.corda.serialization.internal.CordaSerializationEncoding
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule
import org.junit.Assert
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
@RunWith(Parameterized::class)
class ReferenceLoopTest(private val compression: CordaSerializationEncoding?) {
companion object {
@Parameterized.Parameters(name = "{0}")
@JvmStatic
fun compression() = arrayOf<CordaSerializationEncoding?>(null) + CordaSerializationEncoding.values()
}
@get:Rule
val serializationRule = CheckpointSerializationEnvironmentRule(inheritable = true)
private val context: CheckpointSerializationContext = CheckpointSerializationContextImpl(
deserializationClassLoader = javaClass.classLoader,
whitelist = AllWhitelist,
properties = emptyMap(),
objectReferencesEnabled = true,
encoding = compression,
encodingWhitelist = rigorousMock<EncodingWhitelist>()
.also {
if (compression != null) doReturn(true).whenever(it)
.acceptEncoding(compression)
},
checkpointCustomSerializers = listOf(PersonSerializer()))
@Test(timeout=300_000)
fun `custom checkpoint serialization with reference loop`() {
val person = Person("Test name")
val result = person.checkpointSerialize(context).checkpointDeserialize(context)
Assert.assertEquals("Test name", result.name)
Assert.assertEquals("Test name", result.bestFriend.name)
Assert.assertSame(result, result.bestFriend)
}
/**
* Test class that will hold a reference to itself
*/
class Person(val name: String, bestFriend: Person? = null) {
val bestFriend: Person = bestFriend ?: this
}
/**
* Custom serializer for the Person class
*/
@Suppress("unused")
class PersonSerializer : CheckpointCustomSerializer<Person, Map<String, Any>> {
override fun toProxy(obj: Person): Map<String, Any> {
return mapOf("name" to obj.name, "bestFriend" to obj.bestFriend)
}
override fun fromProxy(proxy: Map<String, Any>): Person {
return Person(proxy["name"] as String, proxy["bestFriend"] as Person?)
}
}
}

View File

@ -0,0 +1,214 @@
package net.corda.node.customcheckpointserializer
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StartableByRPC
import net.corda.core.serialization.CheckpointCustomSerializer
import net.corda.testing.node.internal.CustomCordapp
import net.corda.testing.node.internal.enclosedCordapp
import net.i2p.crypto.eddsa.EdDSAPublicKey
import org.assertj.core.api.Assertions
import java.security.PublicKey
import java.time.Duration
/**
* Contains all the flows and custom serializers for testing custom checkpoint serializers
*/
class TestCorDapp {
companion object {
fun getCorDapp(): CustomCordapp = enclosedCordapp()
}
// Flows
@StartableByRPC
class TestFlowWithDifficultToSerializeLocalVariableAsAbstract(private val purchase: Int) : FlowLogic<Int>() {
@Suspendable
override fun call(): Int {
// This object is difficult to serialize with Kryo
val difficultToSerialize: DifficultToSerialize.BrokenMapAbstract<String, Int> = DifficultToSerialize.BrokenMapAbstractImpl()
difficultToSerialize.putAll(mapOf("foo" to purchase))
// Force a checkpoint
sleep(Duration.ofSeconds(0))
// Return value from deserialized object
return difficultToSerialize["foo"] ?: 0
}
}
@StartableByRPC
class TestFlowWithDifficultToSerializeLocalVariableAsFinal(private val purchase: Int) : FlowLogic<Int>() {
@Suspendable
override fun call(): Int {
// This object is difficult to serialize with Kryo
val difficultToSerialize: DifficultToSerialize.BrokenMapFinal<String, Int> = DifficultToSerialize.BrokenMapFinal()
difficultToSerialize.putAll(mapOf("foo" to purchase))
// Force a checkpoint
sleep(Duration.ofSeconds(0))
// Return value from deserialized object
return difficultToSerialize["foo"] ?: 0
}
}
@StartableByRPC
class TestFlowWithDifficultToSerializeLocalVariableAsInterface(private val purchase: Int) : FlowLogic<Int>() {
@Suspendable
override fun call(): Int {
// This object is difficult to serialize with Kryo
val difficultToSerialize: DifficultToSerialize.BrokenMapInterface<String, Int> = DifficultToSerialize.BrokenMapInterfaceImpl()
difficultToSerialize.putAll(mapOf("foo" to purchase))
// Force a checkpoint
sleep(Duration.ofSeconds(0))
// Return value from deserialized object
return difficultToSerialize["foo"] ?: 0
}
}
@StartableByRPC
class TestFlowWithDifficultToSerializeLocalVariable(private val purchase: Int) : FlowLogic<Int>() {
@Suspendable
override fun call(): Int {
// This object is difficult to serialize with Kryo
val difficultToSerialize: DifficultToSerialize.BrokenMapClass<String, Int> = DifficultToSerialize.BrokenMapClass()
difficultToSerialize.putAll(mapOf("foo" to purchase))
// Force a checkpoint
sleep(Duration.ofSeconds(0))
// Return value from deserialized object
return difficultToSerialize["foo"] ?: 0
}
}
@StartableByRPC
class TestFlowCheckingReferencesWork(private val reference: DifficultToSerialize.BrokenMapClass<String, Int>) :
FlowLogic<DifficultToSerialize.BrokenMapClass<String, Int>>() {
private val referenceField = reference
@Suspendable
override fun call(): DifficultToSerialize.BrokenMapClass<String, Int> {
val ref = referenceField
// Force a checkpoint
sleep(Duration.ofSeconds(0))
// Check all objects refer to same object
Assertions.assertThat(reference).isSameAs(referenceField)
Assertions.assertThat(referenceField).isSameAs(ref)
// Return deserialized object
return ref
}
}
@StartableByRPC
class TestFlowCheckingPublicKeySerializer :
FlowLogic<PublicKey>() {
@Suspendable
override fun call(): PublicKey {
val ref = ourIdentity.owningKey
// Force a checkpoint
sleep(Duration.ofSeconds(0))
// Return deserialized object
return ref
}
}
// Custom serializers
@Suppress("unused")
class TestInterfaceSerializer :
CheckpointCustomSerializer<DifficultToSerialize.BrokenMapInterface<Any, Any>, HashMap<Any, Any>> {
override fun toProxy(obj: DifficultToSerialize.BrokenMapInterface<Any, Any>): HashMap<Any, Any> {
val proxy = HashMap<Any, Any>()
return obj.toMap(proxy)
}
override fun fromProxy(proxy: HashMap<Any, Any>): DifficultToSerialize.BrokenMapInterface<Any, Any> {
return DifficultToSerialize.BrokenMapInterfaceImpl<Any, Any>()
.also { it.putAll(proxy) }
}
}
@Suppress("unused")
class TestClassSerializer :
CheckpointCustomSerializer<DifficultToSerialize.BrokenMapClass<Any, Any>, HashMap<Any, Any>> {
override fun toProxy(obj: DifficultToSerialize.BrokenMapClass<Any, Any>): HashMap<Any, Any> {
val proxy = HashMap<Any, Any>()
return obj.toMap(proxy)
}
override fun fromProxy(proxy: HashMap<Any, Any>): DifficultToSerialize.BrokenMapClass<Any, Any> {
return DifficultToSerialize.BrokenMapClass<Any, Any>()
.also { it.putAll(proxy) }
}
}
@Suppress("unused")
class TestAbstractClassSerializer :
CheckpointCustomSerializer<DifficultToSerialize.BrokenMapAbstract<Any, Any>, HashMap<Any, Any>> {
override fun toProxy(obj: DifficultToSerialize.BrokenMapAbstract<Any, Any>): HashMap<Any, Any> {
val proxy = HashMap<Any, Any>()
return obj.toMap(proxy)
}
override fun fromProxy(proxy: HashMap<Any, Any>): DifficultToSerialize.BrokenMapAbstract<Any, Any> {
return DifficultToSerialize.BrokenMapAbstractImpl<Any, Any>()
.also { it.putAll(proxy) }
}
}
@Suppress("unused")
class TestFinalClassSerializer :
CheckpointCustomSerializer<DifficultToSerialize.BrokenMapFinal<Any, Any>, HashMap<Any, Any>> {
override fun toProxy(obj: DifficultToSerialize.BrokenMapFinal<Any, Any>): HashMap<Any, Any> {
val proxy = HashMap<Any, Any>()
return obj.toMap(proxy)
}
override fun fromProxy(proxy: HashMap<Any, Any>): DifficultToSerialize.BrokenMapFinal<Any, Any> {
return DifficultToSerialize.BrokenMapFinal<Any, Any>()
.also { it.putAll(proxy) }
}
}
@Suppress("unused")
class BrokenPublicKeySerializer :
CheckpointCustomSerializer<PublicKey, String> {
override fun toProxy(obj: PublicKey): String {
throw FlowException("Broken on purpose")
}
override fun fromProxy(proxy: String): PublicKey {
throw FlowException("Broken on purpose")
}
}
@Suppress("unused")
class BrokenEdDSAPublicKeySerializer :
CheckpointCustomSerializer<EdDSAPublicKey, String> {
override fun toProxy(obj: EdDSAPublicKey): String {
throw FlowException("Broken on purpose")
}
override fun fromProxy(proxy: String): EdDSAPublicKey {
throw FlowException("Broken on purpose")
}
}
}

View File

@ -0,0 +1,273 @@
package net.corda.node.flows
import co.paralleluniverse.fibers.Suspendable
import net.corda.client.rpc.CordaRPCClient
import net.corda.core.CordaRuntimeException
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.flows.InitiatedBy
import net.corda.core.flows.InitiatingFlow
import net.corda.core.flows.StartableByRPC
import net.corda.core.flows.UnexpectedFlowEndException
import net.corda.core.identity.Party
import net.corda.core.internal.concurrent.transpose
import net.corda.core.messaging.startFlow
import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.seconds
import net.corda.core.utilities.toNonEmptySet
import net.corda.core.utilities.unwrap
import net.corda.node.services.Permissions
import net.corda.node.services.statemachine.transitions.PrematureSessionCloseException
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME
import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.driver
import net.corda.testing.node.User
import net.corda.testing.node.internal.enclosedCordapp
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.Test
import java.sql.SQLTransientConnectionException
import kotlin.test.assertEquals
class FlowSessionCloseTest {
private val user = User("user", "pwd", setOf(Permissions.all()))
@Test(timeout=300_000)
fun `flow cannot close uninitialised session`() {
driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()), notarySpecs = emptyList())) {
val (nodeAHandle, nodeBHandle) = listOf(
startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)),
startNode(providedName = BOB_NAME, rpcUsers = listOf(user))
).transpose().getOrThrow()
CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use {
assertThatThrownBy { it.proxy.startFlow(::InitiatorFlow, nodeBHandle.nodeInfo.legalIdentities.first(), true, null, false).returnValue.getOrThrow() }
.isInstanceOf(CordaRuntimeException::class.java)
.hasMessageContaining(PrematureSessionCloseException::class.java.name)
.hasMessageContaining("The following session was closed before it was initialised")
}
}
}
@Test(timeout=300_000)
fun `flow cannot access closed session`() {
driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()), notarySpecs = emptyList())) {
val (nodeAHandle, nodeBHandle) = listOf(
startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)),
startNode(providedName = BOB_NAME, rpcUsers = listOf(user))
).transpose().getOrThrow()
InitiatorFlow.SessionAPI.values().forEach { sessionAPI ->
CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use {
assertThatThrownBy { it.proxy.startFlow(::InitiatorFlow, nodeBHandle.nodeInfo.legalIdentities.first(), false, sessionAPI, false).returnValue.getOrThrow() }
.isInstanceOf(UnexpectedFlowEndException::class.java)
.hasMessageContaining("Tried to access ended session")
}
}
}
}
@Test(timeout=300_000)
fun `flow can close initialised session successfully`() {
driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()), notarySpecs = emptyList())) {
val (nodeAHandle, nodeBHandle) = listOf(
startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)),
startNode(providedName = BOB_NAME, rpcUsers = listOf(user))
).transpose().getOrThrow()
CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use {
it.proxy.startFlow(::InitiatorFlow, nodeBHandle.nodeInfo.legalIdentities.first(), false, null, false).returnValue.getOrThrow()
}
}
}
@Test(timeout=300_000)
fun `flow can close initialised session successfully even in case of failures and replays`() {
driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()), notarySpecs = emptyList())) {
val (nodeAHandle, nodeBHandle) = listOf(
startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)),
startNode(providedName = BOB_NAME, rpcUsers = listOf(user))
).transpose().getOrThrow()
CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use {
it.proxy.startFlow(::InitiatorFlow, nodeBHandle.nodeInfo.legalIdentities.first(), false, null, true).returnValue.getOrThrow()
}
}
}
@Test(timeout=300_000)
fun `flow can close multiple sessions successfully`() {
driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()), notarySpecs = emptyList())) {
val (nodeAHandle, nodeBHandle) = listOf(
startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)),
startNode(providedName = BOB_NAME, rpcUsers = listOf(user))
).transpose().getOrThrow()
CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use {
it.proxy.startFlow(::InitiatorMultipleSessionsFlow, nodeBHandle.nodeInfo.legalIdentities.first()).returnValue.getOrThrow()
}
}
}
/**
* This test ensures that when sessions are closed, the associated resources are eagerly cleaned up.
* If sessions are not closed, then the node will crash with an out-of-memory error.
* This can be confirmed by commenting out [FlowSession.close] operation in the invoked flow and re-run the test.
*/
@Test(timeout=300_000)
fun `flow looping over sessions can close them to release resources and avoid out-of-memory failures, when the other side does not finish early`() {
driver(DriverParameters(startNodesInProcess = false, cordappsForAllNodes = listOf(enclosedCordapp()), notarySpecs = emptyList())) {
val (nodeAHandle, nodeBHandle) = listOf(
startNode(providedName = ALICE_NAME, rpcUsers = listOf(user), maximumHeapSize = "256m"),
startNode(providedName = BOB_NAME, rpcUsers = listOf(user), maximumHeapSize = "256m")
).transpose().getOrThrow()
CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use {
it.proxy.startFlow(::InitiatorLoopingFlow, nodeBHandle.nodeInfo.legalIdentities.first(), true).returnValue.getOrThrow()
}
}
}
@Test(timeout=300_000)
fun `flow looping over sessions will close sessions automatically, when the other side finishes early`() {
driver(DriverParameters(startNodesInProcess = false, cordappsForAllNodes = listOf(enclosedCordapp()), notarySpecs = emptyList())) {
val (nodeAHandle, nodeBHandle) = listOf(
startNode(providedName = ALICE_NAME, rpcUsers = listOf(user), maximumHeapSize = "256m"),
startNode(providedName = BOB_NAME, rpcUsers = listOf(user), maximumHeapSize = "256m")
).transpose().getOrThrow()
CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use {
it.proxy.startFlow(::InitiatorLoopingFlow, nodeBHandle.nodeInfo.legalIdentities.first(), false).returnValue.getOrThrow()
}
}
}
@InitiatingFlow
@StartableByRPC
class InitiatorFlow(val party: Party, private val prematureClose: Boolean = false,
private val accessClosedSessionWithApi: SessionAPI? = null,
private val retryClose: Boolean = false): FlowLogic<Unit>() {
@CordaSerializable
enum class SessionAPI {
SEND,
SEND_AND_RECEIVE,
RECEIVE,
GET_FLOW_INFO
}
@Suspendable
override fun call() {
val session = initiateFlow(party)
if (prematureClose) {
session.close()
}
session.send(retryClose)
sleep(1.seconds)
if (accessClosedSessionWithApi != null) {
when(accessClosedSessionWithApi) {
SessionAPI.SEND -> session.send("dummy payload ")
SessionAPI.RECEIVE -> session.receive<String>()
SessionAPI.SEND_AND_RECEIVE -> session.sendAndReceive<String>("dummy payload")
SessionAPI.GET_FLOW_INFO -> session.getCounterpartyFlowInfo()
}
}
}
}
@InitiatedBy(InitiatorFlow::class)
class InitiatedFlow(private val otherSideSession: FlowSession): FlowLogic<Unit>() {
companion object {
var thrown = false
}
@Suspendable
override fun call() {
val retryClose = otherSideSession.receive<Boolean>()
.unwrap{ it }
otherSideSession.close()
// failing with a transient exception to force a replay of the close.
if (retryClose) {
if (!thrown) {
thrown = true
throw SQLTransientConnectionException("Connection is not available")
}
}
}
}
@InitiatingFlow
@StartableByRPC
class InitiatorLoopingFlow(val party: Party, val blockingCounterparty: Boolean = false): FlowLogic<Unit>() {
@Suspendable
override fun call() {
for (i in 1..1_000) {
val session = initiateFlow(party)
session.sendAndReceive<String>(blockingCounterparty ).unwrap{ assertEquals("Got it", it) }
/**
* If the counterparty blocks, we need to eagerly close the session and release resources to avoid running out of memory.
* Otherwise, the session end messages from the other side will do that automatically.
*/
if (blockingCounterparty) {
session.close()
}
logger.info("Completed iteration $i")
}
}
}
@InitiatedBy(InitiatorLoopingFlow::class)
class InitiatedLoopingFlow(private val otherSideSession: FlowSession): FlowLogic<Unit>() {
@Suspendable
override fun call() {
val shouldBlock = otherSideSession.receive<Boolean>()
.unwrap{ it }
otherSideSession.send("Got it")
if (shouldBlock) {
otherSideSession.receive<String>()
}
}
}
@InitiatingFlow
@StartableByRPC
class InitiatorMultipleSessionsFlow(val party: Party): FlowLogic<Unit>() {
@Suspendable
override fun call() {
for (round in 1 .. 2) {
val sessions = mutableListOf<FlowSession>()
for (session_number in 1 .. 5) {
val session = initiateFlow(party)
sessions.add(session)
session.sendAndReceive<String>("What's up?").unwrap{ assertEquals("All good!", it) }
}
close(sessions.toNonEmptySet())
}
}
}
@InitiatedBy(InitiatorMultipleSessionsFlow::class)
class InitiatedMultipleSessionsFlow(private val otherSideSession: FlowSession): FlowLogic<Unit>() {
@Suspendable
override fun call() {
otherSideSession.receive<String>()
.unwrap{ assertEquals("What's up?", it) }
otherSideSession.send("All good!")
}
}
}

View File

@ -656,8 +656,8 @@ open class Node(configuration: NodeConfiguration,
storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader), storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader),
checkpointSerializer = KryoCheckpointSerializer, checkpointSerializer = KryoCheckpointSerializer,
checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader) checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader).withCheckpointCustomSerializers(cordappLoader.cordapps.flatMap { it.checkpointCustomSerializers })
) )
} }
/** Starts a blocking event loop for message dispatch. */ /** Starts a blocking event loop for message dispatch. */

View File

@ -18,6 +18,7 @@ import net.corda.core.internal.notary.NotaryService
import net.corda.core.internal.notary.SinglePartyNotaryService import net.corda.core.internal.notary.SinglePartyNotaryService
import net.corda.core.node.services.CordaService import net.corda.core.node.services.CordaService
import net.corda.core.schemas.MappedSchema import net.corda.core.schemas.MappedSchema
import net.corda.core.serialization.CheckpointCustomSerializer
import net.corda.core.serialization.SerializationCustomSerializer import net.corda.core.serialization.SerializationCustomSerializer
import net.corda.core.serialization.SerializationWhitelist import net.corda.core.serialization.SerializationWhitelist
import net.corda.core.serialization.SerializeAsToken import net.corda.core.serialization.SerializeAsToken
@ -185,6 +186,7 @@ class JarScanningCordappLoader private constructor(private val cordappJarPaths:
findServices(this), findServices(this),
findWhitelists(url), findWhitelists(url),
findSerializers(this), findSerializers(this),
findCheckpointSerializers(this),
findCustomSchemas(this), findCustomSchemas(this),
findAllFlows(this), findAllFlows(this),
url.url, url.url,
@ -334,6 +336,10 @@ class JarScanningCordappLoader private constructor(private val cordappJarPaths:
return scanResult.getClassesImplementingWithClassVersionCheck(SerializationCustomSerializer::class) return scanResult.getClassesImplementingWithClassVersionCheck(SerializationCustomSerializer::class)
} }
private fun findCheckpointSerializers(scanResult: RestrictedScanResult): List<CheckpointCustomSerializer<*, *>> {
return scanResult.getClassesImplementingWithClassVersionCheck(CheckpointCustomSerializer::class)
}
private fun findCustomSchemas(scanResult: RestrictedScanResult): Set<MappedSchema> { private fun findCustomSchemas(scanResult: RestrictedScanResult): Set<MappedSchema> {
return scanResult.getClassesWithSuperclass(MappedSchema::class).instances().toSet() return scanResult.getClassesWithSuperclass(MappedSchema::class).instances().toSet()
} }

View File

@ -32,6 +32,7 @@ internal object VirtualCordapp {
services = listOf(), services = listOf(),
serializationWhitelists = listOf(), serializationWhitelists = listOf(),
serializationCustomSerializers = listOf(), serializationCustomSerializers = listOf(),
checkpointCustomSerializers = listOf(),
customSchemas = setOf(), customSchemas = setOf(),
info = Cordapp.Info.Default("corda-core", versionInfo.vendor, versionInfo.releaseVersion, "Open Source (Apache 2)"), info = Cordapp.Info.Default("corda-core", versionInfo.vendor, versionInfo.releaseVersion, "Open Source (Apache 2)"),
allFlows = listOf(), allFlows = listOf(),
@ -55,6 +56,7 @@ internal object VirtualCordapp {
services = listOf(), services = listOf(),
serializationWhitelists = listOf(), serializationWhitelists = listOf(),
serializationCustomSerializers = listOf(), serializationCustomSerializers = listOf(),
checkpointCustomSerializers = listOf(),
customSchemas = setOf(NodeNotarySchemaV1), customSchemas = setOf(NodeNotarySchemaV1),
info = Cordapp.Info.Default("corda-notary", versionInfo.vendor, versionInfo.releaseVersion, "Open Source (Apache 2)"), info = Cordapp.Info.Default("corda-notary", versionInfo.vendor, versionInfo.releaseVersion, "Open Source (Apache 2)"),
allFlows = listOf(), allFlows = listOf(),
@ -78,6 +80,7 @@ internal object VirtualCordapp {
services = listOf(), services = listOf(),
serializationWhitelists = listOf(), serializationWhitelists = listOf(),
serializationCustomSerializers = listOf(), serializationCustomSerializers = listOf(),
checkpointCustomSerializers = listOf(),
customSchemas = setOf(RaftNotarySchemaV1), customSchemas = setOf(RaftNotarySchemaV1),
info = Cordapp.Info.Default("corda-notary-raft", versionInfo.vendor, versionInfo.releaseVersion, "Open Source (Apache 2)"), info = Cordapp.Info.Default("corda-notary-raft", versionInfo.vendor, versionInfo.releaseVersion, "Open Source (Apache 2)"),
allFlows = listOf(), allFlows = listOf(),
@ -101,6 +104,7 @@ internal object VirtualCordapp {
services = listOf(), services = listOf(),
serializationWhitelists = listOf(), serializationWhitelists = listOf(),
serializationCustomSerializers = listOf(), serializationCustomSerializers = listOf(),
checkpointCustomSerializers = listOf(),
customSchemas = setOf(BFTSmartNotarySchemaV1), customSchemas = setOf(BFTSmartNotarySchemaV1),
info = Cordapp.Info.Default("corda-notary-bft-smart", versionInfo.vendor, versionInfo.releaseVersion, "Open Source (Apache 2)"), info = Cordapp.Info.Default("corda-notary-bft-smart", versionInfo.vendor, versionInfo.releaseVersion, "Open Source (Apache 2)"),
allFlows = listOf(), allFlows = listOf(),

View File

@ -121,6 +121,7 @@ class P2PMessagingClient(val config: NodeConfiguration,
var bridgeSession: ClientSession? = null var bridgeSession: ClientSession? = null
var bridgeNotifyConsumer: ClientConsumer? = null var bridgeNotifyConsumer: ClientConsumer? = null
var networkChangeSubscription: Subscription? = null var networkChangeSubscription: Subscription? = null
var sessionFactory: ClientSessionFactory? = null
fun sendMessage(address: String, message: ClientMessage) = producer!!.send(address, message) fun sendMessage(address: String, message: ClientMessage) = producer!!.send(address, message)
} }
@ -172,7 +173,7 @@ class P2PMessagingClient(val config: NodeConfiguration,
minLargeMessageSize = maxMessageSize + JOURNAL_HEADER_SIZE minLargeMessageSize = maxMessageSize + JOURNAL_HEADER_SIZE
isUseGlobalPools = nodeSerializationEnv != null isUseGlobalPools = nodeSerializationEnv != null
} }
val sessionFactory = locator!!.createSessionFactory().addFailoverListener(::failoverCallback) sessionFactory = locator!!.createSessionFactory().addFailoverListener(::failoverCallback)
// Login using the node username. The broker will authenticate us as its node (as opposed to another peer) // Login using the node username. The broker will authenticate us as its node (as opposed to another peer)
// using our TLS certificate. // using our TLS certificate.
// Note that the acknowledgement of messages is not flushed to the Artermis journal until the default buffer // Note that the acknowledgement of messages is not flushed to the Artermis journal until the default buffer
@ -490,8 +491,10 @@ class P2PMessagingClient(val config: NodeConfiguration,
// Wait for the main loop to notice the consumer has gone and finish up. // Wait for the main loop to notice the consumer has gone and finish up.
shutdownLatch.await() shutdownLatch.await()
} }
// Only first caller to gets running true to protect against double stop, which seems to happen in some integration tests. // Only first caller to gets running true to protect against double stop, which seems to happen in some integration tests.
state.locked { state.locked {
sessionFactory?.close()
locator?.close() locator?.close()
} }
} }

View File

@ -60,13 +60,11 @@ import net.corda.nodeapi.internal.lifecycle.NodeLifecycleObserver.Companion.repo
import net.corda.node.internal.NodeStartup import net.corda.node.internal.NodeStartup
import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.statemachine.Checkpoint import net.corda.node.services.statemachine.Checkpoint
import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.node.services.statemachine.ErrorState import net.corda.node.services.statemachine.ErrorState
import net.corda.node.services.statemachine.FlowError import net.corda.node.services.statemachine.ExistingSessionMessagePayload
import net.corda.node.services.statemachine.FlowSessionImpl import net.corda.node.services.statemachine.FlowSessionImpl
import net.corda.node.services.statemachine.FlowState import net.corda.node.services.statemachine.FlowState
import net.corda.node.services.statemachine.FlowStateMachineImpl import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.node.services.statemachine.InitiatedSessionState
import net.corda.node.services.statemachine.SessionId import net.corda.node.services.statemachine.SessionId
import net.corda.node.services.statemachine.SessionState import net.corda.node.services.statemachine.SessionState
import net.corda.node.services.statemachine.SubFlow import net.corda.node.services.statemachine.SubFlow
@ -325,6 +323,7 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri
val send: List<SendJson>? = null, val send: List<SendJson>? = null,
val receive: NonEmptySet<FlowSession>? = null, val receive: NonEmptySet<FlowSession>? = null,
val sendAndReceive: List<SendJson>? = null, val sendAndReceive: List<SendJson>? = null,
val closeSessions: NonEmptySet<FlowSession>? = null,
val waitForLedgerCommit: SecureHash? = null, val waitForLedgerCommit: SecureHash? = null,
val waitForStateConsumption: Set<StateRef>? = null, val waitForStateConsumption: Set<StateRef>? = null,
val getFlowInfo: NonEmptySet<FlowSession>? = null, val getFlowInfo: NonEmptySet<FlowSession>? = null,
@ -352,6 +351,7 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri
is FlowIORequest.Send -> SuspendedOn(send = sessionToMessage.toJson()) is FlowIORequest.Send -> SuspendedOn(send = sessionToMessage.toJson())
is FlowIORequest.Receive -> SuspendedOn(receive = sessions) is FlowIORequest.Receive -> SuspendedOn(receive = sessions)
is FlowIORequest.SendAndReceive -> SuspendedOn(sendAndReceive = sessionToMessage.toJson()) is FlowIORequest.SendAndReceive -> SuspendedOn(sendAndReceive = sessionToMessage.toJson())
is FlowIORequest.CloseSessions -> SuspendedOn(closeSessions = sessions)
is FlowIORequest.WaitForLedgerCommit -> SuspendedOn(waitForLedgerCommit = hash) is FlowIORequest.WaitForLedgerCommit -> SuspendedOn(waitForLedgerCommit = hash)
is FlowIORequest.GetFlowInfo -> SuspendedOn(getFlowInfo = sessions) is FlowIORequest.GetFlowInfo -> SuspendedOn(getFlowInfo = sessions)
is FlowIORequest.Sleep -> SuspendedOn(sleepTill = wakeUpAfter) is FlowIORequest.Sleep -> SuspendedOn(sleepTill = wakeUpAfter)
@ -379,16 +379,14 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri
private class ActiveSession( private class ActiveSession(
val peer: Party, val peer: Party,
val ourSessionId: SessionId, val ourSessionId: SessionId,
val receivedMessages: List<DataSessionMessage>, val receivedMessages: List<ExistingSessionMessagePayload>,
val errors: List<FlowError>,
val peerFlowInfo: FlowInfo, val peerFlowInfo: FlowInfo,
val peerSessionId: SessionId? val peerSessionId: SessionId?
) )
private fun SessionState.toActiveSession(sessionId: SessionId): ActiveSession? { private fun SessionState.toActiveSession(sessionId: SessionId): ActiveSession? {
return if (this is SessionState.Initiated) { return if (this is SessionState.Initiated) {
val peerSessionId = (initiatedState as? InitiatedSessionState.Live)?.peerSinkSessionId ActiveSession(peerParty, sessionId, receivedMessages, peerFlowInfo, peerSinkSessionId)
ActiveSession(peerParty, sessionId, receivedMessages, errors, peerFlowInfo, peerSessionId)
} else { } else {
null null
} }

View File

@ -130,13 +130,9 @@ internal class ActionExecutorImpl(
log.warn("Propagating error", exception) log.warn("Propagating error", exception)
} }
for (sessionState in action.sessions) { for (sessionState in action.sessions) {
// We cannot propagate if the session isn't live.
if (sessionState.initiatedState !is InitiatedSessionState.Live) {
continue
}
// Don't propagate errors to the originating session // Don't propagate errors to the originating session
for (errorMessage in action.errorMessages) { for (errorMessage in action.errorMessages) {
val sinkSessionId = sessionState.initiatedState.peerSinkSessionId val sinkSessionId = sessionState.peerSinkSessionId
val existingMessage = ExistingSessionMessage(sinkSessionId, errorMessage) val existingMessage = ExistingSessionMessage(sinkSessionId, errorMessage)
val deduplicationId = DeduplicationId.createForError(errorMessage.errorId, sinkSessionId) val deduplicationId = DeduplicationId.createForError(errorMessage.errorId, sinkSessionId)
flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, action.senderUUID)) flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, action.senderUUID))

View File

@ -69,11 +69,11 @@ class FlowCreator(
val checkpoint = oldCheckpoint.copy(status = Checkpoint.FlowStatus.RUNNABLE) val checkpoint = oldCheckpoint.copy(status = Checkpoint.FlowStatus.RUNNABLE)
val fiber = checkpoint.getFiberFromCheckpoint(runId) ?: return null val fiber = checkpoint.getFiberFromCheckpoint(runId) ?: return null
val resultFuture = openFuture<Any?>() val resultFuture = openFuture<Any?>()
fiber.transientValues = TransientReference(createTransientValues(runId, resultFuture))
fiber.logic.stateMachine = fiber fiber.logic.stateMachine = fiber
verifyFlowLogicIsSuspendable(fiber.logic) verifyFlowLogicIsSuspendable(fiber.logic)
val state = createStateMachineState(checkpoint, fiber, true) val state = createStateMachineState(checkpoint, fiber, true)
fiber.transientState = TransientReference(state) fiber.transientValues = createTransientValues(runId, resultFuture)
fiber.transientState = state
return Flow(fiber, resultFuture) return Flow(fiber, resultFuture)
} }
@ -91,7 +91,7 @@ class FlowCreator(
// have access to the fiber (and thereby the service hub) // have access to the fiber (and thereby the service hub)
val flowStateMachineImpl = FlowStateMachineImpl(flowId, flowLogic, scheduler) val flowStateMachineImpl = FlowStateMachineImpl(flowId, flowLogic, scheduler)
val resultFuture = openFuture<Any?>() val resultFuture = openFuture<Any?>()
flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture)) flowStateMachineImpl.transientValues = createTransientValues(flowId, resultFuture)
flowLogic.stateMachine = flowStateMachineImpl flowLogic.stateMachine = flowStateMachineImpl
val frozenFlowLogic = (flowLogic as FlowLogic<*>).checkpointSerialize(context = checkpointSerializationContext) val frozenFlowLogic = (flowLogic as FlowLogic<*>).checkpointSerialize(context = checkpointSerializationContext)
val flowCorDappVersion = FlowStateMachineImpl.createSubFlowVersion( val flowCorDappVersion = FlowStateMachineImpl.createSubFlowVersion(
@ -113,7 +113,7 @@ class FlowCreator(
existingCheckpoint != null, existingCheckpoint != null,
deduplicationHandler, deduplicationHandler,
senderUUID) senderUUID)
flowStateMachineImpl.transientState = TransientReference(state) flowStateMachineImpl.transientState = state
return Flow(flowStateMachineImpl, resultFuture) return Flow(flowStateMachineImpl, resultFuture)
} }

View File

@ -39,18 +39,14 @@ class FlowDefaultUncaughtExceptionHandler(
val id = fiber.id val id = fiber.id
if (!fiber.resultFuture.isDone) { if (!fiber.resultFuture.isDone) {
fiber.transientState.let { state -> fiber.transientState.let { state ->
if (state != null) { fiber.logger.warn("Forcing flow $id into overnight observation")
fiber.logger.warn("Forcing flow $id into overnight observation") flowHospital.forceIntoOvernightObservation(state, listOf(throwable))
flowHospital.forceIntoOvernightObservation(state.value, listOf(throwable)) val hospitalizedCheckpoint = state.checkpoint.copy(status = Checkpoint.FlowStatus.HOSPITALIZED)
val hospitalizedCheckpoint = state.value.checkpoint.copy(status = Checkpoint.FlowStatus.HOSPITALIZED) val hospitalizedState = state.copy(checkpoint = hospitalizedCheckpoint)
val hospitalizedState = state.value.copy(checkpoint = hospitalizedCheckpoint) fiber.transientState = hospitalizedState
fiber.transientState = TransientReference(hospitalizedState)
} else {
fiber.logger.warn("The fiber's transient state is not set, cannot force flow $id into in-memory overnight observation, status will still be updated in database")
}
} }
scheduledExecutor.schedule({ setFlowToHospitalizedRescheduleOnFailure(id) }, 0, TimeUnit.SECONDS)
} }
scheduledExecutor.schedule({ setFlowToHospitalizedRescheduleOnFailure(id) }, 0, TimeUnit.SECONDS)
} }
@Suppress("TooGenericExceptionCaught") @Suppress("TooGenericExceptionCaught")

View File

@ -78,6 +78,7 @@ internal class FlowMonitor(
is FlowIORequest.Send -> "to send a message to parties ${request.sessionToMessage.keys.partiesInvolved()}" is FlowIORequest.Send -> "to send a message to parties ${request.sessionToMessage.keys.partiesInvolved()}"
is FlowIORequest.Receive -> "to receive messages from parties ${request.sessions.partiesInvolved()}" is FlowIORequest.Receive -> "to receive messages from parties ${request.sessions.partiesInvolved()}"
is FlowIORequest.SendAndReceive -> "to send and receive messages from parties ${request.sessionToMessage.keys.partiesInvolved()}" is FlowIORequest.SendAndReceive -> "to send and receive messages from parties ${request.sessionToMessage.keys.partiesInvolved()}"
is FlowIORequest.CloseSessions -> "to close sessions: ${request.sessions}"
is FlowIORequest.WaitForLedgerCommit -> "for the ledger to commit transaction with hash ${request.hash}" is FlowIORequest.WaitForLedgerCommit -> "for the ledger to commit transaction with hash ${request.hash}"
is FlowIORequest.GetFlowInfo -> "to get flow information from parties ${request.sessions.partiesInvolved()}" is FlowIORequest.GetFlowInfo -> "to get flow information from parties ${request.sessions.partiesInvolved()}"
is FlowIORequest.Sleep -> "to wake up from sleep ending at ${LocalDateTime.ofInstant(request.wakeUpAfter, ZoneId.systemDefault())}" is FlowIORequest.Sleep -> "to wake up from sleep ending at ${LocalDateTime.ofInstant(request.wakeUpAfter, ZoneId.systemDefault())}"
@ -95,12 +96,12 @@ internal class FlowMonitor(
private fun FlowStateMachineImpl<*>.ioRequest() = (snapshot().checkpoint.flowState as? FlowState.Started)?.flowIORequest private fun FlowStateMachineImpl<*>.ioRequest() = (snapshot().checkpoint.flowState as? FlowState.Started)?.flowIORequest
private fun FlowStateMachineImpl<*>.ongoingDuration(now: Instant): Duration { private fun FlowStateMachineImpl<*>.ongoingDuration(now: Instant): Duration {
return transientState?.value?.checkpoint?.timestamp?.let { Duration.between(it, now) } ?: Duration.ZERO return transientState.checkpoint.timestamp.let { Duration.between(it, now) } ?: Duration.ZERO
} }
private fun FlowStateMachineImpl<*>.isSuspended() = !snapshot().isFlowResumed private fun FlowStateMachineImpl<*>.isSuspended() = !snapshot().isFlowResumed
private fun FlowStateMachineImpl<*>.isStarted() = transientState?.value?.checkpoint?.flowState is FlowState.Started private fun FlowStateMachineImpl<*>.isStarted() = transientState.checkpoint.flowState is FlowState.Started
private operator fun StaffedFlowHospital.contains(flow: FlowStateMachine<*>) = contains(flow.id) private operator fun StaffedFlowHospital.contains(flow: FlowStateMachine<*>) = contains(flow.id)
} }

View File

@ -81,6 +81,12 @@ class FlowSessionImpl(
@Suspendable @Suspendable
override fun send(payload: Any) = send(payload, maySkipCheckpoint = false) override fun send(payload: Any) = send(payload, maySkipCheckpoint = false)
@Suspendable
override fun close() {
val request = FlowIORequest.CloseSessions(NonEmptySet.of(this))
return flowStateMachine.suspend(request, false)
}
private fun enforceNotPrimitive(type: Class<*>) { private fun enforceNotPrimitive(type: Class<*>) {
require(!type.isPrimitive) { "Cannot receive primitive type $type" } require(!type.isPrimitive) { "Cannot receive primitive type $type" }
} }

View File

@ -6,6 +6,10 @@ import co.paralleluniverse.fibers.FiberScheduler
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand import co.paralleluniverse.strands.Strand
import co.paralleluniverse.strands.channels.Channel import co.paralleluniverse.strands.channels.Channel
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.KryoSerializable
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext import net.corda.core.context.InvocationContext
import net.corda.core.contracts.StateRef import net.corda.core.contracts.StateRef
@ -58,7 +62,6 @@ import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import org.slf4j.MDC import org.slf4j.MDC
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import kotlin.reflect.KProperty1
class FlowPermissionException(message: String) : FlowException(message) class FlowPermissionException(message: String) : FlowException(message)
@ -86,52 +89,65 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
private val SERIALIZER_BLOCKER = Fiber::class.java.getDeclaredField("SERIALIZER_BLOCKER").apply { isAccessible = true }.get(null) private val SERIALIZER_BLOCKER = Fiber::class.java.getDeclaredField("SERIALIZER_BLOCKER").apply { isAccessible = true }.get(null)
} }
override val serviceHub get() = getTransientField(TransientValues::serviceHub)
data class TransientValues( data class TransientValues(
val eventQueue: Channel<Event>, val eventQueue: Channel<Event>,
val resultFuture: CordaFuture<Any?>, val resultFuture: CordaFuture<Any?>,
val database: CordaPersistence, val database: CordaPersistence,
val transitionExecutor: TransitionExecutor, val transitionExecutor: TransitionExecutor,
val actionExecutor: ActionExecutor, val actionExecutor: ActionExecutor,
val stateMachine: StateMachine, val stateMachine: StateMachine,
val serviceHub: ServiceHubInternal, val serviceHub: ServiceHubInternal,
val checkpointSerializationContext: CheckpointSerializationContext, val checkpointSerializationContext: CheckpointSerializationContext,
val unfinishedFibers: ReusableLatch, val unfinishedFibers: ReusableLatch,
val waitTimeUpdateHook: (id: StateMachineRunId, timeout: Long) -> Unit val waitTimeUpdateHook: (id: StateMachineRunId, timeout: Long) -> Unit
) ) : KryoSerializable {
override fun write(kryo: Kryo?, output: Output?) {
throw IllegalStateException("${TransientValues::class.qualifiedName} should never be serialized")
}
internal var transientValues: TransientReference<TransientValues>? = null override fun read(kryo: Kryo?, input: Input?) {
internal var transientState: TransientReference<StateMachineState>? = null throw IllegalStateException("${TransientValues::class.qualifiedName} should never be deserialized")
}
/**
* What sender identifier to put on messages sent by this flow. This will either be the identifier for the current
* state machine manager / messaging client, or null to indicate this flow is restored from a checkpoint and
* the de-duplication of messages it sends should not be optimised since this could be unreliable.
*/
override val ourSenderUUID: String?
get() = transientState?.value?.senderUUID
private fun <A> getTransientField(field: KProperty1<TransientValues, A>): A {
val suppliedValues = transientValues ?: throw IllegalStateException("${field.name} wasn't supplied!")
return field.get(suppliedValues.value)
} }
private fun extractThreadLocalTransaction(): TransientReference<DatabaseTransaction> { private var transientValuesReference: TransientReference<TransientValues>? = null
val transaction = contextTransaction internal var transientValues: TransientValues
contextTransactionOrNull = null // After the flow has been created, the transient values should never be null
return TransientReference(transaction) get() = transientValuesReference!!.value
} set(values) {
check(transientValuesReference?.value == null) { "The transient values should only be set once when initialising a flow" }
transientValuesReference = TransientReference(values)
}
private var transientStateReference: TransientReference<StateMachineState>? = null
internal var transientState: StateMachineState
// After the flow has been created, the transient state should never be null
get() = transientStateReference!!.value
set(state) {
transientStateReference = TransientReference(state)
}
/** /**
* Return the logger for this state machine. The logger name incorporates [id] and so including it in the log message * Return the logger for this state machine. The logger name incorporates [id] and so including it in the log message
* is not necessary. * is not necessary.
*/ */
override val logger = log override val logger = log
override val resultFuture: CordaFuture<R> get() = uncheckedCast(getTransientField(TransientValues::resultFuture))
override val context: InvocationContext get() = transientState!!.value.checkpoint.checkpointState.invocationContext override val instanceId: StateMachineInstanceId get() = StateMachineInstanceId(id, super.getId())
override val ourIdentity: Party get() = transientState!!.value.checkpoint.checkpointState.ourIdentity
override val isKilled: Boolean get() = transientState!!.value.isKilled override val serviceHub: ServiceHubInternal get() = transientValues.serviceHub
override val stateMachine: StateMachine get() = transientValues.stateMachine
override val resultFuture: CordaFuture<R> get() = uncheckedCast(transientValues.resultFuture)
override val context: InvocationContext get() = transientState.checkpoint.checkpointState.invocationContext
override val ourIdentity: Party get() = transientState.checkpoint.checkpointState.ourIdentity
override val isKilled: Boolean get() = transientState.isKilled
/**
* What sender identifier to put on messages sent by this flow. This will either be the identifier for the current
* state machine manager / messaging client, or null to indicate this flow is restored from a checkpoint and
* the de-duplication of messages it sends should not be optimised since this could be unreliable.
*/
override val ourSenderUUID: String? get() = transientState.senderUUID
internal val softLockedStates = mutableSetOf<StateRef>() internal val softLockedStates = mutableSetOf<StateRef>()
@ -143,9 +159,9 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
@Suspendable @Suspendable
private fun processEvent(transitionExecutor: TransitionExecutor, event: Event): FlowContinuation { private fun processEvent(transitionExecutor: TransitionExecutor, event: Event): FlowContinuation {
setLoggingContext() setLoggingContext()
val stateMachine = getTransientField(TransientValues::stateMachine) val stateMachine = transientValues.stateMachine
val oldState = transientState!!.value val oldState = transientState
val actionExecutor = getTransientField(TransientValues::actionExecutor) val actionExecutor = transientValues.actionExecutor
val transition = stateMachine.transition(event, oldState) val transition = stateMachine.transition(event, oldState)
val (continuation, newState) = transitionExecutor.executeTransition(this, oldState, event, transition, actionExecutor) val (continuation, newState) = transitionExecutor.executeTransition(this, oldState, event, transition, actionExecutor)
// Ensure that the next state that is being written to the transient state maintains the [isKilled] flag // Ensure that the next state that is being written to the transient state maintains the [isKilled] flag
@ -153,7 +169,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
if (oldState.isKilled && !newState.isKilled) { if (oldState.isKilled && !newState.isKilled) {
newState.isKilled = true newState.isKilled = true
} }
transientState = TransientReference(newState) transientState = newState
setLoggingContext() setLoggingContext()
return continuation return continuation
} }
@ -171,15 +187,15 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
@Suspendable @Suspendable
private fun processEventsUntilFlowIsResumed(isDbTransactionOpenOnEntry: Boolean, isDbTransactionOpenOnExit: Boolean): Any? { private fun processEventsUntilFlowIsResumed(isDbTransactionOpenOnEntry: Boolean, isDbTransactionOpenOnExit: Boolean): Any? {
checkDbTransaction(isDbTransactionOpenOnEntry) checkDbTransaction(isDbTransactionOpenOnEntry)
val transitionExecutor = getTransientField(TransientValues::transitionExecutor) val transitionExecutor = transientValues.transitionExecutor
val eventQueue = getTransientField(TransientValues::eventQueue) val eventQueue = transientValues.eventQueue
try { try {
eventLoop@ while (true) { eventLoop@ while (true) {
val nextEvent = try { val nextEvent = try {
eventQueue.receive() eventQueue.receive()
} catch (interrupted: InterruptedException) { } catch (interrupted: InterruptedException) {
log.error("Flow interrupted while waiting for events, aborting immediately") log.error("Flow interrupted while waiting for events, aborting immediately")
(transientValues?.value?.resultFuture as? OpenFuture<*>)?.setException(KilledFlowException(id)) (transientValues.resultFuture as? OpenFuture<*>)?.setException(KilledFlowException(id))
abortFiber() abortFiber()
} }
val continuation = processEvent(transitionExecutor, nextEvent) val continuation = processEvent(transitionExecutor, nextEvent)
@ -246,7 +262,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
isDbTransactionOpenOnEntry: Boolean, isDbTransactionOpenOnEntry: Boolean,
isDbTransactionOpenOnExit: Boolean): FlowContinuation { isDbTransactionOpenOnExit: Boolean): FlowContinuation {
checkDbTransaction(isDbTransactionOpenOnEntry) checkDbTransaction(isDbTransactionOpenOnEntry)
val transitionExecutor = getTransientField(TransientValues::transitionExecutor) val transitionExecutor = transientValues.transitionExecutor
val continuation = processEvent(transitionExecutor, event) val continuation = processEvent(transitionExecutor, event)
checkDbTransaction(isDbTransactionOpenOnExit) checkDbTransaction(isDbTransactionOpenOnExit)
return continuation return continuation
@ -270,7 +286,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
} }
private fun openThreadLocalWormhole() { private fun openThreadLocalWormhole() {
val threadLocal = getTransientField(TransientValues::database).hikariPoolThreadLocal val threadLocal = transientValues.database.hikariPoolThreadLocal
if (threadLocal != null) { if (threadLocal != null) {
val valueFromThread = swappedOutThreadLocalValue(threadLocal) val valueFromThread = swappedOutThreadLocalValue(threadLocal)
threadLocal.set(valueFromThread) threadLocal.set(valueFromThread)
@ -332,7 +348,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
} }
recordDuration(startTime) recordDuration(startTime)
getTransientField(TransientValues::unfinishedFibers).countDown() transientValues.unfinishedFibers.countDown()
} }
@Suspendable @Suspendable
@ -476,7 +492,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
@Suspendable @Suspendable
override fun <R : Any> suspend(ioRequest: FlowIORequest<R>, maySkipCheckpoint: Boolean): R { override fun <R : Any> suspend(ioRequest: FlowIORequest<R>, maySkipCheckpoint: Boolean): R {
val serializationContext = TransientReference(getTransientField(TransientValues::checkpointSerializationContext)) val serializationContext = TransientReference(transientValues.checkpointSerializationContext)
val transaction = extractThreadLocalTransaction() val transaction = extractThreadLocalTransaction()
parkAndSerialize { _, _ -> parkAndSerialize { _, _ ->
setLoggingContext() setLoggingContext()
@ -524,13 +540,19 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
return subFlowStack.any { IdempotentFlow::class.java.isAssignableFrom(it.flowClass) } return subFlowStack.any { IdempotentFlow::class.java.isAssignableFrom(it.flowClass) }
} }
private fun extractThreadLocalTransaction(): TransientReference<DatabaseTransaction> {
val transaction = contextTransaction
contextTransactionOrNull = null
return TransientReference(transaction)
}
@Suspendable @Suspendable
override fun scheduleEvent(event: Event) { override fun scheduleEvent(event: Event) {
getTransientField(TransientValues::eventQueue).send(event) transientValues.eventQueue.send(event)
} }
override fun snapshot(): StateMachineState { override fun snapshot(): StateMachineState {
return transientState!!.value return transientState
} }
/** /**
@ -538,13 +560,9 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
* retried. * retried.
*/ */
override fun updateTimedFlowTimeout(timeoutSeconds: Long) { override fun updateTimedFlowTimeout(timeoutSeconds: Long) {
getTransientField(TransientValues::waitTimeUpdateHook).invoke(id, timeoutSeconds) transientValues.waitTimeUpdateHook.invoke(id, timeoutSeconds)
} }
override val stateMachine get() = getTransientField(TransientValues::stateMachine)
override val instanceId: StateMachineInstanceId get() = StateMachineInstanceId(id, super.getId())
/** /**
* Records the duration of this flow from call() to completion or failure. * Records the duration of this flow from call() to completion or failure.
* Note that the duration will include the time the flow spent being parked, and not just the total * Note that the duration will include the time the flow spent being parked, and not just the total

View File

@ -261,14 +261,9 @@ internal class SingleThreadedStateMachineManager(
unfinishedFibers.countDown() unfinishedFibers.countDown()
val state = flow.fiber.transientState val state = flow.fiber.transientState
return@withLock if (state != null) { state.isKilled = true
state.value.isKilled = true flow.fiber.scheduleEvent(Event.DoRemainingWork)
flow.fiber.scheduleEvent(Event.DoRemainingWork) true
true
} else {
logger.info("Flow $id has not been initialised correctly and cannot be killed")
false
}
} else { } else {
// It may be that the id refers to a checkpoint that couldn't be deserialised into a flow, so we delete it if it exists. // It may be that the id refers to a checkpoint that couldn't be deserialised into a flow, so we delete it if it exists.
database.transaction { checkpointStorage.removeCheckpoint(id) } database.transaction { checkpointStorage.removeCheckpoint(id) }
@ -386,7 +381,7 @@ internal class SingleThreadedStateMachineManager(
currentState.cancelFutureIfRunning() currentState.cancelFutureIfRunning()
// Get set of external events // Get set of external events
val flowId = currentState.flowLogic.runId val flowId = currentState.flowLogic.runId
val oldFlowLeftOver = innerState.withLock { flows[flowId] }?.fiber?.transientValues?.value?.eventQueue val oldFlowLeftOver = innerState.withLock { flows[flowId] }?.fiber?.transientValues?.eventQueue
if (oldFlowLeftOver == null) { if (oldFlowLeftOver == null) {
logger.error("Unable to find flow for flow $flowId. Something is very wrong. The flow will not retry.") logger.error("Unable to find flow for flow $flowId. Something is very wrong. The flow will not retry.")
return return
@ -592,7 +587,7 @@ internal class SingleThreadedStateMachineManager(
): CordaFuture<FlowStateMachine<A>> { ): CordaFuture<FlowStateMachine<A>> {
val existingFlow = innerState.withLock { flows[flowId] } val existingFlow = innerState.withLock { flows[flowId] }
val existingCheckpoint = if (existingFlow != null && existingFlow.fiber.transientState?.value?.isAnyCheckpointPersisted == true) { val existingCheckpoint = if (existingFlow != null && existingFlow.fiber.transientState.isAnyCheckpointPersisted) {
// Load the flow's checkpoint // Load the flow's checkpoint
// The checkpoint will be missing if the flow failed before persisting the original checkpoint // The checkpoint will be missing if the flow failed before persisting the original checkpoint
// CORDA-3359 - Do not start/retry a flow that failed after deleting its checkpoint (the whole of the flow might replay) // CORDA-3359 - Do not start/retry a flow that failed after deleting its checkpoint (the whole of the flow might replay)
@ -756,7 +751,7 @@ internal class SingleThreadedStateMachineManager(
// The flow's event queue may be non-empty in case it shut down abruptly. We handle outstanding events here. // The flow's event queue may be non-empty in case it shut down abruptly. We handle outstanding events here.
private fun drainFlowEventQueue(flow: Flow<*>) { private fun drainFlowEventQueue(flow: Flow<*>) {
while (true) { while (true) {
val event = flow.fiber.transientValues!!.value.eventQueue.tryReceive() ?: return val event = flow.fiber.transientValues.eventQueue.tryReceive() ?: return
when (event) { when (event) {
is Event.DoRemainingWork -> {} is Event.DoRemainingWork -> {}
is Event.DeliverSessionMessage -> { is Event.DeliverSessionMessage -> {

View File

@ -1,5 +1,9 @@
package net.corda.node.services.statemachine package net.corda.node.services.statemachine
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.KryoSerializable
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import net.corda.core.context.InvocationContext import net.corda.core.context.InvocationContext
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.flows.Destination import net.corda.core.flows.Destination
@ -15,6 +19,7 @@ import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointDeserialize import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.utilities.Try import net.corda.core.utilities.Try
import net.corda.node.services.messaging.DeduplicationHandler import net.corda.node.services.messaging.DeduplicationHandler
import java.lang.IllegalStateException
import java.time.Instant import java.time.Instant
import java.util.concurrent.Future import java.util.concurrent.Future
@ -55,7 +60,15 @@ data class StateMachineState(
@Volatile @Volatile
var isKilled: Boolean, var isKilled: Boolean,
val senderUUID: String? val senderUUID: String?
) ) : KryoSerializable {
override fun write(kryo: Kryo?, output: Output?) {
throw IllegalStateException("${StateMachineState::class.qualifiedName} should never be serialized")
}
override fun read(kryo: Kryo?, input: Input?) {
throw IllegalStateException("${StateMachineState::class.qualifiedName} should never be deserialized")
}
}
/** /**
* @param checkpointState the state of the checkpoint * @param checkpointState the state of the checkpoint
@ -106,6 +119,7 @@ data class Checkpoint(
invocationContext, invocationContext,
ourIdentity, ourIdentity,
emptyMap(), emptyMap(),
emptySet(),
listOf(topLevelSubFlow), listOf(topLevelSubFlow),
numberOfSuspends = 0 numberOfSuspends = 0
), ),
@ -132,6 +146,22 @@ data class Checkpoint(
return copy(checkpointState = checkpointState.copy(sessions = checkpointState.sessions + session)) return copy(checkpointState = checkpointState.copy(sessions = checkpointState.sessions + session))
} }
fun addSessionsToBeClosed(sessionIds: Set<SessionId>): Checkpoint {
return copy(checkpointState = checkpointState.copy(sessionsToBeClosed = checkpointState.sessionsToBeClosed + sessionIds))
}
fun removeSessionsToBeClosed(sessionIds: Set<SessionId>): Checkpoint {
return copy(checkpointState = checkpointState.copy(sessionsToBeClosed = checkpointState.sessionsToBeClosed - sessionIds))
}
/**
* Returns a copy of the Checkpoint with the specified session removed from the session map.
* @param sessionIds the sessions to remove.
*/
fun removeSessions(sessionIds: Set<SessionId>): Checkpoint {
return copy(checkpointState = checkpointState.copy(sessions = checkpointState.sessions - sessionIds))
}
/** /**
* Returns a copy of the Checkpoint with a new subFlow stack. * Returns a copy of the Checkpoint with a new subFlow stack.
* @param subFlows the new List of subFlows. * @param subFlows the new List of subFlows.
@ -193,16 +223,18 @@ data class Checkpoint(
* @param invocationContext the initiator of the flow. * @param invocationContext the initiator of the flow.
* @param ourIdentity the identity the flow is run as. * @param ourIdentity the identity the flow is run as.
* @param sessions map of source session ID to session state. * @param sessions map of source session ID to session state.
* @param sessionsToBeClosed the sessions that have pending session end messages and need to be closed. This is available to avoid scanning all the sessions.
* @param subFlowStack the stack of currently executing subflows. * @param subFlowStack the stack of currently executing subflows.
* @param numberOfSuspends the number of flow suspends due to IO API calls. * @param numberOfSuspends the number of flow suspends due to IO API calls.
*/ */
@CordaSerializable @CordaSerializable
data class CheckpointState( data class CheckpointState(
val invocationContext: InvocationContext, val invocationContext: InvocationContext,
val ourIdentity: Party, val ourIdentity: Party,
val sessions: SessionMap, // This must preserve the insertion order! val sessions: SessionMap, // This must preserve the insertion order!
val subFlowStack: List<SubFlow>, val sessionsToBeClosed: Set<SessionId>,
val numberOfSuspends: Int val subFlowStack: List<SubFlow>,
val numberOfSuspends: Int
) )
/** /**
@ -236,30 +268,25 @@ sealed class SessionState {
/** /**
* We have received a confirmation, the peer party and session id is resolved. * We have received a confirmation, the peer party and session id is resolved.
* @property errors if not empty the session is in an errored state. * @property receivedMessages the messages that have been received and are pending processing.
* this could be any [ExistingSessionMessagePayload] type in theory, but it in practice it can only be one of the following types now:
* * [DataSessionMessage]
* * [ErrorSessionMessage]
* * [EndSessionMessage]
* @property otherSideErrored whether the session has received an error from the other side.
*/ */
data class Initiated( data class Initiated(
val peerParty: Party, val peerParty: Party,
val peerFlowInfo: FlowInfo, val peerFlowInfo: FlowInfo,
val receivedMessages: List<DataSessionMessage>, val receivedMessages: List<ExistingSessionMessagePayload>,
val initiatedState: InitiatedSessionState, val otherSideErrored: Boolean,
val errors: List<FlowError>, val peerSinkSessionId: SessionId,
override val deduplicationSeed: String override val deduplicationSeed: String
) : SessionState() ) : SessionState()
} }
typealias SessionMap = Map<SessionId, SessionState> typealias SessionMap = Map<SessionId, SessionState>
/**
* Tracks whether an initiated session state is live or has ended. This is a separate state, as we still need the rest
* of [SessionState.Initiated], even when the session has ended, for un-drained session messages and potential future
* [FlowInfo] requests.
*/
sealed class InitiatedSessionState {
data class Live(val peerSinkSessionId: SessionId) : InitiatedSessionState()
object Ended : InitiatedSessionState() { override fun toString() = "Ended" }
}
/** /**
* Represents the way the flow has started. * Represents the way the flow has started.
*/ */

View File

@ -1,9 +1,8 @@
package net.corda.node.services.statemachine.transitions package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.FlowException
import net.corda.core.flows.UnexpectedFlowEndException import net.corda.core.flows.UnexpectedFlowEndException
import net.corda.core.identity.Party import net.corda.core.utilities.contextLogger
import net.corda.core.internal.DeclaredField import net.corda.core.utilities.debug
import net.corda.node.services.statemachine.Action import net.corda.node.services.statemachine.Action
import net.corda.node.services.statemachine.ConfirmSessionMessage import net.corda.node.services.statemachine.ConfirmSessionMessage
import net.corda.node.services.statemachine.DataSessionMessage import net.corda.node.services.statemachine.DataSessionMessage
@ -12,7 +11,7 @@ import net.corda.node.services.statemachine.ErrorSessionMessage
import net.corda.node.services.statemachine.Event import net.corda.node.services.statemachine.Event
import net.corda.node.services.statemachine.ExistingSessionMessage import net.corda.node.services.statemachine.ExistingSessionMessage
import net.corda.node.services.statemachine.FlowError import net.corda.node.services.statemachine.FlowError
import net.corda.node.services.statemachine.InitiatedSessionState import net.corda.node.services.statemachine.FlowState
import net.corda.node.services.statemachine.RejectSessionMessage import net.corda.node.services.statemachine.RejectSessionMessage
import net.corda.node.services.statemachine.SenderDeduplicationId import net.corda.node.services.statemachine.SenderDeduplicationId
import net.corda.node.services.statemachine.SessionState import net.corda.node.services.statemachine.SessionState
@ -37,6 +36,11 @@ class DeliverSessionMessageTransition(
override val startingState: StateMachineState, override val startingState: StateMachineState,
val event: Event.DeliverSessionMessage val event: Event.DeliverSessionMessage
) : Transition { ) : Transition {
private companion object {
val log = contextLogger()
}
override fun transition(): TransitionResult { override fun transition(): TransitionResult {
return builder { return builder {
// Add the DeduplicationHandler to the pending ones ASAP so in case an error happens we still know // Add the DeduplicationHandler to the pending ones ASAP so in case an error happens we still know
@ -49,7 +53,7 @@ class DeliverSessionMessageTransition(
// Check whether we have a session corresponding to the message. // Check whether we have a session corresponding to the message.
val existingSession = startingState.checkpoint.checkpointState.sessions[event.sessionMessage.recipientSessionId] val existingSession = startingState.checkpoint.checkpointState.sessions[event.sessionMessage.recipientSessionId]
if (existingSession == null) { if (existingSession == null) {
freshErrorTransition(CannotFindSessionException(event.sessionMessage.recipientSessionId)) checkIfMissingSessionIsAnIssue(event.sessionMessage)
} else { } else {
val payload = event.sessionMessage.payload val payload = event.sessionMessage.payload
// Dispatch based on what kind of message it is. // Dispatch based on what kind of message it is.
@ -58,7 +62,7 @@ class DeliverSessionMessageTransition(
is DataSessionMessage -> dataMessageTransition(existingSession, payload) is DataSessionMessage -> dataMessageTransition(existingSession, payload)
is ErrorSessionMessage -> errorMessageTransition(existingSession, payload) is ErrorSessionMessage -> errorMessageTransition(existingSession, payload)
is RejectSessionMessage -> rejectMessageTransition(existingSession, payload) is RejectSessionMessage -> rejectMessageTransition(existingSession, payload)
is EndSessionMessage -> endMessageTransition() is EndSessionMessage -> endMessageTransition(payload)
} }
} }
// Schedule a DoRemainingWork to check whether the flow needs to be woken up. // Schedule a DoRemainingWork to check whether the flow needs to be woken up.
@ -67,6 +71,14 @@ class DeliverSessionMessageTransition(
} }
} }
private fun TransitionBuilder.checkIfMissingSessionIsAnIssue(message: ExistingSessionMessage) {
val payload = message.payload
if (payload is EndSessionMessage)
log.debug { "Received session end message for a session that has already ended: ${event.sessionMessage.recipientSessionId}"}
else
freshErrorTransition(CannotFindSessionException(event.sessionMessage.recipientSessionId))
}
private fun TransitionBuilder.confirmMessageTransition(sessionState: SessionState, message: ConfirmSessionMessage) { private fun TransitionBuilder.confirmMessageTransition(sessionState: SessionState, message: ConfirmSessionMessage) {
// We received a confirmation message. The corresponding session state must be Initiating. // We received a confirmation message. The corresponding session state must be Initiating.
when (sessionState) { when (sessionState) {
@ -76,9 +88,9 @@ class DeliverSessionMessageTransition(
peerParty = event.sender, peerParty = event.sender,
peerFlowInfo = message.initiatedFlowInfo, peerFlowInfo = message.initiatedFlowInfo,
receivedMessages = emptyList(), receivedMessages = emptyList(),
initiatedState = InitiatedSessionState.Live(message.initiatedSessionId), peerSinkSessionId = message.initiatedSessionId,
errors = emptyList(), deduplicationSeed = sessionState.deduplicationSeed,
deduplicationSeed = sessionState.deduplicationSeed otherSideErrored = false
) )
val newCheckpoint = currentState.checkpoint.addSession( val newCheckpoint = currentState.checkpoint.addSession(
event.sessionMessage.recipientSessionId to initiatedSession event.sessionMessage.recipientSessionId to initiatedSession
@ -115,28 +127,11 @@ class DeliverSessionMessageTransition(
} }
private fun TransitionBuilder.errorMessageTransition(sessionState: SessionState, payload: ErrorSessionMessage) { private fun TransitionBuilder.errorMessageTransition(sessionState: SessionState, payload: ErrorSessionMessage) {
val exception: Throwable = if (payload.flowException == null) {
UnexpectedFlowEndException("Counter-flow errored", cause = null, originalErrorId = payload.errorId)
} else {
payload.flowException.originalErrorId = payload.errorId
payload.flowException
}
return when (sessionState) { return when (sessionState) {
is SessionState.Initiated -> { is SessionState.Initiated -> {
when (exception) {
// reflection used to access private field
is UnexpectedFlowEndException -> DeclaredField<Party?>(
UnexpectedFlowEndException::class.java,
"peer",
exception
).value = sessionState.peerParty
is FlowException -> DeclaredField<Party?>(FlowException::class.java, "peer", exception).value = sessionState.peerParty
}
val checkpoint = currentState.checkpoint val checkpoint = currentState.checkpoint
val sessionId = event.sessionMessage.recipientSessionId val sessionId = event.sessionMessage.recipientSessionId
val flowError = FlowError(payload.errorId, exception) val newSessionState = sessionState.copy(receivedMessages = sessionState.receivedMessages + payload)
val newSessionState = sessionState.copy(errors = sessionState.errors + flowError)
currentState = currentState.copy( currentState = currentState.copy(
checkpoint = checkpoint.addSession(sessionId to newSessionState) checkpoint = checkpoint.addSession(sessionId to newSessionState)
) )
@ -165,23 +160,26 @@ class DeliverSessionMessageTransition(
} }
} }
private fun TransitionBuilder.endMessageTransition() { private fun TransitionBuilder.endMessageTransition(payload: EndSessionMessage) {
val sessionId = event.sessionMessage.recipientSessionId val sessionId = event.sessionMessage.recipientSessionId
val sessions = currentState.checkpoint.checkpointState.sessions val sessions = currentState.checkpoint.checkpointState.sessions
val sessionState = sessions[sessionId] // a check has already been performed to confirm the session exists for this message before this method is invoked.
if (sessionState == null) { val sessionState = sessions[sessionId]!!
return freshErrorTransition(CannotFindSessionException(sessionId))
}
when (sessionState) { when (sessionState) {
is SessionState.Initiated -> { is SessionState.Initiated -> {
val newSessionState = sessionState.copy(initiatedState = InitiatedSessionState.Ended) val flowState = currentState.checkpoint.flowState
currentState = currentState.copy( // flow must have already been started when session end messages are being delivered.
checkpoint = currentState.checkpoint.addSession(sessionId to newSessionState) if (flowState !is FlowState.Started)
return freshErrorTransition(UnexpectedEventInState())
) val newSessionState = sessionState.copy(receivedMessages = sessionState.receivedMessages + payload)
val newCheckpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState)
.addSessionsToBeClosed(setOf(event.sessionMessage.recipientSessionId))
currentState = currentState.copy(checkpoint = newCheckpoint)
} }
else -> { else -> {
freshErrorTransition(UnexpectedEventInState()) freshErrorTransition(PrematureSessionEndException(event.sessionMessage.recipientSessionId))
} }
} }
} }

View File

@ -117,8 +117,9 @@ class ErrorFlowTransition(
sessionState sessionState
} }
} }
// if we have already received error message from the other side, we don't include that session in the list to avoid propagating errors.
val initiatedSessions = sessions.values.mapNotNull { session -> val initiatedSessions = sessions.values.mapNotNull { session ->
if (session is SessionState.Initiated && session.errors.isEmpty()) { if (session is SessionState.Initiated && !session.otherSideErrored) {
session session
} else { } else {
null null

View File

@ -105,8 +105,9 @@ class KilledFlowTransition(
sessionState sessionState
} }
} }
// if we have already received error message from the other side, we don't include that session in the list to avoid propagating errors.
val initiatedSessions = sessions.values.mapNotNull { session -> val initiatedSessions = sessions.values.mapNotNull { session ->
if (session is SessionState.Initiated && session.errors.isEmpty()) { if (session is SessionState.Initiated && !session.otherSideErrored) {
session session
} else { } else {
null null

View File

@ -1,13 +1,18 @@
package net.corda.node.services.statemachine.transitions package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowInfo import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowSession import net.corda.core.flows.FlowSession
import net.corda.core.flows.UnexpectedFlowEndException import net.corda.core.flows.UnexpectedFlowEndException
import net.corda.core.identity.Party
import net.corda.core.internal.DeclaredField
import net.corda.core.internal.FlowIORequest import net.corda.core.internal.FlowIORequest
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.toNonEmptySet import net.corda.core.utilities.toNonEmptySet
import net.corda.node.services.statemachine.* import net.corda.node.services.statemachine.*
import java.lang.IllegalStateException import org.slf4j.Logger
import kotlin.collections.LinkedHashMap
/** /**
* This transition describes what should happen with a specific [FlowIORequest]. Note that at this time the request * This transition describes what should happen with a specific [FlowIORequest]. Note that at this time the request
@ -20,28 +25,62 @@ class StartedFlowTransition(
override val startingState: StateMachineState, override val startingState: StateMachineState,
val started: FlowState.Started val started: FlowState.Started
) : Transition { ) : Transition {
companion object {
private val logger: Logger = contextLogger()
}
override fun transition(): TransitionResult { override fun transition(): TransitionResult {
val flowIORequest = started.flowIORequest val flowIORequest = started.flowIORequest
val checkpoint = startingState.checkpoint val (newState, errorsToThrow) = collectRelevantErrorsToThrow(startingState, flowIORequest)
val errorsToThrow = collectRelevantErrorsToThrow(flowIORequest, checkpoint)
if (errorsToThrow.isNotEmpty()) { if (errorsToThrow.isNotEmpty()) {
return TransitionResult( return TransitionResult(
newState = startingState.copy(isFlowResumed = true), newState = newState.copy(isFlowResumed = true),
// throw the first exception. TODO should this aggregate all of them somehow? // throw the first exception. TODO should this aggregate all of them somehow?
actions = listOf(Action.CreateTransaction), actions = listOf(Action.CreateTransaction),
continuation = FlowContinuation.Throw(errorsToThrow[0]) continuation = FlowContinuation.Throw(errorsToThrow[0])
) )
} }
return when (flowIORequest) { val sessionsToBeTerminated = findSessionsToBeTerminated(startingState)
is FlowIORequest.Send -> sendTransition(flowIORequest) // if there are sessions to be closed, we close them as part of this transition and normal processing will continue on the next transition.
is FlowIORequest.Receive -> receiveTransition(flowIORequest) return if (sessionsToBeTerminated.isNotEmpty()) {
is FlowIORequest.SendAndReceive -> sendAndReceiveTransition(flowIORequest) terminateSessions(sessionsToBeTerminated)
is FlowIORequest.WaitForLedgerCommit -> waitForLedgerCommitTransition(flowIORequest) } else {
is FlowIORequest.Sleep -> sleepTransition(flowIORequest) when (flowIORequest) {
is FlowIORequest.GetFlowInfo -> getFlowInfoTransition(flowIORequest) is FlowIORequest.Send -> sendTransition(flowIORequest)
is FlowIORequest.WaitForSessionConfirmations -> waitForSessionConfirmationsTransition() is FlowIORequest.Receive -> receiveTransition(flowIORequest)
is FlowIORequest.ExecuteAsyncOperation<*> -> executeAsyncOperation(flowIORequest) is FlowIORequest.SendAndReceive -> sendAndReceiveTransition(flowIORequest)
FlowIORequest.ForceCheckpoint -> executeForceCheckpoint() is FlowIORequest.CloseSessions -> closeSessionTransition(flowIORequest)
is FlowIORequest.WaitForLedgerCommit -> waitForLedgerCommitTransition(flowIORequest)
is FlowIORequest.Sleep -> sleepTransition(flowIORequest)
is FlowIORequest.GetFlowInfo -> getFlowInfoTransition(flowIORequest)
is FlowIORequest.WaitForSessionConfirmations -> waitForSessionConfirmationsTransition()
is FlowIORequest.ExecuteAsyncOperation<*> -> executeAsyncOperation(flowIORequest)
FlowIORequest.ForceCheckpoint -> executeForceCheckpoint()
}
}
}
private fun findSessionsToBeTerminated(startingState: StateMachineState): SessionMap {
return startingState.checkpoint.checkpointState.sessionsToBeClosed.mapNotNull { sessionId ->
val sessionState = startingState.checkpoint.checkpointState.sessions[sessionId]!! as SessionState.Initiated
if (sessionState.receivedMessages.isNotEmpty() && sessionState.receivedMessages.first() is EndSessionMessage) {
sessionId to sessionState
} else {
null
}
}.toMap()
}
private fun terminateSessions(sessionsToBeTerminated: SessionMap): TransitionResult {
return builder {
val sessionsToRemove = sessionsToBeTerminated.keys
val newCheckpoint = currentState.checkpoint.removeSessions(sessionsToRemove)
.removeSessionsToBeClosed(sessionsToRemove)
currentState = currentState.copy(checkpoint = newCheckpoint)
actions.add(Action.RemoveSessionBindings(sessionsToRemove))
actions.add(Action.ScheduleEvent(Event.DoRemainingWork))
FlowContinuation.ProcessEvents
} }
} }
@ -149,6 +188,34 @@ class StartedFlowTransition(
} }
} }
private fun closeSessionTransition(flowIORequest: FlowIORequest.CloseSessions): TransitionResult {
return builder {
val sessionIdsToRemove = flowIORequest.sessions.map { sessionToSessionId(it) }.toSet()
val existingSessionsToRemove = currentState.checkpoint.checkpointState.sessions.filter { (sessionId, _) ->
sessionIdsToRemove.contains(sessionId)
}
val alreadyClosedSessions = sessionIdsToRemove.filter { sessionId -> sessionId !in existingSessionsToRemove }
if (alreadyClosedSessions.isNotEmpty()) {
logger.warn("Attempting to close already closed sessions: $alreadyClosedSessions")
}
if (existingSessionsToRemove.isNotEmpty()) {
val sendEndMessageActions = existingSessionsToRemove.values.mapIndexed { index, state ->
val sinkSessionId = (state as SessionState.Initiated).peerSinkSessionId
val message = ExistingSessionMessage(sinkSessionId, EndSessionMessage)
val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index, state)
Action.SendExisting(state.peerParty, message, SenderDeduplicationId(deduplicationId, currentState.senderUUID))
}
currentState = currentState.copy(checkpoint = currentState.checkpoint.removeSessions(existingSessionsToRemove.keys))
actions.add(Action.RemoveSessionBindings(sessionIdsToRemove))
actions.add(Action.SendMultiple(emptyList(), sendEndMessageActions))
}
resumeFlowLogic(Unit)
}
}
private fun receiveTransition(flowIORequest: FlowIORequest.Receive): TransitionResult { private fun receiveTransition(flowIORequest: FlowIORequest.Receive): TransitionResult {
return builder { return builder {
val sessionIdToSession = LinkedHashMap<SessionId, FlowSessionImpl>() val sessionIdToSession = LinkedHashMap<SessionId, FlowSessionImpl>()
@ -199,7 +266,8 @@ class StartedFlowTransition(
someNotFound = true someNotFound = true
} else { } else {
newSessionMessages[sessionId] = sessionState.copy(receivedMessages = messages.subList(1, messages.size).toList()) newSessionMessages[sessionId] = sessionState.copy(receivedMessages = messages.subList(1, messages.size).toList())
resultMessages[sessionId] = messages[0].payload // at this point, we've already checked for errors and session ends, so it's guaranteed that the first message will be a data message.
resultMessages[sessionId] = (messages[0] as DataSessionMessage).payload
} }
} }
else -> { else -> {
@ -257,12 +325,6 @@ class StartedFlowTransition(
val checkpoint = startingState.checkpoint val checkpoint = startingState.checkpoint
val newSessions = LinkedHashMap(checkpoint.checkpointState.sessions) val newSessions = LinkedHashMap(checkpoint.checkpointState.sessions)
var index = 0 var index = 0
for ((sourceSessionId, _) in sourceSessionIdToMessage) {
val existingSessionState = checkpoint.checkpointState.sessions[sourceSessionId] ?: return freshErrorTransition(CannotFindSessionException(sourceSessionId))
if (existingSessionState is SessionState.Initiated && existingSessionState.initiatedState is InitiatedSessionState.Ended) {
return freshErrorTransition(IllegalStateException("Tried to send to ended session $sourceSessionId"))
}
}
val messagesByType = sourceSessionIdToMessage.toList() val messagesByType = sourceSessionIdToMessage.toList()
.map { (sourceSessionId, message) -> Triple(sourceSessionId, checkpoint.checkpointState.sessions[sourceSessionId]!!, message) } .map { (sourceSessionId, message) -> Triple(sourceSessionId, checkpoint.checkpointState.sessions[sourceSessionId]!!, message) }
@ -286,17 +348,13 @@ class StartedFlowTransition(
val newBufferedMessages = initiatingSessionState.bufferedMessages + Pair(deduplicationId, sessionMessage) val newBufferedMessages = initiatingSessionState.bufferedMessages + Pair(deduplicationId, sessionMessage)
newSessions[sourceSessionId] = initiatingSessionState.copy(bufferedMessages = newBufferedMessages) newSessions[sourceSessionId] = initiatingSessionState.copy(bufferedMessages = newBufferedMessages)
} }
val sendExistingActions = messagesByType[SessionState.Initiated::class]?.mapNotNull {(_, sessionState, message) -> val sendExistingActions = messagesByType[SessionState.Initiated::class]?.map {(_, sessionState, message) ->
val initiatedSessionState = sessionState as SessionState.Initiated val initiatedSessionState = sessionState as SessionState.Initiated
if (initiatedSessionState.initiatedState !is InitiatedSessionState.Live) val sessionMessage = DataSessionMessage(message)
null val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++, initiatedSessionState)
else { val sinkSessionId = initiatedSessionState.peerSinkSessionId
val sessionMessage = DataSessionMessage(message) val existingMessage = ExistingSessionMessage(sinkSessionId, sessionMessage)
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++, initiatedSessionState) Action.SendExisting(initiatedSessionState.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID))
val sinkSessionId = initiatedSessionState.initiatedState.peerSinkSessionId
val existingMessage = ExistingSessionMessage(sinkSessionId, sessionMessage)
Action.SendExisting(initiatedSessionState.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID))
}
} ?: emptyList() } ?: emptyList()
if (sendInitialActions.isNotEmpty() || sendExistingActions.isNotEmpty()) { if (sendInitialActions.isNotEmpty() || sendExistingActions.isNotEmpty()) {
@ -309,21 +367,68 @@ class StartedFlowTransition(
return (session as FlowSessionImpl).sourceSessionId return (session as FlowSessionImpl).sourceSessionId
} }
private fun collectErroredSessionErrors(sessionIds: Collection<SessionId>, checkpoint: Checkpoint): List<Throwable> { private fun collectErroredSessionErrors(startingState: StateMachineState, sessionIds: Collection<SessionId>): Pair<StateMachineState, List<Throwable>> {
return sessionIds.flatMap { sessionId -> var newState = startingState
val sessionState = checkpoint.checkpointState.sessions[sessionId]!! val errors = sessionIds.filter { sessionId ->
when (sessionState) { startingState.checkpoint.checkpointState.sessions.containsKey(sessionId)
is SessionState.Uninitiated -> emptyList() }.flatMap { sessionId ->
is SessionState.Initiating -> { val sessionState = startingState.checkpoint.checkpointState.sessions[sessionId]!!
if (sessionState.rejectionError == null) { when (sessionState) {
emptyList() is SessionState.Uninitiated -> emptyList()
} else { is SessionState.Initiating -> {
listOf(sessionState.rejectionError.exception) if (sessionState.rejectionError == null) {
emptyList()
} else {
listOf(sessionState.rejectionError.exception)
}
}
is SessionState.Initiated -> {
if (sessionState.receivedMessages.isNotEmpty() && sessionState.receivedMessages.first() is ErrorSessionMessage) {
val errorMessage = sessionState.receivedMessages.first() as ErrorSessionMessage
val exception = convertErrorMessageToException(errorMessage, sessionState.peerParty)
val newSessionState = sessionState.copy(receivedMessages = sessionState.receivedMessages.subList(1, sessionState.receivedMessages.size), otherSideErrored = true)
val newCheckpoint = startingState.checkpoint.addSession(sessionId to newSessionState)
newState = startingState.copy(checkpoint = newCheckpoint)
listOf(exception)
} else {
emptyList()
}
}
} }
} }
is SessionState.Initiated -> sessionState.errors.map(FlowError::exception) return Pair(newState, errors)
} }
private fun convertErrorMessageToException(errorMessage: ErrorSessionMessage, peer: Party): Throwable {
val exception: Throwable = if (errorMessage.flowException == null) {
UnexpectedFlowEndException("Counter-flow errored", cause = null, originalErrorId = errorMessage.errorId)
} else {
errorMessage.flowException.originalErrorId = errorMessage.errorId
errorMessage.flowException
} }
when (exception) {
// reflection used to access private field
is UnexpectedFlowEndException -> DeclaredField<Party?>(
UnexpectedFlowEndException::class.java,
"peer",
exception
).value = peer
is FlowException -> DeclaredField<Party?>(FlowException::class.java, "peer", exception).value = peer
}
return exception
}
private fun collectUncloseableSessions(sessionIds: Collection<SessionId>, checkpoint: Checkpoint): List<Throwable> {
val uninitialisedSessions = sessionIds.mapNotNull { sessionId ->
if (!checkpoint.checkpointState.sessions.containsKey(sessionId))
null
else
sessionId to checkpoint.checkpointState.sessions[sessionId]
}
.filter { (_, sessionState) -> sessionState !is SessionState.Initiated }
.map { it.first }
return uninitialisedSessions.map { PrematureSessionCloseException(it) }
} }
private fun collectErroredInitiatingSessionErrors(checkpoint: Checkpoint): List<Throwable> { private fun collectErroredInitiatingSessionErrors(checkpoint: Checkpoint): List<Throwable> {
@ -333,77 +438,64 @@ class StartedFlowTransition(
} }
private fun collectEndedSessionErrors(sessionIds: Collection<SessionId>, checkpoint: Checkpoint): List<Throwable> { private fun collectEndedSessionErrors(sessionIds: Collection<SessionId>, checkpoint: Checkpoint): List<Throwable> {
return sessionIds.mapNotNull { sessionId -> return sessionIds.filter { sessionId ->
val sessionState = checkpoint.checkpointState.sessions[sessionId]!! !checkpoint.checkpointState.sessions.containsKey(sessionId)
when (sessionState) { }.map {sessionId ->
is SessionState.Initiated -> { UnexpectedFlowEndException(
if (sessionState.initiatedState === InitiatedSessionState.Ended) { "Tried to access ended session $sessionId",
UnexpectedFlowEndException( cause = null,
"Tried to access ended session $sessionId", originalErrorId = context.secureRandom.nextLong()
cause = null, )
originalErrorId = context.secureRandom.nextLong()
)
} else {
null
}
}
else -> null
}
} }
} }
private fun collectEndedEmptySessionErrors(sessionIds: Collection<SessionId>, checkpoint: Checkpoint): List<Throwable> { private fun collectRelevantErrorsToThrow(startingState: StateMachineState, flowIORequest: FlowIORequest<*>): Pair<StateMachineState, List<Throwable>> {
return sessionIds.mapNotNull { sessionId ->
val sessionState = checkpoint.checkpointState.sessions[sessionId]!!
when (sessionState) {
is SessionState.Initiated -> {
if (sessionState.initiatedState === InitiatedSessionState.Ended &&
sessionState.receivedMessages.isEmpty()) {
UnexpectedFlowEndException(
"Tried to access ended session $sessionId with empty buffer",
cause = null,
originalErrorId = context.secureRandom.nextLong()
)
} else {
null
}
}
else -> null
}
}
}
private fun collectRelevantErrorsToThrow(flowIORequest: FlowIORequest<*>, checkpoint: Checkpoint): List<Throwable> {
return when (flowIORequest) { return when (flowIORequest) {
is FlowIORequest.Send -> { is FlowIORequest.Send -> {
val sessionIds = flowIORequest.sessionToMessage.keys.map(this::sessionToSessionId) val sessionIds = flowIORequest.sessionToMessage.keys.map(this::sessionToSessionId)
collectErroredSessionErrors(sessionIds, checkpoint) + collectEndedSessionErrors(sessionIds, checkpoint) val (newState, erroredSessionErrors) = collectErroredSessionErrors(startingState, sessionIds)
val endedSessionErrors = collectEndedSessionErrors(sessionIds, startingState.checkpoint)
Pair(newState, erroredSessionErrors + endedSessionErrors)
} }
is FlowIORequest.Receive -> { is FlowIORequest.Receive -> {
val sessionIds = flowIORequest.sessions.map(this::sessionToSessionId) val sessionIds = flowIORequest.sessions.map(this::sessionToSessionId)
collectErroredSessionErrors(sessionIds, checkpoint) + collectEndedEmptySessionErrors(sessionIds, checkpoint) val (newState, erroredSessionErrors) = collectErroredSessionErrors(startingState, sessionIds)
val endedSessionErrors = collectEndedSessionErrors(sessionIds, startingState.checkpoint)
Pair(newState, erroredSessionErrors + endedSessionErrors)
} }
is FlowIORequest.SendAndReceive -> { is FlowIORequest.SendAndReceive -> {
val sessionIds = flowIORequest.sessionToMessage.keys.map(this::sessionToSessionId) val sessionIds = flowIORequest.sessionToMessage.keys.map(this::sessionToSessionId)
collectErroredSessionErrors(sessionIds, checkpoint) + collectEndedSessionErrors(sessionIds, checkpoint) val (newState, erroredSessionErrors) = collectErroredSessionErrors(startingState, sessionIds)
val endedSessionErrors = collectEndedSessionErrors(sessionIds, startingState.checkpoint)
Pair(newState, erroredSessionErrors + endedSessionErrors)
} }
is FlowIORequest.WaitForLedgerCommit -> { is FlowIORequest.WaitForLedgerCommit -> {
collectErroredSessionErrors(checkpoint.checkpointState.sessions.keys, checkpoint) return collectErroredSessionErrors(startingState, startingState.checkpoint.checkpointState.sessions.keys)
} }
is FlowIORequest.GetFlowInfo -> { is FlowIORequest.GetFlowInfo -> {
collectErroredSessionErrors(flowIORequest.sessions.map(this::sessionToSessionId), checkpoint) val sessionIds = flowIORequest.sessions.map(this::sessionToSessionId)
val (newState, erroredSessionErrors) = collectErroredSessionErrors(startingState, sessionIds)
val endedSessionErrors = collectEndedSessionErrors(sessionIds, startingState.checkpoint)
Pair(newState, erroredSessionErrors + endedSessionErrors)
}
is FlowIORequest.CloseSessions -> {
val sessionIds = flowIORequest.sessions.map(this::sessionToSessionId)
val (newState, erroredSessionErrors) = collectErroredSessionErrors(startingState, sessionIds)
val uncloseableSessionErrors = collectUncloseableSessions(sessionIds, startingState.checkpoint)
Pair(newState, erroredSessionErrors + uncloseableSessionErrors)
} }
is FlowIORequest.Sleep -> { is FlowIORequest.Sleep -> {
emptyList() Pair(startingState, emptyList())
} }
is FlowIORequest.WaitForSessionConfirmations -> { is FlowIORequest.WaitForSessionConfirmations -> {
collectErroredInitiatingSessionErrors(checkpoint) val errors = collectErroredInitiatingSessionErrors(startingState.checkpoint)
Pair(startingState, errors)
} }
is FlowIORequest.ExecuteAsyncOperation<*> -> { is FlowIORequest.ExecuteAsyncOperation<*> -> {
emptyList() Pair(startingState, emptyList())
} }
FlowIORequest.ForceCheckpoint -> { FlowIORequest.ForceCheckpoint -> {
emptyList() Pair(startingState, emptyList())
} }
} }
} }

View File

@ -18,7 +18,6 @@ import net.corda.node.services.statemachine.FlowRemovalReason
import net.corda.node.services.statemachine.FlowSessionImpl import net.corda.node.services.statemachine.FlowSessionImpl
import net.corda.node.services.statemachine.FlowState import net.corda.node.services.statemachine.FlowState
import net.corda.node.services.statemachine.InitialSessionMessage import net.corda.node.services.statemachine.InitialSessionMessage
import net.corda.node.services.statemachine.InitiatedSessionState
import net.corda.node.services.statemachine.SenderDeduplicationId import net.corda.node.services.statemachine.SenderDeduplicationId
import net.corda.node.services.statemachine.SessionId import net.corda.node.services.statemachine.SessionId
import net.corda.node.services.statemachine.SessionMessage import net.corda.node.services.statemachine.SessionMessage
@ -267,8 +266,8 @@ class TopLevelTransition(
private fun TransitionBuilder.sendEndMessages() { private fun TransitionBuilder.sendEndMessages() {
val sendEndMessageActions = currentState.checkpoint.checkpointState.sessions.values.mapIndexed { index, state -> val sendEndMessageActions = currentState.checkpoint.checkpointState.sessions.values.mapIndexed { index, state ->
if (state is SessionState.Initiated && state.initiatedState is InitiatedSessionState.Live) { if (state is SessionState.Initiated) {
val message = ExistingSessionMessage(state.initiatedState.peerSinkSessionId, EndSessionMessage) val message = ExistingSessionMessage(state.peerSinkSessionId, EndSessionMessage)
val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index, state) val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index, state)
Action.SendExisting(state.peerParty, message, SenderDeduplicationId(deduplicationId, currentState.senderUUID)) Action.SendExisting(state.peerParty, message, SenderDeduplicationId(deduplicationId, currentState.senderUUID))
} else { } else {

View File

@ -81,3 +81,5 @@ class TransitionBuilder(val context: TransitionContext, initialState: StateMachi
class CannotFindSessionException(sessionId: SessionId) : IllegalStateException("Couldn't find session with id $sessionId") class CannotFindSessionException(sessionId: SessionId) : IllegalStateException("Couldn't find session with id $sessionId")
class UnexpectedEventInState : IllegalStateException("Unexpected event") class UnexpectedEventInState : IllegalStateException("Unexpected event")
class PrematureSessionCloseException(sessionId: SessionId): IllegalStateException("The following session was closed before it was initialised: $sessionId")
class PrematureSessionEndException(sessionId: SessionId): IllegalStateException("A premature session end message was received before the session was initialised: $sessionId")

View File

@ -8,7 +8,6 @@ import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ExistingSessionMessage import net.corda.node.services.statemachine.ExistingSessionMessage
import net.corda.node.services.statemachine.FlowStart import net.corda.node.services.statemachine.FlowStart
import net.corda.node.services.statemachine.FlowState import net.corda.node.services.statemachine.FlowState
import net.corda.node.services.statemachine.InitiatedSessionState
import net.corda.node.services.statemachine.SenderDeduplicationId import net.corda.node.services.statemachine.SenderDeduplicationId
import net.corda.node.services.statemachine.SessionState import net.corda.node.services.statemachine.SessionState
import net.corda.node.services.statemachine.StateMachineState import net.corda.node.services.statemachine.StateMachineState
@ -45,7 +44,7 @@ class UnstartedFlowTransition(
val initiatingMessage = flowStart.initiatingMessage val initiatingMessage = flowStart.initiatingMessage
val initiatedState = SessionState.Initiated( val initiatedState = SessionState.Initiated(
peerParty = flowStart.peerSession.counterparty, peerParty = flowStart.peerSession.counterparty,
initiatedState = InitiatedSessionState.Live(initiatingMessage.initiatorSessionId), peerSinkSessionId = initiatingMessage.initiatorSessionId,
peerFlowInfo = FlowInfo( peerFlowInfo = FlowInfo(
flowVersion = flowStart.senderCoreFlowVersion ?: initiatingMessage.flowVersion, flowVersion = flowStart.senderCoreFlowVersion ?: initiatingMessage.flowVersion,
appName = initiatingMessage.appName appName = initiatingMessage.appName
@ -55,8 +54,8 @@ class UnstartedFlowTransition(
} else { } else {
listOf(DataSessionMessage(initiatingMessage.firstPayload)) listOf(DataSessionMessage(initiatingMessage.firstPayload))
}, },
errors = emptyList(), deduplicationSeed = "D-${initiatingMessage.initiatorSessionId.toLong}-${initiatingMessage.initiationEntropy}",
deduplicationSeed = "D-${initiatingMessage.initiatorSessionId.toLong}-${initiatingMessage.initiationEntropy}" otherSideErrored = false
) )
val confirmationMessage = ConfirmSessionMessage(flowStart.initiatedSessionId, flowStart.initiatedFlowInfo) val confirmationMessage = ConfirmSessionMessage(flowStart.initiatedSessionId, flowStart.initiatedFlowInfo)
val sessionMessage = ExistingSessionMessage(initiatingMessage.initiatorSessionId, confirmationMessage) val sessionMessage = ExistingSessionMessage(initiatingMessage.initiatorSessionId, confirmationMessage)

View File

@ -26,6 +26,7 @@ import net.corda.core.internal.FlowIORequest
import net.corda.core.internal.FlowStateMachine import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.concurrent.flatMap import net.corda.core.internal.concurrent.flatMap
import net.corda.core.internal.concurrent.openFuture import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.declaredField
import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.MessageRecipients
import net.corda.core.node.services.PartyInfo import net.corda.core.node.services.PartyInfo
import net.corda.core.node.services.queryBy import net.corda.core.node.services.queryBy
@ -173,9 +174,12 @@ class FlowFrameworkTests {
val flow = ReceiveFlow(bob) val flow = ReceiveFlow(bob)
val fiber = aliceNode.services.startFlow(flow) as FlowStateMachineImpl val fiber = aliceNode.services.startFlow(flow) as FlowStateMachineImpl
// Before the flow runs change the suspend action to throw an exception // Before the flow runs change the suspend action to throw an exception
val throwingActionExecutor = SuspendThrowingActionExecutor(Exception("Thrown during suspend"), val throwingActionExecutor = SuspendThrowingActionExecutor(
fiber.transientValues!!.value.actionExecutor) Exception("Thrown during suspend"),
fiber.transientValues = TransientReference(fiber.transientValues!!.value.copy(actionExecutor = throwingActionExecutor)) fiber.transientValues.actionExecutor
)
fiber.declaredField<TransientReference<FlowStateMachineImpl.TransientValues>>("transientValuesReference").value =
TransientReference(fiber.transientValues.copy(actionExecutor = throwingActionExecutor))
mockNet.runNetwork() mockNet.runNetwork()
fiber.resultFuture.getOrThrow() fiber.resultFuture.getOrThrow()
assertThat(aliceNode.smm.allStateMachines).isEmpty() assertThat(aliceNode.smm.allStateMachines).isEmpty()
@ -201,7 +205,7 @@ class FlowFrameworkTests {
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `other side ends before doing expected send`() { fun `other side ends before doing expected send`() {
bobNode.registerCordappFlowFactory(ReceiveFlow::class) { NoOpFlow() } bobNode.registerCordappFlowFactory(ReceiveFlow::class) { NoOpFlow() }
val resultFuture = aliceNode.services.startFlow(ReceiveFlow(bob)).resultFuture val resultFuture = aliceNode.services.startFlow(ReceiveFlow(bob)).resultFuture
mockNet.runNetwork() mockNet.runNetwork()
@ -679,14 +683,14 @@ class FlowFrameworkTests {
SuspendingFlow.hookBeforeCheckpoint = { SuspendingFlow.hookBeforeCheckpoint = {
val flowFiber = this as? FlowStateMachineImpl<*> val flowFiber = this as? FlowStateMachineImpl<*>
flowState = flowFiber!!.transientState!!.value.checkpoint.flowState flowState = flowFiber!!.transientState.checkpoint.flowState
if (firstExecution) { if (firstExecution) {
throw HospitalizeFlowException() throw HospitalizeFlowException()
} else { } else {
dbCheckpointStatusBeforeSuspension = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status dbCheckpointStatusBeforeSuspension = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status
currentDBSession().clear() // clear session as Hibernate with fails with 'org.hibernate.NonUniqueObjectException' once it tries to save a DBFlowCheckpoint upon checkpoint currentDBSession().clear() // clear session as Hibernate with fails with 'org.hibernate.NonUniqueObjectException' once it tries to save a DBFlowCheckpoint upon checkpoint
inMemoryCheckpointStatusBeforeSuspension = flowFiber.transientState!!.value.checkpoint.status inMemoryCheckpointStatusBeforeSuspension = flowFiber.transientState.checkpoint.status
futureFiber.complete(flowFiber) futureFiber.complete(flowFiber)
} }
@ -701,7 +705,7 @@ class FlowFrameworkTests {
} }
// flow is in hospital // flow is in hospital
assertTrue(flowState is FlowState.Unstarted) assertTrue(flowState is FlowState.Unstarted)
val inMemoryHospitalizedCheckpointStatus = aliceNode.internals.smm.snapshot().first().transientState?.value?.checkpoint?.status val inMemoryHospitalizedCheckpointStatus = aliceNode.internals.smm.snapshot().first().transientState.checkpoint.status
assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, inMemoryHospitalizedCheckpointStatus) assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, inMemoryHospitalizedCheckpointStatus)
aliceNode.database.transaction { aliceNode.database.transaction {
val checkpoint = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second val checkpoint = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second
@ -727,13 +731,13 @@ class FlowFrameworkTests {
SuspendingFlow.hookAfterCheckpoint = { SuspendingFlow.hookAfterCheckpoint = {
val flowFiber = this as? FlowStateMachineImpl<*> val flowFiber = this as? FlowStateMachineImpl<*>
flowState = flowFiber!!.transientState!!.value.checkpoint.flowState flowState = flowFiber!!.transientState.checkpoint.flowState
if (firstExecution) { if (firstExecution) {
throw HospitalizeFlowException() throw HospitalizeFlowException()
} else { } else {
dbCheckpointStatus = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status dbCheckpointStatus = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status
inMemoryCheckpointStatus = flowFiber.transientState!!.value.checkpoint.status inMemoryCheckpointStatus = flowFiber.transientState.checkpoint.status
futureFiber.complete(flowFiber) futureFiber.complete(flowFiber)
} }
@ -820,7 +824,7 @@ class FlowFrameworkTests {
} else { } else {
val flowFiber = this as? FlowStateMachineImpl<*> val flowFiber = this as? FlowStateMachineImpl<*>
dbCheckpointStatus = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status dbCheckpointStatus = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status
inMemoryCheckpointStatus = flowFiber!!.transientState!!.value.checkpoint.status inMemoryCheckpointStatus = flowFiber!!.transientState.checkpoint.status
persistedException = aliceNode.internals.checkpointStorage.getDBCheckpoint(flowFiber.id)!!.exceptionDetails persistedException = aliceNode.internals.checkpointStorage.getDBCheckpoint(flowFiber.id)!!.exceptionDetails
} }
} }
@ -868,6 +872,7 @@ class FlowFrameworkTests {
session.send(1) session.send(1)
// ... then pause this one until it's received the session-end message from the other side // ... then pause this one until it's received the session-end message from the other side
receivedOtherFlowEnd.acquire() receivedOtherFlowEnd.acquire()
session.sendAndReceive<Int>(2) session.sendAndReceive<Int>(2)
} }
} }

View File

@ -247,7 +247,7 @@ class FlowMetadataRecordingTest {
it.initialParameters.deserialize(context = SerializationDefaults.STORAGE_CONTEXT) it.initialParameters.deserialize(context = SerializationDefaults.STORAGE_CONTEXT)
) )
assertThat(it.launchingCordapp).contains("custom-cordapp") assertThat(it.launchingCordapp).contains("custom-cordapp")
assertEquals(7, it.platformVersion) assertEquals(8, it.platformVersion)
assertEquals(nodeAHandle.nodeInfo.singleIdentity().name.toString(), it.startedBy) assertEquals(nodeAHandle.nodeInfo.singleIdentity().name.toString(), it.startedBy)
assertEquals(context!!.trace.invocationId.timestamp, it.invocationInstant) assertEquals(context!!.trace.invocationId.timestamp, it.invocationInstant)
assertTrue(it.startInstant >= it.invocationInstant) assertTrue(it.startInstant >= it.invocationInstant)

View File

@ -183,6 +183,11 @@ class RetryFlowMockTest {
override fun send(payload: Any) { override fun send(payload: Any) {
TODO("not implemented") TODO("not implemented")
} }
override fun close() {
TODO("Not yet implemented")
}
}), nodeA.services.newContext()).get() }), nodeA.services.newContext()).get()
records.next() records.next()
// Killing it should remove it. // Killing it should remove it.

View File

@ -1,6 +1,7 @@
package net.corda.serialization.internal package net.corda.serialization.internal
import net.corda.core.KeepForDJVM import net.corda.core.KeepForDJVM
import net.corda.core.serialization.CheckpointCustomSerializer
import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.EncodingWhitelist import net.corda.core.serialization.EncodingWhitelist
import net.corda.core.serialization.SerializationEncoding import net.corda.core.serialization.SerializationEncoding
@ -13,7 +14,8 @@ data class CheckpointSerializationContextImpl @JvmOverloads constructor(
override val properties: Map<Any, Any>, override val properties: Map<Any, Any>,
override val objectReferencesEnabled: Boolean, override val objectReferencesEnabled: Boolean,
override val encoding: SerializationEncoding?, override val encoding: SerializationEncoding?,
override val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist) : CheckpointSerializationContext { override val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist,
override val checkpointCustomSerializers: Iterable<CheckpointCustomSerializer<*,*>> = emptyList()) : CheckpointSerializationContext {
override fun withProperty(property: Any, value: Any): CheckpointSerializationContext { override fun withProperty(property: Any, value: Any): CheckpointSerializationContext {
return copy(properties = properties + (property to value)) return copy(properties = properties + (property to value))
} }
@ -34,4 +36,6 @@ data class CheckpointSerializationContextImpl @JvmOverloads constructor(
override fun withEncoding(encoding: SerializationEncoding?) = copy(encoding = encoding) override fun withEncoding(encoding: SerializationEncoding?) = copy(encoding = encoding)
override fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist) = copy(encodingWhitelist = encodingWhitelist) override fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist) = copy(encodingWhitelist = encodingWhitelist)
override fun withCheckpointCustomSerializers(checkpointCustomSerializers : Iterable<CheckpointCustomSerializer<*,*>>)
= copy(checkpointCustomSerializers = checkpointCustomSerializers)
} }

View File

@ -2,6 +2,7 @@ package net.corda.coretesting.internal
import net.corda.nodeapi.internal.rpc.client.AMQPClientSerializationScheme import net.corda.nodeapi.internal.rpc.client.AMQPClientSerializationScheme
import net.corda.core.internal.createInstancesOfClassesImplementing import net.corda.core.internal.createInstancesOfClassesImplementing
import net.corda.core.serialization.CheckpointCustomSerializer
import net.corda.core.serialization.SerializationCustomSerializer import net.corda.core.serialization.SerializationCustomSerializer
import net.corda.core.serialization.SerializationWhitelist import net.corda.core.serialization.SerializationWhitelist
import net.corda.core.serialization.internal.SerializationEnvironment import net.corda.core.serialization.internal.SerializationEnvironment
@ -25,8 +26,11 @@ fun createTestSerializationEnv(): SerializationEnvironment {
} }
fun createTestSerializationEnv(classLoader: ClassLoader?): SerializationEnvironment { fun createTestSerializationEnv(classLoader: ClassLoader?): SerializationEnvironment {
var customCheckpointSerializers: Set<CheckpointCustomSerializer<*, *>> = emptySet()
val (clientSerializationScheme, serverSerializationScheme) = if (classLoader != null) { val (clientSerializationScheme, serverSerializationScheme) = if (classLoader != null) {
val customSerializers = createInstancesOfClassesImplementing(classLoader, SerializationCustomSerializer::class.java) val customSerializers = createInstancesOfClassesImplementing(classLoader, SerializationCustomSerializer::class.java)
customCheckpointSerializers = createInstancesOfClassesImplementing(classLoader, CheckpointCustomSerializer::class.java)
val serializationWhitelists = ServiceLoader.load(SerializationWhitelist::class.java, classLoader).toSet() val serializationWhitelists = ServiceLoader.load(SerializationWhitelist::class.java, classLoader).toSet()
Pair(AMQPClientSerializationScheme(customSerializers, serializationWhitelists), Pair(AMQPClientSerializationScheme(customSerializers, serializationWhitelists),
@ -44,7 +48,7 @@ fun createTestSerializationEnv(classLoader: ClassLoader?): SerializationEnvironm
AMQP_RPC_SERVER_CONTEXT, AMQP_RPC_SERVER_CONTEXT,
AMQP_RPC_CLIENT_CONTEXT, AMQP_RPC_CLIENT_CONTEXT,
AMQP_STORAGE_CONTEXT, AMQP_STORAGE_CONTEXT,
KRYO_CHECKPOINT_CONTEXT, KRYO_CHECKPOINT_CONTEXT.withCheckpointCustomSerializers(customCheckpointSerializers),
KryoCheckpointSerializer KryoCheckpointSerializer
) )
} }

View File

@ -536,7 +536,8 @@ open class InternalMockNetwork(cordappPackages: List<String> = emptyList(),
} }
private fun pumpAll(): Boolean { private fun pumpAll(): Boolean {
val transferredMessages = messagingNetwork.endpoints.map { it.pumpReceive(false) } val transferredMessages = messagingNetwork.endpoints.filter { it.active }
.map { it.pumpReceive(false) }
return transferredMessages.any { it != null } return transferredMessages.any { it != null }
} }

View File

@ -173,6 +173,7 @@ class MockNodeMessagingService(private val configuration: NodeConfiguration,
it.join() it.join()
} }
running = false running = false
stateHelper.active = false
network.netNodeHasShutdown(myAddress) network.netNodeHasShutdown(myAddress)
} }