ENT-2509 - Make @InitiatedBy flows overridable via node config (#3960)

* first attempt at a flowManager

fix test breakages

add testing around registering subclasses

make flowManager a param of MockNode

extract interface
rename methods

more work around overriding flows

more test fixes

add sample project showing how to use flowOverrides

rebase

* make smallest possible changes to AttachmentSerializationTest and ReceiveAllFlowTests

* add some comments about how flow manager weights flows

* address review comments
add documentation

* address more review comments
This commit is contained in:
Stefano Franz
2018-10-23 16:45:07 +01:00
committed by GitHub
parent f8ac35df25
commit 0919b01271
36 changed files with 930 additions and 324 deletions

View File

@ -0,0 +1,85 @@
package net.corda.node.flows
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.messaging.startFlow
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.testing.core.singleIdentity
import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.driver
import net.corda.testing.node.internal.cordappForClasses
import org.hamcrest.CoreMatchers.`is`
import org.junit.Assert
import org.junit.Test
class FlowOverrideTests {
@StartableByRPC
@InitiatingFlow
class Ping(private val pongParty: Party) : FlowLogic<String>() {
@Suspendable
override fun call(): String {
val pongSession = initiateFlow(pongParty)
return pongSession.sendAndReceive<String>("PING").unwrap { it }
}
}
@InitiatedBy(Ping::class)
open class Pong(private val pingSession: FlowSession) : FlowLogic<Unit>() {
companion object {
val PONG = "PONG"
}
@Suspendable
override fun call() {
pingSession.send(PONG)
}
}
@InitiatedBy(Ping::class)
class Pong2(private val pingSession: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
pingSession.send("PONGPONG")
}
}
@InitiatedBy(Ping::class)
class Pongiest(private val pingSession: FlowSession) : Pong(pingSession) {
companion object {
val GORGONZOLA = "Gorgonzola"
}
@Suspendable
override fun call() {
pingSession.send(GORGONZOLA)
}
}
private val nodeAClasses = setOf(Ping::class.java,
Pong::class.java, Pongiest::class.java)
private val nodeBClasses = setOf(Ping::class.java, Pong::class.java)
@Test
fun `should use the most "specific" implementation of a responding flow`() {
driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = emptySet())) {
val nodeA = startNode(additionalCordapps = setOf(cordappForClasses(*nodeAClasses.toTypedArray()))).getOrThrow()
val nodeB = startNode(additionalCordapps = setOf(cordappForClasses(*nodeBClasses.toTypedArray()))).getOrThrow()
Assert.assertThat(nodeB.rpc.startFlow(::Ping, nodeA.nodeInfo.singleIdentity()).returnValue.getOrThrow(), `is`(net.corda.node.flows.FlowOverrideTests.Pongiest.GORGONZOLA))
}
}
@Test
fun `should use the overriden implementation of a responding flow`() {
val flowOverrides = mapOf(Ping::class.java to Pong::class.java)
driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = emptySet())) {
val nodeA = startNode(additionalCordapps = setOf(cordappForClasses(*nodeAClasses.toTypedArray())), flowOverrides = flowOverrides).getOrThrow()
val nodeB = startNode(additionalCordapps = setOf(cordappForClasses(*nodeBClasses.toTypedArray()))).getOrThrow()
Assert.assertThat(nodeB.rpc.startFlow(::Ping, nodeA.nodeInfo.singleIdentity()).returnValue.getOrThrow(), `is`(net.corda.node.flows.FlowOverrideTests.Pong.PONG))
}
}
}

View File

