diff --git a/constants.properties b/constants.properties index 8fde90cd02..9ed2417171 100644 --- a/constants.properties +++ b/constants.properties @@ -1,4 +1,4 @@ -gradlePluginsVersion=4.0.32 +gradlePluginsVersion=4.0.33 kotlinVersion=1.2.71 # ***************************************************************# # When incrementing platformVersion make sure to update # diff --git a/core/src/main/kotlin/net/corda/core/contracts/Attachment.kt b/core/src/main/kotlin/net/corda/core/contracts/Attachment.kt index d17d053e1f..0535f051e6 100644 --- a/core/src/main/kotlin/net/corda/core/contracts/Attachment.kt +++ b/core/src/main/kotlin/net/corda/core/contracts/Attachment.kt @@ -1,7 +1,6 @@ package net.corda.core.contracts import net.corda.core.KeepForDJVM -import net.corda.core.identity.Party import net.corda.core.internal.extractFile import net.corda.core.serialization.CordaSerializable import java.io.FileNotFoundException diff --git a/core/src/main/kotlin/net/corda/core/contracts/AttachmentConstraint.kt b/core/src/main/kotlin/net/corda/core/contracts/AttachmentConstraint.kt index 76ffc8a866..55be99325b 100644 --- a/core/src/main/kotlin/net/corda/core/contracts/AttachmentConstraint.kt +++ b/core/src/main/kotlin/net/corda/core/contracts/AttachmentConstraint.kt @@ -3,7 +3,6 @@ package net.corda.core.contracts import net.corda.core.DoNotImplement import net.corda.core.KeepForDJVM import net.corda.core.contracts.AlwaysAcceptAttachmentConstraint.isSatisfiedBy -import net.corda.core.crypto.CompositeKey import net.corda.core.crypto.SecureHash import net.corda.core.crypto.isFulfilledBy import net.corda.core.internal.AttachmentWithContext diff --git a/core/src/main/kotlin/net/corda/core/internal/cordapp/CordappImpl.kt b/core/src/main/kotlin/net/corda/core/internal/cordapp/CordappImpl.kt index 8ab16a5190..2d6075a5e3 100644 --- a/core/src/main/kotlin/net/corda/core/internal/cordapp/CordappImpl.kt +++ b/core/src/main/kotlin/net/corda/core/internal/cordapp/CordappImpl.kt @@ -43,7 +43,7 @@ data class CordappImpl( */ override val cordappClasses: List = run { val classList = rpcFlows + initiatedFlows + services + serializationWhitelists.map { javaClass } + notaryService - classList.mapNotNull { it?.name } + contractClassNames + classList.mapNotNull { it?.name } + contractClassNames } // TODO Why a seperate Info class and not just have the fields directly in CordappImpl? diff --git a/core/src/main/kotlin/net/corda/core/transactions/TransactionBuilder.kt b/core/src/main/kotlin/net/corda/core/transactions/TransactionBuilder.kt index 99c069de88..e52bda5456 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/TransactionBuilder.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/TransactionBuilder.kt @@ -5,7 +5,10 @@ import net.corda.core.CordaInternal import net.corda.core.DeleteForDJVM import net.corda.core.contracts.* import net.corda.core.cordapp.CordappProvider -import net.corda.core.crypto.* +import net.corda.core.crypto.CompositeKey +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.SignableData +import net.corda.core.crypto.SignatureMetadata import net.corda.core.identity.Party import net.corda.core.internal.FlowStateMachine import net.corda.core.internal.ensureMinimumPlatformVersion diff --git a/core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt b/core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt index fbb2d3bad2..ea3b44a98b 100644 --- a/core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt +++ b/core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt @@ -1,10 +1,13 @@ package net.corda.core.flows import co.paralleluniverse.fibers.Suspendable +import net.corda.core.concurrent.CordaFuture +import net.corda.core.toFuture import net.corda.core.utilities.UntrustworthyData import net.corda.core.utilities.unwrap import net.corda.node.internal.InitiatedFlowFactory import net.corda.testing.node.internal.TestStartedNode +import rx.Observable import kotlin.reflect.KClass /** @@ -34,20 +37,6 @@ class NoAnswer(private val closure: () -> Unit = {}) : FlowLogic() { override fun call() = closure() } -/** - * Allows to register a flow of type [R] against an initiating flow of type [I]. - */ -inline fun , reified R : FlowLogic<*>> TestStartedNode.registerInitiatedFlow(initiatingFlowType: KClass, crossinline construct: (session: FlowSession) -> R) { - registerFlowFactory(initiatingFlowType.java, InitiatedFlowFactory.Core { session -> construct(session) }, R::class.javaObjectType, true) -} - -/** - * Allows to register a flow of type [Answer] against an initiating flow of type [I], returning a valure of type [R]. - */ -inline fun , reified R : Any> TestStartedNode.registerAnswer(initiatingFlowType: KClass, value: R) { - registerFlowFactory(initiatingFlowType.java, InitiatedFlowFactory.Core { session -> Answer(session, value) }, Answer::class.javaObjectType, true) -} - /** * Extracts data from a [Map[FlowSession, UntrustworthyData]] without performing checks and casting to [R]. */ @@ -112,4 +101,23 @@ inline fun FlowLogic<*>.receiveAll(session: FlowSession, varar private fun Array>>.enforceNoDuplicates() { require(this.size == this.toSet().size) { "A flow session can only appear once as argument." } +} + +inline fun > TestStartedNode.registerCordappFlowFactory( + initiatingFlowClass: KClass>, + initiatedFlowVersion: Int = 1, + noinline flowFactory: (FlowSession) -> P): CordaFuture

