ENT-2509 - Make @InitiatedBy flows overridable via node config (#3960)

* first attempt at a flowManager

fix test breakages

add testing around registering subclasses

make flowManager a param of MockNode

extract interface
rename methods

more work around overriding flows

more test fixes

add sample project showing how to use flowOverrides

rebase

* make smallest possible changes to AttachmentSerializationTest and ReceiveAllFlowTests

* add some comments about how flow manager weights flows

* address review comments
add documentation

* address more review comments
This commit is contained in:
Stefano Franz 2018-10-23 16:45:07 +01:00 committed by GitHub
parent f8ac35df25
commit 0919b01271
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 930 additions and 324 deletions

View File

@ -1,4 +1,4 @@
gradlePluginsVersion=4.0.32 gradlePluginsVersion=4.0.33
kotlinVersion=1.2.71 kotlinVersion=1.2.71
# ***************************************************************# # ***************************************************************#
# When incrementing platformVersion make sure to update # # When incrementing platformVersion make sure to update #

View File

@ -1,7 +1,6 @@
package net.corda.core.contracts package net.corda.core.contracts
import net.corda.core.KeepForDJVM import net.corda.core.KeepForDJVM
import net.corda.core.identity.Party
import net.corda.core.internal.extractFile import net.corda.core.internal.extractFile
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import java.io.FileNotFoundException import java.io.FileNotFoundException

View File

@ -3,7 +3,6 @@ package net.corda.core.contracts
import net.corda.core.DoNotImplement import net.corda.core.DoNotImplement
import net.corda.core.KeepForDJVM import net.corda.core.KeepForDJVM
import net.corda.core.contracts.AlwaysAcceptAttachmentConstraint.isSatisfiedBy import net.corda.core.contracts.AlwaysAcceptAttachmentConstraint.isSatisfiedBy
import net.corda.core.crypto.CompositeKey
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.isFulfilledBy import net.corda.core.crypto.isFulfilledBy
import net.corda.core.internal.AttachmentWithContext import net.corda.core.internal.AttachmentWithContext

View File

@ -43,7 +43,7 @@ data class CordappImpl(
*/ */
override val cordappClasses: List<String> = run { override val cordappClasses: List<String> = run {
val classList = rpcFlows + initiatedFlows + services + serializationWhitelists.map { javaClass } + notaryService val classList = rpcFlows + initiatedFlows + services + serializationWhitelists.map { javaClass } + notaryService
classList.mapNotNull { it?.name } + contractClassNames classList.mapNotNull { it?.name } + contractClassNames
} }
// TODO Why a seperate Info class and not just have the fields directly in CordappImpl? // TODO Why a seperate Info class and not just have the fields directly in CordappImpl?

View File

@ -5,7 +5,10 @@ import net.corda.core.CordaInternal
import net.corda.core.DeleteForDJVM import net.corda.core.DeleteForDJVM
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.cordapp.CordappProvider import net.corda.core.cordapp.CordappProvider
import net.corda.core.crypto.* import net.corda.core.crypto.CompositeKey
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.SignableData
import net.corda.core.crypto.SignatureMetadata
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.FlowStateMachine import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.ensureMinimumPlatformVersion import net.corda.core.internal.ensureMinimumPlatformVersion

View File

@ -1,10 +1,13 @@
package net.corda.core.flows package net.corda.core.flows
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import net.corda.core.concurrent.CordaFuture
import net.corda.core.toFuture
import net.corda.core.utilities.UntrustworthyData import net.corda.core.utilities.UntrustworthyData
import net.corda.core.utilities.unwrap import net.corda.core.utilities.unwrap
import net.corda.node.internal.InitiatedFlowFactory import net.corda.node.internal.InitiatedFlowFactory
import net.corda.testing.node.internal.TestStartedNode import net.corda.testing.node.internal.TestStartedNode
import rx.Observable
import kotlin.reflect.KClass import kotlin.reflect.KClass
/** /**
@ -34,20 +37,6 @@ class NoAnswer(private val closure: () -> Unit = {}) : FlowLogic<Unit>() {
override fun call() = closure() override fun call() = closure()
} }
/**
* Allows to register a flow of type [R] against an initiating flow of type [I].
*/
inline fun <I : FlowLogic<*>, reified R : FlowLogic<*>> TestStartedNode.registerInitiatedFlow(initiatingFlowType: KClass<I>, crossinline construct: (session: FlowSession) -> R) {
registerFlowFactory(initiatingFlowType.java, InitiatedFlowFactory.Core { session -> construct(session) }, R::class.javaObjectType, true)
}
/**
* Allows to register a flow of type [Answer] against an initiating flow of type [I], returning a valure of type [R].
*/
inline fun <I : FlowLogic<*>, reified R : Any> TestStartedNode.registerAnswer(initiatingFlowType: KClass<I>, value: R) {
registerFlowFactory(initiatingFlowType.java, InitiatedFlowFactory.Core { session -> Answer(session, value) }, Answer::class.javaObjectType, true)
}
/** /**
* Extracts data from a [Map[FlowSession, UntrustworthyData<Any>]] without performing checks and casting to [R]. * Extracts data from a [Map[FlowSession, UntrustworthyData<Any>]] without performing checks and casting to [R].
*/ */
@ -112,4 +101,23 @@ inline fun <reified R : Any> FlowLogic<*>.receiveAll(session: FlowSession, varar
private fun Array<out Pair<FlowSession, Class<out Any>>>.enforceNoDuplicates() { private fun Array<out Pair<FlowSession, Class<out Any>>>.enforceNoDuplicates() {
require(this.size == this.toSet().size) { "A flow session can only appear once as argument." } require(this.size == this.toSet().size) { "A flow session can only appear once as argument." }
}
inline fun <reified P : FlowLogic<*>> TestStartedNode.registerCordappFlowFactory(
initiatingFlowClass: KClass<out FlowLogic<*>>,
initiatedFlowVersion: Int = 1,
noinline flowFactory: (FlowSession) -> P): CordaFuture<P> {
val observable = internals.registerInitiatedFlowFactory(
initiatingFlowClass.java,
P::class.java,
InitiatedFlowFactory.CorDapp(initiatedFlowVersion, "", flowFactory),
track = true)
return observable.toFuture()
}
fun <T : FlowLogic<*>> TestStartedNode.registerCoreFlowFactory(initiatingFlowClass: Class<out FlowLogic<*>>,
initiatedFlowClass: Class<T>,
flowFactory: (FlowSession) -> T , track: Boolean): Observable<T> {
return this.internals.registerInitiatedFlowFactory(initiatingFlowClass, initiatedFlowClass, InitiatedFlowFactory.Core(flowFactory), track)
} }

View File

@ -2,16 +2,18 @@ package net.corda.core.flows
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import com.natpryce.hamkrest.assertion.assert import com.natpryce.hamkrest.assertion.assert
import net.corda.testing.internal.matchers.flow.willReturn
import net.corda.core.flows.mixins.WithMockNet import net.corda.core.flows.mixins.WithMockNet
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.utilities.UntrustworthyData import net.corda.core.utilities.UntrustworthyData
import net.corda.core.utilities.unwrap import net.corda.core.utilities.unwrap
import net.corda.testing.core.singleIdentity import net.corda.testing.core.singleIdentity
import net.corda.testing.internal.matchers.flow.willReturn
import net.corda.testing.node.internal.InternalMockNetwork import net.corda.testing.node.internal.InternalMockNetwork
import net.corda.testing.node.internal.TestStartedNode
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.AfterClass import org.junit.AfterClass
import org.junit.Test import org.junit.Test
import kotlin.reflect.KClass
class ReceiveMultipleFlowTests : WithMockNet { class ReceiveMultipleFlowTests : WithMockNet {
@ -43,7 +45,7 @@ class ReceiveMultipleFlowTests : WithMockNet {
} }
} }
nodes[1].registerInitiatedFlow(initiatingFlow::class) { session -> nodes[1].registerCordappFlowFactory(initiatingFlow::class) { session ->
object : FlowLogic<Unit>() { object : FlowLogic<Unit>() {
@Suspendable @Suspendable
override fun call() { override fun call() {
@ -123,4 +125,15 @@ class ReceiveMultipleFlowTests : WithMockNet {
return double * string.length return double * string.length
} }
} }
}
private inline fun <reified T> TestStartedNode.registerAnswer(kClass: KClass<out FlowLogic<Any>>, value1: T) {
this.registerCordappFlowFactory(kClass) { session ->
object : FlowLogic<Unit>() {
@Suspendable
override fun call() {
session.send(value1!!)
}
}
}
} }

View File

@ -8,7 +8,6 @@ import net.corda.core.JarSignatureTestUtils.updateJar
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.CHARLIE_NAME
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.After import org.junit.After
import org.junit.AfterClass import org.junit.AfterClass

View File

@ -3,16 +3,12 @@ package net.corda.core.serialization
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import net.corda.core.contracts.Attachment import net.corda.core.contracts.Attachment
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowLogic import net.corda.core.flows.*
import net.corda.core.flows.FlowSession
import net.corda.core.flows.InitiatingFlow
import net.corda.core.flows.TestNoSecurityDataVendingFlow
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.FetchAttachmentsFlow import net.corda.core.internal.FetchAttachmentsFlow
import net.corda.core.internal.FetchDataFlow import net.corda.core.internal.FetchDataFlow
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap import net.corda.core.utilities.unwrap
import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.services.persistence.NodeAttachmentService import net.corda.node.services.persistence.NodeAttachmentService
import net.corda.nodeapi.internal.persistence.currentDBSession import net.corda.nodeapi.internal.persistence.currentDBSession
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
@ -151,11 +147,10 @@ class AttachmentSerializationTest {
} }
private fun launchFlow(clientLogic: ClientLogic, rounds: Int, sendData: Boolean = false) { private fun launchFlow(clientLogic: ClientLogic, rounds: Int, sendData: Boolean = false) {
server.registerFlowFactory( server.registerCordappFlowFactory(
ClientLogic::class.java, ClientLogic::class,
InitiatedFlowFactory.Core { ServerLogic(it, sendData) }, 1
ServerLogic::class.java, ) { ServerLogic(it, sendData) }
track = false)
client.services.startFlow(clientLogic) client.services.startFlow(clientLogic)
mockNet.runNetwork(rounds) mockNet.runNetwork(rounds)
} }

View File

@ -0,0 +1,141 @@
Configuring Responder Flows
===========================
A flow can be a fairly complex thing that interacts with many backend systems, and so it is likely that different users
of a specific CordApp will require differences in how flows interact with their specific infrastructure.
Corda supports this functionality by providing two mechanisms to modify the behaviour of apps in your node.
Subclassing a Flow
------------------
If you have a workflow which is mostly common, but also requires slight alterations in specific situations, most developers would be familiar
with refactoring into `Base` and `Sub` classes. A simple example is shown below.
java
~~~~
.. code-block:: java
@InitiatingFlow
public class Initiator extends FlowLogic<String> {
private final Party otherSide;
public Initiator(Party otherSide) {
this.otherSide = otherSide;
}
@Override
public String call() throws FlowException {
return initiateFlow(otherSide).receive(String.class).unwrap((it) -> it);
}
}
@InitiatedBy(Initiator.class)
public class BaseResponder extends FlowLogic<Void> {
private FlowSession counterpartySession;
public BaseResponder(FlowSession counterpartySession) {
super();
this.counterpartySession = counterpartySession;
}
@Override
public Void call() throws FlowException {
counterpartySession.send(getMessage());
return Void;
}
protected String getMessage() {
return "This Is the Legacy Responder";
}
}
public class SubResponder extends BaseResponder {
public SubResponder(FlowSession counterpartySession) {
super(counterpartySession);
}
@Override
protected String getMessage() {
return "This is the sub responder";
}
}
kotlin
~~~~~~
.. code-block:: kotlin
@InitiatedBy(Initiator::class)
open class BaseResponder(internal val otherSideSession: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
otherSideSession.send(getMessage())
}
protected open fun getMessage() = "This Is the Legacy Responder"
}
@InitiatedBy(Initiator::class)
class SubResponder(otherSideSession: FlowSession) : BaseResponder(otherSideSession) {
override fun getMessage(): String {
return "This is the sub responder"
}
}
Corda would detect that both ``BaseResponder`` and ``SubResponder`` are configured for responding to ``Initiator``.
Corda will then calculate the hops to ``FlowLogic`` and select the implementation which is furthest distance, ie: the most subclassed implementation.
In the above example, ``SubResponder`` would be selected as the default responder for ``Initiator``
.. note:: The flows do not need to be within the same CordApp, or package, therefore to customise a shared app you obtained from a third party, you'd write your own CorDapp that subclasses the first."
Overriding a flow via node configuration
----------------------------------------
Whilst the subclassing approach is likely to be useful for most applications, there is another mechanism to override this behaviour.
This would be useful if for example, a specific CordApp user requires such a different responder that subclassing an existing flow
would not be a good solution. In this case, it's possible to specify a hardcoded flow via the node configuration.
The configuration section is named ``flowOverrides`` and it accepts an array of ``overrides``
.. container:: codeset
.. code-block:: json
flowOverrides {
overrides=[
{
initiator="net.corda.Initiator"
responder="net.corda.BaseResponder"
}
]
}
The cordform plugin also provides a ``flowOverride`` method within the ``deployNodes`` block which can be used to override a flow. In the below example, we will override
the ``SubResponder`` with ``BaseResponder``
.. container:: codeset
.. code-block:: groovy
node {
name "O=Bank,L=London,C=GB"
p2pPort 10025
rpcUsers = ext.rpcUsers
rpcSettings {
address "localhost:10026"
adminAddress "localhost:10027"
}
extraConfig = ['h2Settings.address' : 'localhost:10035']
flowOverride("net.corda.Initiator", "net.corda.BaseResponder")
}
This will generate the corresponding ``flowOverrides`` section and place it in the configuration for that node.

View File

@ -0,0 +1,85 @@
package net.corda.node.flows
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.messaging.startFlow
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.testing.core.singleIdentity
import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.driver
import net.corda.testing.node.internal.cordappForClasses
import org.hamcrest.CoreMatchers.`is`
import org.junit.Assert
import org.junit.Test
class FlowOverrideTests {
@StartableByRPC
@InitiatingFlow
class Ping(private val pongParty: Party) : FlowLogic<String>() {
@Suspendable
override fun call(): String {
val pongSession = initiateFlow(pongParty)
return pongSession.sendAndReceive<String>("PING").unwrap { it }
}
}
@InitiatedBy(Ping::class)
open class Pong(private val pingSession: FlowSession) : FlowLogic<Unit>() {
companion object {
val PONG = "PONG"
}
@Suspendable
override fun call() {
pingSession.send(PONG)
}
}
@InitiatedBy(Ping::class)
class Pong2(private val pingSession: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
pingSession.send("PONGPONG")
}
}
@InitiatedBy(Ping::class)
class Pongiest(private val pingSession: FlowSession) : Pong(pingSession) {
companion object {
val GORGONZOLA = "Gorgonzola"
}
@Suspendable
override fun call() {
pingSession.send(GORGONZOLA)
}
}
private val nodeAClasses = setOf(Ping::class.java,
Pong::class.java, Pongiest::class.java)
private val nodeBClasses = setOf(Ping::class.java, Pong::class.java)
@Test
fun `should use the most "specific" implementation of a responding flow`() {
driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = emptySet())) {
val nodeA = startNode(additionalCordapps = setOf(cordappForClasses(*nodeAClasses.toTypedArray()))).getOrThrow()
val nodeB = startNode(additionalCordapps = setOf(cordappForClasses(*nodeBClasses.toTypedArray()))).getOrThrow()
Assert.assertThat(nodeB.rpc.startFlow(::Ping, nodeA.nodeInfo.singleIdentity()).returnValue.getOrThrow(), `is`(net.corda.node.flows.FlowOverrideTests.Pongiest.GORGONZOLA))
}
}
@Test
fun `should use the overriden implementation of a responding flow`() {
val flowOverrides = mapOf(Ping::class.java to Pong::class.java)
driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = emptySet())) {
val nodeA = startNode(additionalCordapps = setOf(cordappForClasses(*nodeAClasses.toTypedArray())), flowOverrides = flowOverrides).getOrThrow()
val nodeB = startNode(additionalCordapps = setOf(cordappForClasses(*nodeBClasses.toTypedArray()))).getOrThrow()
Assert.assertThat(nodeB.rpc.startFlow(::Ping, nodeA.nodeInfo.singleIdentity()).returnValue.getOrThrow(), `is`(net.corda.node.flows.FlowOverrideTests.Pong.PONG))
}
}
}