@ -7,10 +7,11 @@ import net.corda.core.flows.InitiatingFlow
import net.corda.core.identity.Party
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.testing.core.singleIdentity
import net.corda.testing.node.internal.NodeBasedTest
import net.corda.node.internal.NodeFlowManager
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.singleIdentity
import net.corda.testing.node.internal.NodeBasedTest
import net.corda.testing.node.internal.startFlow
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
@ -18,9 +19,10 @@ import org.junit.Test
class FlowVersioningTest : NodeBasedTest() {
@Test
fun `getFlowContext returns the platform version for core flows`() {
val bobFlowManager = NodeFlowManager()
val alice = startNode(ALICE_NAME, platformVersion = 2)
val bob = startNode(BOB_NAME, platformVersion = 3)
bob.node.installCoreFlow(PretendInitiatingCoreFlow::class, ::PretendInitiatedCoreFlow)
val bob = startNode(BOB_NAME, platformVersion = 3, flowManager = bobFlowManager)
bobFlowManager.registerInitiatedCoreFlowFactory(PretendInitiatingCoreFlow::class, ::PretendInitiatedCoreFlow)
val (alicePlatformVersionAccordingToBob, bobPlatformVersionAccordingToAlice) = alice.services.startFlow(
PretendInitiatingCoreFlow(bob.info.singleIdentity())).resultFuture.getOrThrow()
assertThat(alicePlatformVersionAccordingToBob).isEqualTo(2)
@ -45,4 +47,5 @@ class FlowVersioningTest : NodeBasedTest() {
@Suspendable
override fun call() = otherSideSession.send(otherSideSession.getCounterpartyFlowInfo().flowVersion)
}
}

View File

@ -30,7 +30,10 @@ import net.corda.core.schemas.MappedSchema
import net.corda.core.serialization.SerializationWhitelist
import net.corda.core.serialization.SerializeAsToken
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.utilities.*
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.days
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.minutes
import net.corda.node.CordaClock
import net.corda.node.SerialFilter
import net.corda.node.VersionInfo
@ -99,13 +102,10 @@ import java.time.Clock
import java.time.Duration
import java.time.format.DateTimeParseException
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit.MINUTES
import java.util.concurrent.TimeUnit.SECONDS
import kotlin.collections.set
import kotlin.reflect.KClass
import net.corda.core.crypto.generateKeyPair as cryptoGenerateKeyPair
/**
@ -120,11 +120,11 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
val platformClock: CordaClock,
cacheFactoryPrototype: BindableNamedCacheFactory,
protected val versionInfo: VersionInfo,
protected val flowManager: FlowManager,
protected val serverThread: AffinityExecutor.ServiceAffinityExecutor,
private val busyNodeLatch: ReusableLatch = ReusableLatch()) : SingletonSerializeAsToken() {
protected abstract val log: Logger
@Suppress("LeakingThis")
private var tokenizableServices: MutableList<Any>? = mutableListOf(platformClock, this)
@ -211,7 +211,6 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
).tokenize().closeOnStop()
private val cordappServices = MutableClassToInstanceMap.create<SerializeAsToken>()
private val flowFactories = ConcurrentHashMap<Class<out FlowLogic<*>>, InitiatedFlowFactory<*>>()
private val shutdownExecutor = Executors.newSingleThreadExecutor()
protected abstract val transactionVerifierWorkerCount: Int
@ -237,7 +236,8 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
private var _started: S? = null
private fun <T : Any> T.tokenize(): T {
tokenizableServices?.add(this) ?: throw IllegalStateException("The tokenisable services list has already been finalised")
tokenizableServices?.add(this)
?: throw IllegalStateException("The tokenisable services list has already been finalised")
return this
}
@ -607,91 +607,27 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
}
private fun registerCordappFlows() {
cordappLoader.cordapps.flatMap { it.initiatedFlows }
.forEach {
cordappLoader.cordapps.forEach { cordapp ->
cordapp.initiatedFlows.groupBy { it.requireAnnotation<InitiatedBy>().value.java }.forEach { initiator, responders ->
responders.forEach { responder ->
try {
registerInitiatedFlowInternal(smm, it, track = false)
flowManager.registerInitiatedFlow(initiator, responder)
} catch (e: NoSuchMethodException) {
log.error("${it.name}, as an initiated flow, must have a constructor with a single parameter " +
log.error("${responder.name}, as an initiated flow, must have a constructor with a single parameter " +
"of type ${Party::class.java.name}")
} catch (e: Exception) {
log.error("Unable to register initiated flow ${it.name}", e)
throw e
}
}
}
fun <T : FlowLogic<*>> registerInitiatedFlow(smm: StateMachineManager, initiatedFlowClass: Class<T>): Observable<T> {
return registerInitiatedFlowInternal(smm, initiatedFlowClass, track = true)
}
// TODO remove once not needed
private fun deprecatedFlowConstructorMessage(flowClass: Class<*>): String {
return "Installing flow factory for $flowClass accepting a ${Party::class.java.simpleName}, which is deprecated. " +
"It should accept a ${FlowSession::class.java.simpleName} instead"
}
private fun <F : FlowLogic<*>> registerInitiatedFlowInternal(smm: StateMachineManager, initiatedFlow: Class<F>, track: Boolean): Observable<F> {
val constructors = initiatedFlow.declaredConstructors.associateBy { it.parameterTypes.toList() }
val flowSessionCtor = constructors[listOf(FlowSession::class.java)]?.apply { isAccessible = true }
val ctor: (FlowSession) -> F = if (flowSessionCtor == null) {
// Try to fallback to a Party constructor
val partyCtor = constructors[listOf(Party::class.java)]?.apply { isAccessible = true }
if (partyCtor == null) {
throw IllegalArgumentException("$initiatedFlow must have a constructor accepting a ${FlowSession::class.java.name}")
} else {
log.warn(deprecatedFlowConstructorMessage(initiatedFlow))
}
{ flowSession: FlowSession -> uncheckedCast(partyCtor.newInstance(flowSession.counterparty)) }
} else {
{ flowSession: FlowSession -> uncheckedCast(flowSessionCtor.newInstance(flowSession)) }
}
val initiatingFlow = initiatedFlow.requireAnnotation<InitiatedBy>().value.java
val (version, classWithAnnotation) = initiatingFlow.flowVersionAndInitiatingClass
require(classWithAnnotation == initiatingFlow) {
"${InitiatedBy::class.java.name} must point to ${classWithAnnotation.name} and not ${initiatingFlow.name}"
}
val flowFactory = InitiatedFlowFactory.CorDapp(version, initiatedFlow.appName, ctor)
val observable = internalRegisterFlowFactory(smm, initiatingFlow, flowFactory, initiatedFlow, track)
log.info("Registered ${initiatingFlow.name} to initiate ${initiatedFlow.name} (version $version)")
return observable
}
protected fun <F : FlowLogic<*>> internalRegisterFlowFactory(smm: StateMachineManager,
initiatingFlowClass: Class<out FlowLogic<*>>,
flowFactory: InitiatedFlowFactory<F>,
initiatedFlowClass: Class<F>,
track: Boolean): Observable<F> {
val observable = if (track) {
smm.changes.filter { it is StateMachineManager.Change.Add }.map { it.logic }.ofType(initiatedFlowClass)
} else {
Observable.empty()
}
check(initiatingFlowClass !in flowFactories.keys) {
"$initiatingFlowClass is attempting to register multiple initiated flows"
}
flowFactories[initiatingFlowClass] = flowFactory
return observable
}
/**
* Installs a flow that's core to the Corda platform. Unlike CorDapp flows which are versioned individually using
* [InitiatingFlow.version], core flows have the same version as the node's platform version. To cater for backwards
* compatibility [flowFactory] provides a second parameter which is the platform version of the initiating party.
*/
@VisibleForTesting
fun installCoreFlow(clientFlowClass: KClass<out FlowLogic<*>>, flowFactory: (FlowSession) -> FlowLogic<*>) {
require(clientFlowClass.java.flowVersionAndInitiatingClass.first == 1) {
"${InitiatingFlow::class.java.name}.version not applicable for core flows; their version is the node's platform version"
}
flowFactories[clientFlowClass.java] = InitiatedFlowFactory.Core(flowFactory)
log.debug { "Installed core flow ${clientFlowClass.java.name}" }
flowManager.validateRegistrations()
}
private fun installCoreFlows() {
installCoreFlow(FinalityFlow::class, ::FinalityHandler)
installCoreFlow(NotaryChangeFlow::class, ::NotaryChangeHandler)
installCoreFlow(ContractUpgradeFlow.Initiate::class, ::ContractUpgradeHandler)
installCoreFlow(SwapIdentitiesFlow::class, ::SwapIdentitiesHandler)
flowManager.registerInitiatedCoreFlowFactory(FinalityFlow::class, FinalityHandler::class, ::FinalityHandler)
flowManager.registerInitiatedCoreFlowFactory(NotaryChangeFlow::class, NotaryChangeHandler::class, ::NotaryChangeHandler)
flowManager.registerInitiatedCoreFlowFactory(ContractUpgradeFlow.Initiate::class, NotaryChangeHandler::class, ::ContractUpgradeHandler)
flowManager.registerInitiatedCoreFlowFactory(SwapIdentitiesFlow::class, SwapIdentitiesHandler::class, ::SwapIdentitiesHandler)
}
protected open fun makeTransactionStorage(transactionCacheSizeBytes: Long): WritableTransactionStorage {
@ -781,7 +717,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
service.run {
tokenize()
runOnStop += ::stop
installCoreFlow(NotaryFlow.Client::class, ::createServiceFlow)
flowManager.registerInitiatedCoreFlowFactory(NotaryFlow.Client::class, ::createServiceFlow)
start()
}
return service
@ -961,7 +897,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
}
override fun getFlowFactory(initiatingFlowClass: Class<out FlowLogic<*>>): InitiatedFlowFactory<*>? {
return flowFactories[initiatingFlowClass]
return flowManager.getFlowFactoryForInitiatingFlow(initiatingFlowClass)
}
override fun jdbcSession(): Connection = database.createSession()
@ -1066,4 +1002,4 @@ fun clientSslOptionsCompatibleWith(nodeRpcOptions: NodeRpcOptions): ClientRpcSsl
}
// Here we're using the node's RPC key store as the RPC client's trust store.
return ClientRpcSslOptions(trustStorePath = nodeRpcOptions.sslConfig!!.keyStorePath, trustStorePassword = nodeRpcOptions.sslConfig!!.keyStorePassword)
}
}

View File