{ + + val observable = internals.registerInitiatedFlowFactory( + initiatingFlowClass.java, + P::class.java, + InitiatedFlowFactory.CorDapp(initiatedFlowVersion, "", flowFactory), + track = true) + return observable.toFuture() +} + +fun > TestStartedNode.registerCoreFlowFactory(initiatingFlowClass: Class>, + initiatedFlowClass: Class, + flowFactory: (FlowSession) -> T , track: Boolean): Observable { + return this.internals.registerInitiatedFlowFactory(initiatingFlowClass, initiatedFlowClass, InitiatedFlowFactory.Core(flowFactory), track) } \ No newline at end of file diff --git a/core/src/test/kotlin/net/corda/core/flows/ReceiveAllFlowTests.kt b/core/src/test/kotlin/net/corda/core/flows/ReceiveAllFlowTests.kt index c546a06c20..ec56bb1237 100644 --- a/core/src/test/kotlin/net/corda/core/flows/ReceiveAllFlowTests.kt +++ b/core/src/test/kotlin/net/corda/core/flows/ReceiveAllFlowTests.kt @@ -2,16 +2,18 @@ package net.corda.core.flows import co.paralleluniverse.fibers.Suspendable import com.natpryce.hamkrest.assertion.assert -import net.corda.testing.internal.matchers.flow.willReturn import net.corda.core.flows.mixins.WithMockNet import net.corda.core.identity.Party import net.corda.core.utilities.UntrustworthyData import net.corda.core.utilities.unwrap import net.corda.testing.core.singleIdentity +import net.corda.testing.internal.matchers.flow.willReturn import net.corda.testing.node.internal.InternalMockNetwork +import net.corda.testing.node.internal.TestStartedNode import org.assertj.core.api.Assertions.assertThat import org.junit.AfterClass import org.junit.Test +import kotlin.reflect.KClass class ReceiveMultipleFlowTests : WithMockNet { @@ -43,7 +45,7 @@ class ReceiveMultipleFlowTests : WithMockNet { } } - nodes[1].registerInitiatedFlow(initiatingFlow::class) { session -> + nodes[1].registerCordappFlowFactory(initiatingFlow::class) { session -> object : FlowLogic() { @Suspendable override fun call() { @@ -123,4 +125,15 @@ class ReceiveMultipleFlowTests : WithMockNet { return double * string.length } } +} + +private inline fun TestStartedNode.registerAnswer(kClass: KClass>, value1: T) { + this.registerCordappFlowFactory(kClass) { session -> + object : FlowLogic() { + @Suspendable + override fun call() { + session.send(value1!!) + } + } + } } \ No newline at end of file diff --git a/core/src/test/kotlin/net/corda/core/internal/JarSignatureCollectorTest.kt b/core/src/test/kotlin/net/corda/core/internal/JarSignatureCollectorTest.kt index 7f83f23a47..1a379a9f09 100644 --- a/core/src/test/kotlin/net/corda/core/internal/JarSignatureCollectorTest.kt +++ b/core/src/test/kotlin/net/corda/core/internal/JarSignatureCollectorTest.kt @@ -8,7 +8,6 @@ import net.corda.core.JarSignatureTestUtils.updateJar import net.corda.core.identity.Party import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.BOB_NAME -import net.corda.testing.core.CHARLIE_NAME import org.assertj.core.api.Assertions.assertThat import org.junit.After import org.junit.AfterClass diff --git a/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt b/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt index 3ba769ce88..c89e1d0c17 100644 --- a/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt +++ b/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt @@ -3,16 +3,12 @@ package net.corda.core.serialization import co.paralleluniverse.fibers.Suspendable import net.corda.core.contracts.Attachment import net.corda.core.crypto.SecureHash -import net.corda.core.flows.FlowLogic -import net.corda.core.flows.FlowSession -import net.corda.core.flows.InitiatingFlow -import net.corda.core.flows.TestNoSecurityDataVendingFlow +import net.corda.core.flows.* import net.corda.core.identity.Party import net.corda.core.internal.FetchAttachmentsFlow import net.corda.core.internal.FetchDataFlow import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.unwrap -import net.corda.node.internal.InitiatedFlowFactory import net.corda.node.services.persistence.NodeAttachmentService import net.corda.nodeapi.internal.persistence.currentDBSession import net.corda.testing.core.ALICE_NAME @@ -151,11 +147,10 @@ class AttachmentSerializationTest { } private fun launchFlow(clientLogic: ClientLogic, rounds: Int, sendData: Boolean = false) { - server.registerFlowFactory( - ClientLogic::class.java, - InitiatedFlowFactory.Core { ServerLogic(it, sendData) }, - ServerLogic::class.java, - track = false) + server.registerCordappFlowFactory( + ClientLogic::class, + 1 + ) { ServerLogic(it, sendData) } client.services.startFlow(clientLogic) mockNet.runNetwork(rounds) } diff --git a/docs/source/flow-overriding.rst b/docs/source/flow-overriding.rst new file mode 100644 index 0000000000..273ff7fa03 --- /dev/null +++ b/docs/source/flow-overriding.rst @@ -0,0 +1,141 @@ +Configuring Responder Flows +=========================== + +A flow can be a fairly complex thing that interacts with many backend systems, and so it is likely that different users +of a specific CordApp will require differences in how flows interact with their specific infrastructure. + +Corda supports this functionality by providing two mechanisms to modify the behaviour of apps in your node. + +Subclassing a Flow +------------------ + +If you have a workflow which is mostly common, but also requires slight alterations in specific situations, most developers would be familiar +with refactoring into `Base` and `Sub` classes. A simple example is shown below. + +java +~~~~ + + .. code-block:: java + + @InitiatingFlow + public class Initiator extends FlowLogic { + private final Party otherSide; + + public Initiator(Party otherSide) { + this.otherSide = otherSide; + } + + @Override + public String call() throws FlowException { + return initiateFlow(otherSide).receive(String.class).unwrap((it) -> it); + } + } + + @InitiatedBy(Initiator.class) + public class BaseResponder extends FlowLogic { + private FlowSession counterpartySession; + + public BaseResponder(FlowSession counterpartySession) { + super(); + this.counterpartySession = counterpartySession; + } + + @Override + public Void call() throws FlowException { + counterpartySession.send(getMessage()); + return Void; + } + + + protected String getMessage() { + return "This Is the Legacy Responder"; + } + } + + public class SubResponder extends BaseResponder { + + public SubResponder(FlowSession counterpartySession) { + super(counterpartySession); + } + + @Override + protected String getMessage() { + return "This is the sub responder"; + } + } + + + +kotlin +~~~~~~ + + .. code-block:: kotlin + + @InitiatedBy(Initiator::class) + open class BaseResponder(internal val otherSideSession: FlowSession) : FlowLogic() { + @Suspendable + override fun call() { + otherSideSession.send(getMessage()) + } + protected open fun getMessage() = "This Is the Legacy Responder" + } + + @InitiatedBy(Initiator::class) + class SubResponder(otherSideSession: FlowSession) : BaseResponder(otherSideSession) { + override fun getMessage(): String { + return "This is the sub responder" + } + } + + + + + +Corda would detect that both ``BaseResponder`` and ``SubResponder`` are configured for responding to ``Initiator``. +Corda will then calculate the hops to ``FlowLogic`` and select the implementation which is furthest distance, ie: the most subclassed implementation. +In the above example, ``SubResponder`` would be selected as the default responder for ``Initiator`` + +.. note:: The flows do not need to be within the same CordApp, or package, therefore to customise a shared app you obtained from a third party, you'd write your own CorDapp that subclasses the first." + +Overriding a flow via node configuration +---------------------------------------- + +Whilst the subclassing approach is likely to be useful for most applications, there is another mechanism to override this behaviour. +This would be useful if for example, a specific CordApp user requires such a different responder that subclassing an existing flow +would not be a good solution. In this case, it's possible to specify a hardcoded flow via the node configuration. + +The configuration section is named ``flowOverrides`` and it accepts an array of ``overrides`` + +.. container:: codeset + + .. code-block:: json + + flowOverrides { + overrides=[ + { + initiator="net.corda.Initiator" + responder="net.corda.BaseResponder" + } + ] + } + +The cordform plugin also provides a ``flowOverride`` method within the ``deployNodes`` block which can be used to override a flow. In the below example, we will override +the ``SubResponder`` with ``BaseResponder`` + +.. container:: codeset + + .. code-block:: groovy + + node { + name "O=Bank,L=London,C=GB" + p2pPort 10025 + rpcUsers = ext.rpcUsers + rpcSettings { + address "localhost:10026" + adminAddress "localhost:10027" + } + extraConfig = ['h2Settings.address' : 'localhost:10035'] + flowOverride("net.corda.Initiator", "net.corda.BaseResponder") + } + +This will generate the corresponding ``flowOverrides`` section and place it in the configuration for that node. \ No newline at end of file diff --git a/node/src/integration-test/kotlin/net/corda/node/flows/FlowOverrideTests.kt b/node/src/integration-test/kotlin/net/corda/node/flows/FlowOverrideTests.kt new file mode 100644 index 0000000000..881daf18cf --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/flows/FlowOverrideTests.kt @@ -0,0 +1,85 @@ +package net.corda.node.flows + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.flows.* +import net.corda.core.identity.Party +import net.corda.core.messaging.startFlow +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.unwrap +import net.corda.testing.core.singleIdentity +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.driver +import net.corda.testing.node.internal.cordappForClasses +import org.hamcrest.CoreMatchers.`is` +import org.junit.Assert +import org.junit.Test + +class FlowOverrideTests { + + @StartableByRPC + @InitiatingFlow + class Ping(private val pongParty: Party) : FlowLogic() { + @Suspendable + override fun call(): String { + val pongSession = initiateFlow(pongParty) + return pongSession.sendAndReceive("PING").unwrap { it } + } + } + + @InitiatedBy(Ping::class) + open class Pong(private val pingSession: FlowSession) : FlowLogic() { + companion object { + val PONG = "PONG" + } + + @Suspendable + override fun call() { + pingSession.send(PONG) + } + } + + @InitiatedBy(Ping::class) + class Pong2(private val pingSession: FlowSession) : FlowLogic() { + @Suspendable + override fun call() { + pingSession.send("PONGPONG") + } + } + + @InitiatedBy(Ping::class) + class Pongiest(private val pingSession: FlowSession) : Pong(pingSession) { + + companion object { + val GORGONZOLA = "Gorgonzola" + } + + @Suspendable + override fun call() { + pingSession.send(GORGONZOLA) + } + } + + private val nodeAClasses = setOf(Ping::class.java, + Pong::class.java, Pongiest::class.java) + private val nodeBClasses = setOf(Ping::class.java, Pong::class.java) + + @Test + fun `should use the most "specific" implementation of a responding flow`() { + driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = emptySet())) { + val nodeA = startNode(additionalCordapps = setOf(cordappForClasses(*nodeAClasses.toTypedArray()))).getOrThrow() + val nodeB = startNode(additionalCordapps = setOf(cordappForClasses(*nodeBClasses.toTypedArray()))).getOrThrow() + Assert.assertThat(nodeB.rpc.startFlow(::Ping, nodeA.nodeInfo.singleIdentity()).returnValue.getOrThrow(), `is`(net.corda.node.flows.FlowOverrideTests.Pongiest.GORGONZOLA)) + } + } + + @Test + fun `should use the overriden implementation of a responding flow`() { + val flowOverrides = mapOf(Ping::class.java to Pong::class.java) + driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = emptySet())) { + val nodeA = startNode(additionalCordapps = setOf(cordappForClasses(*nodeAClasses.toTypedArray())), flowOverrides = flowOverrides).getOrThrow() + val nodeB = startNode(additionalCordapps = setOf(cordappForClasses(*nodeBClasses.toTypedArray()))).getOrThrow() + Assert.assertThat(nodeB.rpc.startFlow(::Ping, nodeA.nodeInfo.singleIdentity()).returnValue.getOrThrow(), `is`(net.corda.node.flows.FlowOverrideTests.Pong.PONG)) + } + } + +} \ No newline at end of file diff --git a/node/src/integration-test/kotlin/net/corda/node/services/statemachine/FlowVersioningTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/statemachine/FlowVersioningTest.kt index f4b3531b13..1248d75b4f 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/statemachine/FlowVersioningTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/statemachine/FlowVersioningTest.kt @@ -7,10 +7,11 @@ import net.corda.core.flows.InitiatingFlow import net.corda.core.identity.Party import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.unwrap -import net.corda.testing.core.singleIdentity -import net.corda.testing.node.internal.NodeBasedTest +import net.corda.node.internal.NodeFlowManager import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.BOB_NAME +import net.corda.testing.core.singleIdentity +import net.corda.testing.node.internal.NodeBasedTest import net.corda.testing.node.internal.startFlow import org.assertj.core.api.Assertions.assertThat import org.junit.Test @@ -18,9 +19,10 @@ import org.junit.Test class FlowVersioningTest : NodeBasedTest() { @Test fun `getFlowContext returns the platform version for core flows`() { + val bobFlowManager = NodeFlowManager() val alice = startNode(ALICE_NAME, platformVersion = 2) - val bob = startNode(BOB_NAME, platformVersion = 3) - bob.node.installCoreFlow(PretendInitiatingCoreFlow::class, ::PretendInitiatedCoreFlow) + val bob = startNode(BOB_NAME, platformVersion = 3, flowManager = bobFlowManager) + bobFlowManager.registerInitiatedCoreFlowFactory(PretendInitiatingCoreFlow::class, ::PretendInitiatedCoreFlow) val (alicePlatformVersionAccordingToBob, bobPlatformVersionAccordingToAlice) = alice.services.startFlow( PretendInitiatingCoreFlow(bob.info.singleIdentity())).resultFuture.getOrThrow() assertThat(alicePlatformVersionAccordingToBob).isEqualTo(2) @@ -45,4 +47,5 @@ class FlowVersioningTest : NodeBasedTest() { @Suspendable override fun call() = otherSideSession.send(otherSideSession.getCounterpartyFlowInfo().flowVersion) } + } \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index d8b9aa15a6..2d12f73a9e 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -30,7 +30,10 @@ import net.corda.core.schemas.MappedSchema import net.corda.core.serialization.SerializationWhitelist import net.corda.core.serialization.SerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken -import net.corda.core.utilities.* +import net.corda.core.utilities.NetworkHostAndPort +import net.corda.core.utilities.days +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.minutes import net.corda.node.CordaClock import net.corda.node.SerialFilter import net.corda.node.VersionInfo @@ -99,13 +102,10 @@ import java.time.Clock import java.time.Duration import java.time.format.DateTimeParseException import java.util.* -import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ExecutorService import java.util.concurrent.Executors import java.util.concurrent.TimeUnit.MINUTES import java.util.concurrent.TimeUnit.SECONDS -import kotlin.collections.set -import kotlin.reflect.KClass import net.corda.core.crypto.generateKeyPair as cryptoGenerateKeyPair /** @@ -120,11 +120,11 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val platformClock: CordaClock, cacheFactoryPrototype: BindableNamedCacheFactory, protected val versionInfo: VersionInfo, + protected val flowManager: FlowManager, protected val serverThread: AffinityExecutor.ServiceAffinityExecutor, private val busyNodeLatch: ReusableLatch = ReusableLatch()) : SingletonSerializeAsToken() { protected abstract val log: Logger - @Suppress("LeakingThis") private var tokenizableServices: MutableList? = mutableListOf(platformClock, this) @@ -211,7 +211,6 @@ abstract class AbstractNode(val configuration: NodeConfiguration, ).tokenize().closeOnStop() private val cordappServices = MutableClassToInstanceMap.create() - private val flowFactories = ConcurrentHashMap>, InitiatedFlowFactory<*>>() private val shutdownExecutor = Executors.newSingleThreadExecutor() protected abstract val transactionVerifierWorkerCount: Int @@ -237,7 +236,8 @@ abstract class AbstractNode(val configuration: NodeConfiguration, private var _started: S? = null private fun T.tokenize(): T { - tokenizableServices?.add(this) ?: throw IllegalStateException("The tokenisable services list has already been finalised") + tokenizableServices?.add(this) + ?: throw IllegalStateException("The tokenisable services list has already been finalised") return this } @@ -607,91 +607,27 @@ abstract class AbstractNode(val configuration: NodeConfiguration, } private fun registerCordappFlows() { - cordappLoader.cordapps.flatMap { it.initiatedFlows } - .forEach { + cordappLoader.cordapps.forEach { cordapp -> + cordapp.initiatedFlows.groupBy { it.requireAnnotation().value.java }.forEach { initiator, responders -> + responders.forEach { responder -> try { - registerInitiatedFlowInternal(smm, it, track = false) + flowManager.registerInitiatedFlow(initiator, responder) } catch (e: NoSuchMethodException) { - log.error("${it.name}, as an initiated flow, must have a constructor with a single parameter " + + log.error("${responder.name}, as an initiated flow, must have a constructor with a single parameter " + "of type ${Party::class.java.name}") - } catch (e: Exception) { - log.error("Unable to register initiated flow ${it.name}", e) + throw e } } - } - - fun > registerInitiatedFlow(smm: StateMachineManager, initiatedFlowClass: Class): Observable { - return registerInitiatedFlowInternal(smm, initiatedFlowClass, track = true) - } - - // TODO remove once not needed - private fun deprecatedFlowConstructorMessage(flowClass: Class<*>): String { - return "Installing flow factory for $flowClass accepting a ${Party::class.java.simpleName}, which is deprecated. " + - "It should accept a ${FlowSession::class.java.simpleName} instead" - } - - private fun > registerInitiatedFlowInternal(smm: StateMachineManager, initiatedFlow: Class, track: Boolean): Observable { - val constructors = initiatedFlow.declaredConstructors.associateBy { it.parameterTypes.toList() } - val flowSessionCtor = constructors[listOf(FlowSession::class.java)]?.apply { isAccessible = true } - val ctor: (FlowSession) -> F = if (flowSessionCtor == null) { - // Try to fallback to a Party constructor - val partyCtor = constructors[listOf(Party::class.java)]?.apply { isAccessible = true } - if (partyCtor == null) { - throw IllegalArgumentException("$initiatedFlow must have a constructor accepting a ${FlowSession::class.java.name}") - } else { - log.warn(deprecatedFlowConstructorMessage(initiatedFlow)) } - { flowSession: FlowSession -> uncheckedCast(partyCtor.newInstance(flowSession.counterparty)) } - } else { - { flowSession: FlowSession -> uncheckedCast(flowSessionCtor.newInstance(flowSession)) } } - val initiatingFlow = initiatedFlow.requireAnnotation().value.java - val (version, classWithAnnotation) = initiatingFlow.flowVersionAndInitiatingClass - require(classWithAnnotation == initiatingFlow) { - "${InitiatedBy::class.java.name} must point to ${classWithAnnotation.name} and not ${initiatingFlow.name}" - } - val flowFactory = InitiatedFlowFactory.CorDapp(version, initiatedFlow.appName, ctor) - val observable = internalRegisterFlowFactory(smm, initiatingFlow, flowFactory, initiatedFlow, track) - log.info("Registered ${initiatingFlow.name} to initiate ${initiatedFlow.name} (version $version)") - return observable - } - - protected fun > internalRegisterFlowFactory(smm: StateMachineManager, - initiatingFlowClass: Class>, - flowFactory: InitiatedFlowFactory, - initiatedFlowClass: Class, - track: Boolean): Observable { - val observable = if (track) { - smm.changes.filter { it is StateMachineManager.Change.Add }.map { it.logic }.ofType(initiatedFlowClass) - } else { - Observable.empty() - } - check(initiatingFlowClass !in flowFactories.keys) { - "$initiatingFlowClass is attempting to register multiple initiated flows" - } - flowFactories[initiatingFlowClass] = flowFactory - return observable - } - - /** - * Installs a flow that's core to the Corda platform. Unlike CorDapp flows which are versioned individually using - * [InitiatingFlow.version], core flows have the same version as the node's platform version. To cater for backwards - * compatibility [flowFactory] provides a second parameter which is the platform version of the initiating party. - */ - @VisibleForTesting - fun installCoreFlow(clientFlowClass: KClass>, flowFactory: (FlowSession) -> FlowLogic<*>) { - require(clientFlowClass.java.flowVersionAndInitiatingClass.first == 1) { - "${InitiatingFlow::class.java.name}.version not applicable for core flows; their version is the node's platform version" - } - flowFactories[clientFlowClass.java] = InitiatedFlowFactory.Core(flowFactory) - log.debug { "Installed core flow ${clientFlowClass.java.name}" } + flowManager.validateRegistrations() } private fun installCoreFlows() { - installCoreFlow(FinalityFlow::class, ::FinalityHandler) - installCoreFlow(NotaryChangeFlow::class, ::NotaryChangeHandler) - installCoreFlow(ContractUpgradeFlow.Initiate::class, ::ContractUpgradeHandler) - installCoreFlow(SwapIdentitiesFlow::class, ::SwapIdentitiesHandler) + flowManager.registerInitiatedCoreFlowFactory(FinalityFlow::class, FinalityHandler::class, ::FinalityHandler) + flowManager.registerInitiatedCoreFlowFactory(NotaryChangeFlow::class, NotaryChangeHandler::class, ::NotaryChangeHandler) + flowManager.registerInitiatedCoreFlowFactory(ContractUpgradeFlow.Initiate::class, NotaryChangeHandler::class, ::ContractUpgradeHandler) + flowManager.registerInitiatedCoreFlowFactory(SwapIdentitiesFlow::class, SwapIdentitiesHandler::class, ::SwapIdentitiesHandler) } protected open fun makeTransactionStorage(transactionCacheSizeBytes: Long): WritableTransactionStorage { @@ -781,7 +717,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, service.run { tokenize() runOnStop += ::stop - installCoreFlow(NotaryFlow.Client::class, ::createServiceFlow) + flowManager.registerInitiatedCoreFlowFactory(NotaryFlow.Client::class, ::createServiceFlow) start() } return service @@ -961,7 +897,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, } override fun getFlowFactory(initiatingFlowClass: Class>): InitiatedFlowFactory<*>? { - return flowFactories[initiatingFlowClass] + return flowManager.getFlowFactoryForInitiatingFlow(initiatingFlowClass) } override fun jdbcSession(): Connection = database.createSession() @@ -1066,4 +1002,4 @@ fun clientSslOptionsCompatibleWith(nodeRpcOptions: NodeRpcOptions): ClientRpcSsl } // Here we're using the node's RPC key store as the RPC client's trust store. return ClientRpcSslOptions(trustStorePath = nodeRpcOptions.sslConfig!!.keyStorePath, trustStorePassword = nodeRpcOptions.sslConfig!!.keyStorePassword) -} +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/internal/FlowManager.kt b/node/src/main/kotlin/net/corda/node/internal/FlowManager.kt new file mode 100644 index 0000000000..68aa8ff056 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/internal/FlowManager.kt @@ -0,0 +1,222 @@ +package net.corda.node.internal + +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.FlowSession +import net.corda.core.flows.InitiatedBy +import net.corda.core.flows.InitiatingFlow +import net.corda.core.identity.Party +import net.corda.core.internal.uncheckedCast +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.debug +import net.corda.node.internal.classloading.requireAnnotation +import net.corda.node.services.config.FlowOverrideConfig +import net.corda.node.services.statemachine.appName +import net.corda.node.services.statemachine.flowVersionAndInitiatingClass +import javax.annotation.concurrent.ThreadSafe +import kotlin.reflect.KClass + +/** + * + * This class is responsible for organising which flow should respond to a specific @InitiatingFlow + * + * There are two main ways to modify the behaviour of a cordapp with regards to responding with a different flow + * + * 1.) implementing a new subclass. For example, if we have a ResponderFlow similar to @InitiatedBy(Sender) MyBaseResponder : FlowLogic + * If we subclassed a new Flow with specific logic for DB2, it would be similar to IBMB2Responder() : MyBaseResponder + * When these two flows are encountered by the classpath scan for @InitiatedBy, they will both be selected for responding to Sender + * This implementation will sort them for responding in order of their "depth" from FlowLogic - see: FlowWeightComparator + * So IBMB2Responder would win and it would be selected for responding + * + * 2.) It is possible to specify a flowOverride key in the node configuration. Say we configure a node to have + * flowOverrides{ + * "Sender" = "MyBaseResponder" + * } + * In this case, FlowWeightComparator would detect that there is an override in action, and it will assign MyBaseResponder a maximum weight + * This will result in MyBaseResponder being selected for responding to Sender + * + * + */ +interface FlowManager { + + fun registerInitiatedCoreFlowFactory(initiatingFlowClass: KClass>, flowFactory: (FlowSession) -> FlowLogic<*>) + fun registerInitiatedCoreFlowFactory(initiatingFlowClass: KClass>, initiatedFlowClass: KClass>?, flowFactory: (FlowSession) -> FlowLogic<*>) + fun registerInitiatedCoreFlowFactory(initiatingFlowClass: KClass>, initiatedFlowClass: KClass>?, flowFactory: InitiatedFlowFactory.Core>) + + fun > registerInitiatedFlow(initiator: Class>, responder: Class) + fun > registerInitiatedFlow(responder: Class) + + fun getFlowFactoryForInitiatingFlow(initiatedFlowClass: Class>): InitiatedFlowFactory<*>? + + fun validateRegistrations() +} + +@ThreadSafe +open class NodeFlowManager(flowOverrides: FlowOverrideConfig? = null) : FlowManager { + + private val flowFactories = HashMap>, MutableList>() + private val flowOverrides = (flowOverrides + ?: FlowOverrideConfig()).overrides.map { it.initiator to it.responder }.toMutableMap() + + companion object { + private val log = contextLogger() + + } + + @Synchronized + override fun getFlowFactoryForInitiatingFlow(initiatedFlowClass: Class>): InitiatedFlowFactory<*>? { + return flowFactories[initiatedFlowClass]?.firstOrNull()?.flowFactory + } + + @Synchronized + override fun > registerInitiatedFlow(responder: Class) { + return registerInitiatedFlow(responder.requireAnnotation().value.java, responder) + } + + @Synchronized + override fun > registerInitiatedFlow(initiator: Class>, responder: Class) { + val constructors = responder.declaredConstructors.associateBy { it.parameterTypes.toList() } + val flowSessionCtor = constructors[listOf(FlowSession::class.java)]?.apply { isAccessible = true } + val ctor: (FlowSession) -> F = if (flowSessionCtor == null) { + // Try to fallback to a Party constructor + val partyCtor = constructors[listOf(Party::class.java)]?.apply { isAccessible = true } + if (partyCtor == null) { + throw IllegalArgumentException("$responder must have a constructor accepting a ${FlowSession::class.java.name}") + } else { + log.warn("Installing flow factory for $responder accepting a ${Party::class.java.simpleName}, which is deprecated. " + + "It should accept a ${FlowSession::class.java.simpleName} instead") + } + { flowSession: FlowSession -> uncheckedCast(partyCtor.newInstance(flowSession.counterparty)) } + } else { + { flowSession: FlowSession -> uncheckedCast(flowSessionCtor.newInstance(flowSession)) } + } + val (version, classWithAnnotation) = initiator.flowVersionAndInitiatingClass + require(classWithAnnotation == initiator) { + "${InitiatedBy::class.java.name} must point to ${classWithAnnotation.name} and not ${initiator.name}" + } + val flowFactory = InitiatedFlowFactory.CorDapp(version, responder.appName, ctor) + registerInitiatedFlowFactory(initiator, flowFactory, responder) + log.info("Registered ${initiator.name} to initiate ${responder.name} (version $version)") + } + + private fun > registerInitiatedFlowFactory(initiatingFlowClass: Class>, + flowFactory: InitiatedFlowFactory, + initiatedFlowClass: Class?) { + + check(flowFactory !is InitiatedFlowFactory.Core) { "This should only be used for Cordapp flows" } + val listOfFlowsForInitiator = flowFactories.computeIfAbsent(initiatingFlowClass) { mutableListOf() } + if (listOfFlowsForInitiator.isNotEmpty() && listOfFlowsForInitiator.first().type == FlowType.CORE) { + throw IllegalStateException("Attempting to register over an existing platform flow: $initiatingFlowClass") + } + synchronized(listOfFlowsForInitiator) { + val flowToAdd = RegisteredFlowContainer(initiatingFlowClass, initiatedFlowClass, flowFactory, FlowType.CORDAPP) + val flowWeightComparator = FlowWeightComparator(initiatingFlowClass, flowOverrides) + listOfFlowsForInitiator.add(flowToAdd) + listOfFlowsForInitiator.sortWith(flowWeightComparator) + if (listOfFlowsForInitiator.size > 1) { + log.warn("Multiple flows are registered for InitiatingFlow: $initiatingFlowClass, currently using: ${listOfFlowsForInitiator.first().initiatedFlowClass}") + } + } + + } + + // TODO Harmonise use of these methods - 99% of invocations come from tests. + @Synchronized + override fun registerInitiatedCoreFlowFactory(initiatingFlowClass: KClass>, initiatedFlowClass: KClass>?, flowFactory: (FlowSession) -> FlowLogic<*>) { + registerInitiatedCoreFlowFactory(initiatingFlowClass, initiatedFlowClass, InitiatedFlowFactory.Core(flowFactory)) + } + + @Synchronized + override fun registerInitiatedCoreFlowFactory(initiatingFlowClass: KClass>, flowFactory: (FlowSession) -> FlowLogic<*>) { + registerInitiatedCoreFlowFactory(initiatingFlowClass, null, InitiatedFlowFactory.Core(flowFactory)) + } + + @Synchronized + override fun registerInitiatedCoreFlowFactory(initiatingFlowClass: KClass>, initiatedFlowClass: KClass>?, flowFactory: InitiatedFlowFactory.Core>) { + require(initiatingFlowClass.java.flowVersionAndInitiatingClass.first == 1) { + "${InitiatingFlow::class.java.name}.version not applicable for core flows; their version is the node's platform version" + } + flowFactories.computeIfAbsent(initiatingFlowClass.java) { mutableListOf() }.add( + RegisteredFlowContainer( + initiatingFlowClass.java, + initiatedFlowClass?.java, + flowFactory, + FlowType.CORE) + ) + log.debug { "Installed core flow ${initiatingFlowClass.java.name}" } + } + + // To verify the integrity of the current state, it is important that the tip of the responders is a unique weight + // if there are multiple flows with the same weight as the tip, it means that it is impossible to reliably pick one as the responder + private fun validateInvariants(toValidate: List) { + val currentTip = toValidate.first() + val flowWeightComparator = FlowWeightComparator(currentTip.initiatingFlowClass, flowOverrides) + val equalWeightAsCurrentTip = toValidate.map { flowWeightComparator.compare(currentTip, it) to it }.filter { it.first == 0 }.map { it.second } + if (equalWeightAsCurrentTip.size > 1) { + val message = "Unable to determine which flow to use when responding to: ${currentTip.initiatingFlowClass.canonicalName}. ${equalWeightAsCurrentTip.map { it.initiatedFlowClass!!.canonicalName }} are all registered with equal weight." + throw IllegalStateException(message) + } + } + + @Synchronized + override fun validateRegistrations() { + flowFactories.values.forEach { + validateInvariants(it) + } + } + + private enum class FlowType { + CORE, CORDAPP + } + + private data class RegisteredFlowContainer(val initiatingFlowClass: Class>, + val initiatedFlowClass: Class>?, + val flowFactory: InitiatedFlowFactory>, + val type: FlowType) + + // this is used to sort the responding flows in order of "importance" + // the logic is as follows + // IF responder is a specific lambda (like for notary implementations / testing code) always return that responder + // ELSE IF responder is present in the overrides list, always return that responder + // ELSE compare responding flows by their depth from FlowLogic, always return the flow which is most specific (IE, has the most hops to FlowLogic) + private open class FlowWeightComparator(val initiatingFlowClass: Class>, val flowOverrides: Map) : Comparator { + + override fun compare(o1: NodeFlowManager.RegisteredFlowContainer, o2: NodeFlowManager.RegisteredFlowContainer): Int { + if (o1.initiatedFlowClass == null && o2.initiatedFlowClass != null) { + return Int.MIN_VALUE + } + if (o1.initiatedFlowClass != null && o2.initiatedFlowClass == null) { + return Int.MAX_VALUE + } + + if (o1.initiatedFlowClass == null && o2.initiatedFlowClass == null) { + return 0 + } + + val hopsTo1 = calculateHopsToFlowLogic(initiatingFlowClass, o1.initiatedFlowClass!!) + val hopsTo2 = calculateHopsToFlowLogic(initiatingFlowClass, o2.initiatedFlowClass!!) + return hopsTo1.compareTo(hopsTo2) * -1 + } + + private fun calculateHopsToFlowLogic(initiatingFlowClass: Class>, + initiatedFlowClass: Class>): Int { + + val overriddenClassName = flowOverrides[initiatingFlowClass.canonicalName] + return if (overriddenClassName == initiatedFlowClass.canonicalName) { + Int.MAX_VALUE + } else { + var currentClass: Class<*> = initiatedFlowClass + var count = 0 + while (currentClass != FlowLogic::class.java) { + currentClass = currentClass.superclass + count++ + } + count; + } + } + + } +} + +private fun Iterable>.toMutableMap(): MutableMap { + return this.toMap(HashMap()) +} diff --git a/node/src/main/kotlin/net/corda/node/internal/InitiatedFlowFactory.kt b/node/src/main/kotlin/net/corda/node/internal/InitiatedFlowFactory.kt index 3b86147c4e..9d00e83a28 100644 --- a/node/src/main/kotlin/net/corda/node/internal/InitiatedFlowFactory.kt +++ b/node/src/main/kotlin/net/corda/node/internal/InitiatedFlowFactory.kt @@ -4,10 +4,12 @@ import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowSession sealed class InitiatedFlowFactory> { + protected abstract val factory: (FlowSession) -> F fun createFlow(initiatingFlowSession: FlowSession): F = factory(initiatingFlowSession) data class Core>(override val factory: (FlowSession) -> F) : InitiatedFlowFactory() + data class CorDapp>(val flowVersion: Int, val appName: String, override val factory: (FlowSession) -> F) : InitiatedFlowFactory() diff --git a/node/src/main/kotlin/net/corda/node/internal/Node.kt b/node/src/main/kotlin/net/corda/node/internal/Node.kt index c6cad69913..27523d9ccb 100644 --- a/node/src/main/kotlin/net/corda/node/internal/Node.kt +++ b/node/src/main/kotlin/net/corda/node/internal/Node.kt @@ -43,6 +43,7 @@ import net.corda.node.services.api.StartedNodeServices import net.corda.node.services.config.* import net.corda.node.services.messaging.* import net.corda.node.services.rpc.ArtemisRpcBroker +import net.corda.node.services.statemachine.StateMachineManager import net.corda.node.utilities.* import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.INTERNAL_SHELL_USER import net.corda.nodeapi.internal.ShutdownHook @@ -56,7 +57,6 @@ import org.apache.commons.lang.SystemUtils import org.h2.jdbc.JdbcSQLException import org.slf4j.Logger import org.slf4j.LoggerFactory -import rx.Observable import rx.Scheduler import rx.schedulers.Schedulers import java.net.BindException @@ -72,8 +72,7 @@ import kotlin.system.exitProcess class NodeWithInfo(val node: Node, val info: NodeInfo) { val services: StartedNodeServices = object : StartedNodeServices, ServiceHubInternal by node.services, FlowStarter by node.flowStarter {} fun dispose() = node.stop() - fun > registerInitiatedFlow(initiatedFlowClass: Class): Observable = - node.registerInitiatedFlow(node.smm, initiatedFlowClass) + fun > registerInitiatedFlow(initiatedFlowClass: Class) = node.registerInitiatedFlow(node.smm, initiatedFlowClass) } /** @@ -85,12 +84,14 @@ class NodeWithInfo(val node: Node, val info: NodeInfo) { open class Node(configuration: NodeConfiguration, versionInfo: VersionInfo, private val initialiseSerialization: Boolean = true, + flowManager: FlowManager = NodeFlowManager(configuration.flowOverrides), cacheFactoryPrototype: BindableNamedCacheFactory = DefaultNamedCacheFactory() ) : AbstractNode( configuration, createClock(configuration), cacheFactoryPrototype, versionInfo, + flowManager, // Under normal (non-test execution) it will always be "1" AffinityExecutor.ServiceAffinityExecutor("Node thread-${sameVmNodeCounter.incrementAndGet()}", 1) ) { @@ -202,7 +203,8 @@ open class Node(configuration: NodeConfiguration, return P2PMessagingClient( config = configuration, versionInfo = versionInfo, - serverAddress = configuration.messagingServerAddress ?: NetworkHostAndPort("localhost", configuration.p2pAddress.port), + serverAddress = configuration.messagingServerAddress + ?: NetworkHostAndPort("localhost", configuration.p2pAddress.port), nodeExecutor = serverThread, database = database, networkMap = networkMapCache, @@ -228,7 +230,8 @@ open class Node(configuration: NodeConfiguration, } val messageBroker = if (!configuration.messagingServerExternal) { - val brokerBindAddress = configuration.messagingServerAddress ?: NetworkHostAndPort("0.0.0.0", configuration.p2pAddress.port) + val brokerBindAddress = configuration.messagingServerAddress + ?: NetworkHostAndPort("0.0.0.0", configuration.p2pAddress.port) ArtemisMessagingServer(configuration, brokerBindAddress, networkParameters.maxMessageSize) } else { null @@ -442,7 +445,7 @@ open class Node(configuration: NodeConfiguration, }.build().start() } - private fun registerNewRelicReporter (registry: MetricRegistry) { + private fun registerNewRelicReporter(registry: MetricRegistry) { log.info("Registering New Relic JMX Reporter:") val reporter = NewRelicReporter.forRegistry(registry) .name("New Relic Reporter") @@ -504,4 +507,8 @@ open class Node(configuration: NodeConfiguration, log.info("Shutdown complete") } + + fun > registerInitiatedFlow(smm: StateMachineManager, initiatedFlowClass: Class) { + this.flowManager.registerInitiatedFlow(initiatedFlowClass) + } } diff --git a/node/src/main/kotlin/net/corda/node/internal/cordapp/JarScanningCordappLoader.kt b/node/src/main/kotlin/net/corda/node/internal/cordapp/JarScanningCordappLoader.kt index ac1badeafe..b1ba5f9f29 100644 --- a/node/src/main/kotlin/net/corda/node/internal/cordapp/JarScanningCordappLoader.kt +++ b/node/src/main/kotlin/net/corda/node/internal/cordapp/JarScanningCordappLoader.kt @@ -19,7 +19,6 @@ import net.corda.core.serialization.SerializeAsToken import net.corda.core.utilities.contextLogger import net.corda.node.VersionInfo import net.corda.node.cordapp.CordappLoader -import net.corda.node.internal.classloading.requireAnnotation import net.corda.nodeapi.internal.coreContractClasses import net.corda.serialization.internal.DefaultWhitelist import org.apache.commons.collections4.map.LRUMap @@ -148,17 +147,6 @@ class JarScanningCordappLoader private constructor(private val cordappJarPaths: private fun findInitiatedFlows(scanResult: RestrictedScanResult): List>> { return scanResult.getClassesWithAnnotation(FlowLogic::class, InitiatedBy::class) - // First group by the initiating flow class in case there are multiple mappings - .groupBy { it.requireAnnotation().value.java } - .map { (initiatingFlow, initiatedFlows) -> - val sorted = initiatedFlows.sortedWith(FlowTypeHierarchyComparator(initiatingFlow)) - if (sorted.size > 1) { - logger.warn("${initiatingFlow.name} has been specified as the inititating flow by multiple flows " + - "in the same type hierarchy: ${sorted.joinToString { it.name }}. Choosing the most " + - "specific sub-type for registration: ${sorted[0].name}.") - } - sorted[0] - } } private fun Class>.isUserInvokable(): Boolean { @@ -209,17 +197,7 @@ class JarScanningCordappLoader private constructor(private val cordappJarPaths: } } - private class FlowTypeHierarchyComparator(val initiatingFlow: Class>) : Comparator>> { - override fun compare(o1: Class>, o2: Class>): Int { - return when { - o1 == o2 -> 0 - o1.isAssignableFrom(o2) -> 1 - o2.isAssignableFrom(o1) -> -1 - else -> throw IllegalArgumentException("${initiatingFlow.name} has been specified as the initiating flow by " + - "both ${o1.name} and ${o2.name}") - } - } - } + private fun loadClass(className: String, type: KClass): Class? { return try { diff --git a/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt b/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt index eab762dcdd..3073d7574a 100644 --- a/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt +++ b/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt @@ -76,6 +76,7 @@ interface NodeConfiguration { val p2pSslOptions: MutualSslConfiguration val cordappDirectories: List + val flowOverrides: FlowOverrideConfig? fun validate(): List @@ -97,6 +98,9 @@ interface NodeConfiguration { } } +data class FlowOverrideConfig(val overrides: List = listOf()) +data class FlowOverride(val initiator: String, val responder: String) + /** * Currently registered JMX Reporters */ @@ -210,7 +214,8 @@ data class NodeConfigurationImpl( override val flowMonitorPeriodMillis: Duration = DEFAULT_FLOW_MONITOR_PERIOD_MILLIS, override val flowMonitorSuspensionLoggingThresholdMillis: Duration = DEFAULT_FLOW_MONITOR_SUSPENSION_LOGGING_THRESHOLD_MILLIS, override val cordappDirectories: List = listOf(baseDirectory / CORDAPPS_DIR_NAME_DEFAULT), - override val jmxReporterType: JmxReporterType? = JmxReporterType.JOLOKIA + override val jmxReporterType: JmxReporterType? = JmxReporterType.JOLOKIA, + override val flowOverrides: FlowOverrideConfig? ) : NodeConfiguration { companion object { private val logger = loggerFor() diff --git a/node/src/test/kotlin/net/corda/node/internal/FlowRegistrationTest.kt b/node/src/test/kotlin/net/corda/node/internal/FlowRegistrationTest.kt index e6a9b5fd3e..8b59037248 100644 --- a/node/src/test/kotlin/net/corda/node/internal/FlowRegistrationTest.kt +++ b/node/src/test/kotlin/net/corda/node/internal/FlowRegistrationTest.kt @@ -12,7 +12,6 @@ import net.corda.testing.core.singleIdentity import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNodeParameters import net.corda.testing.node.StartedMockNode -import org.assertj.core.api.Assertions.assertThatIllegalStateException import org.junit.After import org.junit.Before import org.junit.Test @@ -39,15 +38,15 @@ class FlowRegistrationTest { } @Test - fun `startup fails when two flows initiated by the same flow are registered`() { + fun `succeeds when a subclass of a flow initiated by the same flow is registered`() { // register the same flow twice to invoke the error without causing errors in other tests - responder.registerInitiatedFlow(Responder::class.java) - assertThatIllegalStateException().isThrownBy { responder.registerInitiatedFlow(Responder::class.java) } + responder.registerInitiatedFlow(Responder1::class.java) + responder.registerInitiatedFlow(Responder1Subclassed::class.java) } @Test fun `a single initiated flow can be registered without error`() { - responder.registerInitiatedFlow(Responder::class.java) + responder.registerInitiatedFlow(Responder1::class.java) val result = initiator.startFlow(Initiator(responder.info.singleIdentity())) mockNetwork.runNetwork() assertNotNull(result.get()) @@ -63,7 +62,38 @@ class Initiator(val party: Party) : FlowLogic() { } @InitiatedBy(Initiator::class) -private class Responder(val session: FlowSession) : FlowLogic() { +private open class Responder1(val session: FlowSession) : FlowLogic() { + open fun getPayload(): String { + return "whats up" + } + + @Suspendable + override fun call() { + session.receive().unwrap { it } + session.send("What's up") + } +} + +@InitiatedBy(Initiator::class) +private open class Responder2(val session: FlowSession) : FlowLogic() { + open fun getPayload(): String { + return "whats up" + } + + @Suspendable + override fun call() { + session.receive().unwrap { it } + session.send("What's up") + } +} + +@InitiatedBy(Initiator::class) +private class Responder1Subclassed(session: FlowSession) : Responder1(session) { + + override fun getPayload(): String { + return "im subclassed! that's what's up!" + } + @Suspendable override fun call() { session.receive().unwrap { it } diff --git a/node/src/test/kotlin/net/corda/node/internal/NodeFlowManagerTest.kt b/node/src/test/kotlin/net/corda/node/internal/NodeFlowManagerTest.kt new file mode 100644 index 0000000000..25722ef295 --- /dev/null +++ b/node/src/test/kotlin/net/corda/node/internal/NodeFlowManagerTest.kt @@ -0,0 +1,110 @@ +package net.corda.node.internal + +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.FlowSession +import net.corda.core.flows.InitiatedBy +import net.corda.core.flows.InitiatingFlow +import net.corda.node.services.config.FlowOverride +import net.corda.node.services.config.FlowOverrideConfig +import org.hamcrest.CoreMatchers.`is` +import org.hamcrest.CoreMatchers.instanceOf +import org.junit.Assert +import org.junit.Test +import org.mockito.Mockito +import java.lang.IllegalStateException + +private val marker = "This is a special marker" + +class NodeFlowManagerTest { + + @InitiatingFlow + class Init : FlowLogic() { + override fun call() { + TODO("not implemented") + } + } + + @InitiatedBy(Init::class) + open class Resp(val otherSesh: FlowSession) : FlowLogic() { + override fun call() { + TODO("not implemented") + } + + } + + @InitiatedBy(Init::class) + class Resp2(val otherSesh: FlowSession) : FlowLogic() { + override fun call() { + TODO("not implemented") + } + + } + + @InitiatedBy(Init::class) + open class RespSub(sesh: FlowSession) : Resp(sesh) { + override fun call() { + TODO("not implemented") + } + + } + + @InitiatedBy(Init::class) + class RespSubSub(sesh: FlowSession) : RespSub(sesh) { + override fun call() { + TODO("not implemented") + } + + } + + + @Test(expected = IllegalStateException::class) + fun `should fail to validate if more than one registration with equal weight`() { + val nodeFlowManager = NodeFlowManager() + nodeFlowManager.registerInitiatedFlow(Init::class.java, Resp::class.java) + nodeFlowManager.registerInitiatedFlow(Init::class.java, Resp2::class.java) + nodeFlowManager.validateRegistrations() + } + + @Test() + fun `should allow registration of flows with different weights`() { + val nodeFlowManager = NodeFlowManager() + nodeFlowManager.registerInitiatedFlow(Init::class.java, Resp::class.java) + nodeFlowManager.registerInitiatedFlow(Init::class.java, RespSub::class.java) + nodeFlowManager.validateRegistrations() + val factory = nodeFlowManager.getFlowFactoryForInitiatingFlow(Init::class.java)!! + val flow = factory.createFlow(Mockito.mock(FlowSession::class.java)) + Assert.assertThat(flow, `is`(instanceOf(RespSub::class.java))) + } + + @Test() + fun `should allow updating of registered responder at runtime`() { + val nodeFlowManager = NodeFlowManager() + nodeFlowManager.registerInitiatedFlow(Init::class.java, Resp::class.java) + nodeFlowManager.registerInitiatedFlow(Init::class.java, RespSub::class.java) + nodeFlowManager.validateRegistrations() + var factory = nodeFlowManager.getFlowFactoryForInitiatingFlow(Init::class.java)!! + var flow = factory.createFlow(Mockito.mock(FlowSession::class.java)) + Assert.assertThat(flow, `is`(instanceOf(RespSub::class.java))) + // update + nodeFlowManager.registerInitiatedFlow(Init::class.java, RespSubSub::class.java) + nodeFlowManager.validateRegistrations() + + factory = nodeFlowManager.getFlowFactoryForInitiatingFlow(Init::class.java)!! + flow = factory.createFlow(Mockito.mock(FlowSession::class.java)) + Assert.assertThat(flow, `is`(instanceOf(RespSubSub::class.java))) + } + + @Test + fun `should allow an override to be specified`() { + val nodeFlowManager = NodeFlowManager(FlowOverrideConfig(listOf(FlowOverride(Init::class.qualifiedName!!, Resp::class.qualifiedName!!)))) + nodeFlowManager.registerInitiatedFlow(Init::class.java, Resp::class.java) + nodeFlowManager.registerInitiatedFlow(Init::class.java, Resp2::class.java) + nodeFlowManager.registerInitiatedFlow(Init::class.java, RespSubSub::class.java) + nodeFlowManager.validateRegistrations() + + val factory = nodeFlowManager.getFlowFactoryForInitiatingFlow(Init::class.java)!! + val flow = factory.createFlow(Mockito.mock(FlowSession::class.java)) + + Assert.assertThat(flow, `is`(instanceOf(Resp::class.java))) + } +} \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/internal/NodeTest.kt b/node/src/test/kotlin/net/corda/node/internal/NodeTest.kt index 48a266ce31..fd2e1db83f 100644 --- a/node/src/test/kotlin/net/corda/node/internal/NodeTest.kt +++ b/node/src/test/kotlin/net/corda/node/internal/NodeTest.kt @@ -149,7 +149,7 @@ class NodeTest { } } - private fun createConfig(nodeName: CordaX500Name): NodeConfiguration { + private fun createConfig(nodeName: CordaX500Name): NodeConfigurationImpl { val fakeAddress = NetworkHostAndPort("0.1.2.3", 456) return NodeConfigurationImpl( baseDirectory = temporaryFolder.root.toPath(), @@ -167,7 +167,8 @@ class NodeTest { flowTimeout = FlowTimeoutConfiguration(timeout = Duration.ZERO, backoffBase = 1.0, maxRestartCount = 1), rpcSettings = NodeRpcSettings(address = fakeAddress, adminAddress = null, ssl = null), messagingServerAddress = null, - notary = null + notary = null, + flowOverrides = FlowOverrideConfig(listOf()) ) } diff --git a/node/src/test/kotlin/net/corda/node/internal/cordapp/JarScanningCordappLoaderTest.kt b/node/src/test/kotlin/net/corda/node/internal/cordapp/JarScanningCordappLoaderTest.kt index 176bbfb5d8..9eba3d64ef 100644 --- a/node/src/test/kotlin/net/corda/node/internal/cordapp/JarScanningCordappLoaderTest.kt +++ b/node/src/test/kotlin/net/corda/node/internal/cordapp/JarScanningCordappLoaderTest.kt @@ -56,7 +56,7 @@ class JarScanningCordappLoaderTest { val actualCordapp = loader.cordapps.single() assertThat(actualCordapp.contractClassNames).isEqualTo(listOf(isolatedContractId)) - assertThat(actualCordapp.initiatedFlows.single().name).isEqualTo("net.corda.finance.contracts.isolated.IsolatedDummyFlow\$Acceptor") + assertThat(actualCordapp.initiatedFlows.first().name).isEqualTo("net.corda.finance.contracts.isolated.IsolatedDummyFlow\$Acceptor") assertThat(actualCordapp.rpcFlows).isEmpty() assertThat(actualCordapp.schedulableFlows).isEmpty() assertThat(actualCordapp.services).isEmpty() @@ -74,7 +74,7 @@ class JarScanningCordappLoaderTest { assertThat(loader.cordapps).isNotEmpty val actualCordapp = loader.cordapps.single { !it.initiatedFlows.isEmpty() } - assertThat(actualCordapp.initiatedFlows).first().hasSameClassAs(DummyFlow::class.java) + assertThat(actualCordapp.initiatedFlows.first()).hasSameClassAs(DummyFlow::class.java) assertThat(actualCordapp.rpcFlows).first().hasSameClassAs(DummyRPCFlow::class.java) assertThat(actualCordapp.schedulableFlows).first().hasSameClassAs(DummySchedulableFlow::class.java) } diff --git a/node/src/test/kotlin/net/corda/node/services/config/NodeConfigurationImplTest.kt b/node/src/test/kotlin/net/corda/node/services/config/NodeConfigurationImplTest.kt index 27247b8eb5..a9e7da8a92 100644 --- a/node/src/test/kotlin/net/corda/node/services/config/NodeConfigurationImplTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/config/NodeConfigurationImplTest.kt @@ -172,8 +172,8 @@ class NodeConfigurationImplTest { val errors = configuration.validate() - assertThat(errors).hasOnlyOneElementSatisfying { - error -> error.contains("Cannot configure both compatibilityZoneUrl and networkServices simultaneously") + assertThat(errors).hasOnlyOneElementSatisfying { error -> + error.contains("Cannot configure both compatibilityZoneUrl and networkServices simultaneously") } } @@ -268,7 +268,8 @@ class NodeConfigurationImplTest { noLocalShell = false, rpcSettings = rpcSettings, crlCheckSoftFail = true, - tlsCertCrlDistPoint = null + tlsCertCrlDistPoint = null, + flowOverrides = FlowOverrideConfig(listOf()) ) } } diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt index 075d324e9f..d1517dc318 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt @@ -26,7 +26,6 @@ import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker.Change import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.unwrap -import net.corda.node.internal.InitiatedFlowFactory import net.corda.node.services.persistence.checkpoints import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyState @@ -116,7 +115,7 @@ class FlowFrameworkTests { @Test fun `exception while fiber suspended`() { - bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) } + bobNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) } val flow = ReceiveFlow(bob) val fiber = aliceNode.services.startFlow(flow) as FlowStateMachineImpl // Before the flow runs change the suspend action to throw an exception @@ -134,7 +133,7 @@ class FlowFrameworkTests { @Test fun `both sides do a send as their first IO request`() { - bobNode.registerFlowFactory(PingPongFlow::class) { PingPongFlow(it, 20L) } + bobNode.registerCordappFlowFactory(PingPongFlow::class) { PingPongFlow(it, 20L) } aliceNode.services.startFlow(PingPongFlow(bob, 10L)) mockNet.runNetwork() @@ -151,7 +150,7 @@ class FlowFrameworkTests { @Test fun `other side ends before doing expected send`() { - bobNode.registerFlowFactory(ReceiveFlow::class) { NoOpFlow() } + bobNode.registerCordappFlowFactory(ReceiveFlow::class) { NoOpFlow() } val resultFuture = aliceNode.services.startFlow(ReceiveFlow(bob)).resultFuture mockNet.runNetwork() assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy { @@ -161,7 +160,7 @@ class FlowFrameworkTests { @Test fun `receiving unexpected session end before entering sendAndReceive`() { - bobNode.registerFlowFactory(WaitForOtherSideEndBeforeSendAndReceive::class) { NoOpFlow() } + bobNode.registerCordappFlowFactory(WaitForOtherSideEndBeforeSendAndReceive::class) { NoOpFlow() } val sessionEndReceived = Semaphore(0) receivedSessionMessagesObservable().filter { it.message is ExistingSessionMessage && it.message.payload === EndSessionMessage @@ -176,7 +175,7 @@ class FlowFrameworkTests { @Test fun `FlowException thrown on other side`() { - val erroringFlow = bobNode.registerFlowFactory(ReceiveFlow::class) { + val erroringFlow = bobNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Nothing useful") } } val erroringFlowSteps = erroringFlow.flatMap { it.progressSteps } @@ -240,7 +239,7 @@ class FlowFrameworkTests { } } - bobNode.registerFlowFactory(AskForExceptionFlow::class) { ConditionalExceptionFlow(it, "Hello") } + bobNode.registerCordappFlowFactory(AskForExceptionFlow::class) { ConditionalExceptionFlow(it, "Hello") } val resultFuture = aliceNode.services.startFlow(RetryOnExceptionFlow(bob)).resultFuture mockNet.runNetwork() assertThat(resultFuture.getOrThrow()).isEqualTo("Hello") @@ -248,7 +247,7 @@ class FlowFrameworkTests { @Test fun `serialisation issue in counterparty`() { - bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(NonSerialisableData(1), it) } + bobNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(NonSerialisableData(1), it) } val result = aliceNode.services.startFlow(ReceiveFlow(bob)).resultFuture mockNet.runNetwork() assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy { @@ -258,7 +257,7 @@ class FlowFrameworkTests { @Test fun `FlowException has non-serialisable object`() { - bobNode.registerFlowFactory(ReceiveFlow::class) { + bobNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { NonSerialisableFlowException(NonSerialisableData(1)) } } val result = aliceNode.services.startFlow(ReceiveFlow(bob)).resultFuture @@ -275,7 +274,7 @@ class FlowFrameworkTests { .addCommand(dummyCommand(alice.owningKey)) val stx = aliceNode.services.signInitialTransaction(ptx) - val committerFiber = aliceNode.registerFlowFactory(WaitingFlows.Waiter::class) { + val committerFiber = aliceNode.registerCordappFlowFactory(WaitingFlows.Waiter::class) { WaitingFlows.Committer(it) }.map { it.stateMachine }.map { uncheckedCast, FlowStateMachine>(it) } val waiterStx = bobNode.services.startFlow(WaitingFlows.Waiter(stx, alice)).resultFuture @@ -290,7 +289,7 @@ class FlowFrameworkTests { .addCommand(dummyCommand()) val stx = aliceNode.services.signInitialTransaction(ptx) - aliceNode.registerFlowFactory(WaitingFlows.Waiter::class) { + aliceNode.registerCordappFlowFactory(WaitingFlows.Waiter::class) { WaitingFlows.Committer(it) { throw Exception("Error") } } val waiter = bobNode.services.startFlow(WaitingFlows.Waiter(stx, alice)).resultFuture @@ -307,7 +306,7 @@ class FlowFrameworkTests { .addCommand(dummyCommand(alice.owningKey)) val stx = aliceNode.services.signInitialTransaction(ptx) - aliceNode.registerFlowFactory(VaultQueryFlow::class) { + aliceNode.registerCordappFlowFactory(VaultQueryFlow::class) { WaitingFlows.Committer(it) } val result = bobNode.services.startFlow(VaultQueryFlow(stx, alice)).resultFuture @@ -318,7 +317,7 @@ class FlowFrameworkTests { @Test fun `customised client flow`() { - val receiveFlowFuture = bobNode.registerFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it) } + val receiveFlowFuture = bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it) } aliceNode.services.startFlow(CustomSendFlow("Hello", bob)).resultFuture mockNet.runNetwork() assertThat(receiveFlowFuture.getOrThrow().receivedPayloads).containsOnly("Hello") @@ -333,7 +332,7 @@ class FlowFrameworkTests { @Test fun `upgraded initiating flow`() { - bobNode.registerFlowFactory(UpgradedFlow::class, initiatedFlowVersion = 1) { InitiatedSendFlow("Old initiated", it) } + bobNode.registerCordappFlowFactory(UpgradedFlow::class, initiatedFlowVersion = 1) { InitiatedSendFlow("Old initiated", it) } val result = aliceNode.services.startFlow(UpgradedFlow(bob)).resultFuture mockNet.runNetwork() assertThat(receivedSessionMessages).startsWith( @@ -347,7 +346,7 @@ class FlowFrameworkTests { @Test fun `upgraded initiated flow`() { - bobNode.registerFlowFactory(SendFlow::class, initiatedFlowVersion = 2) { UpgradedFlow(it) } + bobNode.registerCordappFlowFactory(SendFlow::class, initiatedFlowVersion = 2) { UpgradedFlow(it) } val initiatingFlow = SendFlow("Old initiating", bob) val flowInfo = aliceNode.services.startFlow(initiatingFlow).resultFuture mockNet.runNetwork() @@ -387,7 +386,7 @@ class FlowFrameworkTests { @Test fun `single inlined sub-flow`() { - bobNode.registerFlowFactory(SendAndReceiveFlow::class) { SingleInlinedSubFlow(it) } + bobNode.registerCordappFlowFactory(SendAndReceiveFlow::class) { SingleInlinedSubFlow(it) } val result = aliceNode.services.startFlow(SendAndReceiveFlow(bob, "Hello")).resultFuture mockNet.runNetwork() assertThat(result.getOrThrow()).isEqualTo("HelloHello") @@ -395,7 +394,7 @@ class FlowFrameworkTests { @Test fun `double inlined sub-flow`() { - bobNode.registerFlowFactory(SendAndReceiveFlow::class) { DoubleInlinedSubFlow(it) } + bobNode.registerCordappFlowFactory(SendAndReceiveFlow::class) { DoubleInlinedSubFlow(it) } val result = aliceNode.services.startFlow(SendAndReceiveFlow(bob, "Hello")).resultFuture mockNet.runNetwork() assertThat(result.getOrThrow()).isEqualTo("HelloHello") @@ -403,7 +402,7 @@ class FlowFrameworkTests { @Test fun `non-FlowException thrown on other side`() { - val erroringFlowFuture = bobNode.registerFlowFactory(ReceiveFlow::class) { + val erroringFlowFuture = bobNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { Exception("evil bug!") } } val erroringFlowSteps = erroringFlowFuture.flatMap { it.progressSteps } @@ -507,8 +506,8 @@ class FlowFrameworkTripartyTests { @Test fun `sending to multiple parties`() { - bobNode.registerFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() } - charlieNode.registerFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() } + bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() } + charlieNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() } val payload = "Hello World" aliceNode.services.startFlow(SendFlow(payload, bob, charlie)) mockNet.runNetwork() @@ -538,8 +537,8 @@ class FlowFrameworkTripartyTests { fun `receiving from multiple parties`() { val bobPayload = "Test 1" val charliePayload = "Test 2" - bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(bobPayload, it) } - charlieNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(charliePayload, it) } + bobNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(bobPayload, it) } + charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(charliePayload, it) } val multiReceiveFlow = ReceiveFlow(bob, charlie).nonTerminating() aliceNode.services.startFlow(multiReceiveFlow) aliceNode.internals.acceptableLiveFiberCountOnStop = 1 @@ -564,8 +563,8 @@ class FlowFrameworkTripartyTests { @Test fun `FlowException only propagated to parent`() { - charlieNode.registerFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Chain") } } - bobNode.registerFlowFactory(ReceiveFlow::class) { ReceiveFlow(charlie) } + charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Chain") } } + bobNode.registerCordappFlowFactory(ReceiveFlow::class) { ReceiveFlow(charlie) } val receivingFiber = aliceNode.services.startFlow(ReceiveFlow(bob)) mockNet.runNetwork() assertThatExceptionOfType(UnexpectedFlowEndException::class.java) @@ -577,9 +576,9 @@ class FlowFrameworkTripartyTests { // Bob will send its payload and then block waiting for the receive from Alice. Meanwhile Alice will move // onto Charlie which will throw the exception val node2Fiber = bobNode - .registerFlowFactory(ReceiveFlow::class) { SendAndReceiveFlow(it, "Hello") } + .registerCordappFlowFactory(ReceiveFlow::class) { SendAndReceiveFlow(it, "Hello") } .map { it.stateMachine } - charlieNode.registerFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Nothing useful") } } + charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Nothing useful") } } val aliceFiber = aliceNode.services.startFlow(ReceiveFlow(bob, charlie)) as FlowStateMachineImpl mockNet.runNetwork() @@ -630,6 +629,8 @@ class FlowFrameworkPersistenceTests { private lateinit var notaryIdentity: Party private lateinit var alice: Party private lateinit var bob: Party + private lateinit var aliceFlowManager: MockNodeFlowManager + private lateinit var bobFlowManager: MockNodeFlowManager @Before fun start() { @@ -637,8 +638,11 @@ class FlowFrameworkPersistenceTests { cordappsForAllNodes = cordappsForPackages("net.corda.finance.contracts", "net.corda.testing.contracts"), servicePeerAllocationStrategy = RoundRobin() ) - aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME)) - bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME)) + aliceFlowManager = MockNodeFlowManager() + bobFlowManager = MockNodeFlowManager() + + aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME, flowManager = aliceFlowManager)) + bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME, flowManager = bobFlowManager)) receivedSessionMessagesObservable().forEach { receivedSessionMessages += it } @@ -664,7 +668,7 @@ class FlowFrameworkPersistenceTests { @Test fun `flow restarted just after receiving payload`() { - bobNode.registerFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() } + bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() } aliceNode.services.startFlow(SendFlow("Hello", bob)) // We push through just enough messages to get only the payload sent @@ -679,7 +683,7 @@ class FlowFrameworkPersistenceTests { @Test fun `flow loaded from checkpoint will respond to messages from before start`() { - aliceNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) } + aliceNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) } bobNode.services.startFlow(ReceiveFlow(alice).nonTerminating()) // Prepare checkpointed receive flow val restoredFlow = bobNode.restartAndGetRestoredFlow() assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello") @@ -694,7 +698,7 @@ class FlowFrameworkPersistenceTests { var sentCount = 0 mockNet.messagingNetwork.sentMessages.toSessionTransfers().filter { it.isPayloadTransfer }.forEach { sentCount++ } val charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME)) - val secondFlow = charlieNode.registerFlowFactory(PingPongFlow::class) { PingPongFlow(it, payload2) } + val secondFlow = charlieNode.registerCordappFlowFactory(PingPongFlow::class) { PingPongFlow(it, payload2) } mockNet.runNetwork() val charlie = charlieNode.info.singleIdentity() @@ -802,23 +806,14 @@ private infix fun TestStartedNode.sent(message: SessionMessage): Pair.to(node: TestStartedNode): SessionTransfer = SessionTransfer(first, second, node.network.myAddress) private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: MessageRecipients) { - val isPayloadTransfer: Boolean get() = - message is ExistingSessionMessage && message.payload is DataSessionMessage || - message is InitialSessionMessage && message.firstPayload != null + val isPayloadTransfer: Boolean + get() = + message is ExistingSessionMessage && message.payload is DataSessionMessage || + message is InitialSessionMessage && message.firstPayload != null + override fun toString(): String = "$from sent $message to $to" } -private inline fun > TestStartedNode.registerFlowFactory( - initiatingFlowClass: KClass>, - initiatedFlowVersion: Int = 1, - noinline flowFactory: (FlowSession) -> P): CordaFuture