View File

@ -7,10 +7,11 @@ import net.corda.core.flows.InitiatingFlow
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap import net.corda.core.utilities.unwrap
import net.corda.testing.core.singleIdentity import net.corda.node.internal.NodeFlowManager
import net.corda.testing.node.internal.NodeBasedTest
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.singleIdentity
import net.corda.testing.node.internal.NodeBasedTest
import net.corda.testing.node.internal.startFlow import net.corda.testing.node.internal.startFlow
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.Test import org.junit.Test
@ -18,9 +19,10 @@ import org.junit.Test
class FlowVersioningTest : NodeBasedTest() { class FlowVersioningTest : NodeBasedTest() {
@Test @Test
fun `getFlowContext returns the platform version for core flows`() { fun `getFlowContext returns the platform version for core flows`() {
val bobFlowManager = NodeFlowManager()
val alice = startNode(ALICE_NAME, platformVersion = 2) val alice = startNode(ALICE_NAME, platformVersion = 2)
val bob = startNode(BOB_NAME, platformVersion = 3) val bob = startNode(BOB_NAME, platformVersion = 3, flowManager = bobFlowManager)
bob.node.installCoreFlow(PretendInitiatingCoreFlow::class, ::PretendInitiatedCoreFlow) bobFlowManager.registerInitiatedCoreFlowFactory(PretendInitiatingCoreFlow::class, ::PretendInitiatedCoreFlow)
val (alicePlatformVersionAccordingToBob, bobPlatformVersionAccordingToAlice) = alice.services.startFlow( val (alicePlatformVersionAccordingToBob, bobPlatformVersionAccordingToAlice) = alice.services.startFlow(
PretendInitiatingCoreFlow(bob.info.singleIdentity())).resultFuture.getOrThrow() PretendInitiatingCoreFlow(bob.info.singleIdentity())).resultFuture.getOrThrow()
assertThat(alicePlatformVersionAccordingToBob).isEqualTo(2) assertThat(alicePlatformVersionAccordingToBob).isEqualTo(2)
@ -45,4 +47,5 @@ class FlowVersioningTest : NodeBasedTest() {
@Suspendable @Suspendable
override fun call() = otherSideSession.send(otherSideSession.getCounterpartyFlowInfo().flowVersion) override fun call() = otherSideSession.send(otherSideSession.getCounterpartyFlowInfo().flowVersion)
} }
} }

View File

@ -30,7 +30,10 @@ import net.corda.core.schemas.MappedSchema
import net.corda.core.serialization.SerializationWhitelist import net.corda.core.serialization.SerializationWhitelist
import net.corda.core.serialization.SerializeAsToken import net.corda.core.serialization.SerializeAsToken
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.utilities.* import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.days
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.minutes
import net.corda.node.CordaClock import net.corda.node.CordaClock
import net.corda.node.SerialFilter import net.corda.node.SerialFilter
import net.corda.node.VersionInfo import net.corda.node.VersionInfo
@ -99,13 +102,10 @@ import java.time.Clock
import java.time.Duration import java.time.Duration
import java.time.format.DateTimeParseException import java.time.format.DateTimeParseException
import java.util.* import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit.MINUTES import java.util.concurrent.TimeUnit.MINUTES
import java.util.concurrent.TimeUnit.SECONDS import java.util.concurrent.TimeUnit.SECONDS
import kotlin.collections.set
import kotlin.reflect.KClass
import net.corda.core.crypto.generateKeyPair as cryptoGenerateKeyPair import net.corda.core.crypto.generateKeyPair as cryptoGenerateKeyPair
/** /**
@ -120,11 +120,11 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
val platformClock: CordaClock, val platformClock: CordaClock,
cacheFactoryPrototype: BindableNamedCacheFactory, cacheFactoryPrototype: BindableNamedCacheFactory,
protected val versionInfo: VersionInfo, protected val versionInfo: VersionInfo,
protected val flowManager: FlowManager,
protected val serverThread: AffinityExecutor.ServiceAffinityExecutor, protected val serverThread: AffinityExecutor.ServiceAffinityExecutor,
private val busyNodeLatch: ReusableLatch = ReusableLatch()) : SingletonSerializeAsToken() { private val busyNodeLatch: ReusableLatch = ReusableLatch()) : SingletonSerializeAsToken() {
protected abstract val log: Logger protected abstract val log: Logger
@Suppress("LeakingThis") @Suppress("LeakingThis")
private var tokenizableServices: MutableList<Any>? = mutableListOf(platformClock, this) private var tokenizableServices: MutableList<Any>? = mutableListOf(platformClock, this)
@ -211,7 +211,6 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
).tokenize().closeOnStop() ).tokenize().closeOnStop()
private val cordappServices = MutableClassToInstanceMap.create<SerializeAsToken>() private val cordappServices = MutableClassToInstanceMap.create<SerializeAsToken>()
private val flowFactories = ConcurrentHashMap<Class<out FlowLogic<*>>, InitiatedFlowFactory<*>>()
private val shutdownExecutor = Executors.newSingleThreadExecutor() private val shutdownExecutor = Executors.newSingleThreadExecutor()
protected abstract val transactionVerifierWorkerCount: Int protected abstract val transactionVerifierWorkerCount: Int
@ -237,7 +236,8 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
private var _started: S? = null private var _started: S? = null
private fun <T : Any> T.tokenize(): T { private fun <T : Any> T.tokenize(): T {
tokenizableServices?.add(this) ?: throw IllegalStateException("The tokenisable services list has already been finalised") tokenizableServices?.add(this)
?: throw IllegalStateException("The tokenisable services list has already been finalised")
return this return this
} }
@ -607,91 +607,27 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
} }
private fun registerCordappFlows() { private fun registerCordappFlows() {
cordappLoader.cordapps.flatMap { it.initiatedFlows } cordappLoader.cordapps.forEach { cordapp ->
.forEach { cordapp.initiatedFlows.groupBy { it.requireAnnotation<InitiatedBy>().value.java }.forEach { initiator, responders ->
responders.forEach { responder ->
try { try {
registerInitiatedFlowInternal(smm, it, track = false) flowManager.registerInitiatedFlow(initiator, responder)
} catch (e: NoSuchMethodException) { } catch (e: NoSuchMethodException) {
log.error("${it.name}, as an initiated flow, must have a constructor with a single parameter " + log.error("${responder.name}, as an initiated flow, must have a constructor with a single parameter " +
"of type ${Party::class.java.name}") "of type ${Party::class.java.name}")
} catch (e: Exception) { throw e
log.error("Unable to register initiated flow ${it.name}", e)
} }
} }
}
fun <T : FlowLogic<*>> registerInitiatedFlow(smm: StateMachineManager, initiatedFlowClass: Class<T>): Observable<T> {
return registerInitiatedFlowInternal(smm, initiatedFlowClass, track = true)
}
// TODO remove once not needed
private fun deprecatedFlowConstructorMessage(flowClass: Class<*>): String {
return "Installing flow factory for $flowClass accepting a ${Party::class.java.simpleName}, which is deprecated. " +
"It should accept a ${FlowSession::class.java.simpleName} instead"
}
private fun <F : FlowLogic<*>> registerInitiatedFlowInternal(smm: StateMachineManager, initiatedFlow: Class<F>, track: Boolean): Observable<F> {
val constructors = initiatedFlow.declaredConstructors.associateBy { it.parameterTypes.toList() }
val flowSessionCtor = constructors[listOf(FlowSession::class.java)]?.apply { isAccessible = true }
val ctor: (FlowSession) -> F = if (flowSessionCtor == null) {
// Try to fallback to a Party constructor
val partyCtor = constructors[listOf(Party::class.java)]?.apply { isAccessible = true }
if (partyCtor == null) {
throw IllegalArgumentException("$initiatedFlow must have a constructor accepting a ${FlowSession::class.java.name}")
} else {
log.warn(deprecatedFlowConstructorMessage(initiatedFlow))
} }
{ flowSession: FlowSession -> uncheckedCast(partyCtor.newInstance(flowSession.counterparty)) }
} else {
{ flowSession: FlowSession -> uncheckedCast(flowSessionCtor.newInstance(flowSession)) }
} }
val initiatingFlow = initiatedFlow.requireAnnotation<InitiatedBy>().value.java flowManager.validateRegistrations()
val (version, classWithAnnotation) = initiatingFlow.flowVersionAndInitiatingClass
require(classWithAnnotation == initiatingFlow) {
"${InitiatedBy::class.java.name} must point to ${classWithAnnotation.name} and not ${initiatingFlow.name}"
}
val flowFactory = InitiatedFlowFactory.CorDapp(version, initiatedFlow.appName, ctor)
val observable = internalRegisterFlowFactory(smm, initiatingFlow, flowFactory, initiatedFlow, track)
log.info("Registered ${initiatingFlow.name} to initiate ${initiatedFlow.name} (version $version)")
return observable
}
protected fun <F : FlowLogic<*>> internalRegisterFlowFactory(smm: StateMachineManager,
initiatingFlowClass: Class<out FlowLogic<*>>,
flowFactory: InitiatedFlowFactory<F>,
initiatedFlowClass: Class<F>,
track: Boolean): Observable<F> {
val observable = if (track) {
smm.changes.filter { it is StateMachineManager.Change.Add }.map { it.logic }.ofType(initiatedFlowClass)
} else {
Observable.empty()
}
check(initiatingFlowClass !in flowFactories.keys) {
"$initiatingFlowClass is attempting to register multiple initiated flows"
}
flowFactories[initiatingFlowClass] = flowFactory
return observable
}
/**
* Installs a flow that's core to the Corda platform. Unlike CorDapp flows which are versioned individually using
* [InitiatingFlow.version], core flows have the same version as the node's platform version. To cater for backwards
* compatibility [flowFactory] provides a second parameter which is the platform version of the initiating party.
*/
@VisibleForTesting
fun installCoreFlow(clientFlowClass: KClass<out FlowLogic<*>>, flowFactory: (FlowSession) -> FlowLogic<*>) {
require(clientFlowClass.java.flowVersionAndInitiatingClass.first == 1) {
"${InitiatingFlow::class.java.name}.version not applicable for core flows; their version is the node's platform version"
}
flowFactories[clientFlowClass.java] = InitiatedFlowFactory.Core(flowFactory)
log.debug { "Installed core flow ${clientFlowClass.java.name}" }
} }
private fun installCoreFlows() { private fun installCoreFlows() {
installCoreFlow(FinalityFlow::class, ::FinalityHandler) flowManager.registerInitiatedCoreFlowFactory(FinalityFlow::class, FinalityHandler::class, ::FinalityHandler)
installCoreFlow(NotaryChangeFlow::class, ::NotaryChangeHandler) flowManager.registerInitiatedCoreFlowFactory(NotaryChangeFlow::class, NotaryChangeHandler::class, ::NotaryChangeHandler)
installCoreFlow(ContractUpgradeFlow.Initiate::class, ::ContractUpgradeHandler) flowManager.registerInitiatedCoreFlowFactory(ContractUpgradeFlow.Initiate::class, NotaryChangeHandler::class, ::ContractUpgradeHandler)
installCoreFlow(SwapIdentitiesFlow::class, ::SwapIdentitiesHandler) flowManager.registerInitiatedCoreFlowFactory(SwapIdentitiesFlow::class, SwapIdentitiesHandler::class, ::SwapIdentitiesHandler)
} }
protected open fun makeTransactionStorage(transactionCacheSizeBytes: Long): WritableTransactionStorage { protected open fun makeTransactionStorage(transactionCacheSizeBytes: Long): WritableTransactionStorage {
@ -781,7 +717,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
service.run { service.run {
tokenize() tokenize()
runOnStop += ::stop runOnStop += ::stop
installCoreFlow(NotaryFlow.Client::class, ::createServiceFlow) flowManager.registerInitiatedCoreFlowFactory(NotaryFlow.Client::class, ::createServiceFlow)
start() start()
} }
return service return service
@ -961,7 +897,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
} }
override fun getFlowFactory(initiatingFlowClass: Class<out FlowLogic<*>>): InitiatedFlowFactory<*>? { override fun getFlowFactory(initiatingFlowClass: Class<out FlowLogic<*>>): InitiatedFlowFactory<*>? {
return flowFactories[initiatingFlowClass] return flowManager.getFlowFactoryForInitiatingFlow(initiatingFlowClass)
} }
override fun jdbcSession(): Connection = database.createSession() override fun jdbcSession(): Connection = database.createSession()
@ -1066,4 +1002,4 @@ fun clientSslOptionsCompatibleWith(nodeRpcOptions: NodeRpcOptions): ClientRpcSsl
} }
// Here we're using the node's RPC key store as the RPC client's trust store. // Here we're using the node's RPC key store as the RPC client's trust store.
return ClientRpcSslOptions(trustStorePath = nodeRpcOptions.sslConfig!!.keyStorePath, trustStorePassword = nodeRpcOptions.sslConfig!!.keyStorePassword) return ClientRpcSslOptions(trustStorePath = nodeRpcOptions.sslConfig!!.keyStorePath, trustStorePassword = nodeRpcOptions.sslConfig!!.keyStorePassword)
} }

View File