@ -0,0 +1,222 @@
package net.corda.node.internal
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.flows.InitiatedBy
import net.corda.core.flows.InitiatingFlow
import net.corda.core.identity.Party
import net.corda.core.internal.uncheckedCast
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.node.internal.classloading.requireAnnotation
import net.corda.node.services.config.FlowOverrideConfig
import net.corda.node.services.statemachine.appName
import net.corda.node.services.statemachine.flowVersionAndInitiatingClass
import javax.annotation.concurrent.ThreadSafe
import kotlin.reflect.KClass
/**
*
* This class is responsible for organising which flow should respond to a specific @InitiatingFlow
*
* There are two main ways to modify the behaviour of a cordapp with regards to responding with a different flow
*
* 1.) implementing a new subclass. For example, if we have a ResponderFlow similar to @InitiatedBy(Sender) MyBaseResponder : FlowLogic
* If we subclassed a new Flow with specific logic for DB2, it would be similar to IBMB2Responder() : MyBaseResponder
* When these two flows are encountered by the classpath scan for @InitiatedBy, they will both be selected for responding to Sender
* This implementation will sort them for responding in order of their "depth" from FlowLogic - see: FlowWeightComparator
* So IBMB2Responder would win and it would be selected for responding
*
* 2.) It is possible to specify a flowOverride key in the node configuration. Say we configure a node to have
* flowOverrides{
* "Sender" = "MyBaseResponder"
* }
* In this case, FlowWeightComparator would detect that there is an override in action, and it will assign MyBaseResponder a maximum weight
* This will result in MyBaseResponder being selected for responding to Sender
*
*
*/
interface FlowManager {
fun registerInitiatedCoreFlowFactory(initiatingFlowClass: KClass<out FlowLogic<*>>, flowFactory: (FlowSession) -> FlowLogic<*>)
fun registerInitiatedCoreFlowFactory(initiatingFlowClass: KClass<out FlowLogic<*>>, initiatedFlowClass: KClass<out FlowLogic<*>>?, flowFactory: (FlowSession) -> FlowLogic<*>)
fun registerInitiatedCoreFlowFactory(initiatingFlowClass: KClass<out FlowLogic<*>>, initiatedFlowClass: KClass<out FlowLogic<*>>?, flowFactory: InitiatedFlowFactory.Core<FlowLogic<*>>)
fun <F : FlowLogic<*>> registerInitiatedFlow(initiator: Class<out FlowLogic<*>>, responder: Class<F>)
fun <F : FlowLogic<*>> registerInitiatedFlow(responder: Class<F>)
fun getFlowFactoryForInitiatingFlow(initiatedFlowClass: Class<out FlowLogic<*>>): InitiatedFlowFactory<*>?
fun validateRegistrations()
}
@ThreadSafe
open class NodeFlowManager(flowOverrides: FlowOverrideConfig? = null) : FlowManager {
private val flowFactories = HashMap<Class<out FlowLogic<*>>, MutableList<RegisteredFlowContainer>>()
private val flowOverrides = (flowOverrides
?: FlowOverrideConfig()).overrides.map { it.initiator to it.responder }.toMutableMap()
companion object {
private val log = contextLogger()
}
@Synchronized
override fun getFlowFactoryForInitiatingFlow(initiatedFlowClass: Class<out FlowLogic<*>>): InitiatedFlowFactory<*>? {
return flowFactories[initiatedFlowClass]?.firstOrNull()?.flowFactory
}
@Synchronized
override fun <F : FlowLogic<*>> registerInitiatedFlow(responder: Class<F>) {
return registerInitiatedFlow(responder.requireAnnotation<InitiatedBy>().value.java, responder)
}
@Synchronized
override fun <F : FlowLogic<*>> registerInitiatedFlow(initiator: Class<out FlowLogic<*>>, responder: Class<F>) {
val constructors = responder.declaredConstructors.associateBy { it.parameterTypes.toList() }
val flowSessionCtor = constructors[listOf(FlowSession::class.java)]?.apply { isAccessible = true }
val ctor: (FlowSession) -> F = if (flowSessionCtor == null) {
// Try to fallback to a Party constructor
val partyCtor = constructors[listOf(Party::class.java)]?.apply { isAccessible = true }
if (partyCtor == null) {
throw IllegalArgumentException("$responder must have a constructor accepting a ${FlowSession::class.java.name}")
} else {
log.warn("Installing flow factory for $responder accepting a ${Party::class.java.simpleName}, which is deprecated. " +
"It should accept a ${FlowSession::class.java.simpleName} instead")
}
{ flowSession: FlowSession -> uncheckedCast(partyCtor.newInstance(flowSession.counterparty)) }
} else {
{ flowSession: FlowSession -> uncheckedCast(flowSessionCtor.newInstance(flowSession)) }
}
val (version, classWithAnnotation) = initiator.flowVersionAndInitiatingClass
require(classWithAnnotation == initiator) {
"${InitiatedBy::class.java.name} must point to ${classWithAnnotation.name} and not ${initiator.name}"
}
val flowFactory = InitiatedFlowFactory.CorDapp(version, responder.appName, ctor)
registerInitiatedFlowFactory(initiator, flowFactory, responder)
log.info("Registered ${initiator.name} to initiate ${responder.name} (version $version)")
}
private fun <F : FlowLogic<*>> registerInitiatedFlowFactory(initiatingFlowClass: Class<out FlowLogic<*>>,
flowFactory: InitiatedFlowFactory<F>,
initiatedFlowClass: Class<F>?) {
check(flowFactory !is InitiatedFlowFactory.Core) { "This should only be used for Cordapp flows" }
val listOfFlowsForInitiator = flowFactories.computeIfAbsent(initiatingFlowClass) { mutableListOf() }
if (listOfFlowsForInitiator.isNotEmpty() && listOfFlowsForInitiator.first().type == FlowType.CORE) {
throw IllegalStateException("Attempting to register over an existing platform flow: $initiatingFlowClass")
}
synchronized(listOfFlowsForInitiator) {
val flowToAdd = RegisteredFlowContainer(initiatingFlowClass, initiatedFlowClass, flowFactory, FlowType.CORDAPP)
val flowWeightComparator = FlowWeightComparator(initiatingFlowClass, flowOverrides)
listOfFlowsForInitiator.add(flowToAdd)
listOfFlowsForInitiator.sortWith(flowWeightComparator)
if (listOfFlowsForInitiator.size > 1) {
log.warn("Multiple flows are registered for InitiatingFlow: $initiatingFlowClass, currently using: ${listOfFlowsForInitiator.first().initiatedFlowClass}")
}
}
}
// TODO Harmonise use of these methods - 99% of invocations come from tests.
@Synchronized
override fun registerInitiatedCoreFlowFactory(initiatingFlowClass: KClass<out FlowLogic<*>>, initiatedFlowClass: KClass<out FlowLogic<*>>?, flowFactory: (FlowSession) -> FlowLogic<*>) {
registerInitiatedCoreFlowFactory(initiatingFlowClass, initiatedFlowClass, InitiatedFlowFactory.Core(flowFactory))
}
@Synchronized
override fun registerInitiatedCoreFlowFactory(initiatingFlowClass: KClass<out FlowLogic<*>>, flowFactory: (FlowSession) -> FlowLogic<*>) {
registerInitiatedCoreFlowFactory(initiatingFlowClass, null, InitiatedFlowFactory.Core(flowFactory))
}
@Synchronized
override fun registerInitiatedCoreFlowFactory(initiatingFlowClass: KClass<out FlowLogic<*>>, initiatedFlowClass: KClass<out FlowLogic<*>>?, flowFactory: InitiatedFlowFactory.Core<FlowLogic<*>>) {
require(initiatingFlowClass.java.flowVersionAndInitiatingClass.first == 1) {
"${InitiatingFlow::class.java.name}.version not applicable for core flows; their version is the node's platform version"
}
flowFactories.computeIfAbsent(initiatingFlowClass.java) { mutableListOf() }.add(
RegisteredFlowContainer(
initiatingFlowClass.java,
initiatedFlowClass?.java,
flowFactory,
FlowType.CORE)
)
log.debug { "Installed core flow ${initiatingFlowClass.java.name}" }
}
// To verify the integrity of the current state, it is important that the tip of the responders is a unique weight
// if there are multiple flows with the same weight as the tip, it means that it is impossible to reliably pick one as the responder
private fun validateInvariants(toValidate: List<RegisteredFlowContainer>) {
val currentTip = toValidate.first()
val flowWeightComparator = FlowWeightComparator(currentTip.initiatingFlowClass, flowOverrides)
val equalWeightAsCurrentTip = toValidate.map { flowWeightComparator.compare(currentTip, it) to it }.filter { it.first == 0 }.map { it.second }
if (equalWeightAsCurrentTip.size > 1) {
val message = "Unable to determine which flow to use when responding to: ${currentTip.initiatingFlowClass.canonicalName}. ${equalWeightAsCurrentTip.map { it.initiatedFlowClass!!.canonicalName }} are all registered with equal weight."
throw IllegalStateException(message)
}
}
@Synchronized
override fun validateRegistrations() {
flowFactories.values.forEach {
validateInvariants(it)
}
}
private enum class FlowType {
CORE, CORDAPP
}
private data class RegisteredFlowContainer(val initiatingFlowClass: Class<out FlowLogic<*>>,
val initiatedFlowClass: Class<out FlowLogic<*>>?,
val flowFactory: InitiatedFlowFactory<FlowLogic<*>>,
val type: FlowType)
// this is used to sort the responding flows in order of "importance"
// the logic is as follows
// IF responder is a specific lambda (like for notary implementations / testing code) always return that responder
// ELSE IF responder is present in the overrides list, always return that responder
// ELSE compare responding flows by their depth from FlowLogic, always return the flow which is most specific (IE, has the most hops to FlowLogic)
private open class FlowWeightComparator(val initiatingFlowClass: Class<out FlowLogic<*>>, val flowOverrides: Map<String, String>) : Comparator<NodeFlowManager.RegisteredFlowContainer> {
override fun compare(o1: NodeFlowManager.RegisteredFlowContainer, o2: NodeFlowManager.RegisteredFlowContainer): Int {
if (o1.initiatedFlowClass == null && o2.initiatedFlowClass != null) {
return Int.MIN_VALUE
}
if (o1.initiatedFlowClass != null && o2.initiatedFlowClass == null) {
return Int.MAX_VALUE
}
if (o1.initiatedFlowClass == null && o2.initiatedFlowClass == null) {
return 0
}
val hopsTo1 = calculateHopsToFlowLogic(initiatingFlowClass, o1.initiatedFlowClass!!)
val hopsTo2 = calculateHopsToFlowLogic(initiatingFlowClass, o2.initiatedFlowClass!!)
return hopsTo1.compareTo(hopsTo2) * -1
}
private fun calculateHopsToFlowLogic(initiatingFlowClass: Class<out FlowLogic<*>>,
initiatedFlowClass: Class<out FlowLogic<*>>): Int {
val overriddenClassName = flowOverrides[initiatingFlowClass.canonicalName]
return if (overriddenClassName == initiatedFlowClass.canonicalName) {
Int.MAX_VALUE
} else {
var currentClass: Class<*> = initiatedFlowClass
var count = 0
while (currentClass != FlowLogic::class.java) {
currentClass = currentClass.superclass
count++
}
count;
}
}
}
}
private fun <X, Y> Iterable<Pair<X, Y>>.toMutableMap(): MutableMap<X, Y> {
return this.toMap(HashMap())
}