{ - val observable = registerFlowFactory( - initiatingFlowClass.java, - InitiatedFlowFactory.CorDapp(initiatedFlowVersion, "", flowFactory), - P::class.java, - track = true) - return observable.toFuture() -} private fun sessionInit(clientFlowClass: KClass>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage { return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize()) @@ -1061,7 +1056,8 @@ private class SendAndReceiveFlow(val otherParty: Party, val payload: Any, val ot constructor(otherPartySession: FlowSession, payload: Any) : this(otherPartySession.counterparty, payload, otherPartySession) @Suspendable - override fun call(): Any = (otherPartySession ?: initiateFlow(otherParty)).sendAndReceive(payload).unwrap { it } + override fun call(): Any = (otherPartySession + ?: initiateFlow(otherParty)).sendAndReceive(payload).unwrap { it } } private class InlinedSendFlow(val payload: String, val otherPartySession: FlowSession) : FlowLogic() { @@ -1098,4 +1094,4 @@ private class ExceptionFlow(val exception: () -> E) : FlowLogic communicate(clientLogic: AbstractClientLogic, rebootClient: Boolean): FlowStateMachine { - server.registerFlowFactory(AbstractClientLogic::class.java, InitiatedFlowFactory.Core { ServerLogic(it, serverRunning) }, ServerLogic::class.java, false) + server.registerCoreFlowFactory(AbstractClientLogic::class.java, ServerLogic::class.java, { ServerLogic(it, serverRunning) }, false) client.services.startFlow(clientLogic) while (!serverRunning.get()) mockNet.runNetwork(1) if (rebootClient) { diff --git a/samples/trader-demo/build.gradle b/samples/trader-demo/build.gradle index f8ffd4edf3..291f15d988 100644 --- a/samples/trader-demo/build.gradle +++ b/samples/trader-demo/build.gradle @@ -89,6 +89,20 @@ task deployNodes(type: net.corda.plugins.Cordform, dependsOn: ['jar', nodeTask, } extraConfig = ['h2Settings.address' : 'localhost:10017'] } + + //All other nodes should be using LoggingBuyerFlow as it is a subclass of BuyerFlow + node { + name "O=LoggingBank,L=London,C=GB" + p2pPort 10025 + cordapps = ["$project.group:finance:$corda_release_version"] + rpcUsers = ext.rpcUsers + rpcSettings { + address "localhost:10026" + adminAddress "localhost:10027" + } + extraConfig = ['h2Settings.address' : 'localhost:10035'] + flowOverride("net.corda.traderdemo.flow.SellerFlow", "net.corda.traderdemo.flow.BuyerFlow") + } } task integrationTest(type: Test, dependsOn: []) { diff --git a/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/flow/BuyerFlow.kt b/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/flow/BuyerFlow.kt index 70cee50154..52a5a7c982 100644 --- a/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/flow/BuyerFlow.kt +++ b/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/flow/BuyerFlow.kt @@ -11,20 +11,17 @@ import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.unwrap import net.corda.finance.contracts.CommercialPaper -import net.corda.finance.contracts.getCashBalances import net.corda.finance.flows.TwoPartyTradeFlow -import net.corda.traderdemo.TransactionGraphSearch import java.util.* @InitiatedBy(SellerFlow::class) -class BuyerFlow(private val otherSideSession: FlowSession) : FlowLogic() { +open class BuyerFlow(private val otherSideSession: FlowSession) : FlowLogic() { object STARTING_BUY : ProgressTracker.Step("Seller connected, purchasing commercial paper asset") - override val progressTracker: ProgressTracker = ProgressTracker(STARTING_BUY) @Suspendable - override fun call() { + override fun call(): SignedTransaction { progressTracker.currentStep = STARTING_BUY // Receive the offered amount and automatically agree to it (in reality this would be a longer negotiation) @@ -43,33 +40,6 @@ class BuyerFlow(private val otherSideSession: FlowSession) : FlowLogic() { println("Purchase complete - we are a happy customer! Final transaction is: " + "\n\n${Emoji.renderIfSupported(tradeTX.tx)}") - logIssuanceAttachment(tradeTX) - logBalance() - } - - private fun logBalance() { - val balances = serviceHub.getCashBalances().entries.map { "${it.key.currencyCode} ${it.value}" } - println("Remaining balance: ${balances.joinToString()}") - } - - private fun logIssuanceAttachment(tradeTX: SignedTransaction) { - // Find the original CP issuance. - // TODO: This is potentially very expensive, and requires transaction details we may no longer have once - // SGX is enabled. Should be replaced with including the attachment on all transactions involving - // the state. - val search = TransactionGraphSearch(serviceHub.validatedTransactions, listOf(tradeTX.tx), - TransactionGraphSearch.Query(withCommandOfType = CommercialPaper.Commands.Issue::class.java, - followInputsOfType = CommercialPaper.State::class.java)) - val cpIssuance = search.call().single() - - // Buyer will fetch the attachment from the seller automatically when it resolves the transaction. - - cpIssuance.attachments.first().let { - println(""" - -The issuance of the commercial paper came with an attachment. You can find it in the attachments directory: $it.jar - -${Emoji.renderIfSupported(cpIssuance)}""") - } + return tradeTX } } diff --git a/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/flow/LoggingBuyerFlow.kt b/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/flow/LoggingBuyerFlow.kt new file mode 100644 index 0000000000..841b96f594 --- /dev/null +++ b/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/flow/LoggingBuyerFlow.kt @@ -0,0 +1,48 @@ +package net.corda.traderdemo.flow + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.flows.FlowSession +import net.corda.core.flows.InitiatedBy +import net.corda.core.internal.Emoji +import net.corda.core.transactions.SignedTransaction +import net.corda.finance.contracts.CommercialPaper +import net.corda.finance.contracts.getCashBalances +import net.corda.traderdemo.TransactionGraphSearch + +@InitiatedBy(SellerFlow::class) +class LoggingBuyerFlow(private val otherSideSession: FlowSession) : BuyerFlow(otherSideSession) { + + @Suspendable + override fun call(): SignedTransaction { + val tradeTX = super.call() + logIssuanceAttachment(tradeTX) + logBalance() + return tradeTX + } + + private fun logBalance() { + val balances = serviceHub.getCashBalances().entries.map { "${it.key.currencyCode} ${it.value}" } + println("Remaining balance: ${balances.joinToString()}") + } + + private fun logIssuanceAttachment(tradeTX: SignedTransaction) { + // Find the original CP issuance. + // TODO: This is potentially very expensive, and requires transaction details we may no longer have once + // SGX is enabled. Should be replaced with including the attachment on all transactions involving + // the state. + val search = TransactionGraphSearch(serviceHub.validatedTransactions, listOf(tradeTX.tx), + TransactionGraphSearch.Query(withCommandOfType = CommercialPaper.Commands.Issue::class.java, + followInputsOfType = CommercialPaper.State::class.java)) + val cpIssuance = search.call().single() + + // Buyer will fetch the attachment from the seller automatically when it resolves the transaction. + + cpIssuance.attachments.first().let { + println(""" + +The issuance of the commercial paper came with an attachment. You can find it in the attachments directory: $it.jar + +${Emoji.renderIfSupported(cpIssuance)}""") + } + } +} diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/driver/Driver.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/driver/Driver.kt index 793c6fd8bb..14d7037398 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/driver/Driver.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/driver/Driver.kt @@ -148,7 +148,8 @@ data class NodeParameters( val maximumHeapSize: String = "512m", val logLevel: String? = null, val additionalCordapps: Collection = emptySet(), - val regenerateCordappsOnStart: Boolean = false + val regenerateCordappsOnStart: Boolean = false, + val flowOverrides: Map>, Class>> = emptyMap() ) { /** * Helper builder for configuring a [Node] from Java. diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/driver/DriverDSL.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/driver/DriverDSL.kt index bd1bcb464a..6eb2e7a71a 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/driver/DriverDSL.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/driver/DriverDSL.kt @@ -2,6 +2,7 @@ package net.corda.testing.driver import net.corda.core.DoNotImplement import net.corda.core.concurrent.CordaFuture +import net.corda.core.flows.FlowLogic import net.corda.core.identity.CordaX500Name import net.corda.core.identity.Party import net.corda.core.internal.concurrent.map @@ -27,13 +28,14 @@ interface DriverDSL { * Returns the [NotaryHandle] for the single notary on the network. Throws if there are none or more than one. * @see notaryHandles */ - val defaultNotaryHandle: NotaryHandle get() { - return when (notaryHandles.size) { - 0 -> throw IllegalStateException("There are no notaries defined on the network") - 1 -> notaryHandles[0] - else -> throw IllegalStateException("There is more than one notary defined on the network") + val defaultNotaryHandle: NotaryHandle + get() { + return when (notaryHandles.size) { + 0 -> throw IllegalStateException("There are no notaries defined on the network") + 1 -> notaryHandles[0] + else -> throw IllegalStateException("There is more than one notary defined on the network") + } } - } /** * Returns the identity of the single notary on the network. Throws if there are none or more than one. @@ -47,11 +49,12 @@ interface DriverDSL { * @see defaultNotaryHandle * @see notaryHandles */ - val defaultNotaryNode: CordaFuture get() { - return defaultNotaryHandle.nodeHandles.map { - it.singleOrNull() ?: throw IllegalStateException("Default notary is not a single node") + val defaultNotaryNode: CordaFuture + get() { + return defaultNotaryHandle.nodeHandles.map { + it.singleOrNull() ?: throw IllegalStateException("Default notary is not a single node") + } } - } /** * Start a node. @@ -110,7 +113,8 @@ interface DriverDSL { startInSameProcess: Boolean? = defaultParameters.startInSameProcess, maximumHeapSize: String = defaultParameters.maximumHeapSize, additionalCordapps: Collection = defaultParameters.additionalCordapps, - regenerateCordappsOnStart: Boolean = defaultParameters.regenerateCordappsOnStart + regenerateCordappsOnStart: Boolean = defaultParameters.regenerateCordappsOnStart, + flowOverrides: Map>, Class>> = defaultParameters.flowOverrides ): CordaFuture /** diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/driver/internal/DriverInternal.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/driver/internal/DriverInternal.kt index cf85d45e15..927314827a 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/driver/internal/DriverInternal.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/driver/internal/DriverInternal.kt @@ -14,6 +14,7 @@ import net.corda.testing.driver.OutOfProcess import net.corda.testing.node.User import rx.Observable import java.nio.file.Path +import javax.validation.constraints.NotNull interface NodeHandleInternal : NodeHandle { val configuration: NodeConfiguration @@ -70,7 +71,11 @@ data class InProcessImpl( } override fun close() = stop() - override fun > registerInitiatedFlow(initiatedFlowClass: Class): Observable = node.registerInitiatedFlow(initiatedFlowClass) + @NotNull + override fun > registerInitiatedFlow(initiatedFlowClass: Class): Observable { + node.registerInitiatedFlow(initiatedFlowClass) + return Observable.empty() + } } val InProcess.internalServices: StartedNodeServices get() = services as StartedNodeServices diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNetwork.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNetwork.kt index 3c3ba57aa5..b3731f27ec 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNetwork.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNetwork.kt @@ -206,11 +206,8 @@ class StartedMockNode private constructor(private val node: TestStartedNode) { fun > registerResponderFlow(initiatingFlowClass: Class>, flowFactory: ResponderFlowFactory, responderFlowClass: Class): CordaFuture = - node.registerFlowFactory( - initiatingFlowClass, - InitiatedFlowFactory.CorDapp(flowVersion = 0, appName = "", factory = flowFactory::invoke), - responderFlowClass, true) - .toFuture() + + node.registerInitiatedFlow(initiatingFlowClass, responderFlowClass).toFuture() } /** @@ -240,13 +237,12 @@ interface ResponderFlowFactory> { */ inline fun > StartedMockNode.registerResponderFlow( initiatingFlowClass: Class>, - noinline flowFactory: (FlowSession) -> F): Future = - registerResponderFlow( - initiatingFlowClass, - object : ResponderFlowFactory { - override fun invoke(flowSession: FlowSession) = flowFactory(flowSession) - }, - F::class.java) + noinline flowFactory: (FlowSession) -> F): Future = registerResponderFlow( + initiatingFlowClass, + object : ResponderFlowFactory { + override fun invoke(flowSession: FlowSession) = flowFactory(flowSession) + }, + F::class.java) /** * A mock node brings up a suite of in-memory services in a fast manner suitable for unit testing. @@ -302,13 +298,13 @@ open class MockNetwork( constructor(cordappPackages: List, parameters: MockNetworkParameters = MockNetworkParameters()) : this(cordappPackages, defaultParameters = parameters) constructor( - cordappPackages: List, - defaultParameters: MockNetworkParameters = MockNetworkParameters(), - networkSendManuallyPumped: Boolean = defaultParameters.networkSendManuallyPumped, - threadPerNode: Boolean = defaultParameters.threadPerNode, - servicePeerAllocationStrategy: InMemoryMessagingNetwork.ServicePeerAllocationStrategy = defaultParameters.servicePeerAllocationStrategy, - notarySpecs: List = defaultParameters.notarySpecs, - networkParameters: NetworkParameters = defaultParameters.networkParameters + cordappPackages: List, + defaultParameters: MockNetworkParameters = MockNetworkParameters(), + networkSendManuallyPumped: Boolean = defaultParameters.networkSendManuallyPumped, + threadPerNode: Boolean = defaultParameters.threadPerNode, + servicePeerAllocationStrategy: InMemoryMessagingNetwork.ServicePeerAllocationStrategy = defaultParameters.servicePeerAllocationStrategy, + notarySpecs: List = defaultParameters.notarySpecs, + networkParameters: NetworkParameters = defaultParameters.networkParameters ) : this(emptyList(), defaultParameters, networkSendManuallyPumped, threadPerNode, servicePeerAllocationStrategy, notarySpecs, networkParameters, cordappsForAllNodes = cordappsForPackages(cordappPackages)) private val internalMockNetwork: InternalMockNetwork = InternalMockNetwork(defaultParameters, networkSendManuallyPumped, threadPerNode, servicePeerAllocationStrategy, notarySpecs, networkParameters = networkParameters, cordappsForAllNodes = cordappsForAllNodes) diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/DriverDSLImpl.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/DriverDSLImpl.kt index c722058ba3..ad2173c181 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/DriverDSLImpl.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/DriverDSLImpl.kt @@ -8,6 +8,7 @@ import com.typesafe.config.ConfigValueFactory import net.corda.client.rpc.internal.createCordaRPCClientWithSslAndClassLoader import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.firstOf +import net.corda.core.flows.FlowLogic import net.corda.core.identity.CordaX500Name import net.corda.core.internal.* import net.corda.core.internal.concurrent.* @@ -37,7 +38,6 @@ import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.network.NetworkParametersCopier import net.corda.nodeapi.internal.network.NodeInfoFilesCopier import net.corda.serialization.internal.amqp.AbstractAMQPSerializationScheme -import net.corda.testing.node.TestCordapp import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.BOB_NAME import net.corda.testing.core.DUMMY_BANK_A_NAME @@ -50,6 +50,7 @@ import net.corda.testing.internal.setGlobalSerialization import net.corda.testing.internal.stubs.CertificateStoreStubs import net.corda.testing.node.ClusterSpec import net.corda.testing.node.NotarySpec +import net.corda.testing.node.TestCordapp import net.corda.testing.node.User import net.corda.testing.node.internal.DriverDSLImpl.Companion.cordappsInCurrentAndAdditionalPackages import okhttp3.OkHttpClient @@ -213,7 +214,8 @@ class DriverDSLImpl( startInSameProcess: Boolean?, maximumHeapSize: String, additionalCordapps: Collection, - regenerateCordappsOnStart: Boolean + regenerateCordappsOnStart: Boolean, + flowOverrides: Map>, Class>> ): CordaFuture { val p2pAddress = portAllocation.nextHostAndPort() // TODO: Derive name from the full picked name, don't just wrap the common name @@ -230,7 +232,7 @@ class DriverDSLImpl( return registrationFuture.flatMap { networkMapAvailability.flatMap { // But starting the node proper does require the network map - startRegisteredNode(name, it, rpcUsers, verifierType, customOverrides, startInSameProcess, maximumHeapSize, p2pAddress, additionalCordapps, regenerateCordappsOnStart) + startRegisteredNode(name, it, rpcUsers, verifierType, customOverrides, startInSameProcess, maximumHeapSize, p2pAddress, additionalCordapps, regenerateCordappsOnStart, flowOverrides) } } } @@ -244,7 +246,8 @@ class DriverDSLImpl( maximumHeapSize: String = "512m", p2pAddress: NetworkHostAndPort = portAllocation.nextHostAndPort(), additionalCordapps: Collection = emptySet(), - regenerateCordappsOnStart: Boolean = false): CordaFuture { + regenerateCordappsOnStart: Boolean = false, + flowOverrides: Map>, Class>> = emptyMap()): CordaFuture { val rpcAddress = portAllocation.nextHostAndPort() val rpcAdminAddress = portAllocation.nextHostAndPort() val webAddress = portAllocation.nextHostAndPort() @@ -258,14 +261,16 @@ class DriverDSLImpl( "networkServices.networkMapURL" to compatibilityZone.networkMapURL().toString()) } + val flowOverrideConfig = flowOverrides.entries.map { FlowOverride(it.key.canonicalName, it.value.canonicalName) }.let { FlowOverrideConfig(it) } val overrides = configOf( - "myLegalName" to name.toString(), - "p2pAddress" to p2pAddress.toString(), + NodeConfiguration::myLegalName.name to name.toString(), + NodeConfiguration::p2pAddress.name to p2pAddress.toString(), "rpcSettings.address" to rpcAddress.toString(), "rpcSettings.adminAddress" to rpcAdminAddress.toString(), - "useTestClock" to useTestClock, - "rpcUsers" to if (users.isEmpty()) defaultRpcUserList else users.map { it.toConfig().root().unwrapped() }, - "verifierType" to verifierType.name + NodeConfiguration::useTestClock.name to useTestClock, + NodeConfiguration::rpcUsers.name to if (users.isEmpty()) defaultRpcUserList else users.map { it.toConfig().root().unwrapped() }, + NodeConfiguration::verifierType.name to verifierType.name, + NodeConfiguration::flowOverrides.name to flowOverrideConfig.toConfig().root().unwrapped() ) + czUrlConfig + customOverrides val config = NodeConfig(ConfigHelper.loadConfig( baseDirectory = baseDirectory(name), @@ -516,8 +521,7 @@ class DriverDSLImpl( localNetworkMap, spec.rpcUsers, spec.verifierType, - customOverrides = notaryConfig(clusterAddress) - ) + customOverrides = notaryConfig(clusterAddress)) // All other nodes will join the cluster val restNodeFutures = nodeNames.drop(1).map { diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalMockNetwork.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalMockNetwork.kt index d60f9f6ee1..f280cee72d 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalMockNetwork.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalMockNetwork.kt @@ -30,6 +30,8 @@ import net.corda.core.utilities.seconds import net.corda.node.VersionInfo import net.corda.node.internal.AbstractNode import net.corda.node.internal.InitiatedFlowFactory +import net.corda.node.internal.NodeFlowManager +import net.corda.node.internal.cordapp.JarScanningCordappLoader import net.corda.node.services.api.FlowStarter import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.StartedNodeServices @@ -51,7 +53,6 @@ import net.corda.nodeapi.internal.config.User import net.corda.nodeapi.internal.network.NetworkParametersCopier import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseConfig -import net.corda.testing.node.TestCordapp import net.corda.testing.common.internal.testNetworkParameters import net.corda.testing.internal.rigorousMock import net.corda.testing.internal.setGlobalSerialization @@ -79,7 +80,8 @@ data class MockNodeArgs( val network: InternalMockNetwork, val id: Int, val entropyRoot: BigInteger, - val version: VersionInfo = MOCK_VERSION_INFO + val version: VersionInfo = MOCK_VERSION_INFO, + val flowManager: MockNodeFlowManager = MockNodeFlowManager() ) // TODO We don't need a parameters object as this is internal only @@ -89,7 +91,8 @@ data class InternalMockNodeParameters( val entropyRoot: BigInteger = BigInteger.valueOf(random63BitValue()), val configOverrides: (NodeConfiguration) -> Any? = {}, val version: VersionInfo = MOCK_VERSION_INFO, - val additionalCordapps: Collection? = null) { + val additionalCordapps: Collection? = null, + val flowManager: MockNodeFlowManager = MockNodeFlowManager()) { constructor(mockNodeParameters: MockNodeParameters) : this( mockNodeParameters.forcedID, mockNodeParameters.legalName, @@ -132,12 +135,10 @@ interface TestStartedNode { * starts up for all [FlowLogic] classes it finds which are annotated with [InitiatedBy]. * @return An [Observable] of the initiated flows started by counterparties. */ - fun > registerInitiatedFlow(initiatedFlowClass: Class): Observable + fun > registerInitiatedFlow(initiatedFlowClass: Class, track: Boolean = false): Observable + + fun > registerInitiatedFlow(initiatingFlowClass: Class>, initiatedFlowClass: Class, track: Boolean = false): Observable - fun > registerFlowFactory(initiatingFlowClass: Class>, - flowFactory: InitiatedFlowFactory, - initiatedFlowClass: Class, - track: Boolean): Observable } open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNetworkParameters(), @@ -202,7 +203,8 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe */ val defaultNotaryIdentity: Party get() { - return defaultNotaryNode.info.legalIdentities.singleOrNull() ?: throw IllegalStateException("Default notary has multiple identities") + return defaultNotaryNode.info.legalIdentities.singleOrNull() + ?: throw IllegalStateException("Default notary has multiple identities") } /** @@ -270,11 +272,12 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe } } - open class MockNode(args: MockNodeArgs) : AbstractNode( + open class MockNode(args: MockNodeArgs, private val mockFlowManager: MockNodeFlowManager = args.flowManager) : AbstractNode( args.config, TestClock(Clock.systemUTC()), DefaultNamedCacheFactory(), args.version, + mockFlowManager, args.network.getServerThread(args.id), args.network.busyLatch ) { @@ -294,24 +297,28 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe override val rpcOps: CordaRPCOps, override val notaryService: NotaryService?) : TestStartedNode { - override fun > registerFlowFactory( - initiatingFlowClass: Class>, - flowFactory: InitiatedFlowFactory, - initiatedFlowClass: Class, - track: Boolean): Observable = - internals.internalRegisterFlowFactory(smm, initiatingFlowClass, flowFactory, initiatedFlowClass, track) - override fun dispose() = internals.stop() - override fun > registerInitiatedFlow(initiatedFlowClass: Class): Observable = - internals.registerInitiatedFlow(smm, initiatedFlowClass) + override fun > registerInitiatedFlow(initiatedFlowClass: Class, track: Boolean): Observable { + internals.flowManager.registerInitiatedFlow(initiatedFlowClass) + return smm.changes.filter { it is StateMachineManager.Change.Add }.map { it.logic }.ofType(initiatedFlowClass) + } + + override fun > registerInitiatedFlow(initiatingFlowClass: Class>, initiatedFlowClass: Class, track: Boolean): Observable { + internals.flowManager.registerInitiatedFlow(initiatingFlowClass, initiatedFlowClass) + return smm.changes.filter { it is StateMachineManager.Change.Add }.map { it.logic }.ofType(initiatedFlowClass) + } + + } val mockNet = args.network val id = args.id + init { require(id >= 0) { "Node ID must be zero or positive, was passed: $id" } } + private val entropyRoot = args.entropyRoot var counter = entropyRoot override val log get() = staticLog @@ -333,7 +340,7 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe this, attachments, network as MockNodeMessagingService, - object : StartedNodeServices, ServiceHubInternal by services, FlowStarter by flowStarter { }, + object : StartedNodeServices, ServiceHubInternal by services, FlowStarter by flowStarter {}, nodeInfo, smm, database, @@ -417,8 +424,19 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe var acceptableLiveFiberCountOnStop: Int = 0 override fun acceptableLiveFiberCountOnStop(): Int = acceptableLiveFiberCountOnStop + + fun > registerInitiatedFlowFactory(initiatingFlowClass: Class>, initiatedFlowClass: Class, factory: InitiatedFlowFactory, track: Boolean): Observable { + mockFlowManager.registerTestingFactory(initiatingFlowClass, factory) + return if (track) { + smm.changes.filter { it is StateMachineManager.Change.Add }.map { it.logic }.ofType(initiatedFlowClass) + } else { + Observable.empty() + } + } } + + fun createUnstartedNode(parameters: InternalMockNodeParameters = InternalMockNodeParameters()): MockNode { return createUnstartedNode(parameters, defaultFactory) } @@ -453,7 +471,7 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe val cordappDirectories = cordapps.map { TestCordappDirectories.getJarDirectory(it) }.distinct() doReturn(cordappDirectories).whenever(config).cordappDirectories - val node = nodeFactory(MockNodeArgs(config, this, id, parameters.entropyRoot, parameters.version)) + val node = nodeFactory(MockNodeArgs(config, this, id, parameters.entropyRoot, parameters.version, flowManager = parameters.flowManager)) _nodes += node if (start) { node.start() @@ -482,8 +500,10 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe */ @JvmOverloads fun runNetwork(rounds: Int = -1) { - check(!networkSendManuallyPumped) { "MockNetwork.runNetwork() should only be used when networkSendManuallyPumped == false. " + - "You can use MockNetwork.waitQuiescent() to wait for all the nodes to process all the messages on their queues instead." } + check(!networkSendManuallyPumped) { + "MockNetwork.runNetwork() should only be used when networkSendManuallyPumped == false. " + + "You can use MockNetwork.waitQuiescent() to wait for all the nodes to process all the messages on their queues instead." + } fun pumpAll() = messagingNetwork.endpoints.map { it.pumpReceive(false) } if (rounds == -1) { @@ -572,3 +592,17 @@ private fun mockNodeConfiguration(certificatesDirectory: Path): NodeConfiguratio doReturn(null).whenever(it).devModeOptions } } + +class MockNodeFlowManager : NodeFlowManager() { + val testingRegistrations = HashMap>, InitiatedFlowFactory<*>>() + override fun getFlowFactoryForInitiatingFlow(initiatedFlowClass: Class>): InitiatedFlowFactory<*>? { + if (initiatedFlowClass in testingRegistrations) { + return testingRegistrations.get(initiatedFlowClass) + } + return super.getFlowFactoryForInitiatingFlow(initiatedFlowClass) + } + + fun registerTestingFactory(initiator: Class>, factory: InitiatedFlowFactory<*>) { + testingRegistrations.put(initiator, factory) + } +} \ No newline at end of file diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/NodeBasedTest.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/NodeBasedTest.kt index dae33f8c4b..a86abe6b9d 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/NodeBasedTest.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/NodeBasedTest.kt @@ -10,7 +10,9 @@ import net.corda.core.internal.div import net.corda.core.node.NodeInfo import net.corda.core.utilities.getOrThrow import net.corda.node.VersionInfo +import net.corda.node.internal.FlowManager import net.corda.node.internal.Node +import net.corda.node.internal.NodeFlowManager import net.corda.node.internal.NodeWithInfo import net.corda.node.services.config.* import net.corda.nodeapi.internal.config.toConfig @@ -87,7 +89,8 @@ abstract class NodeBasedTest(private val cordappPackages: List = emptyLi fun startNode(legalName: CordaX500Name, platformVersion: Int = PLATFORM_VERSION, rpcUsers: List = emptyList(), - configOverrides: Map = emptyMap()): NodeWithInfo { + configOverrides: Map = emptyMap(), + flowManager: FlowManager = NodeFlowManager(FlowOverrideConfig())): NodeWithInfo { val baseDirectory = baseDirectory(legalName).createDirectories() val p2pAddress = configOverrides["p2pAddress"] ?: portAllocation.nextHostAndPort().toString() val config = ConfigHelper.loadConfig( @@ -103,7 +106,8 @@ abstract class NodeBasedTest(private val cordappPackages: List = emptyLi ) + configOverrides ) - val cordapps = cordappsForPackages(getCallerPackage(NodeBasedTest::class)?.let { cordappPackages + it } ?: cordappPackages) + val cordapps = cordappsForPackages(getCallerPackage(NodeBasedTest::class)?.let { cordappPackages + it } + ?: cordappPackages) val existingCorDappDirectoriesOption = if (config.hasPath(NodeConfiguration.cordappDirectoriesKey)) config.getStringList(NodeConfiguration.cordappDirectoriesKey) else emptyList() @@ -119,7 +123,7 @@ abstract class NodeBasedTest(private val cordappPackages: List = emptyLi } defaultNetworkParameters.install(baseDirectory) - val node = InProcessNode(parsedConfig, MOCK_VERSION_INFO.copy(platformVersion = platformVersion)) + val node = InProcessNode(parsedConfig, MOCK_VERSION_INFO.copy(platformVersion = platformVersion), flowManager = flowManager) val nodeInfo = node.start() val nodeWithInfo = NodeWithInfo(node, nodeInfo) nodes += nodeWithInfo @@ -145,7 +149,7 @@ abstract class NodeBasedTest(private val cordappPackages: List = emptyLi } } -class InProcessNode(configuration: NodeConfiguration, versionInfo: VersionInfo) : Node(configuration, versionInfo, false) { +class InProcessNode(configuration: NodeConfiguration, versionInfo: VersionInfo, flowManager: FlowManager = NodeFlowManager(configuration.flowOverrides)) : Node(configuration, versionInfo, false, flowManager = flowManager) { override fun start() : NodeInfo { check(isValidJavaVersion()) { "You are using a version of Java that is not supported (${SystemUtils.JAVA_VERSION}). Please upgrade to the latest version of Java 8." } diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/MockCordappProvider.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/MockCordappProvider.kt index a6847b8fa2..1494a38715 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/MockCordappProvider.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/MockCordappProvider.kt @@ -3,7 +3,6 @@ package net.corda.testing.internal import net.corda.core.contracts.ContractClassName import net.corda.core.cordapp.Cordapp import net.corda.core.crypto.SecureHash -import net.corda.core.identity.Party import net.corda.core.internal.DEPLOYED_CORDAPP_UPLOADER import net.corda.core.internal.cordapp.CordappImpl import net.corda.core.node.services.AttachmentId @@ -13,7 +12,6 @@ import net.corda.node.internal.cordapp.CordappProviderImpl import net.corda.testing.services.MockAttachmentStorage import java.nio.file.Paths import java.security.PublicKey -import java.util.* class MockCordappProvider( cordappLoader: CordappLoader,