@ -0,0 +1,222 @@
package net.corda.node.internal
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.flows.InitiatedBy
import net.corda.core.flows.InitiatingFlow
import net.corda.core.identity.Party
import net.corda.core.internal.uncheckedCast
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.node.internal.classloading.requireAnnotation
import net.corda.node.services.config.FlowOverrideConfig
import net.corda.node.services.statemachine.appName
import net.corda.node.services.statemachine.flowVersionAndInitiatingClass
import javax.annotation.concurrent.ThreadSafe
import kotlin.reflect.KClass
/**
*
* This class is responsible for organising which flow should respond to a specific @InitiatingFlow
*
* There are two main ways to modify the behaviour of a cordapp with regards to responding with a different flow
*
* 1.) implementing a new subclass. For example, if we have a ResponderFlow similar to @InitiatedBy(Sender) MyBaseResponder : FlowLogic
* If we subclassed a new Flow with specific logic for DB2, it would be similar to IBMB2Responder() : MyBaseResponder
* When these two flows are encountered by the classpath scan for @InitiatedBy, they will both be selected for responding to Sender
* This implementation will sort them for responding in order of their "depth" from FlowLogic - see: FlowWeightComparator
* So IBMB2Responder would win and it would be selected for responding
*
* 2.) It is possible to specify a flowOverride key in the node configuration. Say we configure a node to have
* flowOverrides{
* "Sender" = "MyBaseResponder"
* }
* In this case, FlowWeightComparator would detect that there is an override in action, and it will assign MyBaseResponder a maximum weight
* This will result in MyBaseResponder being selected for responding to Sender
*
*
*/
interface FlowManager {
fun registerInitiatedCoreFlowFactory(initiatingFlowClass: KClass<out FlowLogic<*>>, flowFactory: (FlowSession) -> FlowLogic<*>)
fun registerInitiatedCoreFlowFactory(initiatingFlowClass: KClass<out FlowLogic<*>>, initiatedFlowClass: KClass<out FlowLogic<*>>?, flowFactory: (FlowSession) -> FlowLogic<*>)
fun registerInitiatedCoreFlowFactory(initiatingFlowClass: KClass<out FlowLogic<*>>, initiatedFlowClass: KClass<out FlowLogic<*>>?, flowFactory: InitiatedFlowFactory.Core<FlowLogic<*>>)
fun <F : FlowLogic<*>> registerInitiatedFlow(initiator: Class<out FlowLogic<*>>, responder: Class<F>)
fun <F : FlowLogic<*>> registerInitiatedFlow(responder: Class<F>)
fun getFlowFactoryForInitiatingFlow(initiatedFlowClass: Class<out FlowLogic<*>>): InitiatedFlowFactory<*>?
fun validateRegistrations()
}
@ThreadSafe
open class NodeFlowManager(flowOverrides: FlowOverrideConfig? = null) : FlowManager {
private val flowFactories = HashMap<Class<out FlowLogic<*>>, MutableList<RegisteredFlowContainer>>()
private val flowOverrides = (flowOverrides
?: FlowOverrideConfig()).overrides.map { it.initiator to it.responder }.toMutableMap()
companion object {
private val log = contextLogger()
}
@Synchronized
override fun getFlowFactoryForInitiatingFlow(initiatedFlowClass: Class<out FlowLogic<*>>): InitiatedFlowFactory<*>? {
return flowFactories[initiatedFlowClass]?.firstOrNull()?.flowFactory
}
@Synchronized
override fun <F : FlowLogic<*>> registerInitiatedFlow(responder: Class<F>) {
return registerInitiatedFlow(responder.requireAnnotation<InitiatedBy>().value.java, responder)
}
@Synchronized
override fun <F : FlowLogic<*>> registerInitiatedFlow(initiator: Class<out FlowLogic<*>>, responder: Class<F>) {
val constructors = responder.declaredConstructors.associateBy { it.parameterTypes.toList() }
val flowSessionCtor = constructors[listOf(FlowSession::class.java)]?.apply { isAccessible = true }
val ctor: (FlowSession) -> F = if (flowSessionCtor == null) {
// Try to fallback to a Party constructor
val partyCtor = constructors[listOf(Party::class.java)]?.apply { isAccessible = true }
if (partyCtor == null) {
throw IllegalArgumentException("$responder must have a constructor accepting a ${FlowSession::class.java.name}")
} else {
log.warn("Installing flow factory for $responder accepting a ${Party::class.java.simpleName}, which is deprecated. " +
"It should accept a ${FlowSession::class.java.simpleName} instead")
}
{ flowSession: FlowSession -> uncheckedCast(partyCtor.newInstance(flowSession.counterparty)) }
} else {
{ flowSession: FlowSession -> uncheckedCast(flowSessionCtor.newInstance(flowSession)) }
}
val (version, classWithAnnotation) = initiator.flowVersionAndInitiatingClass
require(classWithAnnotation == initiator) {
"${InitiatedBy::class.java.name} must point to ${classWithAnnotation.name} and not ${initiator.name}"
}
val flowFactory = InitiatedFlowFactory.CorDapp(version, responder.appName, ctor)
registerInitiatedFlowFactory(initiator, flowFactory, responder)
log.info("Registered ${initiator.name} to initiate ${responder.name} (version $version)")
}
private fun <F : FlowLogic<*>> registerInitiatedFlowFactory(initiatingFlowClass: Class<out FlowLogic<*>>,
flowFactory: InitiatedFlowFactory<F>,
initiatedFlowClass: Class<F>?) {
check(flowFactory !is InitiatedFlowFactory.Core) { "This should only be used for Cordapp flows" }
val listOfFlowsForInitiator = flowFactories.computeIfAbsent(initiatingFlowClass) { mutableListOf() }
if (listOfFlowsForInitiator.isNotEmpty() && listOfFlowsForInitiator.first().type == FlowType.CORE) {
throw IllegalStateException("Attempting to register over an existing platform flow: $initiatingFlowClass")
}
synchronized(listOfFlowsForInitiator) {
val flowToAdd = RegisteredFlowContainer(initiatingFlowClass, initiatedFlowClass, flowFactory, FlowType.CORDAPP)
val flowWeightComparator = FlowWeightComparator(initiatingFlowClass, flowOverrides)
listOfFlowsForInitiator.add(flowToAdd)
listOfFlowsForInitiator.sortWith(flowWeightComparator)
if (listOfFlowsForInitiator.size > 1) {
log.warn("Multiple flows are registered for InitiatingFlow: $initiatingFlowClass, currently using: ${listOfFlowsForInitiator.first().initiatedFlowClass}")
}
}
}
// TODO Harmonise use of these methods - 99% of invocations come from tests.
@Synchronized
override fun registerInitiatedCoreFlowFactory(initiatingFlowClass: KClass<out FlowLogic<*>>, initiatedFlowClass: KClass<out FlowLogic<*>>?, flowFactory: (FlowSession) -> FlowLogic<*>) {
registerInitiatedCoreFlowFactory(initiatingFlowClass, initiatedFlowClass, InitiatedFlowFactory.Core(flowFactory))
}
@Synchronized
override fun registerInitiatedCoreFlowFactory(initiatingFlowClass: KClass<out FlowLogic<*>>, flowFactory: (FlowSession) -> FlowLogic<*>) {
registerInitiatedCoreFlowFactory(initiatingFlowClass, null, InitiatedFlowFactory.Core(flowFactory))
}
@Synchronized
override fun registerInitiatedCoreFlowFactory(initiatingFlowClass: KClass<out FlowLogic<*>>, initiatedFlowClass: KClass<out FlowLogic<*>>?, flowFactory: InitiatedFlowFactory.Core<FlowLogic<*>>) {
require(initiatingFlowClass.java.flowVersionAndInitiatingClass.first == 1) {
"${InitiatingFlow::class.java.name}.version not applicable for core flows; their version is the node's platform version"
}
flowFactories.computeIfAbsent(initiatingFlowClass.java) { mutableListOf() }.add(
RegisteredFlowContainer(
initiatingFlowClass.java,
initiatedFlowClass?.java,
flowFactory,
FlowType.CORE)
)
log.debug { "Installed core flow ${initiatingFlowClass.java.name}" }
}
// To verify the integrity of the current state, it is important that the tip of the responders is a unique weight
// if there are multiple flows with the same weight as the tip, it means that it is impossible to reliably pick one as the responder
private fun validateInvariants(toValidate: List<RegisteredFlowContainer>) {
val currentTip = toValidate.first()
val flowWeightComparator = FlowWeightComparator(currentTip.initiatingFlowClass, flowOverrides)
val equalWeightAsCurrentTip = toValidate.map { flowWeightComparator.compare(currentTip, it) to it }.filter { it.first == 0 }.map { it.second }
if (equalWeightAsCurrentTip.size > 1) {
val message = "Unable to determine which flow to use when responding to: ${currentTip.initiatingFlowClass.canonicalName}. ${equalWeightAsCurrentTip.map { it.initiatedFlowClass!!.canonicalName }} are all registered with equal weight."
throw IllegalStateException(message)
}
}
@Synchronized
override fun validateRegistrations() {
flowFactories.values.forEach {
validateInvariants(it)
}
}
private enum class FlowType {
CORE, CORDAPP
}
private data class RegisteredFlowContainer(val initiatingFlowClass: Class<out FlowLogic<*>>,
val initiatedFlowClass: Class<out FlowLogic<*>>?,
val flowFactory: InitiatedFlowFactory<FlowLogic<*>>,
val type: FlowType)
// this is used to sort the responding flows in order of "importance"
// the logic is as follows
// IF responder is a specific lambda (like for notary implementations / testing code) always return that responder
// ELSE IF responder is present in the overrides list, always return that responder
// ELSE compare responding flows by their depth from FlowLogic, always return the flow which is most specific (IE, has the most hops to FlowLogic)
private open class FlowWeightComparator(val initiatingFlowClass: Class<out FlowLogic<*>>, val flowOverrides: Map<String, String>) : Comparator<NodeFlowManager.RegisteredFlowContainer> {
override fun compare(o1: NodeFlowManager.RegisteredFlowContainer, o2: NodeFlowManager.RegisteredFlowContainer): Int {
if (o1.initiatedFlowClass == null && o2.initiatedFlowClass != null) {
return Int.MIN_VALUE
}
if (o1.initiatedFlowClass != null && o2.initiatedFlowClass == null) {
return Int.MAX_VALUE
}
if (o1.initiatedFlowClass == null && o2.initiatedFlowClass == null) {
return 0
}
val hopsTo1 = calculateHopsToFlowLogic(initiatingFlowClass, o1.initiatedFlowClass!!)
val hopsTo2 = calculateHopsToFlowLogic(initiatingFlowClass, o2.initiatedFlowClass!!)
return hopsTo1.compareTo(hopsTo2) * -1
}
private fun calculateHopsToFlowLogic(initiatingFlowClass: Class<out FlowLogic<*>>,
initiatedFlowClass: Class<out FlowLogic<*>>): Int {
val overriddenClassName = flowOverrides[initiatingFlowClass.canonicalName]
return if (overriddenClassName == initiatedFlowClass.canonicalName) {
Int.MAX_VALUE
} else {
var currentClass: Class<*> = initiatedFlowClass
var count = 0
while (currentClass != FlowLogic::class.java) {
currentClass = currentClass.superclass
count++
}
count;
}
}
}
}
private fun <X, Y> Iterable<Pair<X, Y>>.toMutableMap(): MutableMap<X, Y> {
return this.toMap(HashMap())
}

View File