View File

@ -4,10 +4,12 @@ import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
sealed class InitiatedFlowFactory<out F : FlowLogic<*>> {
protected abstract val factory: (FlowSession) -> F
fun createFlow(initiatingFlowSession: FlowSession): F = factory(initiatingFlowSession)
data class Core<out F : FlowLogic<*>>(override val factory: (FlowSession) -> F) : InitiatedFlowFactory<F>()
data class CorDapp<out F : FlowLogic<*>>(val flowVersion: Int,
val appName: String,
override val factory: (FlowSession) -> F) : InitiatedFlowFactory<F>()

View File

@ -43,6 +43,7 @@ import net.corda.node.services.api.StartedNodeServices
import net.corda.node.services.config.*
import net.corda.node.services.messaging.*
import net.corda.node.services.rpc.ArtemisRpcBroker
import net.corda.node.services.statemachine.StateMachineManager
import net.corda.node.utilities.*
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.INTERNAL_SHELL_USER
import net.corda.nodeapi.internal.ShutdownHook
@ -56,7 +57,6 @@ import org.apache.commons.lang.SystemUtils
import org.h2.jdbc.JdbcSQLException
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import rx.Observable
import rx.Scheduler
import rx.schedulers.Schedulers
import java.net.BindException
@ -72,8 +72,7 @@ import kotlin.system.exitProcess
class NodeWithInfo(val node: Node, val info: NodeInfo) {
val services: StartedNodeServices = object : StartedNodeServices, ServiceHubInternal by node.services, FlowStarter by node.flowStarter {}
fun dispose() = node.stop()
fun <T : FlowLogic<*>> registerInitiatedFlow(initiatedFlowClass: Class<T>): Observable<T> =
node.registerInitiatedFlow(node.smm, initiatedFlowClass)
fun <T : FlowLogic<*>> registerInitiatedFlow(initiatedFlowClass: Class<T>) = node.registerInitiatedFlow(node.smm, initiatedFlowClass)
}
/**
@ -85,12 +84,14 @@ class NodeWithInfo(val node: Node, val info: NodeInfo) {
open class Node(configuration: NodeConfiguration,
versionInfo: VersionInfo,
private val initialiseSerialization: Boolean = true,
flowManager: FlowManager = NodeFlowManager(configuration.flowOverrides),
cacheFactoryPrototype: BindableNamedCacheFactory = DefaultNamedCacheFactory()
) : AbstractNode<NodeInfo>(
configuration,
createClock(configuration),
cacheFactoryPrototype,
versionInfo,
flowManager,
// Under normal (non-test execution) it will always be "1"
AffinityExecutor.ServiceAffinityExecutor("Node thread-${sameVmNodeCounter.incrementAndGet()}", 1)
) {
@ -202,7 +203,8 @@ open class Node(configuration: NodeConfiguration,
return P2PMessagingClient(
config = configuration,
versionInfo = versionInfo,
serverAddress = configuration.messagingServerAddress ?: NetworkHostAndPort("localhost", configuration.p2pAddress.port),
serverAddress = configuration.messagingServerAddress
?: NetworkHostAndPort("localhost", configuration.p2pAddress.port),
nodeExecutor = serverThread,
database = database,
networkMap = networkMapCache,
@ -228,7 +230,8 @@ open class Node(configuration: NodeConfiguration,
}
val messageBroker = if (!configuration.messagingServerExternal) {
val brokerBindAddress = configuration.messagingServerAddress ?: NetworkHostAndPort("0.0.0.0", configuration.p2pAddress.port)
val brokerBindAddress = configuration.messagingServerAddress
?: NetworkHostAndPort("0.0.0.0", configuration.p2pAddress.port)
ArtemisMessagingServer(configuration, brokerBindAddress, networkParameters.maxMessageSize)
} else {
null
@ -442,7 +445,7 @@ open class Node(configuration: NodeConfiguration,
}.build().start()
}
private fun registerNewRelicReporter (registry: MetricRegistry) {
private fun registerNewRelicReporter(registry: MetricRegistry) {
log.info("Registering New Relic JMX Reporter:")
val reporter = NewRelicReporter.forRegistry(registry)
.name("New Relic Reporter")
@ -504,4 +507,8 @@ open class Node(configuration: NodeConfiguration,
log.info("Shutdown complete")
}
fun <T : FlowLogic<*>> registerInitiatedFlow(smm: StateMachineManager, initiatedFlowClass: Class<T>) {
this.flowManager.registerInitiatedFlow(initiatedFlowClass)
}
}

View File

@ -19,7 +19,6 @@ import net.corda.core.serialization.SerializeAsToken
import net.corda.core.utilities.contextLogger
import net.corda.node.VersionInfo
import net.corda.node.cordapp.CordappLoader
import net.corda.node.internal.classloading.requireAnnotation
import net.corda.nodeapi.internal.coreContractClasses
import net.corda.serialization.internal.DefaultWhitelist
import org.apache.commons.collections4.map.LRUMap
@ -148,17 +147,6 @@ class JarScanningCordappLoader private constructor(private val cordappJarPaths:
private fun findInitiatedFlows(scanResult: RestrictedScanResult): List<Class<out FlowLogic<*>>> {
return scanResult.getClassesWithAnnotation(FlowLogic::class, InitiatedBy::class)
// First group by the initiating flow class in case there are multiple mappings
.groupBy { it.requireAnnotation<InitiatedBy>().value.java }
.map { (initiatingFlow, initiatedFlows) ->
val sorted = initiatedFlows.sortedWith(FlowTypeHierarchyComparator(initiatingFlow))
if (sorted.size > 1) {
logger.warn("${initiatingFlow.name} has been specified as the inititating flow by multiple flows " +
"in the same type hierarchy: ${sorted.joinToString { it.name }}. Choosing the most " +
"specific sub-type for registration: ${sorted[0].name}.")
}
sorted[0]
}
}
private fun Class<out FlowLogic<*>>.isUserInvokable(): Boolean {
@ -209,17 +197,7 @@ class JarScanningCordappLoader private constructor(private val cordappJarPaths:
}
}
private class FlowTypeHierarchyComparator(val initiatingFlow: Class<out FlowLogic<*>>) : Comparator<Class<out FlowLogic<*>>> {
override fun compare(o1: Class<out FlowLogic<*>>, o2: Class<out FlowLogic<*>>): Int {
return when {
o1 == o2 -> 0
o1.isAssignableFrom(o2) -> 1
o2.isAssignableFrom(o1) -> -1
else -> throw IllegalArgumentException("${initiatingFlow.name} has been specified as the initiating flow by " +
"both ${o1.name} and ${o2.name}")
}
}
}
private fun <T : Any> loadClass(className: String, type: KClass<T>): Class<out T>? {
return try {

View File

@ -76,6 +76,7 @@ interface NodeConfiguration {
val p2pSslOptions: MutualSslConfiguration
val cordappDirectories: List<Path>
val flowOverrides: FlowOverrideConfig?
fun validate(): List<String>
@ -97,6 +98,9 @@ interface NodeConfiguration {
}
}
data class FlowOverrideConfig(val overrides: List<FlowOverride> = listOf())
data class FlowOverride(val initiator: String, val responder: String)
/**
* Currently registered JMX Reporters
*/
@ -210,7 +214,8 @@ data class NodeConfigurationImpl(
override val flowMonitorPeriodMillis: Duration = DEFAULT_FLOW_MONITOR_PERIOD_MILLIS,
override val flowMonitorSuspensionLoggingThresholdMillis: Duration = DEFAULT_FLOW_MONITOR_SUSPENSION_LOGGING_THRESHOLD_MILLIS,
override val cordappDirectories: List<Path> = listOf(baseDirectory / CORDAPPS_DIR_NAME_DEFAULT),
override val jmxReporterType: JmxReporterType? = JmxReporterType.JOLOKIA
override val jmxReporterType: JmxReporterType? = JmxReporterType.JOLOKIA,
override val flowOverrides: FlowOverrideConfig?
) : NodeConfiguration {
companion object {
private val logger = loggerFor<NodeConfigurationImpl>()

View File

@ -12,7 +12,6 @@ import net.corda.testing.core.singleIdentity
import net.corda.testing.node.MockNetwork
import net.corda.testing.node.MockNodeParameters
import net.corda.testing.node.StartedMockNode
import org.assertj.core.api.Assertions.assertThatIllegalStateException
import org.junit.After
import org.junit.Before
import org.junit.Test
@ -39,15 +38,15 @@ class FlowRegistrationTest {
}
@Test
fun `startup fails when two flows initiated by the same flow are registered`() {
fun `succeeds when a subclass of a flow initiated by the same flow is registered`() {
// register the same flow twice to invoke the error without causing errors in other tests
responder.registerInitiatedFlow(Responder::class.java)
assertThatIllegalStateException().isThrownBy { responder.registerInitiatedFlow(Responder::class.java) }
responder.registerInitiatedFlow(Responder1::class.java)
responder.registerInitiatedFlow(Responder1Subclassed::class.java)
}
@Test
fun `a single initiated flow can be registered without error`() {
responder.registerInitiatedFlow(Responder::class.java)
responder.registerInitiatedFlow(Responder1::class.java)
val result = initiator.startFlow(Initiator(responder.info.singleIdentity()))
mockNetwork.runNetwork()
assertNotNull(result.get())
@ -63,7 +62,38 @@ class Initiator(val party: Party) : FlowLogic<String>() {
}
@InitiatedBy(Initiator::class)
private class Responder(val session: FlowSession) : FlowLogic<Unit>() {
private open class Responder1(val session: FlowSession) : FlowLogic<Unit>() {
open fun getPayload(): String {
return "whats up"
}
@Suspendable
override fun call() {
session.receive<String>().unwrap { it }
session.send("What's up")
}
}
@InitiatedBy(Initiator::class)
private open class Responder2(val session: FlowSession) : FlowLogic<Unit>() {
open fun getPayload(): String {
return "whats up"
}
@Suspendable
override fun call() {
session.receive<String>().unwrap { it }
session.send("What's up")
}
}
@InitiatedBy(Initiator::class)
private class Responder1Subclassed(session: FlowSession) : Responder1(session) {
override fun getPayload(): String {
return "im subclassed! that's what's up!"
}
@Suspendable
override fun call() {
session.receive<String>().unwrap { it }

View File

@ -0,0 +1,110 @@
package net.corda.node.internal
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.flows.InitiatedBy
import net.corda.core.flows.InitiatingFlow
import net.corda.node.services.config.FlowOverride
import net.corda.node.services.config.FlowOverrideConfig
import org.hamcrest.CoreMatchers.`is`
import org.hamcrest.CoreMatchers.instanceOf
import org.junit.Assert
import org.junit.Test
import org.mockito.Mockito
import java.lang.IllegalStateException
private val marker = "This is a special marker"
class NodeFlowManagerTest {
@InitiatingFlow
class Init : FlowLogic<Unit>() {
override fun call() {
TODO("not implemented")
}
}
@InitiatedBy(Init::class)
open class Resp(val otherSesh: FlowSession) : FlowLogic<Unit>() {
override fun call() {
TODO("not implemented")
}
}
@InitiatedBy(Init::class)
class Resp2(val otherSesh: FlowSession) : FlowLogic<Unit>() {
override fun call() {
TODO("not implemented")
}
}
@InitiatedBy(Init::class)
open class RespSub(sesh: FlowSession) : Resp(sesh) {
override fun call() {
TODO("not implemented")
}
}
@InitiatedBy(Init::class)
class RespSubSub(sesh: FlowSession) : RespSub(sesh) {
override fun call() {
TODO("not implemented")
}
}
@Test(expected = IllegalStateException::class)
fun `should fail to validate if more than one registration with equal weight`() {
val nodeFlowManager = NodeFlowManager()
nodeFlowManager.registerInitiatedFlow(Init::class.java, Resp::class.java)
nodeFlowManager.registerInitiatedFlow(Init::class.java, Resp2::class.java)
nodeFlowManager.validateRegistrations()
}
@Test()
fun `should allow registration of flows with different weights`() {
val nodeFlowManager = NodeFlowManager()
nodeFlowManager.registerInitiatedFlow(Init::class.java, Resp::class.java)
nodeFlowManager.registerInitiatedFlow(Init::class.java, RespSub::class.java)
nodeFlowManager.validateRegistrations()
val factory = nodeFlowManager.getFlowFactoryForInitiatingFlow(Init::class.java)!!
val flow = factory.createFlow(Mockito.mock(FlowSession::class.java))
Assert.assertThat(flow, `is`(instanceOf(RespSub::class.java)))
}
@Test()
fun `should allow updating of registered responder at runtime`() {
val nodeFlowManager = NodeFlowManager()
nodeFlowManager.registerInitiatedFlow(Init::class.java, Resp::class.java)
nodeFlowManager.registerInitiatedFlow(Init::class.java, RespSub::class.java)
nodeFlowManager.validateRegistrations()
var factory = nodeFlowManager.getFlowFactoryForInitiatingFlow(Init::class.java)!!
var flow = factory.createFlow(Mockito.mock(FlowSession::class.java))
Assert.assertThat(flow, `is`(instanceOf(RespSub::class.java)))
// update
nodeFlowManager.registerInitiatedFlow(Init::class.java, RespSubSub::class.java)
nodeFlowManager.validateRegistrations()
factory = nodeFlowManager.getFlowFactoryForInitiatingFlow(Init::class.java)!!
flow = factory.createFlow(Mockito.mock(FlowSession::class.java))
Assert.assertThat(flow, `is`(instanceOf(RespSubSub::class.java)))
}
@Test
fun `should allow an override to be specified`() {
val nodeFlowManager = NodeFlowManager(FlowOverrideConfig(listOf(FlowOverride(Init::class.qualifiedName!!, Resp::class.qualifiedName!!))))
nodeFlowManager.registerInitiatedFlow(Init::class.java, Resp::class.java)
nodeFlowManager.registerInitiatedFlow(Init::class.java, Resp2::class.java)
nodeFlowManager.registerInitiatedFlow(Init::class.java, RespSubSub::class.java)
nodeFlowManager.validateRegistrations()
val factory = nodeFlowManager.getFlowFactoryForInitiatingFlow(Init::class.java)!!
val flow = factory.createFlow(Mockito.mock(FlowSession::class.java))
Assert.assertThat(flow, `is`(instanceOf(Resp::class.java)))
}
}

View File

@ -149,7 +149,7 @@ class NodeTest {
}
}
private fun createConfig(nodeName: CordaX500Name): NodeConfiguration {
private fun createConfig(nodeName: CordaX500Name): NodeConfigurationImpl {
val fakeAddress = NetworkHostAndPort("0.1.2.3", 456)
return NodeConfigurationImpl(
baseDirectory = temporaryFolder.root.toPath(),
@ -167,7 +167,8 @@ class NodeTest {
flowTimeout = FlowTimeoutConfiguration(timeout = Duration.ZERO, backoffBase = 1.0, maxRestartCount = 1),
rpcSettings = NodeRpcSettings(address = fakeAddress, adminAddress = null, ssl = null),
messagingServerAddress = null,
notary = null
notary = null,
flowOverrides = FlowOverrideConfig(listOf())
)
}

View File

@ -56,7 +56,7 @@ class JarScanningCordappLoaderTest {
val actualCordapp = loader.cordapps.single()
assertThat(actualCordapp.contractClassNames).isEqualTo(listOf(isolatedContractId))
assertThat(actualCordapp.initiatedFlows.single().name).isEqualTo("net.corda.finance.contracts.isolated.IsolatedDummyFlow\$Acceptor")
assertThat(actualCordapp.initiatedFlows.first().name).isEqualTo("net.corda.finance.contracts.isolated.IsolatedDummyFlow\$Acceptor")
assertThat(actualCordapp.rpcFlows).isEmpty()
assertThat(actualCordapp.schedulableFlows).isEmpty()
assertThat(actualCordapp.services).isEmpty()
@ -74,7 +74,7 @@ class JarScanningCordappLoaderTest {
assertThat(loader.cordapps).isNotEmpty
val actualCordapp = loader.cordapps.single { !it.initiatedFlows.isEmpty() }
assertThat(actualCordapp.initiatedFlows).first().hasSameClassAs(DummyFlow::class.java)
assertThat(actualCordapp.initiatedFlows.first()).hasSameClassAs(DummyFlow::class.java)
assertThat(actualCordapp.rpcFlows).first().hasSameClassAs(DummyRPCFlow::class.java)
assertThat(actualCordapp.schedulableFlows).first().hasSameClassAs(DummySchedulableFlow::class.java)
}

View File

@ -172,8 +172,8 @@ class NodeConfigurationImplTest {
val errors = configuration.validate()
assertThat(errors).hasOnlyOneElementSatisfying {
error -> error.contains("Cannot configure both compatibilityZoneUrl and networkServices simultaneously")
assertThat(errors).hasOnlyOneElementSatisfying { error ->
error.contains("Cannot configure both compatibilityZoneUrl and networkServices simultaneously")
}
}
@ -268,7 +268,8 @@ class NodeConfigurationImplTest {
noLocalShell = false,
rpcSettings = rpcSettings,
crlCheckSoftFail = true,
tlsCertCrlDistPoint = null
tlsCertCrlDistPoint = null,
flowOverrides = FlowOverrideConfig(listOf())
)
}
}

View File

@ -26,7 +26,6 @@ import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.ProgressTracker.Change
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.services.persistence.checkpoints
import net.corda.testing.contracts.DummyContract
import net.corda.testing.contracts.DummyState
@ -116,7 +115,7 @@ class FlowFrameworkTests {
@Test
fun `exception while fiber suspended`() {
bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) }
bobNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) }
val flow = ReceiveFlow(bob)
val fiber = aliceNode.services.startFlow(flow) as FlowStateMachineImpl
// Before the flow runs change the suspend action to throw an exception
@ -134,7 +133,7 @@ class FlowFrameworkTests {
@Test
fun `both sides do a send as their first IO request`() {
bobNode.registerFlowFactory(PingPongFlow::class) { PingPongFlow(it, 20L) }
bobNode.registerCordappFlowFactory(PingPongFlow::class) { PingPongFlow(it, 20L) }
aliceNode.services.startFlow(PingPongFlow(bob, 10L))
mockNet.runNetwork()
@ -151,7 +150,7 @@ class FlowFrameworkTests {
@Test
fun `other side ends before doing expected send`() {
bobNode.registerFlowFactory(ReceiveFlow::class) { NoOpFlow() }
bobNode.registerCordappFlowFactory(ReceiveFlow::class) { NoOpFlow() }
val resultFuture = aliceNode.services.startFlow(ReceiveFlow(bob)).resultFuture
mockNet.runNetwork()
assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy {
@ -161,7 +160,7 @@ class FlowFrameworkTests {
@Test
fun `receiving unexpected session end before entering sendAndReceive`() {
bobNode.registerFlowFactory(WaitForOtherSideEndBeforeSendAndReceive::class) { NoOpFlow() }
bobNode.registerCordappFlowFactory(WaitForOtherSideEndBeforeSendAndReceive::class) { NoOpFlow() }
val sessionEndReceived = Semaphore(0)
receivedSessionMessagesObservable().filter {
it.message is ExistingSessionMessage && it.message.payload === EndSessionMessage
@ -176,7 +175,7 @@ class FlowFrameworkTests {
@Test
fun `FlowException thrown on other side`() {
val erroringFlow = bobNode.registerFlowFactory(ReceiveFlow::class) {
val erroringFlow = bobNode.registerCordappFlowFactory(ReceiveFlow::class) {
ExceptionFlow { MyFlowException("Nothing useful") }
}
val erroringFlowSteps = erroringFlow.flatMap { it.progressSteps }
@ -240,7 +239,7 @@ class FlowFrameworkTests {
}
}
bobNode.registerFlowFactory(AskForExceptionFlow::class) { ConditionalExceptionFlow(it, "Hello") }
bobNode.registerCordappFlowFactory(AskForExceptionFlow::class) { ConditionalExceptionFlow(it, "Hello") }
val resultFuture = aliceNode.services.startFlow(RetryOnExceptionFlow(bob)).resultFuture
mockNet.runNetwork()
assertThat(resultFuture.getOrThrow()).isEqualTo("Hello")
@ -248,7 +247,7 @@ class FlowFrameworkTests {
@Test
fun `serialisation issue in counterparty`() {
bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(NonSerialisableData(1), it) }
bobNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(NonSerialisableData(1), it) }
val result = aliceNode.services.startFlow(ReceiveFlow(bob)).resultFuture
mockNet.runNetwork()
assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy {
@ -258,7 +257,7 @@ class FlowFrameworkTests {
@Test
fun `FlowException has non-serialisable object`() {
bobNode.registerFlowFactory(ReceiveFlow::class) {
bobNode.registerCordappFlowFactory(ReceiveFlow::class) {
ExceptionFlow { NonSerialisableFlowException(NonSerialisableData(1)) }
}
val result = aliceNode.services.startFlow(ReceiveFlow(bob)).resultFuture
@ -275,7 +274,7 @@ class FlowFrameworkTests {
.addCommand(dummyCommand(alice.owningKey))
val stx = aliceNode.services.signInitialTransaction(ptx)
val committerFiber = aliceNode.registerFlowFactory(WaitingFlows.Waiter::class) {
val committerFiber = aliceNode.registerCordappFlowFactory(WaitingFlows.Waiter::class) {
WaitingFlows.Committer(it)
}.map { it.stateMachine }.map { uncheckedCast<FlowStateMachine<*>, FlowStateMachine<Any>>(it) }
val waiterStx = bobNode.services.startFlow(WaitingFlows.Waiter(stx, alice)).resultFuture
@ -290,7 +289,7 @@ class FlowFrameworkTests {
.addCommand(dummyCommand())
val stx = aliceNode.services.signInitialTransaction(ptx)
aliceNode.registerFlowFactory(WaitingFlows.Waiter::class) {
aliceNode.registerCordappFlowFactory(WaitingFlows.Waiter::class) {
WaitingFlows.Committer(it) { throw Exception("Error") }
}
val waiter = bobNode.services.startFlow(WaitingFlows.Waiter(stx, alice)).resultFuture
@ -307,7 +306,7 @@ class FlowFrameworkTests {
.addCommand(dummyCommand(alice.owningKey))
val stx = aliceNode.services.signInitialTransaction(ptx)
aliceNode.registerFlowFactory(VaultQueryFlow::class) {
aliceNode.registerCordappFlowFactory(VaultQueryFlow::class) {
WaitingFlows.Committer(it)
}
val result = bobNode.services.startFlow(VaultQueryFlow(stx, alice)).resultFuture
@ -318,7 +317,7 @@ class FlowFrameworkTests {
@Test
fun `customised client flow`() {
val receiveFlowFuture = bobNode.registerFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it) }
val receiveFlowFuture = bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it) }
aliceNode.services.startFlow(CustomSendFlow("Hello", bob)).resultFuture
mockNet.runNetwork()
assertThat(receiveFlowFuture.getOrThrow().receivedPayloads).containsOnly("Hello")
@ -333,7 +332,7 @@ class FlowFrameworkTests {
@Test
fun `upgraded initiating flow`() {
bobNode.registerFlowFactory(UpgradedFlow::class, initiatedFlowVersion = 1) { InitiatedSendFlow("Old initiated", it) }
bobNode.registerCordappFlowFactory(UpgradedFlow::class, initiatedFlowVersion = 1) { InitiatedSendFlow("Old initiated", it) }
val result = aliceNode.services.startFlow(UpgradedFlow(bob)).resultFuture
mockNet.runNetwork()
assertThat(receivedSessionMessages).startsWith(
@ -347,7 +346,7 @@ class FlowFrameworkTests {
@Test
fun `upgraded initiated flow`() {
bobNode.registerFlowFactory(SendFlow::class, initiatedFlowVersion = 2) { UpgradedFlow(it) }
bobNode.registerCordappFlowFactory(SendFlow::class, initiatedFlowVersion = 2) { UpgradedFlow(it) }
val initiatingFlow = SendFlow("Old initiating", bob)
val flowInfo = aliceNode.services.startFlow(initiatingFlow).resultFuture
mockNet.runNetwork()
@ -387,7 +386,7 @@ class FlowFrameworkTests {
@Test
fun `single inlined sub-flow`() {
bobNode.registerFlowFactory(SendAndReceiveFlow::class) { SingleInlinedSubFlow(it) }
bobNode.registerCordappFlowFactory(SendAndReceiveFlow::class) { SingleInlinedSubFlow(it) }
val result = aliceNode.services.startFlow(SendAndReceiveFlow(bob, "Hello")).resultFuture
mockNet.runNetwork()
assertThat(result.getOrThrow()).isEqualTo("HelloHello")
@ -395,7 +394,7 @@ class FlowFrameworkTests {
@Test
fun `double inlined sub-flow`() {
bobNode.registerFlowFactory(SendAndReceiveFlow::class) { DoubleInlinedSubFlow(it) }
bobNode.registerCordappFlowFactory(SendAndReceiveFlow::class) { DoubleInlinedSubFlow(it) }
val result = aliceNode.services.startFlow(SendAndReceiveFlow(bob, "Hello")).resultFuture
mockNet.runNetwork()
assertThat(result.getOrThrow()).isEqualTo("HelloHello")
@ -403,7 +402,7 @@ class FlowFrameworkTests {
@Test
fun `non-FlowException thrown on other side`() {
val erroringFlowFuture = bobNode.registerFlowFactory(ReceiveFlow::class) {
val erroringFlowFuture = bobNode.registerCordappFlowFactory(ReceiveFlow::class) {
ExceptionFlow { Exception("evil bug!") }
}
val erroringFlowSteps = erroringFlowFuture.flatMap { it.progressSteps }
@ -507,8 +506,8 @@ class FlowFrameworkTripartyTests {
@Test
fun `sending to multiple parties`() {
bobNode.registerFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() }
charlieNode.registerFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() }
bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() }
charlieNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() }
val payload = "Hello World"
aliceNode.services.startFlow(SendFlow(payload, bob, charlie))
mockNet.runNetwork()
@ -538,8 +537,8 @@ class FlowFrameworkTripartyTests {
fun `receiving from multiple parties`() {
val bobPayload = "Test 1"
val charliePayload = "Test 2"
bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(bobPayload, it) }
charlieNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(charliePayload, it) }
bobNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(bobPayload, it) }
charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(charliePayload, it) }
val multiReceiveFlow = ReceiveFlow(bob, charlie).nonTerminating()
aliceNode.services.startFlow(multiReceiveFlow)
aliceNode.internals.acceptableLiveFiberCountOnStop = 1
@ -564,8 +563,8 @@ class FlowFrameworkTripartyTests {
@Test
fun `FlowException only propagated to parent`() {
charlieNode.registerFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Chain") } }
bobNode.registerFlowFactory(ReceiveFlow::class) { ReceiveFlow(charlie) }
charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Chain") } }
bobNode.registerCordappFlowFactory(ReceiveFlow::class) { ReceiveFlow(charlie) }
val receivingFiber = aliceNode.services.startFlow(ReceiveFlow(bob))
mockNet.runNetwork()
assertThatExceptionOfType(UnexpectedFlowEndException::class.java)
@ -577,9 +576,9 @@ class FlowFrameworkTripartyTests {
// Bob will send its payload and then block waiting for the receive from Alice. Meanwhile Alice will move
// onto Charlie which will throw the exception
val node2Fiber = bobNode
.registerFlowFactory(ReceiveFlow::class) { SendAndReceiveFlow(it, "Hello") }
.registerCordappFlowFactory(ReceiveFlow::class) { SendAndReceiveFlow(it, "Hello") }
.map { it.stateMachine }
charlieNode.registerFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Nothing useful") } }
charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Nothing useful") } }
val aliceFiber = aliceNode.services.startFlow(ReceiveFlow(bob, charlie)) as FlowStateMachineImpl
mockNet.runNetwork()
@ -630,6 +629,8 @@ class FlowFrameworkPersistenceTests {
private lateinit var notaryIdentity: Party
private lateinit var alice: Party
private lateinit var bob: Party
private lateinit var aliceFlowManager: MockNodeFlowManager
private lateinit var bobFlowManager: MockNodeFlowManager
@Before
fun start() {
@ -637,8 +638,11 @@ class FlowFrameworkPersistenceTests {
cordappsForAllNodes = cordappsForPackages("net.corda.finance.contracts", "net.corda.testing.contracts"),
servicePeerAllocationStrategy = RoundRobin()
)
aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME))
bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME))
aliceFlowManager = MockNodeFlowManager()
bobFlowManager = MockNodeFlowManager()
aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME, flowManager = aliceFlowManager))
bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME, flowManager = bobFlowManager))
receivedSessionMessagesObservable().forEach { receivedSessionMessages += it }
@ -664,7 +668,7 @@ class FlowFrameworkPersistenceTests {
@Test
fun `flow restarted just after receiving payload`() {
bobNode.registerFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() }
bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() }
aliceNode.services.startFlow(SendFlow("Hello", bob))
// We push through just enough messages to get only the payload sent
@ -679,7 +683,7 @@ class FlowFrameworkPersistenceTests {
@Test
fun `flow loaded from checkpoint will respond to messages from before start`() {
aliceNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) }
aliceNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) }
bobNode.services.startFlow(ReceiveFlow(alice).nonTerminating()) // Prepare checkpointed receive flow
val restoredFlow = bobNode.restartAndGetRestoredFlow<ReceiveFlow>()
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello")
@ -694,7 +698,7 @@ class FlowFrameworkPersistenceTests {
var sentCount = 0
mockNet.messagingNetwork.sentMessages.toSessionTransfers().filter { it.isPayloadTransfer }.forEach { sentCount++ }
val charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME))
val secondFlow = charlieNode.registerFlowFactory(PingPongFlow::class) { PingPongFlow(it, payload2) }
val secondFlow = charlieNode.registerCordappFlowFactory(PingPongFlow::class) { PingPongFlow(it, payload2) }
mockNet.runNetwork()
val charlie = charlieNode.info.singleIdentity()
@ -802,23 +806,14 @@ private infix fun TestStartedNode.sent(message: SessionMessage): Pair<Int, Sessi
private infix fun Pair<Int, SessionMessage>.to(node: TestStartedNode): SessionTransfer = SessionTransfer(first, second, node.network.myAddress)
private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: MessageRecipients) {
val isPayloadTransfer: Boolean get() =
message is ExistingSessionMessage && message.payload is DataSessionMessage ||
message is InitialSessionMessage && message.firstPayload != null
val isPayloadTransfer: Boolean
get() =
message is ExistingSessionMessage && message.payload is DataSessionMessage ||
message is InitialSessionMessage && message.firstPayload != null
override fun toString(): String = "$from sent $message to $to"
}
private inline fun <reified P : FlowLogic<*>> TestStartedNode.registerFlowFactory(
initiatingFlowClass: KClass<out FlowLogic<*>>,
initiatedFlowVersion: Int = 1,
noinline flowFactory: (FlowSession) -> P): CordaFuture<P> {
val observable = registerFlowFactory(
initiatingFlowClass.java,
InitiatedFlowFactory.CorDapp(initiatedFlowVersion, "", flowFactory),
P::class.java,
track = true)
return observable.toFuture()
}
private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage {
return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
@ -1061,7 +1056,8 @@ private class SendAndReceiveFlow(val otherParty: Party, val payload: Any, val ot
constructor(otherPartySession: FlowSession, payload: Any) : this(otherPartySession.counterparty, payload, otherPartySession)
@Suspendable
override fun call(): Any = (otherPartySession ?: initiateFlow(otherParty)).sendAndReceive<Any>(payload).unwrap { it }
override fun call(): Any = (otherPartySession
?: initiateFlow(otherParty)).sendAndReceive<Any>(payload).unwrap { it }
}
private class InlinedSendFlow(val payload: String, val otherPartySession: FlowSession) : FlowLogic<Unit>() {
@ -1098,4 +1094,4 @@ private class ExceptionFlow<E : Exception>(val exception: () -> E) : FlowLogic<N
exceptionThrown = exception()
throw exceptionThrown
}
}
}

View File

@ -3,10 +3,7 @@ package net.corda.node.services.vault
import co.paralleluniverse.fibers.Suspendable
import com.nhaarman.mockito_kotlin.*
import net.corda.core.contracts.*
import net.corda.core.flows.FinalityFlow
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.flows.InitiatingFlow
import net.corda.core.flows.*
import net.corda.core.identity.AbstractParty
import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.packageName
@ -23,14 +20,12 @@ import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.services.api.SchemaService
import net.corda.node.services.api.VaultServiceInternal
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.testing.core.singleIdentity
import net.corda.testing.internal.rigorousMock
import net.corda.testing.node.internal.cordappsForPackages
import net.corda.testing.node.internal.InternalMockNetwork
import net.corda.testing.node.internal.cordappsForPackages
import net.corda.testing.node.internal.startFlow
import org.junit.After
import org.junit.Test
@ -68,7 +63,7 @@ class NodePair(private val mockNet: InternalMockNetwork) {
private set
fun <T> communicate(clientLogic: AbstractClientLogic<T>, rebootClient: Boolean): FlowStateMachine<T> {
server.registerFlowFactory(AbstractClientLogic::class.java, InitiatedFlowFactory.Core { ServerLogic(it, serverRunning) }, ServerLogic::class.java, false)
server.registerCoreFlowFactory(AbstractClientLogic::class.java, ServerLogic::class.java, { ServerLogic(it, serverRunning) }, false)
client.services.startFlow(clientLogic)
while (!serverRunning.get()) mockNet.runNetwork(1)
if (rebootClient) {