@ -4,10 +4,12 @@ import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession import net.corda.core.flows.FlowSession
sealed class InitiatedFlowFactory<out F : FlowLogic<*>> { sealed class InitiatedFlowFactory<out F : FlowLogic<*>> {
protected abstract val factory: (FlowSession) -> F protected abstract val factory: (FlowSession) -> F
fun createFlow(initiatingFlowSession: FlowSession): F = factory(initiatingFlowSession) fun createFlow(initiatingFlowSession: FlowSession): F = factory(initiatingFlowSession)
data class Core<out F : FlowLogic<*>>(override val factory: (FlowSession) -> F) : InitiatedFlowFactory<F>() data class Core<out F : FlowLogic<*>>(override val factory: (FlowSession) -> F) : InitiatedFlowFactory<F>()
data class CorDapp<out F : FlowLogic<*>>(val flowVersion: Int, data class CorDapp<out F : FlowLogic<*>>(val flowVersion: Int,
val appName: String, val appName: String,
override val factory: (FlowSession) -> F) : InitiatedFlowFactory<F>() override val factory: (FlowSession) -> F) : InitiatedFlowFactory<F>()

View File

@ -43,6 +43,7 @@ import net.corda.node.services.api.StartedNodeServices
import net.corda.node.services.config.* import net.corda.node.services.config.*
import net.corda.node.services.messaging.* import net.corda.node.services.messaging.*
import net.corda.node.services.rpc.ArtemisRpcBroker import net.corda.node.services.rpc.ArtemisRpcBroker
import net.corda.node.services.statemachine.StateMachineManager
import net.corda.node.utilities.* import net.corda.node.utilities.*
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.INTERNAL_SHELL_USER import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.INTERNAL_SHELL_USER
import net.corda.nodeapi.internal.ShutdownHook import net.corda.nodeapi.internal.ShutdownHook
@ -56,7 +57,6 @@ import org.apache.commons.lang.SystemUtils
import org.h2.jdbc.JdbcSQLException import org.h2.jdbc.JdbcSQLException
import org.slf4j.Logger import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import rx.Observable
import rx.Scheduler import rx.Scheduler
import rx.schedulers.Schedulers import rx.schedulers.Schedulers
import java.net.BindException import java.net.BindException
@ -72,8 +72,7 @@ import kotlin.system.exitProcess
class NodeWithInfo(val node: Node, val info: NodeInfo) { class NodeWithInfo(val node: Node, val info: NodeInfo) {
val services: StartedNodeServices = object : StartedNodeServices, ServiceHubInternal by node.services, FlowStarter by node.flowStarter {} val services: StartedNodeServices = object : StartedNodeServices, ServiceHubInternal by node.services, FlowStarter by node.flowStarter {}
fun dispose() = node.stop() fun dispose() = node.stop()
fun <T : FlowLogic<*>> registerInitiatedFlow(initiatedFlowClass: Class<T>): Observable<T> = fun <T : FlowLogic<*>> registerInitiatedFlow(initiatedFlowClass: Class<T>) = node.registerInitiatedFlow(node.smm, initiatedFlowClass)
node.registerInitiatedFlow(node.smm, initiatedFlowClass)
} }
/** /**
@ -85,12 +84,14 @@ class NodeWithInfo(val node: Node, val info: NodeInfo) {
open class Node(configuration: NodeConfiguration, open class Node(configuration: NodeConfiguration,
versionInfo: VersionInfo, versionInfo: VersionInfo,
private val initialiseSerialization: Boolean = true, private val initialiseSerialization: Boolean = true,
flowManager: FlowManager = NodeFlowManager(configuration.flowOverrides),
cacheFactoryPrototype: BindableNamedCacheFactory = DefaultNamedCacheFactory() cacheFactoryPrototype: BindableNamedCacheFactory = DefaultNamedCacheFactory()
) : AbstractNode<NodeInfo>( ) : AbstractNode<NodeInfo>(
configuration, configuration,
createClock(configuration), createClock(configuration),
cacheFactoryPrototype, cacheFactoryPrototype,
versionInfo, versionInfo,
flowManager,
// Under normal (non-test execution) it will always be "1" // Under normal (non-test execution) it will always be "1"
AffinityExecutor.ServiceAffinityExecutor("Node thread-${sameVmNodeCounter.incrementAndGet()}", 1) AffinityExecutor.ServiceAffinityExecutor("Node thread-${sameVmNodeCounter.incrementAndGet()}", 1)
) { ) {
@ -202,7 +203,8 @@ open class Node(configuration: NodeConfiguration,
return P2PMessagingClient( return P2PMessagingClient(
config = configuration, config = configuration,
versionInfo = versionInfo, versionInfo = versionInfo,
serverAddress = configuration.messagingServerAddress ?: NetworkHostAndPort("localhost", configuration.p2pAddress.port), serverAddress = configuration.messagingServerAddress
?: NetworkHostAndPort("localhost", configuration.p2pAddress.port),
nodeExecutor = serverThread, nodeExecutor = serverThread,
database = database, database = database,
networkMap = networkMapCache, networkMap = networkMapCache,
@ -228,7 +230,8 @@ open class Node(configuration: NodeConfiguration,
} }
val messageBroker = if (!configuration.messagingServerExternal) { val messageBroker = if (!configuration.messagingServerExternal) {
val brokerBindAddress = configuration.messagingServerAddress ?: NetworkHostAndPort("0.0.0.0", configuration.p2pAddress.port) val brokerBindAddress = configuration.messagingServerAddress
?: NetworkHostAndPort("0.0.0.0", configuration.p2pAddress.port)
ArtemisMessagingServer(configuration, brokerBindAddress, networkParameters.maxMessageSize) ArtemisMessagingServer(configuration, brokerBindAddress, networkParameters.maxMessageSize)
} else { } else {
null null
@ -442,7 +445,7 @@ open class Node(configuration: NodeConfiguration,
}.build().start() }.build().start()
} }
private fun registerNewRelicReporter (registry: MetricRegistry) { private fun registerNewRelicReporter(registry: MetricRegistry) {
log.info("Registering New Relic JMX Reporter:") log.info("Registering New Relic JMX Reporter:")
val reporter = NewRelicReporter.forRegistry(registry) val reporter = NewRelicReporter.forRegistry(registry)
.name("New Relic Reporter") .name("New Relic Reporter")
@ -504,4 +507,8 @@ open class Node(configuration: NodeConfiguration,
log.info("Shutdown complete") log.info("Shutdown complete")
} }
fun <T : FlowLogic<*>> registerInitiatedFlow(smm: StateMachineManager, initiatedFlowClass: Class<T>) {
this.flowManager.registerInitiatedFlow(initiatedFlowClass)
}
} }

View File

@ -19,7 +19,6 @@ import net.corda.core.serialization.SerializeAsToken
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.node.VersionInfo import net.corda.node.VersionInfo
import net.corda.node.cordapp.CordappLoader import net.corda.node.cordapp.CordappLoader
import net.corda.node.internal.classloading.requireAnnotation
import net.corda.nodeapi.internal.coreContractClasses import net.corda.nodeapi.internal.coreContractClasses
import net.corda.serialization.internal.DefaultWhitelist import net.corda.serialization.internal.DefaultWhitelist
import org.apache.commons.collections4.map.LRUMap import org.apache.commons.collections4.map.LRUMap
@ -148,17 +147,6 @@ class JarScanningCordappLoader private constructor(private val cordappJarPaths:
private fun findInitiatedFlows(scanResult: RestrictedScanResult): List<Class<out FlowLogic<*>>> { private fun findInitiatedFlows(scanResult: RestrictedScanResult): List<Class<out FlowLogic<*>>> {
return scanResult.getClassesWithAnnotation(FlowLogic::class, InitiatedBy::class) return scanResult.getClassesWithAnnotation(FlowLogic::class, InitiatedBy::class)
// First group by the initiating flow class in case there are multiple mappings
.groupBy { it.requireAnnotation<InitiatedBy>().value.java }
.map { (initiatingFlow, initiatedFlows) ->
val sorted = initiatedFlows.sortedWith(FlowTypeHierarchyComparator(initiatingFlow))
if (sorted.size > 1) {
logger.warn("${initiatingFlow.name} has been specified as the inititating flow by multiple flows " +
"in the same type hierarchy: ${sorted.joinToString { it.name }}. Choosing the most " +
"specific sub-type for registration: ${sorted[0].name}.")
}
sorted[0]
}
} }
private fun Class<out FlowLogic<*>>.isUserInvokable(): Boolean { private fun Class<out FlowLogic<*>>.isUserInvokable(): Boolean {
@ -209,17 +197,7 @@ class JarScanningCordappLoader private constructor(private val cordappJarPaths:
} }
} }
private class FlowTypeHierarchyComparator(val initiatingFlow: Class<out FlowLogic<*>>) : Comparator<Class<out FlowLogic<*>>> {
override fun compare(o1: Class<out FlowLogic<*>>, o2: Class<out FlowLogic<*>>): Int {
return when {
o1 == o2 -> 0
o1.isAssignableFrom(o2) -> 1
o2.isAssignableFrom(o1) -> -1
else -> throw IllegalArgumentException("${initiatingFlow.name} has been specified as the initiating flow by " +
"both ${o1.name} and ${o2.name}")
}
}
}
private fun <T : Any> loadClass(className: String, type: KClass<T>): Class<out T>? { private fun <T : Any> loadClass(className: String, type: KClass<T>): Class<out T>? {
return try { return try {

View File

@ -76,6 +76,7 @@ interface NodeConfiguration {
val p2pSslOptions: MutualSslConfiguration val p2pSslOptions: MutualSslConfiguration
val cordappDirectories: List<Path> val cordappDirectories: List<Path>
val flowOverrides: FlowOverrideConfig?
fun validate(): List<String> fun validate(): List<String>
@ -97,6 +98,9 @@ interface NodeConfiguration {
} }
} }
data class FlowOverrideConfig(val overrides: List<FlowOverride> = listOf())
data class FlowOverride(val initiator: String, val responder: String)
/** /**
* Currently registered JMX Reporters * Currently registered JMX Reporters
*/ */
@ -210,7 +214,8 @@ data class NodeConfigurationImpl(
override val flowMonitorPeriodMillis: Duration = DEFAULT_FLOW_MONITOR_PERIOD_MILLIS, override val flowMonitorPeriodMillis: Duration = DEFAULT_FLOW_MONITOR_PERIOD_MILLIS,
override val flowMonitorSuspensionLoggingThresholdMillis: Duration = DEFAULT_FLOW_MONITOR_SUSPENSION_LOGGING_THRESHOLD_MILLIS, override val flowMonitorSuspensionLoggingThresholdMillis: Duration = DEFAULT_FLOW_MONITOR_SUSPENSION_LOGGING_THRESHOLD_MILLIS,
override val cordappDirectories: List<Path> = listOf(baseDirectory / CORDAPPS_DIR_NAME_DEFAULT), override val cordappDirectories: List<Path> = listOf(baseDirectory / CORDAPPS_DIR_NAME_DEFAULT),
override val jmxReporterType: JmxReporterType? = JmxReporterType.JOLOKIA override val jmxReporterType: JmxReporterType? = JmxReporterType.JOLOKIA,
override val flowOverrides: FlowOverrideConfig?
) : NodeConfiguration { ) : NodeConfiguration {
companion object { companion object {
private val logger = loggerFor<NodeConfigurationImpl>() private val logger = loggerFor<NodeConfigurationImpl>()

View File

@ -12,7 +12,6 @@ import net.corda.testing.core.singleIdentity
import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork
import net.corda.testing.node.MockNodeParameters import net.corda.testing.node.MockNodeParameters
import net.corda.testing.node.StartedMockNode import net.corda.testing.node.StartedMockNode
import org.assertj.core.api.Assertions.assertThatIllegalStateException
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
@ -39,15 +38,15 @@ class FlowRegistrationTest {
} }
@Test @Test
fun `startup fails when two flows initiated by the same flow are registered`() { fun `succeeds when a subclass of a flow initiated by the same flow is registered`() {
// register the same flow twice to invoke the error without causing errors in other tests // register the same flow twice to invoke the error without causing errors in other tests
responder.registerInitiatedFlow(Responder::class.java) responder.registerInitiatedFlow(Responder1::class.java)
assertThatIllegalStateException().isThrownBy { responder.registerInitiatedFlow(Responder::class.java) } responder.registerInitiatedFlow(Responder1Subclassed::class.java)
} }
@Test @Test
fun `a single initiated flow can be registered without error`() { fun `a single initiated flow can be registered without error`() {
responder.registerInitiatedFlow(Responder::class.java) responder.registerInitiatedFlow(Responder1::class.java)
val result = initiator.startFlow(Initiator(responder.info.singleIdentity())) val result = initiator.startFlow(Initiator(responder.info.singleIdentity()))
mockNetwork.runNetwork() mockNetwork.runNetwork()
assertNotNull(result.get()) assertNotNull(result.get())
@ -63,7 +62,38 @@ class Initiator(val party: Party) : FlowLogic<String>() {
} }
@InitiatedBy(Initiator::class) @InitiatedBy(Initiator::class)
private class Responder(val session: FlowSession) : FlowLogic<Unit>() { private open class Responder1(val session: FlowSession) : FlowLogic<Unit>() {
open fun getPayload(): String {
return "whats up"
}
@Suspendable
override fun call() {
session.receive<String>().unwrap { it }
session.send("What's up")
}
}
@InitiatedBy(Initiator::class)
private open class Responder2(val session: FlowSession) : FlowLogic<Unit>() {
open fun getPayload(): String {
return "whats up"
}
@Suspendable
override fun call() {
session.receive<String>().unwrap { it }
session.send("What's up")
}
}
@InitiatedBy(Initiator::class)
private class Responder1Subclassed(session: FlowSession) : Responder1(session) {
override fun getPayload(): String {
return "im subclassed! that's what's up!"
}
@Suspendable @Suspendable
override fun call() { override fun call() {
session.receive<String>().unwrap { it } session.receive<String>().unwrap { it }

View File

@ -0,0 +1,110 @@
package net.corda.node.internal
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.flows.InitiatedBy
import net.corda.core.flows.InitiatingFlow
import net.corda.node.services.config.FlowOverride
import net.corda.node.services.config.FlowOverrideConfig
import org.hamcrest.CoreMatchers.`is`
import org.hamcrest.CoreMatchers.instanceOf
import org.junit.Assert
import org.junit.Test
import org.mockito.Mockito
import java.lang.IllegalStateException
private val marker = "This is a special marker"
class NodeFlowManagerTest {
@InitiatingFlow
class Init : FlowLogic<Unit>() {
override fun call() {
TODO("not implemented")
}
}
@InitiatedBy(Init::class)
open class Resp(val otherSesh: FlowSession) : FlowLogic<Unit>() {
override fun call() {
TODO("not implemented")
}
}
@InitiatedBy(Init::class)
class Resp2(val otherSesh: FlowSession) : FlowLogic<Unit>() {
override fun call() {
TODO("not implemented")
}
}
@InitiatedBy(Init::class)
open class RespSub(sesh: FlowSession) : Resp(sesh) {
override fun call() {
TODO("not implemented")
}
}
@InitiatedBy(Init::class)
class RespSubSub(sesh: FlowSession) : RespSub(sesh) {
override fun call() {
TODO("not implemented")
}
}
@Test(expected = IllegalStateException::class)
fun `should fail to validate if more than one registration with equal weight`() {
val nodeFlowManager = NodeFlowManager()
nodeFlowManager.registerInitiatedFlow(Init::class.java, Resp::class.java)
nodeFlowManager.registerInitiatedFlow(Init::class.java, Resp2::class.java)
nodeFlowManager.validateRegistrations()
}
@Test()
fun `should allow registration of flows with different weights`() {
val nodeFlowManager = NodeFlowManager()
nodeFlowManager.registerInitiatedFlow(Init::class.java, Resp::class.java)
nodeFlowManager.registerInitiatedFlow(Init::class.java, RespSub::class.java)
nodeFlowManager.validateRegistrations()
val factory = nodeFlowManager.getFlowFactoryForInitiatingFlow(Init::class.java)!!
val flow = factory.createFlow(Mockito.mock(FlowSession::class.java))
Assert.assertThat(flow, `is`(instanceOf(RespSub::class.java)))
}
@Test()
fun `should allow updating of registered responder at runtime`() {
val nodeFlowManager = NodeFlowManager()
nodeFlowManager.registerInitiatedFlow(Init::class.java, Resp::class.java)
nodeFlowManager.registerInitiatedFlow(Init::class.java, RespSub::class.java)
nodeFlowManager.validateRegistrations()
var factory = nodeFlowManager.getFlowFactoryForInitiatingFlow(Init::class.java)!!
var flow = factory.createFlow(Mockito.mock(FlowSession::class.java))
Assert.assertThat(flow, `is`(instanceOf(RespSub::class.java)))
// update
nodeFlowManager.registerInitiatedFlow(Init::class.java, RespSubSub::class.java)
nodeFlowManager.validateRegistrations()
factory = nodeFlowManager.getFlowFactoryForInitiatingFlow(Init::class.java)!!
flow = factory.createFlow(Mockito.mock(FlowSession::class.java))
Assert.assertThat(flow, `is`(instanceOf(RespSubSub::class.java)))
}
@Test
fun `should allow an override to be specified`() {
val nodeFlowManager = NodeFlowManager(FlowOverrideConfig(listOf(FlowOverride(Init::class.qualifiedName!!, Resp::class.qualifiedName!!))))
nodeFlowManager.registerInitiatedFlow(Init::class.java, Resp::class.java)
nodeFlowManager.registerInitiatedFlow(Init::class.java, Resp2::class.java)
nodeFlowManager.registerInitiatedFlow(Init::class.java, RespSubSub::class.java)
nodeFlowManager.validateRegistrations()
val factory = nodeFlowManager.getFlowFactoryForInitiatingFlow(Init::class.java)!!
val flow = factory.createFlow(Mockito.mock(FlowSession::class.java))
Assert.assertThat(flow, `is`(instanceOf(Resp::class.java)))
}
}

View File

@ -149,7 +149,7 @@ class NodeTest {
} }
} }
private fun createConfig(nodeName: CordaX500Name): NodeConfiguration { private fun createConfig(nodeName: CordaX500Name): NodeConfigurationImpl {
val fakeAddress = NetworkHostAndPort("0.1.2.3", 456) val fakeAddress = NetworkHostAndPort("0.1.2.3", 456)
return NodeConfigurationImpl( return NodeConfigurationImpl(
baseDirectory = temporaryFolder.root.toPath(), baseDirectory = temporaryFolder.root.toPath(),
@ -167,7 +167,8 @@ class NodeTest {
flowTimeout = FlowTimeoutConfiguration(timeout = Duration.ZERO, backoffBase = 1.0, maxRestartCount = 1), flowTimeout = FlowTimeoutConfiguration(timeout = Duration.ZERO, backoffBase = 1.0, maxRestartCount = 1),
rpcSettings = NodeRpcSettings(address = fakeAddress, adminAddress = null, ssl = null), rpcSettings = NodeRpcSettings(address = fakeAddress, adminAddress = null, ssl = null),
messagingServerAddress = null, messagingServerAddress = null,
notary = null notary = null,
flowOverrides = FlowOverrideConfig(listOf())
) )
} }

View File

@ -56,7 +56,7 @@ class JarScanningCordappLoaderTest {
val actualCordapp = loader.cordapps.single() val actualCordapp = loader.cordapps.single()
assertThat(actualCordapp.contractClassNames).isEqualTo(listOf(isolatedContractId)) assertThat(actualCordapp.contractClassNames).isEqualTo(listOf(isolatedContractId))
assertThat(actualCordapp.initiatedFlows.single().name).isEqualTo("net.corda.finance.contracts.isolated.IsolatedDummyFlow\$Acceptor") assertThat(actualCordapp.initiatedFlows.first().name).isEqualTo("net.corda.finance.contracts.isolated.IsolatedDummyFlow\$Acceptor")
assertThat(actualCordapp.rpcFlows).isEmpty() assertThat(actualCordapp.rpcFlows).isEmpty()
assertThat(actualCordapp.schedulableFlows).isEmpty() assertThat(actualCordapp.schedulableFlows).isEmpty()
assertThat(actualCordapp.services).isEmpty() assertThat(actualCordapp.services).isEmpty()
@ -74,7 +74,7 @@ class JarScanningCordappLoaderTest {
assertThat(loader.cordapps).isNotEmpty assertThat(loader.cordapps).isNotEmpty
val actualCordapp = loader.cordapps.single { !it.initiatedFlows.isEmpty() } val actualCordapp = loader.cordapps.single { !it.initiatedFlows.isEmpty() }
assertThat(actualCordapp.initiatedFlows).first().hasSameClassAs(DummyFlow::class.java) assertThat(actualCordapp.initiatedFlows.first()).hasSameClassAs(DummyFlow::class.java)
assertThat(actualCordapp.rpcFlows).first().hasSameClassAs(DummyRPCFlow::class.java) assertThat(actualCordapp.rpcFlows).first().hasSameClassAs(DummyRPCFlow::class.java)
assertThat(actualCordapp.schedulableFlows).first().hasSameClassAs(DummySchedulableFlow::class.java) assertThat(actualCordapp.schedulableFlows).first().hasSameClassAs(DummySchedulableFlow::class.java)
} }

View File

@ -172,8 +172,8 @@ class NodeConfigurationImplTest {
val errors = configuration.validate() val errors = configuration.validate()
assertThat(errors).hasOnlyOneElementSatisfying { assertThat(errors).hasOnlyOneElementSatisfying { error ->
error -> error.contains("Cannot configure both compatibilityZoneUrl and networkServices simultaneously") error.contains("Cannot configure both compatibilityZoneUrl and networkServices simultaneously")
} }
} }
@ -268,7 +268,8 @@ class NodeConfigurationImplTest {
noLocalShell = false, noLocalShell = false,
rpcSettings = rpcSettings, rpcSettings = rpcSettings,
crlCheckSoftFail = true, crlCheckSoftFail = true,
tlsCertCrlDistPoint = null tlsCertCrlDistPoint = null,
flowOverrides = FlowOverrideConfig(listOf())
) )
} }
} }

View File

@ -26,7 +26,6 @@ import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.ProgressTracker.Change import net.corda.core.utilities.ProgressTracker.Change
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap import net.corda.core.utilities.unwrap
import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.services.persistence.checkpoints import net.corda.node.services.persistence.checkpoints
import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyContract
import net.corda.testing.contracts.DummyState import net.corda.testing.contracts.DummyState
@ -116,7 +115,7 @@ class FlowFrameworkTests {
@Test @Test
fun `exception while fiber suspended`() { fun `exception while fiber suspended`() {
bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) } bobNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) }
val flow = ReceiveFlow(bob) val flow = ReceiveFlow(bob)
val fiber = aliceNode.services.startFlow(flow) as FlowStateMachineImpl val fiber = aliceNode.services.startFlow(flow) as FlowStateMachineImpl
// Before the flow runs change the suspend action to throw an exception // Before the flow runs change the suspend action to throw an exception
@ -134,7 +133,7 @@ class FlowFrameworkTests {
@Test @Test
fun `both sides do a send as their first IO request`() { fun `both sides do a send as their first IO request`() {
bobNode.registerFlowFactory(PingPongFlow::class) { PingPongFlow(it, 20L) } bobNode.registerCordappFlowFactory(PingPongFlow::class) { PingPongFlow(it, 20L) }
aliceNode.services.startFlow(PingPongFlow(bob, 10L)) aliceNode.services.startFlow(PingPongFlow(bob, 10L))
mockNet.runNetwork() mockNet.runNetwork()
@ -151,7 +150,7 @@ class FlowFrameworkTests {
@Test @Test
fun `other side ends before doing expected send`() { fun `other side ends before doing expected send`() {
bobNode.registerFlowFactory(ReceiveFlow::class) { NoOpFlow() } bobNode.registerCordappFlowFactory(ReceiveFlow::class) { NoOpFlow() }
val resultFuture = aliceNode.services.startFlow(ReceiveFlow(bob)).resultFuture val resultFuture = aliceNode.services.startFlow(ReceiveFlow(bob)).resultFuture
mockNet.runNetwork() mockNet.runNetwork()
assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy { assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy {
@ -161,7 +160,7 @@ class FlowFrameworkTests {
@Test @Test
fun `receiving unexpected session end before entering sendAndReceive`() { fun `receiving unexpected session end before entering sendAndReceive`() {
bobNode.registerFlowFactory(WaitForOtherSideEndBeforeSendAndReceive::class) { NoOpFlow() } bobNode.registerCordappFlowFactory(WaitForOtherSideEndBeforeSendAndReceive::class) { NoOpFlow() }
val sessionEndReceived = Semaphore(0) val sessionEndReceived = Semaphore(0)
receivedSessionMessagesObservable().filter { receivedSessionMessagesObservable().filter {
it.message is ExistingSessionMessage && it.message.payload === EndSessionMessage it.message is ExistingSessionMessage && it.message.payload === EndSessionMessage
@ -176,7 +175,7 @@ class FlowFrameworkTests {
@Test @Test
fun `FlowException thrown on other side`() { fun `FlowException thrown on other side`() {
val erroringFlow = bobNode.registerFlowFactory(ReceiveFlow::class) { val erroringFlow = bobNode.registerCordappFlowFactory(ReceiveFlow::class) {
ExceptionFlow { MyFlowException("Nothing useful") } ExceptionFlow { MyFlowException("Nothing useful") }
} }
val erroringFlowSteps = erroringFlow.flatMap { it.progressSteps } val erroringFlowSteps = erroringFlow.flatMap { it.progressSteps }
@ -240,7 +239,7 @@ class FlowFrameworkTests {
} }
} }
bobNode.registerFlowFactory(AskForExceptionFlow::class) { ConditionalExceptionFlow(it, "Hello") } bobNode.registerCordappFlowFactory(AskForExceptionFlow::class) { ConditionalExceptionFlow(it, "Hello") }
val resultFuture = aliceNode.services.startFlow(RetryOnExceptionFlow(bob)).resultFuture val resultFuture = aliceNode.services.startFlow(RetryOnExceptionFlow(bob)).resultFuture
mockNet.runNetwork() mockNet.runNetwork()
assertThat(resultFuture.getOrThrow()).isEqualTo("Hello") assertThat(resultFuture.getOrThrow()).isEqualTo("Hello")
@ -248,7 +247,7 @@ class FlowFrameworkTests {
@Test @Test
fun `serialisation issue in counterparty`() { fun `serialisation issue in counterparty`() {
bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(NonSerialisableData(1), it) } bobNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(NonSerialisableData(1), it) }
val result = aliceNode.services.startFlow(ReceiveFlow(bob)).resultFuture val result = aliceNode.services.startFlow(ReceiveFlow(bob)).resultFuture
mockNet.runNetwork() mockNet.runNetwork()
assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy { assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy {
@ -258,7 +257,7 @@ class FlowFrameworkTests {
@Test @Test
fun `FlowException has non-serialisable object`() { fun `FlowException has non-serialisable object`() {
bobNode.registerFlowFactory(ReceiveFlow::class) { bobNode.registerCordappFlowFactory(ReceiveFlow::class) {
ExceptionFlow { NonSerialisableFlowException(NonSerialisableData(1)) } ExceptionFlow { NonSerialisableFlowException(NonSerialisableData(1)) }
} }
val result = aliceNode.services.startFlow(ReceiveFlow(bob)).resultFuture val result = aliceNode.services.startFlow(ReceiveFlow(bob)).resultFuture
@ -275,7 +274,7 @@ class FlowFrameworkTests {
.addCommand(dummyCommand(alice.owningKey)) .addCommand(dummyCommand(alice.owningKey))
val stx = aliceNode.services.signInitialTransaction(ptx) val stx = aliceNode.services.signInitialTransaction(ptx)
val committerFiber = aliceNode.registerFlowFactory(WaitingFlows.Waiter::class) { val committerFiber = aliceNode.registerCordappFlowFactory(WaitingFlows.Waiter::class) {
WaitingFlows.Committer(it) WaitingFlows.Committer(it)
}.map { it.stateMachine }.map { uncheckedCast<FlowStateMachine<*>, FlowStateMachine<Any>>(it) } }.map { it.stateMachine }.map { uncheckedCast<FlowStateMachine<*>, FlowStateMachine<Any>>(it) }
val waiterStx = bobNode.services.startFlow(WaitingFlows.Waiter(stx, alice)).resultFuture val waiterStx = bobNode.services.startFlow(WaitingFlows.Waiter(stx, alice)).resultFuture
@ -290,7 +289,7 @@ class FlowFrameworkTests {
.addCommand(dummyCommand()) .addCommand(dummyCommand())
val stx = aliceNode.services.signInitialTransaction(ptx) val stx = aliceNode.services.signInitialTransaction(ptx)
aliceNode.registerFlowFactory(WaitingFlows.Waiter::class) { aliceNode.registerCordappFlowFactory(WaitingFlows.Waiter::class) {
WaitingFlows.Committer(it) { throw Exception("Error") } WaitingFlows.Committer(it) { throw Exception("Error") }
} }
val waiter = bobNode.services.startFlow(WaitingFlows.Waiter(stx, alice)).resultFuture val waiter = bobNode.services.startFlow(WaitingFlows.Waiter(stx, alice)).resultFuture
@ -307,7 +306,7 @@ class FlowFrameworkTests {
.addCommand(dummyCommand(alice.owningKey)) .addCommand(dummyCommand(alice.owningKey))
val stx = aliceNode.services.signInitialTransaction(ptx) val stx = aliceNode.services.signInitialTransaction(ptx)
aliceNode.registerFlowFactory(VaultQueryFlow::class) { aliceNode.registerCordappFlowFactory(VaultQueryFlow::class) {
WaitingFlows.Committer(it) WaitingFlows.Committer(it)
} }
val result = bobNode.services.startFlow(VaultQueryFlow(stx, alice)).resultFuture val result = bobNode.services.startFlow(VaultQueryFlow(stx, alice)).resultFuture
@ -318,7 +317,7 @@ class FlowFrameworkTests {
@Test @Test
fun `customised client flow`() { fun `customised client flow`() {
val receiveFlowFuture = bobNode.registerFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it) } val receiveFlowFuture = bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it) }
aliceNode.services.startFlow(CustomSendFlow("Hello", bob)).resultFuture aliceNode.services.startFlow(CustomSendFlow("Hello", bob)).resultFuture
mockNet.runNetwork() mockNet.runNetwork()
assertThat(receiveFlowFuture.getOrThrow().receivedPayloads).containsOnly("Hello") assertThat(receiveFlowFuture.getOrThrow().receivedPayloads).containsOnly("Hello")
@ -333,7 +332,7 @@ class FlowFrameworkTests {
@Test @Test
fun `upgraded initiating flow`() { fun `upgraded initiating flow`() {
bobNode.registerFlowFactory(UpgradedFlow::class, initiatedFlowVersion = 1) { InitiatedSendFlow("Old initiated", it) } bobNode.registerCordappFlowFactory(UpgradedFlow::class, initiatedFlowVersion = 1) { InitiatedSendFlow("Old initiated", it) }
val result = aliceNode.services.startFlow(UpgradedFlow(bob)).resultFuture val result = aliceNode.services.startFlow(UpgradedFlow(bob)).resultFuture
mockNet.runNetwork() mockNet.runNetwork()
assertThat(receivedSessionMessages).startsWith( assertThat(receivedSessionMessages).startsWith(
@ -347,7 +346,7 @@ class FlowFrameworkTests {
@Test @Test
fun `upgraded initiated flow`() { fun `upgraded initiated flow`() {
bobNode.registerFlowFactory(SendFlow::class, initiatedFlowVersion = 2) { UpgradedFlow(it) } bobNode.registerCordappFlowFactory(SendFlow::class, initiatedFlowVersion = 2) { UpgradedFlow(it) }
val initiatingFlow = SendFlow("Old initiating", bob) val initiatingFlow = SendFlow("Old initiating", bob)
val flowInfo = aliceNode.services.startFlow(initiatingFlow).resultFuture val flowInfo = aliceNode.services.startFlow(initiatingFlow).resultFuture
mockNet.runNetwork() mockNet.runNetwork()
@ -387,7 +386,7 @@ class FlowFrameworkTests {
@Test @Test
fun `single inlined sub-flow`() { fun `single inlined sub-flow`() {
bobNode.registerFlowFactory(SendAndReceiveFlow::class) { SingleInlinedSubFlow(it) } bobNode.registerCordappFlowFactory(SendAndReceiveFlow::class) { SingleInlinedSubFlow(it) }
val result = aliceNode.services.startFlow(SendAndReceiveFlow(bob, "Hello")).resultFuture val result = aliceNode.services.startFlow(SendAndReceiveFlow(bob, "Hello")).resultFuture
mockNet.runNetwork() mockNet.runNetwork()
assertThat(result.getOrThrow()).isEqualTo("HelloHello") assertThat(result.getOrThrow()).isEqualTo("HelloHello")
@ -395,7 +394,7 @@ class FlowFrameworkTests {
@Test @Test
fun `double inlined sub-flow`() { fun `double inlined sub-flow`() {
bobNode.registerFlowFactory(SendAndReceiveFlow::class) { DoubleInlinedSubFlow(it) } bobNode.registerCordappFlowFactory(SendAndReceiveFlow::class) { DoubleInlinedSubFlow(it) }
val result = aliceNode.services.startFlow(SendAndReceiveFlow(bob, "Hello")).resultFuture val result = aliceNode.services.startFlow(SendAndReceiveFlow(bob, "Hello")).resultFuture
mockNet.runNetwork() mockNet.runNetwork()
assertThat(result.getOrThrow()).isEqualTo("HelloHello") assertThat(result.getOrThrow()).isEqualTo("HelloHello")
@ -403,7 +402,7 @@ class FlowFrameworkTests {
@Test @Test
fun `non-FlowException thrown on other side`() { fun `non-FlowException thrown on other side`() {
val erroringFlowFuture = bobNode.registerFlowFactory(ReceiveFlow::class) { val erroringFlowFuture = bobNode.registerCordappFlowFactory(ReceiveFlow::class) {
ExceptionFlow { Exception("evil bug!") } ExceptionFlow { Exception("evil bug!") }
} }
val erroringFlowSteps = erroringFlowFuture.flatMap { it.progressSteps } val erroringFlowSteps = erroringFlowFuture.flatMap { it.progressSteps }
@ -507,8 +506,8 @@ class FlowFrameworkTripartyTests {
@Test @Test
fun `sending to multiple parties`() { fun `sending to multiple parties`() {
bobNode.registerFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() } bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() }
charlieNode.registerFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() } charlieNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() }
val payload = "Hello World" val payload = "Hello World"
aliceNode.services.startFlow(SendFlow(payload, bob, charlie)) aliceNode.services.startFlow(SendFlow(payload, bob, charlie))
mockNet.runNetwork() mockNet.runNetwork()
@ -538,8 +537,8 @@ class FlowFrameworkTripartyTests {
fun `receiving from multiple parties`() { fun `receiving from multiple parties`() {
val bobPayload = "Test 1" val bobPayload = "Test 1"
val charliePayload = "Test 2" val charliePayload = "Test 2"
bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(bobPayload, it) } bobNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(bobPayload, it) }
charlieNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(charliePayload, it) } charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(charliePayload, it) }
val multiReceiveFlow = ReceiveFlow(bob, charlie).nonTerminating() val multiReceiveFlow = ReceiveFlow(bob, charlie).nonTerminating()
aliceNode.services.startFlow(multiReceiveFlow) aliceNode.services.startFlow(multiReceiveFlow)
aliceNode.internals.acceptableLiveFiberCountOnStop = 1 aliceNode.internals.acceptableLiveFiberCountOnStop = 1
@ -564,8 +563,8 @@ class FlowFrameworkTripartyTests {
@Test @Test
fun `FlowException only propagated to parent`() { fun `FlowException only propagated to parent`() {
charlieNode.registerFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Chain") } } charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Chain") } }
bobNode.registerFlowFactory(ReceiveFlow::class) { ReceiveFlow(charlie) } bobNode.registerCordappFlowFactory(ReceiveFlow::class) { ReceiveFlow(charlie) }
val receivingFiber = aliceNode.services.startFlow(ReceiveFlow(bob)) val receivingFiber = aliceNode.services.startFlow(ReceiveFlow(bob))
mockNet.runNetwork() mockNet.runNetwork()
assertThatExceptionOfType(UnexpectedFlowEndException::class.java) assertThatExceptionOfType(UnexpectedFlowEndException::class.java)
@ -577,9 +576,9 @@ class FlowFrameworkTripartyTests {
// Bob will send its payload and then block waiting for the receive from Alice. Meanwhile Alice will move // Bob will send its payload and then block waiting for the receive from Alice. Meanwhile Alice will move
// onto Charlie which will throw the exception // onto Charlie which will throw the exception
val node2Fiber = bobNode val node2Fiber = bobNode
.registerFlowFactory(ReceiveFlow::class) { SendAndReceiveFlow(it, "Hello") } .registerCordappFlowFactory(ReceiveFlow::class) { SendAndReceiveFlow(it, "Hello") }
.map { it.stateMachine } .map { it.stateMachine }
charlieNode.registerFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Nothing useful") } } charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Nothing useful") } }
val aliceFiber = aliceNode.services.startFlow(ReceiveFlow(bob, charlie)) as FlowStateMachineImpl val aliceFiber = aliceNode.services.startFlow(ReceiveFlow(bob, charlie)) as FlowStateMachineImpl
mockNet.runNetwork() mockNet.runNetwork()
@ -630,6 +629,8 @@ class FlowFrameworkPersistenceTests {
private lateinit var notaryIdentity: Party private lateinit var notaryIdentity: Party
private lateinit var alice: Party private lateinit var alice: Party
private lateinit var bob: Party private lateinit var bob: Party
private lateinit var aliceFlowManager: MockNodeFlowManager
private lateinit var bobFlowManager: MockNodeFlowManager
@Before @Before
fun start() { fun start() {
@ -637,8 +638,11 @@ class FlowFrameworkPersistenceTests {
cordappsForAllNodes = cordappsForPackages("net.corda.finance.contracts", "net.corda.testing.contracts"), cordappsForAllNodes = cordappsForPackages("net.corda.finance.contracts", "net.corda.testing.contracts"),
servicePeerAllocationStrategy = RoundRobin() servicePeerAllocationStrategy = RoundRobin()
) )
aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME)) aliceFlowManager = MockNodeFlowManager()
bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME)) bobFlowManager = MockNodeFlowManager()
aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME, flowManager = aliceFlowManager))
bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME, flowManager = bobFlowManager))
receivedSessionMessagesObservable().forEach { receivedSessionMessages += it } receivedSessionMessagesObservable().forEach { receivedSessionMessages += it }
@ -664,7 +668,7 @@ class FlowFrameworkPersistenceTests {
@Test @Test
fun `flow restarted just after receiving payload`() { fun `flow restarted just after receiving payload`() {
bobNode.registerFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() } bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() }
aliceNode.services.startFlow(SendFlow("Hello", bob)) aliceNode.services.startFlow(SendFlow("Hello", bob))
// We push through just enough messages to get only the payload sent // We push through just enough messages to get only the payload sent
@ -679,7 +683,7 @@ class FlowFrameworkPersistenceTests {
@Test @Test
fun `flow loaded from checkpoint will respond to messages from before start`() { fun `flow loaded from checkpoint will respond to messages from before start`() {
aliceNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) } aliceNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) }
bobNode.services.startFlow(ReceiveFlow(alice).nonTerminating()) // Prepare checkpointed receive flow bobNode.services.startFlow(ReceiveFlow(alice).nonTerminating()) // Prepare checkpointed receive flow
val restoredFlow = bobNode.restartAndGetRestoredFlow<ReceiveFlow>() val restoredFlow = bobNode.restartAndGetRestoredFlow<ReceiveFlow>()
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello") assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello")
@ -694,7 +698,7 @@ class FlowFrameworkPersistenceTests {
var sentCount = 0 var sentCount = 0
mockNet.messagingNetwork.sentMessages.toSessionTransfers().filter { it.isPayloadTransfer }.forEach { sentCount++ } mockNet.messagingNetwork.sentMessages.toSessionTransfers().filter { it.isPayloadTransfer }.forEach { sentCount++ }
val charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME)) val charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME))
val secondFlow = charlieNode.registerFlowFactory(PingPongFlow::class) { PingPongFlow(it, payload2) } val secondFlow = charlieNode.registerCordappFlowFactory(PingPongFlow::class) { PingPongFlow(it, payload2) }
mockNet.runNetwork() mockNet.runNetwork()
val charlie = charlieNode.info.singleIdentity() val charlie = charlieNode.info.singleIdentity()
@ -802,23 +806,14 @@ private infix fun TestStartedNode.sent(message: SessionMessage): Pair<Int, Sessi
private infix fun Pair<Int, SessionMessage>.to(node: TestStartedNode): SessionTransfer = SessionTransfer(first, second, node.network.myAddress) private infix fun Pair<Int, SessionMessage>.to(node: TestStartedNode): SessionTransfer = SessionTransfer(first, second, node.network.myAddress)
private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: MessageRecipients) { private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: MessageRecipients) {
val isPayloadTransfer: Boolean get() = val isPayloadTransfer: Boolean
message is ExistingSessionMessage && message.payload is DataSessionMessage || get() =
message is InitialSessionMessage && message.firstPayload != null message is ExistingSessionMessage && message.payload is DataSessionMessage ||
message is InitialSessionMessage && message.firstPayload != null
override fun toString(): String = "$from sent $message to $to" override fun toString(): String = "$from sent $message to $to"
} }
private inline fun <reified P : FlowLogic<*>> TestStartedNode.registerFlowFactory(
initiatingFlowClass: KClass<out FlowLogic<*>>,
initiatedFlowVersion: Int = 1,
noinline flowFactory: (FlowSession) -> P): CordaFuture<P> {
val observable = registerFlowFactory(
initiatingFlowClass.java,
InitiatedFlowFactory.CorDapp(initiatedFlowVersion, "", flowFactory),
P::class.java,
track = true)
return observable.toFuture()
}
private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage { private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage {
return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize()) return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
@ -1061,7 +1056,8 @@ private class SendAndReceiveFlow(val otherParty: Party, val payload: Any, val ot
constructor(otherPartySession: FlowSession, payload: Any) : this(otherPartySession.counterparty, payload, otherPartySession) constructor(otherPartySession: FlowSession, payload: Any) : this(otherPartySession.counterparty, payload, otherPartySession)
@Suspendable @Suspendable
override fun call(): Any = (otherPartySession ?: initiateFlow(otherParty)).sendAndReceive<Any>(payload).unwrap { it } override fun call(): Any = (otherPartySession
?: initiateFlow(otherParty)).sendAndReceive<Any>(payload).unwrap { it }
} }
private class InlinedSendFlow(val payload: String, val otherPartySession: FlowSession) : FlowLogic<Unit>() { private class InlinedSendFlow(val payload: String, val otherPartySession: FlowSession) : FlowLogic<Unit>() {
@ -1098,4 +1094,4 @@ private class ExceptionFlow<E : Exception>(val exception: () -> E) : FlowLogic<N
exceptionThrown = exception() exceptionThrown = exception()
throw exceptionThrown throw exceptionThrown
} }
} }

View File

@ -3,10 +3,7 @@ package net.corda.node.services.vault
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import com.nhaarman.mockito_kotlin.* import com.nhaarman.mockito_kotlin.*
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.flows.FinalityFlow import net.corda.core.flows.*
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.flows.InitiatingFlow
import net.corda.core.identity.AbstractParty import net.corda.core.identity.AbstractParty
import net.corda.core.internal.FlowStateMachine import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.packageName import net.corda.core.internal.packageName
@ -23,14 +20,12 @@ import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap import net.corda.core.utilities.unwrap
import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.services.api.SchemaService
import net.corda.node.services.api.VaultServiceInternal import net.corda.node.services.api.VaultServiceInternal
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.testing.core.singleIdentity import net.corda.testing.core.singleIdentity
import net.corda.testing.internal.rigorousMock import net.corda.testing.internal.rigorousMock
import net.corda.testing.node.internal.cordappsForPackages
import net.corda.testing.node.internal.InternalMockNetwork import net.corda.testing.node.internal.InternalMockNetwork
import net.corda.testing.node.internal.cordappsForPackages
import net.corda.testing.node.internal.startFlow import net.corda.testing.node.internal.startFlow
import org.junit.After import org.junit.After
import org.junit.Test import org.junit.Test
@ -68,7 +63,7 @@ class NodePair(private val mockNet: InternalMockNetwork) {
private set private set
fun <T> communicate(clientLogic: AbstractClientLogic<T>, rebootClient: Boolean): FlowStateMachine<T> { fun <T> communicate(clientLogic: AbstractClientLogic<T>, rebootClient: Boolean): FlowStateMachine<T> {
server.registerFlowFactory(AbstractClientLogic::class.java, InitiatedFlowFactory.Core { ServerLogic(it, serverRunning) }, ServerLogic::class.java, false) server.registerCoreFlowFactory(AbstractClientLogic::class.java, ServerLogic::class.java, { ServerLogic(it, serverRunning) }, false)
client.services.startFlow(clientLogic) client.services.startFlow(clientLogic)
while (!serverRunning.get()) mockNet.runNetwork(1) while (!serverRunning.get()) mockNet.runNetwork(1)
if (rebootClient) { if (rebootClient) {

View File

@ -89,6 +89,20 @@ task deployNodes(type: net.corda.plugins.Cordform, dependsOn: ['jar', nodeTask,
} }
extraConfig = ['h2Settings.address' : 'localhost:10017'] extraConfig = ['h2Settings.address' : 'localhost:10017']
} }
//All other nodes should be using LoggingBuyerFlow as it is a subclass of BuyerFlow
node {
name "O=LoggingBank,L=London,C=GB"
p2pPort 10025
cordapps = ["$project.group:finance:$corda_release_version"]
rpcUsers = ext.rpcUsers
rpcSettings {
address "localhost:10026"
adminAddress "localhost:10027"
}
extraConfig = ['h2Settings.address' : 'localhost:10035']
flowOverride("net.corda.traderdemo.flow.SellerFlow", "net.corda.traderdemo.flow.BuyerFlow")
}
} }
task integrationTest(type: Test, dependsOn: []) { task integrationTest(type: Test, dependsOn: []) {

View File

@ -11,20 +11,17 @@ import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.unwrap import net.corda.core.utilities.unwrap
import net.corda.finance.contracts.CommercialPaper import net.corda.finance.contracts.CommercialPaper
import net.corda.finance.contracts.getCashBalances
import net.corda.finance.flows.TwoPartyTradeFlow import net.corda.finance.flows.TwoPartyTradeFlow
import net.corda.traderdemo.TransactionGraphSearch
import java.util.* import java.util.*
@InitiatedBy(SellerFlow::class) @InitiatedBy(SellerFlow::class)
class BuyerFlow(private val otherSideSession: FlowSession) : FlowLogic<Unit>() { open class BuyerFlow(private val otherSideSession: FlowSession) : FlowLogic<SignedTransaction>() {
object STARTING_BUY : ProgressTracker.Step("Seller connected, purchasing commercial paper asset") object STARTING_BUY : ProgressTracker.Step("Seller connected, purchasing commercial paper asset")
override val progressTracker: ProgressTracker = ProgressTracker(STARTING_BUY) override val progressTracker: ProgressTracker = ProgressTracker(STARTING_BUY)
@Suspendable @Suspendable
override fun call() { override fun call(): SignedTransaction {
progressTracker.currentStep = STARTING_BUY progressTracker.currentStep = STARTING_BUY
// Receive the offered amount and automatically agree to it (in reality this would be a longer negotiation) // Receive the offered amount and automatically agree to it (in reality this would be a longer negotiation)
@ -43,33 +40,6 @@ class BuyerFlow(private val otherSideSession: FlowSession) : FlowLogic<Unit>() {
println("Purchase complete - we are a happy customer! Final transaction is: " + println("Purchase complete - we are a happy customer! Final transaction is: " +
"\n\n${Emoji.renderIfSupported(tradeTX.tx)}") "\n\n${Emoji.renderIfSupported(tradeTX.tx)}")
logIssuanceAttachment(tradeTX) return tradeTX
logBalance()
}
private fun logBalance() {
val balances = serviceHub.getCashBalances().entries.map { "${it.key.currencyCode} ${it.value}" }
println("Remaining balance: ${balances.joinToString()}")
}
private fun logIssuanceAttachment(tradeTX: SignedTransaction) {
// Find the original CP issuance.
// TODO: This is potentially very expensive, and requires transaction details we may no longer have once
// SGX is enabled. Should be replaced with including the attachment on all transactions involving
// the state.
val search = TransactionGraphSearch(serviceHub.validatedTransactions, listOf(tradeTX.tx),
TransactionGraphSearch.Query(withCommandOfType = CommercialPaper.Commands.Issue::class.java,
followInputsOfType = CommercialPaper.State::class.java))
val cpIssuance = search.call().single()
// Buyer will fetch the attachment from the seller automatically when it resolves the transaction.
cpIssuance.attachments.first().let {
println("""
The issuance of the commercial paper came with an attachment. You can find it in the attachments directory: $it.jar
${Emoji.renderIfSupported(cpIssuance)}""")
}
} }
} }

View File

@ -0,0 +1,48 @@
package net.corda.traderdemo.flow
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.FlowSession
import net.corda.core.flows.InitiatedBy
import net.corda.core.internal.Emoji
import net.corda.core.transactions.SignedTransaction
import net.corda.finance.contracts.CommercialPaper
import net.corda.finance.contracts.getCashBalances
import net.corda.traderdemo.TransactionGraphSearch
@InitiatedBy(SellerFlow::class)
class LoggingBuyerFlow(private val otherSideSession: FlowSession) : BuyerFlow(otherSideSession) {
@Suspendable
override fun call(): SignedTransaction {
val tradeTX = super.call()
logIssuanceAttachment(tradeTX)
logBalance()
return tradeTX
}
private fun logBalance() {
val balances = serviceHub.getCashBalances().entries.map { "${it.key.currencyCode} ${it.value}" }
println("Remaining balance: ${balances.joinToString()}")
}
private fun logIssuanceAttachment(tradeTX: SignedTransaction) {
// Find the original CP issuance.
// TODO: This is potentially very expensive, and requires transaction details we may no longer have once
// SGX is enabled. Should be replaced with including the attachment on all transactions involving
// the state.
val search = TransactionGraphSearch(serviceHub.validatedTransactions, listOf(tradeTX.tx),
TransactionGraphSearch.Query(withCommandOfType = CommercialPaper.Commands.Issue::class.java,
followInputsOfType = CommercialPaper.State::class.java))
val cpIssuance = search.call().single()
// Buyer will fetch the attachment from the seller automatically when it resolves the transaction.
cpIssuance.attachments.first().let {
println("""
The issuance of the commercial paper came with an attachment. You can find it in the attachments directory: $it.jar
${Emoji.renderIfSupported(cpIssuance)}""")
}
}
}

View File

@ -148,7 +148,8 @@ data class NodeParameters(
val maximumHeapSize: String = "512m", val maximumHeapSize: String = "512m",
val logLevel: String? = null, val logLevel: String? = null,
val additionalCordapps: Collection<TestCordapp> = emptySet(), val additionalCordapps: Collection<TestCordapp> = emptySet(),
val regenerateCordappsOnStart: Boolean = false val regenerateCordappsOnStart: Boolean = false,
val flowOverrides: Map<Class<out FlowLogic<*>>, Class<out FlowLogic<*>>> = emptyMap()
) { ) {
/** /**
* Helper builder for configuring a [Node] from Java. * Helper builder for configuring a [Node] from Java.

View File

@ -2,6 +2,7 @@ package net.corda.testing.driver
import net.corda.core.DoNotImplement import net.corda.core.DoNotImplement
import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.CordaFuture
import net.corda.core.flows.FlowLogic
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.concurrent.map import net.corda.core.internal.concurrent.map
@ -27,13 +28,14 @@ interface DriverDSL {
* Returns the [NotaryHandle] for the single notary on the network. Throws if there are none or more than one. * Returns the [NotaryHandle] for the single notary on the network. Throws if there are none or more than one.
* @see notaryHandles * @see notaryHandles
*/ */
val defaultNotaryHandle: NotaryHandle get() { val defaultNotaryHandle: NotaryHandle
return when (notaryHandles.size) { get() {
0 -> throw IllegalStateException("There are no notaries defined on the network") return when (notaryHandles.size) {
1 -> notaryHandles[0] 0 -> throw IllegalStateException("There are no notaries defined on the network")
else -> throw IllegalStateException("There is more than one notary defined on the network") 1 -> notaryHandles[0]
else -> throw IllegalStateException("There is more than one notary defined on the network")
}
} }
}
/** /**
* Returns the identity of the single notary on the network. Throws if there are none or more than one. * Returns the identity of the single notary on the network. Throws if there are none or more than one.
@ -47,11 +49,12 @@ interface DriverDSL {
* @see defaultNotaryHandle * @see defaultNotaryHandle
* @see notaryHandles * @see notaryHandles
*/ */
val defaultNotaryNode: CordaFuture<NodeHandle> get() { val defaultNotaryNode: CordaFuture<NodeHandle>
return defaultNotaryHandle.nodeHandles.map { get() {
it.singleOrNull() ?: throw IllegalStateException("Default notary is not a single node") return defaultNotaryHandle.nodeHandles.map {
it.singleOrNull() ?: throw IllegalStateException("Default notary is not a single node")
}
} }
}
/** /**
* Start a node. * Start a node.
@ -110,7 +113,8 @@ interface DriverDSL {
startInSameProcess: Boolean? = defaultParameters.startInSameProcess, startInSameProcess: Boolean? = defaultParameters.startInSameProcess,
maximumHeapSize: String = defaultParameters.maximumHeapSize, maximumHeapSize: String = defaultParameters.maximumHeapSize,
additionalCordapps: Collection<TestCordapp> = defaultParameters.additionalCordapps, additionalCordapps: Collection<TestCordapp> = defaultParameters.additionalCordapps,
regenerateCordappsOnStart: Boolean = defaultParameters.regenerateCordappsOnStart regenerateCordappsOnStart: Boolean = defaultParameters.regenerateCordappsOnStart,
flowOverrides: Map<out Class<out FlowLogic<*>>, Class<out FlowLogic<*>>> = defaultParameters.flowOverrides
): CordaFuture<NodeHandle> ): CordaFuture<NodeHandle>
/** /**

View File

@ -14,6 +14,7 @@ import net.corda.testing.driver.OutOfProcess
import net.corda.testing.node.User import net.corda.testing.node.User
import rx.Observable import rx.Observable
import java.nio.file.Path import java.nio.file.Path
import javax.validation.constraints.NotNull
interface NodeHandleInternal : NodeHandle { interface NodeHandleInternal : NodeHandle {
val configuration: NodeConfiguration val configuration: NodeConfiguration
@ -70,7 +71,11 @@ data class InProcessImpl(
} }
override fun close() = stop() override fun close() = stop()
override fun <T : FlowLogic<*>> registerInitiatedFlow(initiatedFlowClass: Class<T>): Observable<T> = node.registerInitiatedFlow(initiatedFlowClass) @NotNull
override fun <T : FlowLogic<*>> registerInitiatedFlow(initiatedFlowClass: Class<T>): Observable<T> {
node.registerInitiatedFlow(initiatedFlowClass)
return Observable.empty()
}
} }
val InProcess.internalServices: StartedNodeServices get() = services as StartedNodeServices val InProcess.internalServices: StartedNodeServices get() = services as StartedNodeServices

View File

@ -206,11 +206,8 @@ class StartedMockNode private constructor(private val node: TestStartedNode) {
fun <F : FlowLogic<*>> registerResponderFlow(initiatingFlowClass: Class<out FlowLogic<*>>, fun <F : FlowLogic<*>> registerResponderFlow(initiatingFlowClass: Class<out FlowLogic<*>>,
flowFactory: ResponderFlowFactory<F>, flowFactory: ResponderFlowFactory<F>,
responderFlowClass: Class<F>): CordaFuture<F> = responderFlowClass: Class<F>): CordaFuture<F> =
node.registerFlowFactory(
initiatingFlowClass, node.registerInitiatedFlow(initiatingFlowClass, responderFlowClass).toFuture()
InitiatedFlowFactory.CorDapp(flowVersion = 0, appName = "", factory = flowFactory::invoke),
responderFlowClass, true)
.toFuture()
} }
/** /**
@ -240,13 +237,12 @@ interface ResponderFlowFactory<F : FlowLogic<*>> {
*/ */
inline fun <reified F : FlowLogic<*>> StartedMockNode.registerResponderFlow( inline fun <reified F : FlowLogic<*>> StartedMockNode.registerResponderFlow(
initiatingFlowClass: Class<out FlowLogic<*>>, initiatingFlowClass: Class<out FlowLogic<*>>,
noinline flowFactory: (FlowSession) -> F): Future<F> = noinline flowFactory: (FlowSession) -> F): Future<F> = registerResponderFlow(
registerResponderFlow( initiatingFlowClass,
initiatingFlowClass, object : ResponderFlowFactory<F> {
object : ResponderFlowFactory<F> { override fun invoke(flowSession: FlowSession) = flowFactory(flowSession)
override fun invoke(flowSession: FlowSession) = flowFactory(flowSession) },
}, F::class.java)
F::class.java)
/** /**
* A mock node brings up a suite of in-memory services in a fast manner suitable for unit testing. * A mock node brings up a suite of in-memory services in a fast manner suitable for unit testing.
@ -302,13 +298,13 @@ open class MockNetwork(
constructor(cordappPackages: List<String>, parameters: MockNetworkParameters = MockNetworkParameters()) : this(cordappPackages, defaultParameters = parameters) constructor(cordappPackages: List<String>, parameters: MockNetworkParameters = MockNetworkParameters()) : this(cordappPackages, defaultParameters = parameters)
constructor( constructor(
cordappPackages: List<String>, cordappPackages: List<String>,
defaultParameters: MockNetworkParameters = MockNetworkParameters(), defaultParameters: MockNetworkParameters = MockNetworkParameters(),
networkSendManuallyPumped: Boolean = defaultParameters.networkSendManuallyPumped, networkSendManuallyPumped: Boolean = defaultParameters.networkSendManuallyPumped,
threadPerNode: Boolean = defaultParameters.threadPerNode, threadPerNode: Boolean = defaultParameters.threadPerNode,
servicePeerAllocationStrategy: InMemoryMessagingNetwork.ServicePeerAllocationStrategy = defaultParameters.servicePeerAllocationStrategy, servicePeerAllocationStrategy: InMemoryMessagingNetwork.ServicePeerAllocationStrategy = defaultParameters.servicePeerAllocationStrategy,
notarySpecs: List<MockNetworkNotarySpec> = defaultParameters.notarySpecs, notarySpecs: List<MockNetworkNotarySpec> = defaultParameters.notarySpecs,
networkParameters: NetworkParameters = defaultParameters.networkParameters networkParameters: NetworkParameters = defaultParameters.networkParameters
) : this(emptyList(), defaultParameters, networkSendManuallyPumped, threadPerNode, servicePeerAllocationStrategy, notarySpecs, networkParameters, cordappsForAllNodes = cordappsForPackages(cordappPackages)) ) : this(emptyList(), defaultParameters, networkSendManuallyPumped, threadPerNode, servicePeerAllocationStrategy, notarySpecs, networkParameters, cordappsForAllNodes = cordappsForPackages(cordappPackages))
private val internalMockNetwork: InternalMockNetwork = InternalMockNetwork(defaultParameters, networkSendManuallyPumped, threadPerNode, servicePeerAllocationStrategy, notarySpecs, networkParameters = networkParameters, cordappsForAllNodes = cordappsForAllNodes) private val internalMockNetwork: InternalMockNetwork = InternalMockNetwork(defaultParameters, networkSendManuallyPumped, threadPerNode, servicePeerAllocationStrategy, notarySpecs, networkParameters = networkParameters, cordappsForAllNodes = cordappsForAllNodes)

View File

@ -8,6 +8,7 @@ import com.typesafe.config.ConfigValueFactory
import net.corda.client.rpc.internal.createCordaRPCClientWithSslAndClassLoader import net.corda.client.rpc.internal.createCordaRPCClientWithSslAndClassLoader
import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.CordaFuture
import net.corda.core.concurrent.firstOf import net.corda.core.concurrent.firstOf
import net.corda.core.flows.FlowLogic
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.* import net.corda.core.internal.*
import net.corda.core.internal.concurrent.* import net.corda.core.internal.concurrent.*
@ -37,7 +38,6 @@ import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.network.NetworkParametersCopier import net.corda.nodeapi.internal.network.NetworkParametersCopier
import net.corda.nodeapi.internal.network.NodeInfoFilesCopier import net.corda.nodeapi.internal.network.NodeInfoFilesCopier
import net.corda.serialization.internal.amqp.AbstractAMQPSerializationScheme import net.corda.serialization.internal.amqp.AbstractAMQPSerializationScheme
import net.corda.testing.node.TestCordapp
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.DUMMY_BANK_A_NAME import net.corda.testing.core.DUMMY_BANK_A_NAME
@ -50,6 +50,7 @@ import net.corda.testing.internal.setGlobalSerialization
import net.corda.testing.internal.stubs.CertificateStoreStubs import net.corda.testing.internal.stubs.CertificateStoreStubs
import net.corda.testing.node.ClusterSpec import net.corda.testing.node.ClusterSpec
import net.corda.testing.node.NotarySpec import net.corda.testing.node.NotarySpec
import net.corda.testing.node.TestCordapp
import net.corda.testing.node.User import net.corda.testing.node.User
import net.corda.testing.node.internal.DriverDSLImpl.Companion.cordappsInCurrentAndAdditionalPackages import net.corda.testing.node.internal.DriverDSLImpl.Companion.cordappsInCurrentAndAdditionalPackages
import okhttp3.OkHttpClient import okhttp3.OkHttpClient
@ -213,7 +214,8 @@ class DriverDSLImpl(
startInSameProcess: Boolean?, startInSameProcess: Boolean?,
maximumHeapSize: String, maximumHeapSize: String,
additionalCordapps: Collection<TestCordapp>, additionalCordapps: Collection<TestCordapp>,
regenerateCordappsOnStart: Boolean regenerateCordappsOnStart: Boolean,
flowOverrides: Map<out Class<out FlowLogic<*>>, Class<out FlowLogic<*>>>
): CordaFuture<NodeHandle> { ): CordaFuture<NodeHandle> {
val p2pAddress = portAllocation.nextHostAndPort() val p2pAddress = portAllocation.nextHostAndPort()
// TODO: Derive name from the full picked name, don't just wrap the common name // TODO: Derive name from the full picked name, don't just wrap the common name
@ -230,7 +232,7 @@ class DriverDSLImpl(
return registrationFuture.flatMap { return registrationFuture.flatMap {
networkMapAvailability.flatMap { networkMapAvailability.flatMap {
// But starting the node proper does require the network map // But starting the node proper does require the network map
startRegisteredNode(name, it, rpcUsers, verifierType, customOverrides, startInSameProcess, maximumHeapSize, p2pAddress, additionalCordapps, regenerateCordappsOnStart) startRegisteredNode(name, it, rpcUsers, verifierType, customOverrides, startInSameProcess, maximumHeapSize, p2pAddress, additionalCordapps, regenerateCordappsOnStart, flowOverrides)
} }
} }
} }
@ -244,7 +246,8 @@ class DriverDSLImpl(
maximumHeapSize: String = "512m", maximumHeapSize: String = "512m",
p2pAddress: NetworkHostAndPort = portAllocation.nextHostAndPort(), p2pAddress: NetworkHostAndPort = portAllocation.nextHostAndPort(),
additionalCordapps: Collection<TestCordapp> = emptySet(), additionalCordapps: Collection<TestCordapp> = emptySet(),
regenerateCordappsOnStart: Boolean = false): CordaFuture<NodeHandle> { regenerateCordappsOnStart: Boolean = false,
flowOverrides: Map<out Class<out FlowLogic<*>>, Class<out FlowLogic<*>>> = emptyMap()): CordaFuture<NodeHandle> {
val rpcAddress = portAllocation.nextHostAndPort() val rpcAddress = portAllocation.nextHostAndPort()
val rpcAdminAddress = portAllocation.nextHostAndPort() val rpcAdminAddress = portAllocation.nextHostAndPort()
val webAddress = portAllocation.nextHostAndPort() val webAddress = portAllocation.nextHostAndPort()
@ -258,14 +261,16 @@ class DriverDSLImpl(
"networkServices.networkMapURL" to compatibilityZone.networkMapURL().toString()) "networkServices.networkMapURL" to compatibilityZone.networkMapURL().toString())
} }
val flowOverrideConfig = flowOverrides.entries.map { FlowOverride(it.key.canonicalName, it.value.canonicalName) }.let { FlowOverrideConfig(it) }
val overrides = configOf( val overrides = configOf(
"myLegalName" to name.toString(), NodeConfiguration::myLegalName.name to name.toString(),
"p2pAddress" to p2pAddress.toString(), NodeConfiguration::p2pAddress.name to p2pAddress.toString(),
"rpcSettings.address" to rpcAddress.toString(), "rpcSettings.address" to rpcAddress.toString(),
"rpcSettings.adminAddress" to rpcAdminAddress.toString(), "rpcSettings.adminAddress" to rpcAdminAddress.toString(),
"useTestClock" to useTestClock, NodeConfiguration::useTestClock.name to useTestClock,
"rpcUsers" to if (users.isEmpty()) defaultRpcUserList else users.map { it.toConfig().root().unwrapped() }, NodeConfiguration::rpcUsers.name to if (users.isEmpty()) defaultRpcUserList else users.map { it.toConfig().root().unwrapped() },
"verifierType" to verifierType.name NodeConfiguration::verifierType.name to verifierType.name,
NodeConfiguration::flowOverrides.name to flowOverrideConfig.toConfig().root().unwrapped()
) + czUrlConfig + customOverrides ) + czUrlConfig + customOverrides
val config = NodeConfig(ConfigHelper.loadConfig( val config = NodeConfig(ConfigHelper.loadConfig(
baseDirectory = baseDirectory(name), baseDirectory = baseDirectory(name),
@ -516,8 +521,7 @@ class DriverDSLImpl(
localNetworkMap, localNetworkMap,
spec.rpcUsers, spec.rpcUsers,
spec.verifierType, spec.verifierType,
customOverrides = notaryConfig(clusterAddress) customOverrides = notaryConfig(clusterAddress))
)
// All other nodes will join the cluster // All other nodes will join the cluster
val restNodeFutures = nodeNames.drop(1).map { val restNodeFutures = nodeNames.drop(1).map {

View File

@ -30,6 +30,8 @@ import net.corda.core.utilities.seconds
import net.corda.node.VersionInfo import net.corda.node.VersionInfo
import net.corda.node.internal.AbstractNode import net.corda.node.internal.AbstractNode
import net.corda.node.internal.InitiatedFlowFactory import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.internal.NodeFlowManager
import net.corda.node.internal.cordapp.JarScanningCordappLoader
import net.corda.node.services.api.FlowStarter import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.api.StartedNodeServices import net.corda.node.services.api.StartedNodeServices
@ -51,7 +53,6 @@ import net.corda.nodeapi.internal.config.User
import net.corda.nodeapi.internal.network.NetworkParametersCopier import net.corda.nodeapi.internal.network.NetworkParametersCopier
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.testing.node.TestCordapp
import net.corda.testing.common.internal.testNetworkParameters import net.corda.testing.common.internal.testNetworkParameters
import net.corda.testing.internal.rigorousMock import net.corda.testing.internal.rigorousMock
import net.corda.testing.internal.setGlobalSerialization import net.corda.testing.internal.setGlobalSerialization
@ -79,7 +80,8 @@ data class MockNodeArgs(
val network: InternalMockNetwork, val network: InternalMockNetwork,
val id: Int, val id: Int,
val entropyRoot: BigInteger, val entropyRoot: BigInteger,
val version: VersionInfo = MOCK_VERSION_INFO val version: VersionInfo = MOCK_VERSION_INFO,
val flowManager: MockNodeFlowManager = MockNodeFlowManager()
) )
// TODO We don't need a parameters object as this is internal only // TODO We don't need a parameters object as this is internal only
@ -89,7 +91,8 @@ data class InternalMockNodeParameters(
val entropyRoot: BigInteger = BigInteger.valueOf(random63BitValue()), val entropyRoot: BigInteger = BigInteger.valueOf(random63BitValue()),
val configOverrides: (NodeConfiguration) -> Any? = {}, val configOverrides: (NodeConfiguration) -> Any? = {},
val version: VersionInfo = MOCK_VERSION_INFO, val version: VersionInfo = MOCK_VERSION_INFO,
val additionalCordapps: Collection<TestCordapp>? = null) { val additionalCordapps: Collection<TestCordapp>? = null,
val flowManager: MockNodeFlowManager = MockNodeFlowManager()) {
constructor(mockNodeParameters: MockNodeParameters) : this( constructor(mockNodeParameters: MockNodeParameters) : this(
mockNodeParameters.forcedID, mockNodeParameters.forcedID,
mockNodeParameters.legalName, mockNodeParameters.legalName,
@ -132,12 +135,10 @@ interface TestStartedNode {
* starts up for all [FlowLogic] classes it finds which are annotated with [InitiatedBy]. * starts up for all [FlowLogic] classes it finds which are annotated with [InitiatedBy].
* @return An [Observable] of the initiated flows started by counterparties. * @return An [Observable] of the initiated flows started by counterparties.
*/ */
fun <T : FlowLogic<*>> registerInitiatedFlow(initiatedFlowClass: Class<T>): Observable<T> fun <T : FlowLogic<*>> registerInitiatedFlow(initiatedFlowClass: Class<T>, track: Boolean = false): Observable<T>
fun <T : FlowLogic<*>> registerInitiatedFlow(initiatingFlowClass: Class<out FlowLogic<*>>, initiatedFlowClass: Class<T>, track: Boolean = false): Observable<T>
fun <F : FlowLogic<*>> registerFlowFactory(initiatingFlowClass: Class<out FlowLogic<*>>,
flowFactory: InitiatedFlowFactory<F>,
initiatedFlowClass: Class<F>,
track: Boolean): Observable<F>
} }
open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNetworkParameters(), open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNetworkParameters(),
@ -202,7 +203,8 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe
*/ */
val defaultNotaryIdentity: Party val defaultNotaryIdentity: Party
get() { get() {
return defaultNotaryNode.info.legalIdentities.singleOrNull() ?: throw IllegalStateException("Default notary has multiple identities") return defaultNotaryNode.info.legalIdentities.singleOrNull()
?: throw IllegalStateException("Default notary has multiple identities")
} }
/** /**
@ -270,11 +272,12 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe
} }
} }
open class MockNode(args: MockNodeArgs) : AbstractNode<TestStartedNode>( open class MockNode(args: MockNodeArgs, private val mockFlowManager: MockNodeFlowManager = args.flowManager) : AbstractNode<TestStartedNode>(
args.config, args.config,
TestClock(Clock.systemUTC()), TestClock(Clock.systemUTC()),
DefaultNamedCacheFactory(), DefaultNamedCacheFactory(),
args.version, args.version,
mockFlowManager,
args.network.getServerThread(args.id), args.network.getServerThread(args.id),
args.network.busyLatch args.network.busyLatch
) { ) {
@ -294,24 +297,28 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe
override val rpcOps: CordaRPCOps, override val rpcOps: CordaRPCOps,
override val notaryService: NotaryService?) : TestStartedNode { override val notaryService: NotaryService?) : TestStartedNode {
override fun <F : FlowLogic<*>> registerFlowFactory(
initiatingFlowClass: Class<out FlowLogic<*>>,
flowFactory: InitiatedFlowFactory<F>,
initiatedFlowClass: Class<F>,
track: Boolean): Observable<F> =
internals.internalRegisterFlowFactory(smm, initiatingFlowClass, flowFactory, initiatedFlowClass, track)
override fun dispose() = internals.stop() override fun dispose() = internals.stop()
override fun <T : FlowLogic<*>> registerInitiatedFlow(initiatedFlowClass: Class<T>): Observable<T> = override fun <T : FlowLogic<*>> registerInitiatedFlow(initiatedFlowClass: Class<T>, track: Boolean): Observable<T> {
internals.registerInitiatedFlow(smm, initiatedFlowClass) internals.flowManager.registerInitiatedFlow(initiatedFlowClass)
return smm.changes.filter { it is StateMachineManager.Change.Add }.map { it.logic }.ofType(initiatedFlowClass)
}
override fun <T : FlowLogic<*>> registerInitiatedFlow(initiatingFlowClass: Class<out FlowLogic<*>>, initiatedFlowClass: Class<T>, track: Boolean): Observable<T> {
internals.flowManager.registerInitiatedFlow(initiatingFlowClass, initiatedFlowClass)
return smm.changes.filter { it is StateMachineManager.Change.Add }.map { it.logic }.ofType(initiatedFlowClass)
}
} }
val mockNet = args.network val mockNet = args.network
val id = args.id val id = args.id
init { init {
require(id >= 0) { "Node ID must be zero or positive, was passed: $id" } require(id >= 0) { "Node ID must be zero or positive, was passed: $id" }
} }
private val entropyRoot = args.entropyRoot private val entropyRoot = args.entropyRoot
var counter = entropyRoot var counter = entropyRoot
override val log get() = staticLog override val log get() = staticLog
@ -333,7 +340,7 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe
this, this,
attachments, attachments,
network as MockNodeMessagingService, network as MockNodeMessagingService,
object : StartedNodeServices, ServiceHubInternal by services, FlowStarter by flowStarter { }, object : StartedNodeServices, ServiceHubInternal by services, FlowStarter by flowStarter {},
nodeInfo, nodeInfo,
smm, smm,
database, database,
@ -417,8 +424,19 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe
var acceptableLiveFiberCountOnStop: Int = 0 var acceptableLiveFiberCountOnStop: Int = 0
override fun acceptableLiveFiberCountOnStop(): Int = acceptableLiveFiberCountOnStop override fun acceptableLiveFiberCountOnStop(): Int = acceptableLiveFiberCountOnStop
fun <T : FlowLogic<*>> registerInitiatedFlowFactory(initiatingFlowClass: Class<out FlowLogic<*>>, initiatedFlowClass: Class<T>, factory: InitiatedFlowFactory<T>, track: Boolean): Observable<T> {
mockFlowManager.registerTestingFactory(initiatingFlowClass, factory)
return if (track) {
smm.changes.filter { it is StateMachineManager.Change.Add }.map { it.logic }.ofType(initiatedFlowClass)
} else {
Observable.empty<T>()
}
}
} }
fun createUnstartedNode(parameters: InternalMockNodeParameters = InternalMockNodeParameters()): MockNode { fun createUnstartedNode(parameters: InternalMockNodeParameters = InternalMockNodeParameters()): MockNode {
return createUnstartedNode(parameters, defaultFactory) return createUnstartedNode(parameters, defaultFactory)
} }
@ -453,7 +471,7 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe
val cordappDirectories = cordapps.map { TestCordappDirectories.getJarDirectory(it) }.distinct() val cordappDirectories = cordapps.map { TestCordappDirectories.getJarDirectory(it) }.distinct()
doReturn(cordappDirectories).whenever(config).cordappDirectories doReturn(cordappDirectories).whenever(config).cordappDirectories
val node = nodeFactory(MockNodeArgs(config, this, id, parameters.entropyRoot, parameters.version)) val node = nodeFactory(MockNodeArgs(config, this, id, parameters.entropyRoot, parameters.version, flowManager = parameters.flowManager))
_nodes += node _nodes += node
if (start) { if (start) {
node.start() node.start()
@ -482,8 +500,10 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe
*/ */
@JvmOverloads @JvmOverloads
fun runNetwork(rounds: Int = -1) { fun runNetwork(rounds: Int = -1) {
check(!networkSendManuallyPumped) { "MockNetwork.runNetwork() should only be used when networkSendManuallyPumped == false. " + check(!networkSendManuallyPumped) {
"You can use MockNetwork.waitQuiescent() to wait for all the nodes to process all the messages on their queues instead." } "MockNetwork.runNetwork() should only be used when networkSendManuallyPumped == false. " +
"You can use MockNetwork.waitQuiescent() to wait for all the nodes to process all the messages on their queues instead."
}
fun pumpAll() = messagingNetwork.endpoints.map { it.pumpReceive(false) } fun pumpAll() = messagingNetwork.endpoints.map { it.pumpReceive(false) }
if (rounds == -1) { if (rounds == -1) {
@ -572,3 +592,17 @@ private fun mockNodeConfiguration(certificatesDirectory: Path): NodeConfiguratio
doReturn(null).whenever(it).devModeOptions doReturn(null).whenever(it).devModeOptions
} }
} }
class MockNodeFlowManager : NodeFlowManager() {
val testingRegistrations = HashMap<Class<out FlowLogic<*>>, InitiatedFlowFactory<*>>()
override fun getFlowFactoryForInitiatingFlow(initiatedFlowClass: Class<out FlowLogic<*>>): InitiatedFlowFactory<*>? {
if (initiatedFlowClass in testingRegistrations) {
return testingRegistrations.get(initiatedFlowClass)
}
return super.getFlowFactoryForInitiatingFlow(initiatedFlowClass)
}
fun registerTestingFactory(initiator: Class<out FlowLogic<*>>, factory: InitiatedFlowFactory<*>) {
testingRegistrations.put(initiator, factory)
}
}

View File

@ -10,7 +10,9 @@ import net.corda.core.internal.div
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.node.VersionInfo import net.corda.node.VersionInfo
import net.corda.node.internal.FlowManager
import net.corda.node.internal.Node import net.corda.node.internal.Node
import net.corda.node.internal.NodeFlowManager
import net.corda.node.internal.NodeWithInfo import net.corda.node.internal.NodeWithInfo
import net.corda.node.services.config.* import net.corda.node.services.config.*
import net.corda.nodeapi.internal.config.toConfig import net.corda.nodeapi.internal.config.toConfig
@ -87,7 +89,8 @@ abstract class NodeBasedTest(private val cordappPackages: List<String> = emptyLi
fun startNode(legalName: CordaX500Name, fun startNode(legalName: CordaX500Name,
platformVersion: Int = PLATFORM_VERSION, platformVersion: Int = PLATFORM_VERSION,
rpcUsers: List<User> = emptyList(), rpcUsers: List<User> = emptyList(),
configOverrides: Map<String, Any> = emptyMap()): NodeWithInfo { configOverrides: Map<String, Any> = emptyMap(),
flowManager: FlowManager = NodeFlowManager(FlowOverrideConfig())): NodeWithInfo {
val baseDirectory = baseDirectory(legalName).createDirectories() val baseDirectory = baseDirectory(legalName).createDirectories()
val p2pAddress = configOverrides["p2pAddress"] ?: portAllocation.nextHostAndPort().toString() val p2pAddress = configOverrides["p2pAddress"] ?: portAllocation.nextHostAndPort().toString()
val config = ConfigHelper.loadConfig( val config = ConfigHelper.loadConfig(
@ -103,7 +106,8 @@ abstract class NodeBasedTest(private val cordappPackages: List<String> = emptyLi
) + configOverrides ) + configOverrides
) )
val cordapps = cordappsForPackages(getCallerPackage(NodeBasedTest::class)?.let { cordappPackages + it } ?: cordappPackages) val cordapps = cordappsForPackages(getCallerPackage(NodeBasedTest::class)?.let { cordappPackages + it }
?: cordappPackages)
val existingCorDappDirectoriesOption = if (config.hasPath(NodeConfiguration.cordappDirectoriesKey)) config.getStringList(NodeConfiguration.cordappDirectoriesKey) else emptyList() val existingCorDappDirectoriesOption = if (config.hasPath(NodeConfiguration.cordappDirectoriesKey)) config.getStringList(NodeConfiguration.cordappDirectoriesKey) else emptyList()
@ -119,7 +123,7 @@ abstract class NodeBasedTest(private val cordappPackages: List<String> = emptyLi
} }
defaultNetworkParameters.install(baseDirectory) defaultNetworkParameters.install(baseDirectory)
val node = InProcessNode(parsedConfig, MOCK_VERSION_INFO.copy(platformVersion = platformVersion)) val node = InProcessNode(parsedConfig, MOCK_VERSION_INFO.copy(platformVersion = platformVersion), flowManager = flowManager)
val nodeInfo = node.start() val nodeInfo = node.start()
val nodeWithInfo = NodeWithInfo(node, nodeInfo) val nodeWithInfo = NodeWithInfo(node, nodeInfo)
nodes += nodeWithInfo nodes += nodeWithInfo
@ -145,7 +149,7 @@ abstract class NodeBasedTest(private val cordappPackages: List<String> = emptyLi
} }
} }
class InProcessNode(configuration: NodeConfiguration, versionInfo: VersionInfo) : Node(configuration, versionInfo, false) { class InProcessNode(configuration: NodeConfiguration, versionInfo: VersionInfo, flowManager: FlowManager = NodeFlowManager(configuration.flowOverrides)) : Node(configuration, versionInfo, false, flowManager = flowManager) {
override fun start() : NodeInfo { override fun start() : NodeInfo {
check(isValidJavaVersion()) { "You are using a version of Java that is not supported (${SystemUtils.JAVA_VERSION}). Please upgrade to the latest version of Java 8." } check(isValidJavaVersion()) { "You are using a version of Java that is not supported (${SystemUtils.JAVA_VERSION}). Please upgrade to the latest version of Java 8." }

View File

@ -3,7 +3,6 @@ package net.corda.testing.internal
import net.corda.core.contracts.ContractClassName import net.corda.core.contracts.ContractClassName
import net.corda.core.cordapp.Cordapp import net.corda.core.cordapp.Cordapp
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.identity.Party
import net.corda.core.internal.DEPLOYED_CORDAPP_UPLOADER import net.corda.core.internal.DEPLOYED_CORDAPP_UPLOADER
import net.corda.core.internal.cordapp.CordappImpl import net.corda.core.internal.cordapp.CordappImpl
import net.corda.core.node.services.AttachmentId import net.corda.core.node.services.AttachmentId
@ -13,7 +12,6 @@ import net.corda.node.internal.cordapp.CordappProviderImpl
import net.corda.testing.services.MockAttachmentStorage import net.corda.testing.services.MockAttachmentStorage
import java.nio.file.Paths import java.nio.file.Paths
import java.security.PublicKey import java.security.PublicKey
import java.util.*
class MockCordappProvider( class MockCordappProvider(
cordappLoader: CordappLoader, cordappLoader: CordappLoader,