Introducing InitiatedBy annotation to be used on initiated flows to simplify flow registration.

This removes the need to do manual registration using the PluginServiceHub. As a result CordaPluginRegistry.servicePlugins is no longer needed. For oracles and services there is a CorDappService annotation.

I've also fixed the InitiatingFlow annotation such that client flows can be customised (sub-typed) without it breaking the flow sessions.
This commit is contained in:
Shams Asari
2017-05-19 16:14:48 +01:00
parent e117ab7703
commit 329e5ff17b
66 changed files with 892 additions and 760 deletions

View File

@ -0,0 +1,82 @@
package net.corda.node
import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.Futures
import net.corda.core.copyToDirectory
import net.corda.core.createDirectories
import net.corda.core.div
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.InitiatedBy
import net.corda.core.flows.InitiatingFlow
import net.corda.core.flows.StartableByRPC
import net.corda.core.getOrThrow
import net.corda.core.identity.Party
import net.corda.core.messaging.startFlow
import net.corda.core.utilities.ALICE
import net.corda.core.utilities.BOB
import net.corda.core.utilities.unwrap
import net.corda.node.driver.driver
import net.corda.node.services.startFlowPermission
import net.corda.nodeapi.User
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
import java.nio.file.Paths
class CordappScanningTest {
@Test
fun `CorDapp jar in plugins directory is scanned`() {
// If the CorDapp jar does't exist then run the integrationTestClasses gradle task
val cordappJar = Paths.get(javaClass.getResource("/trader-demo.jar").toURI())
driver {
val pluginsDir = (baseDirectory(ALICE.name) / "plugins").createDirectories()
cordappJar.copyToDirectory(pluginsDir)
val user = User("u", "p", emptySet())
val alice = startNode(ALICE.name, rpcUsers = listOf(user)).getOrThrow()
val rpc = alice.rpcClientToNode().start(user.username, user.password)
// If the CorDapp wasn't scanned then SellerFlow won't have been picked up as an RPC flow
assertThat(rpc.proxy.registeredFlows()).contains("net.corda.traderdemo.flow.SellerFlow")
}
}
@Test
fun `empty plugins directory`() {
driver {
val baseDirectory = baseDirectory(ALICE.name)
(baseDirectory / "plugins").createDirectories()
startNode(ALICE.name).getOrThrow()
}
}
@Test
fun `sub-classed initiated flow pointing to the same initiating flow as its super-class`() {
val user = User("u", "p", setOf(startFlowPermission<ReceiveFlow>()))
driver(systemProperties = mapOf("net.corda.node.cordapp.scan.package" to "net.corda.node")) {
val (alice, bob) = Futures.allAsList(
startNode(ALICE.name, rpcUsers = listOf(user)),
startNode(BOB.name)).getOrThrow()
val initiatedFlowClass = alice.rpcClientToNode()
.start(user.username, user.password)
.proxy
.startFlow(::ReceiveFlow, bob.nodeInfo.legalIdentity)
.returnValue
assertThat(initiatedFlowClass.getOrThrow()).isEqualTo(SendSubClassFlow::class.java.name)
}
}
@StartableByRPC
@InitiatingFlow
class ReceiveFlow(val otherParty: Party) : FlowLogic<String>() {
@Suspendable
override fun call(): String = receive<String>(otherParty).unwrap { it }
}
@InitiatedBy(ReceiveFlow::class)
open class SendClassFlow(val otherParty: Party) : FlowLogic<Unit>() {
@Suspendable
override fun call() = send(otherParty, javaClass.name)
}
@InitiatedBy(ReceiveFlow::class)
class SendSubClassFlow(otherParty: Party) : SendClassFlow(otherParty)
}

View File

@ -6,6 +6,7 @@ import net.corda.client.rpc.CordaRPCClient
import net.corda.core.crypto.generateKeyPair
import net.corda.core.crypto.toBase58String
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.InitiatedBy
import net.corda.core.flows.InitiatingFlow
import net.corda.core.getOrThrow
import net.corda.core.identity.Party
@ -216,7 +217,7 @@ abstract class MQSecurityTest : NodeBasedTest() {
private fun startBobAndCommunicateWithAlice(): Party {
val bob = startNode(BOB.name).getOrThrow()
bob.services.registerServiceFlow(SendFlow::class.java, ::ReceiveFlow)
bob.registerInitiatedFlow(ReceiveFlow::class.java)
val bobParty = bob.info.legalIdentity
// Perform a protocol exchange to force the peer queue to be created
alice.services.startFlow(SendFlow(bobParty, 0)).resultFuture.getOrThrow()
@ -229,6 +230,7 @@ abstract class MQSecurityTest : NodeBasedTest() {
override fun call() = send(otherParty, payload)
}
@InitiatedBy(SendFlow::class)
private class ReceiveFlow(val otherParty: Party) : FlowLogic<Any>() {
@Suspendable
override fun call() = receive<Any>(otherParty).unwrap { it }

View File

@ -7,6 +7,8 @@ import com.google.common.util.concurrent.*
import com.typesafe.config.Config
import com.typesafe.config.ConfigRenderOptions
import net.corda.client.rpc.CordaRPCClient
import net.corda.cordform.CordformContext
import net.corda.cordform.CordformNode
import net.corda.core.*
import net.corda.core.crypto.X509Utilities
import net.corda.core.crypto.appendToCommonName
@ -19,10 +21,6 @@ import net.corda.core.node.services.ServiceType
import net.corda.core.utilities.*
import net.corda.node.LOGS_DIRECTORY_NAME
import net.corda.node.services.config.*
import net.corda.node.services.config.ConfigHelper
import net.corda.node.services.config.FullNodeConfiguration
import net.corda.node.services.config.VerifierType
import net.corda.node.services.config.configOf
import net.corda.node.services.network.NetworkMapService
import net.corda.node.services.transactions.RaftValidatingNotaryService
import net.corda.node.utilities.ServiceIdentityGenerator
@ -30,8 +28,6 @@ import net.corda.nodeapi.ArtemisMessagingComponent
import net.corda.nodeapi.User
import net.corda.nodeapi.config.SSLConfiguration
import net.corda.nodeapi.config.parseAs
import net.corda.cordform.CordformNode
import net.corda.cordform.CordformContext
import net.corda.core.internal.ShutdownHook
import net.corda.core.internal.addShutdownHook
import okhttp3.OkHttpClient
@ -39,6 +35,7 @@ import okhttp3.Request
import org.bouncycastle.asn1.x500.X500Name
import org.slf4j.Logger
import java.io.File
import java.io.File.pathSeparator
import java.net.*
import java.nio.file.Path
import java.nio.file.Paths
@ -614,7 +611,7 @@ class DriverDSL(
}
}
override fun baseDirectory(nodeName: X500Name) = driverDirectory / nodeName.commonName.replace(WHITESPACE, "")
override fun baseDirectory(nodeName: X500Name): Path = driverDirectory / nodeName.commonName.replace(WHITESPACE, "")
override fun startDedicatedNetworkMapService(): ListenableFuture<Unit> {
val debugPort = if (isDebug) debugPortAllocation.nextPort() else null
@ -679,6 +676,8 @@ class DriverDSL(
"-javaagent:$quasarJarPath"
val loggingLevel = if (debugPort == null) "INFO" else "DEBUG"
val pluginsDirectory = nodeConf.baseDirectory / "plugins"
ProcessUtilities.startJavaProcess(
className = "net.corda.node.Corda", // cannot directly get class for this, so just use string
arguments = listOf(
@ -686,6 +685,8 @@ class DriverDSL(
"--logging-level=$loggingLevel",
"--no-local-shell"
),
// Like the capsule, include the node's plugin directory
classpath = "${ProcessUtilities.defaultClassPath}$pathSeparator$pluginsDirectory/*",
jdwpPort = debugPort,
extraJvmArguments = extraJvmArguments,
errorLogPath = nodeConf.baseDirectory / LOGS_DIRECTORY_NAME / "error.log",

View File

@ -2,16 +2,15 @@ package net.corda.node.internal
import com.codahale.metrics.MetricRegistry
import com.google.common.annotations.VisibleForTesting
import com.google.common.collect.MutableClassToInstanceMap
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.MoreExecutors
import com.google.common.util.concurrent.SettableFuture
import io.github.lukehutch.fastclasspathscanner.FastClasspathScanner
import io.github.lukehutch.fastclasspathscanner.scanner.ScanResult
import net.corda.core.*
import net.corda.core.crypto.*
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.InitiatingFlow
import net.corda.core.flows.StartableByRPC
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.RPCOps
@ -19,6 +18,7 @@ import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.node.*
import net.corda.core.node.services.*
import net.corda.core.node.services.NetworkMapCache.MapChange
import net.corda.core.serialization.SerializeAsToken
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.serialization.deserialize
import net.corda.core.transactions.SignedTransaction
@ -45,7 +45,7 @@ import net.corda.node.services.schema.HibernateObserver
import net.corda.node.services.schema.NodeSchemaService
import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.node.services.statemachine.StateMachineManager
import net.corda.node.services.statemachine.flowVersion
import net.corda.node.services.statemachine.flowVersionAndInitiatingClass
import net.corda.node.services.transactions.*
import net.corda.node.services.vault.CashBalanceAsMetricsObserver
import net.corda.node.services.vault.NodeVaultService
@ -60,9 +60,11 @@ import org.bouncycastle.cert.X509CertificateHolder
import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter
import org.jetbrains.exposed.sql.Database
import org.slf4j.Logger
import rx.Observable
import java.io.IOException
import java.lang.reflect.Modifier.*
import java.net.URL
import java.net.JarURLConnection
import java.net.URI
import java.nio.file.FileAlreadyExistsException
import java.nio.file.Path
import java.nio.file.Paths
@ -73,6 +75,7 @@ import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import java.util.concurrent.TimeUnit.SECONDS
import java.util.stream.Collectors.toList
import kotlin.collections.ArrayList
import kotlin.reflect.KClass
import net.corda.core.crypto.generateKeyPair as cryptoGenerateKeyPair
@ -106,7 +109,8 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
// low-performance prototyping period.
protected abstract val serverThread: AffinityExecutor
protected val serviceFlowFactories = ConcurrentHashMap<Class<*>, ServiceFlowInfo>()
private val cordappServices = MutableClassToInstanceMap.create<SerializeAsToken>()
private val flowFactories = ConcurrentHashMap<Class<out FlowLogic<*>>, InitiatedFlowFactory<*>>()
protected val partyKeys = mutableSetOf<KeyPair>()
val services = object : ServiceHubInternal() {
@ -122,6 +126,12 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
override val schemaService: SchemaService get() = schemas
override val transactionVerifierService: TransactionVerifierService get() = txVerifierService
override val auditService: AuditService get() = auditService
override fun <T : SerializeAsToken> cordaService(type: Class<T>): T {
require(type.isAnnotationPresent(CordaService::class.java)) { "${type.name} is not a Corda service" }
return cordappServices.getInstance(type) ?: throw IllegalArgumentException("Corda service ${type.name} does not exist")
}
override val rpcFlows: List<Class<out FlowLogic<*>>> get() = this@AbstractNode.rpcFlows
// Internal only
@ -131,17 +141,8 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
return serverThread.fetchFrom { smm.add(logic, flowInitiator) }
}
override fun registerServiceFlow(initiatingFlowClass: Class<out FlowLogic<*>>, serviceFlowFactory: (Party) -> FlowLogic<*>) {
require(initiatingFlowClass !in serviceFlowFactories) {
"${initiatingFlowClass.name} has already been used to register a service flow"
}
val info = ServiceFlowInfo.CorDapp(initiatingFlowClass.flowVersion, serviceFlowFactory)
log.info("Registering service flow for ${initiatingFlowClass.name}: $info")
serviceFlowFactories[initiatingFlowClass] = info
}
override fun getServiceFlowFactory(clientFlowClass: Class<out FlowLogic<*>>): ServiceFlowInfo? {
return serviceFlowFactories[clientFlowClass]
override fun getFlowFactory(initiatingFlowClass: Class<out FlowLogic<*>>): InitiatedFlowFactory<*>? {
return flowFactories[initiatingFlowClass]
}
override fun recordTransactions(txs: Iterable<SignedTransaction>) {
@ -167,15 +168,11 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
lateinit var scheduler: NodeSchedulerService
lateinit var schemas: SchemaService
lateinit var auditService: AuditService
val customServices: ArrayList<Any> = ArrayList()
protected val runOnStop: ArrayList<Runnable> = ArrayList()
lateinit var database: Database
protected var dbCloser: Runnable? = null
private lateinit var rpcFlows: List<Class<out FlowLogic<*>>>
/** Locates and returns a service of the given type if loaded, or throws an exception if not found. */
inline fun <reified T : Any> findService() = customServices.filterIsInstance<T>().single()
var isPreviousCheckpointsPresent = false
private set
@ -217,7 +214,6 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
val tokenizableServices = makeServices()
smm = StateMachineManager(services,
listOf(tokenizableServices),
checkpointStorage,
serverThread,
database,
@ -240,22 +236,24 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
startMessagingService(rpcOps)
installCoreFlows()
fun Class<out FlowLogic<*>>.isUserInvokable(): Boolean {
return isPublic(modifiers) && !isLocalClass && !isAnonymousClass && (!isMemberClass || isStatic(modifiers))
val scanResult = scanCorDapps()
if (scanResult != null) {
val cordappServices = installCordaServices(scanResult)
tokenizableServices.addAll(cordappServices)
registerInitiatedFlows(scanResult)
rpcFlows = findRPCFlows(scanResult)
} else {
rpcFlows = emptyList()
}
val flows = scanForFlows()
rpcFlows = flows.filter { it.isUserInvokable() && it.isAnnotationPresent(StartableByRPC::class.java) } +
// Add any core flows here
listOf(ContractUpgradeFlow::class.java,
// TODO Remove all Cash flows from default list once they are split into separate CorDapp.
CashIssueFlow::class.java,
CashExitFlow::class.java,
CashPaymentFlow::class.java)
// TODO Remove this once the cash stuff is in its own CorDapp
registerInitiatedFlow(IssuerFlow.Issuer::class.java)
initUploaders()
runOnStop += Runnable { net.stop() }
_networkMapRegistrationFuture.setFuture(registerWithNetworkMapIfConfigured())
smm.start()
smm.start(tokenizableServices)
// Shut down the SMM so no Fibers are scheduled.
runOnStop += Runnable { smm.stop(acceptableLiveFiberCountOnStop()) }
scheduler.start()
@ -264,18 +262,142 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
return this
}
private fun installCordaServices(scanResult: ScanResult): List<SerializeAsToken> {
return scanResult.getClassesWithAnnotation(SerializeAsToken::class, CordaService::class).mapNotNull {
try {
installCordaService(it)
} catch (e: NoSuchMethodException) {
log.error("${it.name}, as a Corda service, must have a constructor with a single parameter " +
"of type ${PluginServiceHub::class.java.name}")
null
} catch (e: Exception) {
log.error("Unable to install Corda service ${it.name}", e)
null
}
}
}
/**
* Use this method to install your Corda services in your tests. This is automatically done by the node when it
* starts up for all classes it finds which are annotated with [CordaService].
*/
fun <T : SerializeAsToken> installCordaService(clazz: Class<T>): T {
clazz.requireAnnotation<CordaService>()
val ctor = clazz.getDeclaredConstructor(PluginServiceHub::class.java).apply { isAccessible = true }
val service = ctor.newInstance(services)
cordappServices.putInstance(clazz, service)
log.info("Installed ${clazz.name} Corda service")
return service
}
private inline fun <reified A : Annotation> Class<*>.requireAnnotation(): A {
return requireNotNull(getDeclaredAnnotation(A::class.java)) { "$name needs to be annotated with ${A::class.java.name}" }
}
private fun registerInitiatedFlows(scanResult: ScanResult) {
scanResult
.getClassesWithAnnotation(FlowLogic::class, InitiatedBy::class)
// First group by the initiating flow class in case there are multiple mappings
.groupBy { it.requireAnnotation<InitiatedBy>().value.java }
.map { (initiatingFlow, initiatedFlows) ->
val sorted = initiatedFlows.sortedWith(FlowTypeHierarchyComparator(initiatingFlow))
if (sorted.size > 1) {
log.warn("${initiatingFlow.name} has been specified as the inititating flow by multiple flows " +
"in the same type hierarchy: ${sorted.joinToString { it.name }}. Choosing the most " +
"specific sub-type for registration: ${sorted[0].name}.")
}
sorted[0]
}
.forEach {
try {
registerInitiatedFlowInternal(it, track = false)
} catch (e: NoSuchMethodException) {
log.error("${it.name}, as an initiated flow, must have a constructor with a single parameter " +
"of type ${Party::class.java.name}")
} catch (e: Exception) {
log.error("Unable to register initiated flow ${it.name}", e)
}
}
}
private class FlowTypeHierarchyComparator(val initiatingFlow: Class<out FlowLogic<*>>) : Comparator<Class<out FlowLogic<*>>> {
override fun compare(o1: Class<out FlowLogic<*>>, o2: Class<out FlowLogic<*>>): Int {
return if (o1 == o2) {
0
} else if (o1.isAssignableFrom(o2)) {
1
} else if (o2.isAssignableFrom(o1)) {
-1
} else {
throw IllegalArgumentException("${initiatingFlow.name} has been specified as the initiating flow by " +
"both ${o1.name} and ${o2.name}")
}
}
}
/**
* Use this method to register your initiated flows in your tests. This is automatically done by the node when it
* starts up for all [FlowLogic] classes it finds which are annotated with [InitiatedBy].
* @return An [Observable] of the initiated flows started by counter-parties.
*/
fun <T : FlowLogic<*>> registerInitiatedFlow(initiatedFlowClass: Class<T>): Observable<T> {
return registerInitiatedFlowInternal(initiatedFlowClass, track = true)
}
private fun <F : FlowLogic<*>> registerInitiatedFlowInternal(initiatedFlow: Class<F>, track: Boolean): Observable<F> {
val ctor = initiatedFlow.getDeclaredConstructor(Party::class.java).apply { isAccessible = true }
val initiatingFlow = initiatedFlow.requireAnnotation<InitiatedBy>().value.java
val (version, classWithAnnotation) = initiatingFlow.flowVersionAndInitiatingClass
require(classWithAnnotation == initiatingFlow) {
"${InitiatingFlow::class.java.name} must be annotated on ${initiatingFlow.name} and not on a super-type"
}
val flowFactory = InitiatedFlowFactory.CorDapp(version, { ctor.newInstance(it) })
val observable = registerFlowFactory(initiatingFlow, flowFactory, initiatedFlow, track)
log.info("Registered ${initiatingFlow.name} to initiate ${initiatedFlow.name} (version $version)")
return observable
}
@VisibleForTesting
fun <F : FlowLogic<*>> registerFlowFactory(initiatingFlowClass: Class<out FlowLogic<*>>,
flowFactory: InitiatedFlowFactory<F>,
initiatedFlowClass: Class<F>,
track: Boolean): Observable<F> {
val observable = if (track) {
smm.changes.filter { it is StateMachineManager.Change.Add }.map { it.logic }.ofType(initiatedFlowClass)
} else {
Observable.empty()
}
flowFactories[initiatingFlowClass] = flowFactory
return observable
}
private fun findRPCFlows(scanResult: ScanResult): List<Class<out FlowLogic<*>>> {
fun Class<out FlowLogic<*>>.isUserInvokable(): Boolean {
return isPublic(modifiers) && !isLocalClass && !isAnonymousClass && (!isMemberClass || isStatic(modifiers))
}
return scanResult.getClassesWithAnnotation(FlowLogic::class, StartableByRPC::class).filter { it.isUserInvokable() } +
// Add any core flows here
listOf(
ContractUpgradeFlow::class.java,
// TODO Remove all Cash flows from default list once they are split into separate CorDapp.
CashIssueFlow::class.java,
CashExitFlow::class.java,
CashPaymentFlow::class.java)
}
/**
* Installs a flow that's core to the Corda platform. Unlike CorDapp flows which are versioned individually using
* [InitiatingFlow.version], core flows have the same version as the node's platform version. To cater for backwards
* compatibility [serviceFlowFactory] provides a second parameter which is the platform version of the initiating party.
* compatibility [flowFactory] provides a second parameter which is the platform version of the initiating party.
* @suppress
*/
@VisibleForTesting
fun installCoreFlow(clientFlowClass: KClass<out FlowLogic<*>>, serviceFlowFactory: (Party, Int) -> FlowLogic<*>) {
require(clientFlowClass.java.flowVersion == 1) {
fun installCoreFlow(clientFlowClass: KClass<out FlowLogic<*>>, flowFactory: (Party, Int) -> FlowLogic<*>) {
require(clientFlowClass.java.flowVersionAndInitiatingClass.first == 1) {
"${InitiatingFlow::class.java.name}.version not applicable for core flows; their version is the node's platform version"
}
serviceFlowFactories[clientFlowClass.java] = ServiceFlowInfo.Core(serviceFlowFactory)
flowFactories[clientFlowClass.java] = InitiatedFlowFactory.Core(flowFactory)
log.debug { "Installed core flow ${clientFlowClass.java.name}" }
}
@ -313,61 +435,62 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
val tokenizableServices = mutableListOf(storage, net, vault, keyManagement, identity, platformClock, scheduler)
makeAdvertisedServices(tokenizableServices)
customServices.clear()
customServices.addAll(makePluginServices(tokenizableServices))
initUploaders(storageServices)
return tokenizableServices
}
private fun scanForFlows(): List<Class<out FlowLogic<*>>> {
val pluginsDir = configuration.baseDirectory / "plugins"
log.info("Scanning plugins in $pluginsDir ...")
if (!pluginsDir.exists()) return emptyList()
val pluginJars = pluginsDir.list {
it.filter { it.isRegularFile() && it.toString().endsWith(".jar") }.toArray()
private fun scanCorDapps(): ScanResult? {
val scanPackage = System.getProperty("net.corda.node.cordapp.scan.package")
val paths = if (scanPackage != null) {
// This is purely for integration tests so that classes defined in the test can automatically be picked up
check(configuration.devMode) { "Package scanning can only occur in dev mode" }
val resource = scanPackage.replace('.', '/')
javaClass.classLoader.getResources(resource)
.asSequence()
.map {
val uri = if (it.protocol == "jar") {
(it.openConnection() as JarURLConnection).jarFileURL.toURI()
} else {
URI(it.toExternalForm().removeSuffix(resource))
}
Paths.get(uri)
}
.toList()
} else {
val pluginsDir = configuration.baseDirectory / "plugins"
if (!pluginsDir.exists()) return null
pluginsDir.list {
it.filter { it.isRegularFile() && it.toString().endsWith(".jar") }.collect(toList())
}
}
if (pluginJars.isEmpty()) return emptyList()
log.info("Scanning CorDapps in $paths")
val scanResult = FastClasspathScanner().overrideClasspath(*pluginJars).scan() // This will only scan the plugin jars and nothing else
// This will only scan the plugin jars and nothing else
return if (paths.isNotEmpty()) FastClasspathScanner().overrideClasspath(paths).scan() else null
}
fun loadFlowClass(className: String): Class<out FlowLogic<*>>? {
private fun <T : Any> ScanResult.getClassesWithAnnotation(type: KClass<T>, annotation: KClass<out Annotation>): List<Class<out T>> {
fun loadClass(className: String): Class<out T>? {
return try {
// TODO Make sure this is loaded by the correct class loader
@Suppress("UNCHECKED_CAST")
Class.forName(className, false, javaClass.classLoader) as Class<out FlowLogic<*>>
Class.forName(className, false, javaClass.classLoader).asSubclass(type.java)
} catch (e: ClassCastException) {
log.warn("As $className is annotated with ${annotation.qualifiedName} it must be a sub-type of ${type.java.name}")
null
} catch (e: Exception) {
log.warn("Unable to load flow class $className", e)
log.warn("Unable to load class $className", e)
null
}
}
val flowClasses = scanResult.getNamesOfSubclassesOf(FlowLogic::class.java)
.mapNotNull { loadFlowClass(it) }
return getNamesOfClassesWithAnnotation(annotation.java)
.mapNotNull { loadClass(it) }
.filterNot { isAbstract(it.modifiers) }
fun URL.pluginName(): String {
return try {
Paths.get(toURI()).fileName.toString()
} catch (e: Exception) {
toString()
}
}
flowClasses.groupBy {
scanResult.classNameToClassInfo[it.name]!!.classpathElementURLs.first()
}.forEach { url, classes ->
log.info("Found flows in plugin ${url.pluginName()}: ${classes.joinToString { it.name }}")
}
return flowClasses
}
private fun initUploaders(storageServices: Pair<TxWritableStorageService, CheckpointStorage>) {
val uploaders: List<FileUploader> = listOf(storageServices.first.attachments as NodeAttachmentService) +
customServices.filterIsInstance(AcceptsFileUpload::class.java)
private fun initUploaders() {
val uploaders: List<FileUploader> = listOf(storage.attachments as NodeAttachmentService) +
cordappServices.values.filterIsInstance(AcceptsFileUpload::class.java)
(storage as StorageServiceImpl).initUploaders(uploaders)
}
@ -433,12 +556,6 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
}
}
private fun makePluginServices(tokenizableServices: MutableList<Any>): List<Any> {
val pluginServices = pluginRegistries.flatMap { it.servicePlugins }.map { it.apply(services) }
tokenizableServices.addAll(pluginServices)
return pluginServices
}
/**
* Run any tasks that are needed to ensure the node is in a correct state before running start().
*/
@ -658,11 +775,6 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
}
}
sealed class ServiceFlowInfo {
data class Core(val factory: (Party, Int) -> FlowLogic<*>) : ServiceFlowInfo()
data class CorDapp(val version: Int, val factory: (Party) -> FlowLogic<*>) : ServiceFlowInfo()
}
private class KeyStoreWrapper(private val storePath: Path, private val storePassword: String) {
private val keyStore = KeyStoreUtilities.loadKeyStore(storePath, storePassword)

View File

@ -0,0 +1,27 @@
package net.corda.node.internal
import net.corda.core.flows.FlowLogic
import net.corda.core.identity.Party
import net.corda.node.services.statemachine.SessionInit
interface InitiatedFlowFactory<out F : FlowLogic<*>> {
fun createFlow(platformVersion: Int, otherParty: Party, sessionInit: SessionInit): F
data class Core<out F : FlowLogic<*>>(val factory: (Party, Int) -> F) : InitiatedFlowFactory<F> {
override fun createFlow(platformVersion: Int, otherParty: Party, sessionInit: SessionInit): F {
return factory(otherParty, platformVersion)
}
}
data class CorDapp<out F : FlowLogic<*>>(val version: Int, val factory: (Party) -> F) : InitiatedFlowFactory<F> {
override fun createFlow(platformVersion: Int, otherParty: Party, sessionInit: SessionInit): F {
// TODO Add support for multiple versions of the same flow when CorDapps are loaded in separate class loaders
if (sessionInit.flowVerison == version) return factory(otherParty)
throw SessionRejectException(
"Version not supported",
"Version mismatch - ${sessionInit.initiatingFlowClass} is only registered for version $version")
}
}
}
class SessionRejectException(val rejectMessage: String, val logMessage: String) : Exception()

View File

@ -13,7 +13,7 @@ import net.corda.core.node.services.TxWritableStorageService
import net.corda.core.serialization.CordaSerializable
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.loggerFor
import net.corda.node.internal.ServiceFlowInfo
import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.statemachine.FlowLogicRefFactoryImpl
import net.corda.node.services.statemachine.FlowStateMachineImpl
@ -47,7 +47,6 @@ interface NetworkMapCacheInternal : NetworkMapCache {
/** For testing where the network map cache is manipulated marks the service as immediately ready. */
@VisibleForTesting
fun runWithoutMapService()
}
@CordaSerializable
@ -93,7 +92,6 @@ abstract class ServiceHubInternal : PluginServiceHub {
* Starts an already constructed flow. Note that you must be on the server thread to call this method. [FlowInitiator]
* defaults to [FlowInitiator.RPC] with username "Only For Testing".
*/
// TODO Move it to test utils.
@VisibleForTesting
fun <T> startFlow(logic: FlowLogic<T>): FlowStateMachine<T> = startFlow(logic, FlowInitiator.RPC("Only For Testing"))
@ -103,7 +101,6 @@ abstract class ServiceHubInternal : PluginServiceHub {
*/
abstract fun <T> startFlow(logic: FlowLogic<T>, flowInitiator: FlowInitiator): FlowStateMachineImpl<T>
/**
* Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the flow.
* Note that you must be on the server thread to call this method. [flowInitiator] points how flow was started,
@ -122,5 +119,5 @@ abstract class ServiceHubInternal : PluginServiceHub {
return startFlow(logic, flowInitiator)
}
abstract fun getServiceFlowFactory(clientFlowClass: Class<out FlowLogic<*>>): ServiceFlowInfo?
abstract fun getFlowFactory(initiatingFlowClass: Class<out FlowLogic<*>>): InitiatedFlowFactory<*>?
}

View File

@ -8,9 +8,9 @@ import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import net.corda.core.ErrorOr
import net.corda.core.abbreviate
import net.corda.core.identity.Party
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.random63BitValue
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.ProgressTracker
@ -27,7 +27,6 @@ import org.jetbrains.exposed.sql.Transaction
import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.lang.reflect.Modifier
import java.sql.Connection
import java.sql.SQLException
import java.util.*
@ -322,9 +321,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
logger.trace { "Initiating a new session with $otherParty" }
val session = FlowSession(sessionFlow, random63BitValue(), null, FlowSessionState.Initiating(otherParty), retryable)
openSessions[Pair(sessionFlow, otherParty)] = session
// We get the top-most concrete class object to cater for the case where the client flow is customised via a sub-class
val clientFlowClass = sessionFlow.topConcreteFlowClass
val sessionInit = SessionInit(session.ourSessionId, clientFlowClass, clientFlowClass.flowVersion, firstPayload)
val (version, initiatingFlowClass) = sessionFlow.javaClass.flowVersionAndInitiatingClass
val sessionInit = SessionInit(session.ourSessionId, initiatingFlowClass, version, firstPayload)
sendInternal(session, sessionInit)
if (waitForConfirmation) {
session.waitForConfirmation()
@ -332,15 +330,6 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
return session
}
@Suppress("UNCHECKED_CAST")
private val FlowLogic<*>.topConcreteFlowClass: Class<out FlowLogic<*>> get() {
var current: Class<out FlowLogic<*>> = javaClass
while (!Modifier.isAbstract(current.superclass.modifiers)) {
current = current.superclass as Class<out FlowLogic<*>>
}
return current
}
@Suspendable
private fun <M : ExistingSessionMessage> waitForMessage(receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> {
return receiveRequest.suspendAndExpectReceive().confirmReceiveType(receiveRequest)
@ -460,10 +449,19 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
}
}
val Class<out FlowLogic<*>>.flowVersion: Int get() {
val annotation = requireNotNull(getAnnotation(InitiatingFlow::class.java)) {
"$name as the initiating flow must be annotated with ${InitiatingFlow::class.java.name}"
@Suppress("UNCHECKED_CAST")
val Class<out FlowLogic<*>>.flowVersionAndInitiatingClass: Pair<Int, Class<out FlowLogic<*>>> get() {
var current: Class<*> = this
var found: Pair<Int, Class<out FlowLogic<*>>>? = null
while (true) {
val annotation = current.getDeclaredAnnotation(InitiatingFlow::class.java)
if (annotation != null) {
if (found != null) throw IllegalArgumentException("${InitiatingFlow::class.java.name} can only be annotated once")
require(annotation.version > 0) { "Flow versions have to be greater or equal to 1" }
found = annotation.version to (current as Class<out FlowLogic<*>>)
}
current = current.superclass
?: return found
?: throw IllegalArgumentException("$name as an initiating flow must be annotated with ${InitiatingFlow::class.java.name}")
}
require(annotation.version > 0) { "Flow versions have to be greater or equal to 1" }
return annotation.version
}

View File

@ -15,7 +15,7 @@ import net.corda.core.utilities.UntrustworthyData
interface SessionMessage
data class SessionInit(val initiatorSessionId: Long,
val clientFlowClass: Class<out FlowLogic<*>>,
val initiatingFlowClass: Class<out FlowLogic<*>>,
val flowVerison: Int,
val firstPayload: Any?) : SessionMessage

View File

@ -15,14 +15,14 @@ import com.google.common.collect.HashMultimap
import com.google.common.util.concurrent.ListenableFuture
import io.requery.util.CloseableIterator
import net.corda.core.*
import net.corda.core.identity.Party
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.serialization.*
import net.corda.core.utilities.debug
import net.corda.core.utilities.loggerFor
import net.corda.core.utilities.trace
import net.corda.node.internal.ServiceFlowInfo
import net.corda.node.internal.SessionRejectException
import net.corda.node.services.api.Checkpoint
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal
@ -61,7 +61,6 @@ import javax.annotation.concurrent.ThreadSafe
*/
@ThreadSafe
class StateMachineManager(val serviceHub: ServiceHubInternal,
tokenizableServices: List<Any>,
val checkpointStorage: CheckpointStorage,
val executor: AffinityExecutor,
val database: Database,
@ -147,7 +146,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
private val recentlyClosedSessions = ConcurrentHashMap<Long, Party>()
// Context for tokenized services in checkpoints
private val serializationContext = SerializeAsTokenContext(tokenizableServices, quasarKryoPool, serviceHub)
private lateinit var serializationContext: SerializeAsTokenContext
/** Returns a list of all state machines executing the given flow logic at the top level (subflows do not count) */
fun <P : FlowLogic<T>, T> findStateMachines(flowClass: Class<P>): List<Pair<P, ListenableFuture<T>>> {
@ -171,7 +170,8 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
*/
val changes: Observable<Change> = mutex.content.changesPublisher.wrapWithDatabaseTransaction()
fun start() {
fun start(tokenizableServices: List<Any>) {
serializationContext = SerializeAsTokenContext(tokenizableServices, quasarKryoPool, serviceHub)
restoreFibersFromCheckpoints()
listenToLedgerTransactions()
serviceHub.networkMapCache.mapServiceRegistered.then(executor) { resumeRestoredFibers() }
@ -345,28 +345,15 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
fun sendSessionReject(message: String) = sendSessionMessage(sender, SessionReject(otherPartySessionId, message))
val serviceFlowInfo = serviceHub.getServiceFlowFactory(sessionInit.clientFlowClass)
if (serviceFlowInfo == null) {
logger.warn("${sessionInit.clientFlowClass} has not been registered with a service flow: $sessionInit")
sendSessionReject("${sessionInit.clientFlowClass.name} has not been registered with a service flow")
val initiatedFlowFactory = serviceHub.getFlowFactory(sessionInit.initiatingFlowClass)
if (initiatedFlowFactory == null) {
logger.warn("${sessionInit.initiatingFlowClass} has not been registered: $sessionInit")
sendSessionReject("${sessionInit.initiatingFlowClass.name} has not been registered with a service flow")
return
}
val session = try {
val flow = when (serviceFlowInfo) {
is ServiceFlowInfo.CorDapp -> {
// TODO Add support for multiple versions of the same flow when CorDapps are loaded in separate class loaders
if (sessionInit.flowVerison != serviceFlowInfo.version) {
logger.warn("Version mismatch - ${sessionInit.clientFlowClass} is only registered for version " +
"${serviceFlowInfo.version}: $sessionInit")
sendSessionReject("Version not supported")
return
}
serviceFlowInfo.factory(sender)
}
is ServiceFlowInfo.Core -> serviceFlowInfo.factory(sender, receivedMessage.platformVersion)
}
val flow = initiatedFlowFactory.createFlow(receivedMessage.platformVersion, sender, sessionInit)
val fiber = createFiber(flow, FlowInitiator.Peer(sender))
val session = FlowSession(flow, random63BitValue(), sender, FlowSessionState.Initiated(sender, otherPartySessionId))
if (sessionInit.firstPayload != null) {
@ -376,6 +363,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
fiber.openSessions[Pair(flow, sender)] = session
updateCheckpoint(fiber)
session
} catch (e: SessionRejectException) {
logger.warn("${e.logMessage}: $sessionInit")
sendSessionReject(e.rejectMessage)
return
} catch (e: Exception) {
logger.warn("Couldn't start flow session from $sessionInit", e)
sendSessionReject("Unable to establish session")
@ -383,7 +374,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
}
sendSessionMessage(sender, SessionConfirm(otherPartySessionId, session.ourSessionId), session.fiber)
session.fiber.logger.debug { "Initiated by $sender using ${sessionInit.clientFlowClass.name}" }
session.fiber.logger.debug { "Initiated by $sender using ${sessionInit.initiatingFlowClass.name}" }
session.fiber.logger.trace { "Initiated from $sessionInit on $session" }
resumeFiber(session.fiber)
}

View File

@ -1,20 +0,0 @@
package net.corda.node.internal
import net.corda.core.createDirectories
import net.corda.core.crypto.commonName
import net.corda.core.div
import net.corda.core.getOrThrow
import net.corda.core.utilities.ALICE
import net.corda.core.utilities.WHITESPACE
import net.corda.testing.node.NodeBasedTest
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
class NodeTest : NodeBasedTest() {
@Test
fun `empty plugins directory`() {
val baseDirectory = baseDirectory(ALICE.name)
(baseDirectory / "plugins").createDirectories()
startNode(ALICE.name).getOrThrow()
}
}

View File

@ -4,24 +4,18 @@ import co.paralleluniverse.fibers.Suspendable
import net.corda.contracts.CommercialPaper
import net.corda.contracts.asset.*
import net.corda.contracts.testing.fillWithSomeTestCash
import net.corda.core.*
import net.corda.core.contracts.*
import net.corda.core.crypto.DigitalSignature
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.sign
import net.corda.core.days
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine
import net.corda.core.flows.InitiatingFlow
import net.corda.core.flows.StateMachineRunId
import net.corda.core.getOrThrow
import net.corda.core.flows.*
import net.corda.core.identity.AbstractParty
import net.corda.core.identity.AnonymousParty
import net.corda.core.identity.Party
import net.corda.core.identity.AbstractParty
import net.corda.core.map
import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.node.NodeInfo
import net.corda.core.node.services.*
import net.corda.core.rootCause
import net.corda.core.serialization.serialize
import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.TransactionBuilder
@ -86,9 +80,10 @@ class TwoPartyTradeFlowTests {
net = MockNetwork(false, true)
ledger {
val notaryNode = net.createNotaryNode(null, DUMMY_NOTARY.name)
val aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name)
val bobNode = net.createPartyNode(notaryNode.info.address, BOB.name)
val basketOfNodes = net.createSomeNodes(2)
val notaryNode = basketOfNodes.notaryNode
val aliceNode = basketOfNodes.partyNodes[0]
val bobNode = basketOfNodes.partyNodes[1]
aliceNode.disableDBCloseOnStop()
bobNode.disableDBCloseOnStop()
@ -137,8 +132,7 @@ class TwoPartyTradeFlowTests {
aliceNode.disableDBCloseOnStop()
bobNode.disableDBCloseOnStop()
val cashStates =
bobNode.database.transaction {
val cashStates = bobNode.database.transaction {
bobNode.services.fillWithSomeTestCash(2000.DOLLARS, notaryNode.info.notaryIdentity, 3, 3)
}
@ -239,7 +233,7 @@ class TwoPartyTradeFlowTests {
}, true, BOB.name)
// Find the future representing the result of this state machine again.
val bobFuture = bobNode.smm.findStateMachines(Buyer::class.java).single().second
val bobFuture = bobNode.smm.findStateMachines(BuyerAcceptor::class.java).single().second
// And off we go again.
net.runNetwork()
@ -489,25 +483,42 @@ class TwoPartyTradeFlowTests {
sellerNode: MockNetwork.MockNode,
buyerNode: MockNetwork.MockNode,
assetToSell: StateAndRef<OwnableState>): RunResult {
@InitiatingFlow
class SellerRunnerFlow(val buyer: Party, val notary: NodeInfo) : FlowLogic<SignedTransaction>() {
@Suspendable
override fun call(): SignedTransaction = subFlow(Seller(
buyer,
notary,
assetToSell,
1000.DOLLARS,
serviceHub.legalIdentityKey))
}
sellerNode.services.identityService.registerIdentity(buyerNode.info.legalIdentity)
buyerNode.services.identityService.registerIdentity(sellerNode.info.legalIdentity)
val buyerFuture = buyerNode.initiateSingleShotFlow(SellerRunnerFlow::class) { otherParty ->
Buyer(otherParty, notaryNode.info.notaryIdentity, 1000.DOLLARS, CommercialPaper.State::class.java)
}.map { it.stateMachine }
val seller = SellerRunnerFlow(buyerNode.info.legalIdentity, notaryNode.info)
val sellerResultFuture = sellerNode.services.startFlow(seller).resultFuture
return RunResult(buyerFuture, sellerResultFuture, seller.stateMachine.id)
val buyerFlows: Observable<BuyerAcceptor> = buyerNode.registerInitiatedFlow(BuyerAcceptor::class.java)
val firstBuyerFiber = buyerFlows.toFuture().map { it.stateMachine }
val seller = SellerInitiator(buyerNode.info.legalIdentity, notaryNode.info, assetToSell, 1000.DOLLARS)
val sellerResult = sellerNode.services.startFlow(seller).resultFuture
return RunResult(firstBuyerFiber, sellerResult, seller.stateMachine.id)
}
@InitiatingFlow
class SellerInitiator(val buyer: Party,
val notary: NodeInfo,
val assetToSell: StateAndRef<OwnableState>,
val price: Amount<Currency>) : FlowLogic<SignedTransaction>() {
@Suspendable
override fun call(): SignedTransaction {
send(buyer, Pair(notary.notaryIdentity, price))
return subFlow(Seller(
buyer,
notary,
assetToSell,
price,
serviceHub.legalIdentityKey))
}
}
@InitiatedBy(SellerInitiator::class)
class BuyerAcceptor(val seller: Party) : FlowLogic<SignedTransaction>() {
@Suspendable
override fun call(): SignedTransaction {
val (notary, price) = receive<Pair<Party, Amount<Currency>>>(seller).unwrap {
require(serviceHub.networkMapCache.isNotary(it.first)) { "${it.first} is not a notary" }
it
}
return subFlow(Buyer(seller, notary, price, CommercialPaper.State::class.java))
}
}
private fun LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter>.runWithError(

View File

@ -3,11 +3,11 @@ package net.corda.node.services
import com.codahale.metrics.MetricRegistry
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic
import net.corda.core.identity.Party
import net.corda.core.node.NodeInfo
import net.corda.core.node.services.*
import net.corda.core.serialization.SerializeAsToken
import net.corda.core.transactions.SignedTransaction
import net.corda.node.internal.ServiceFlowInfo
import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.serialization.NodeClock
import net.corda.node.services.api.*
import net.corda.node.services.messaging.MessagingService
@ -67,11 +67,11 @@ open class MockServiceHubInternal(
override fun recordTransactions(txs: Iterable<SignedTransaction>) = recordTransactionsInternal(txStorageService, txs)
override fun <T : SerializeAsToken> cordaService(type: Class<T>): T = throw UnsupportedOperationException()
override fun <T> startFlow(logic: FlowLogic<T>, flowInitiator: FlowInitiator): FlowStateMachineImpl<T> {
return smm.executor.fetchFrom { smm.add(logic, flowInitiator) }
}
override fun registerServiceFlow(initiatingFlowClass: Class<out FlowLogic<*>>, serviceFlowFactory: (Party) -> FlowLogic<*>) = Unit
override fun getServiceFlowFactory(clientFlowClass: Class<out FlowLogic<*>>): ServiceFlowInfo? = null
override fun getFlowFactory(initiatingFlowClass: Class<out FlowLogic<*>>): InitiatedFlowFactory<*>? = null
}

View File

@ -92,13 +92,13 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
}
scheduler = NodeSchedulerService(services, database, schedulerGatedExecutor)
smmExecutor = AffinityExecutor.ServiceAffinityExecutor("test", 1)
val mockSMM = StateMachineManager(services, listOf(services, scheduler), DBCheckpointStorage(), smmExecutor, database)
val mockSMM = StateMachineManager(services, DBCheckpointStorage(), smmExecutor, database)
mockSMM.changes.subscribe { change ->
if (change is StateMachineManager.Change.Removed && mockSMM.allStateMachines.isEmpty()) {
smmHasRemovedAllFlows.countDown()
}
}
mockSMM.start()
mockSMM.start(listOf(services, scheduler))
services.smm = mockSMM
scheduler.start()
}

View File

@ -6,9 +6,10 @@ import net.corda.core.contracts.Amount
import net.corda.core.contracts.Issued
import net.corda.core.contracts.TransactionType
import net.corda.core.contracts.USD
import net.corda.core.identity.Party
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.InitiatedBy
import net.corda.core.flows.InitiatingFlow
import net.corda.core.identity.Party
import net.corda.core.node.services.unconsumedStates
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.DUMMY_NOTARY
@ -86,16 +87,20 @@ class DataVendingServiceTests {
}
private fun MockNode.sendNotifyTx(tx: SignedTransaction, walletServiceNode: MockNode) {
walletServiceNode.registerServiceFlow(clientFlowClass = NotifyTxFlow::class, serviceFlowFactory = ::NotifyTransactionHandler)
walletServiceNode.registerInitiatedFlow(InitiateNotifyTxFlow::class.java)
services.startFlow(NotifyTxFlow(walletServiceNode.info.legalIdentity, tx))
network.runNetwork()
}
@InitiatingFlow
private class NotifyTxFlow(val otherParty: Party, val stx: SignedTransaction) : FlowLogic<Unit>() {
@Suspendable
override fun call() = send(otherParty, NotifyTxRequest(stx))
}
@InitiatedBy(NotifyTxFlow::class)
private class InitiateNotifyTxFlow(val otherParty: Party) : FlowLogic<Unit>() {
@Suspendable
override fun call() = subFlow(NotifyTransactionHandler(otherParty))
}
}

View File

@ -7,13 +7,13 @@ import net.corda.contracts.asset.Cash
import net.corda.core.*
import net.corda.core.contracts.DOLLARS
import net.corda.core.contracts.DummyState
import net.corda.core.identity.Party
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.generateKeyPair
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSessionException
import net.corda.core.flows.InitiatingFlow
import net.corda.core.identity.Party
import net.corda.core.messaging.MessageRecipients
import net.corda.core.node.services.PartyInfo
import net.corda.core.node.services.ServiceInfo
@ -30,15 +30,19 @@ import net.corda.flows.CashIssueFlow
import net.corda.flows.CashPaymentFlow
import net.corda.flows.FinalityFlow
import net.corda.flows.NotaryFlow
import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.services.persistence.checkpoints
import net.corda.node.services.transactions.ValidatingNotaryService
import net.corda.node.utilities.transaction
import net.corda.testing.*
import net.corda.testing.expect
import net.corda.testing.expectEvents
import net.corda.testing.getTestX509Name
import net.corda.testing.node.InMemoryMessagingNetwork
import net.corda.testing.node.InMemoryMessagingNetwork.MessageTransfer
import net.corda.testing.node.InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin
import net.corda.testing.node.MockNetwork
import net.corda.testing.node.MockNetwork.MockNode
import net.corda.testing.sequence
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType
@ -110,7 +114,7 @@ class FlowFrameworkTests {
@Test
fun `exception while fiber suspended`() {
node2.registerServiceFlow(ReceiveFlow::class) { SendFlow("Hello", it) }
node2.registerFlowFactory(ReceiveFlow::class) { SendFlow("Hello", it) }
val flow = ReceiveFlow(node2.info.legalIdentity)
val fiber = node1.services.startFlow(flow) as FlowStateMachineImpl
// Before the flow runs change the suspend action to throw an exception
@ -129,7 +133,7 @@ class FlowFrameworkTests {
@Test
fun `flow restarted just after receiving payload`() {
node2.registerServiceFlow(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
node2.registerFlowFactory(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
node1.services.startFlow(SendFlow("Hello", node2.info.legalIdentity))
// We push through just enough messages to get only the payload sent
@ -179,7 +183,7 @@ class FlowFrameworkTests {
@Test
fun `flow loaded from checkpoint will respond to messages from before start`() {
node1.registerServiceFlow(ReceiveFlow::class) { SendFlow("Hello", it) }
node1.registerFlowFactory(ReceiveFlow::class) { SendFlow("Hello", it) }
node2.services.startFlow(ReceiveFlow(node1.info.legalIdentity).nonTerminating()) // Prepare checkpointed receive flow
// Make sure the add() has finished initial processing.
node2.smm.executor.flush()
@ -198,7 +202,7 @@ class FlowFrameworkTests {
net.messagingNetwork.sentMessages.toSessionTransfers().filter { it.isPayloadTransfer }.forEach { sentCount++ }
val node3 = net.createNode(node1.info.address)
val secondFlow = node3.initiateSingleShotFlow(PingPongFlow::class) { PingPongFlow(it, payload2) }
val secondFlow = node3.registerFlowFactory(PingPongFlow::class) { PingPongFlow(it, payload2) }
net.runNetwork()
// Kick off first send and receive
@ -243,8 +247,8 @@ class FlowFrameworkTests {
fun `sending to multiple parties`() {
val node3 = net.createNode(node1.info.address)
net.runNetwork()
node2.registerServiceFlow(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
node3.registerServiceFlow(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
node2.registerFlowFactory(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
node3.registerFlowFactory(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
val payload = "Hello World"
node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity, node3.info.legalIdentity))
net.runNetwork()
@ -277,8 +281,8 @@ class FlowFrameworkTests {
net.runNetwork()
val node2Payload = "Test 1"
val node3Payload = "Test 2"
node2.registerServiceFlow(ReceiveFlow::class) { SendFlow(node2Payload, it) }
node3.registerServiceFlow(ReceiveFlow::class) { SendFlow(node3Payload, it) }
node2.registerFlowFactory(ReceiveFlow::class) { SendFlow(node2Payload, it) }
node3.registerFlowFactory(ReceiveFlow::class) { SendFlow(node3Payload, it) }
val multiReceiveFlow = ReceiveFlow(node2.info.legalIdentity, node3.info.legalIdentity).nonTerminating()
node1.services.startFlow(multiReceiveFlow)
node1.acceptableLiveFiberCountOnStop = 1
@ -303,7 +307,7 @@ class FlowFrameworkTests {
@Test
fun `both sides do a send as their first IO request`() {
node2.registerServiceFlow(PingPongFlow::class) { PingPongFlow(it, 20L) }
node2.registerFlowFactory(PingPongFlow::class) { PingPongFlow(it, 20L) }
node1.services.startFlow(PingPongFlow(node2.info.legalIdentity, 10L))
net.runNetwork()
@ -339,7 +343,7 @@ class FlowFrameworkTests {
sessionTransfers.expectEvents(isStrict = false) {
sequence(
// First Pay
expect(match = { it.message is SessionInit && it.message.clientFlowClass == NotaryFlow.Client::class.java }) {
expect(match = { it.message is SessionInit && it.message.initiatingFlowClass == NotaryFlow.Client::class.java }) {
it.message as SessionInit
assertEquals(node1.id, it.from)
assertEquals(notary1Address, it.to)
@ -349,7 +353,7 @@ class FlowFrameworkTests {
assertEquals(notary1.id, it.from)
},
// Second pay
expect(match = { it.message is SessionInit && it.message.clientFlowClass == NotaryFlow.Client::class.java }) {
expect(match = { it.message is SessionInit && it.message.initiatingFlowClass == NotaryFlow.Client::class.java }) {
it.message as SessionInit
assertEquals(node1.id, it.from)
assertEquals(notary1Address, it.to)
@ -359,7 +363,7 @@ class FlowFrameworkTests {
assertEquals(notary2.id, it.from)
},
// Third pay
expect(match = { it.message is SessionInit && it.message.clientFlowClass == NotaryFlow.Client::class.java }) {
expect(match = { it.message is SessionInit && it.message.initiatingFlowClass == NotaryFlow.Client::class.java }) {
it.message as SessionInit
assertEquals(node1.id, it.from)
assertEquals(notary1Address, it.to)
@ -374,7 +378,7 @@ class FlowFrameworkTests {
@Test
fun `other side ends before doing expected send`() {
node2.registerServiceFlow(ReceiveFlow::class) { NoOpFlow() }
node2.registerFlowFactory(ReceiveFlow::class) { NoOpFlow() }
val resultFuture = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)).resultFuture
net.runNetwork()
assertThatExceptionOfType(FlowSessionException::class.java).isThrownBy {
@ -384,7 +388,7 @@ class FlowFrameworkTests {
@Test
fun `non-FlowException thrown on other side`() {
val erroringFlowFuture = node2.initiateSingleShotFlow(ReceiveFlow::class) {
val erroringFlowFuture = node2.registerFlowFactory(ReceiveFlow::class) {
ExceptionFlow { Exception("evil bug!") }
}
val erroringFlowSteps = erroringFlowFuture.flatMap { it.progressSteps }
@ -418,7 +422,7 @@ class FlowFrameworkTests {
@Test
fun `FlowException thrown on other side`() {
val erroringFlow = node2.initiateSingleShotFlow(ReceiveFlow::class) {
val erroringFlow = node2.registerFlowFactory(ReceiveFlow::class) {
ExceptionFlow { MyFlowException("Nothing useful") }
}
val erroringFlowSteps = erroringFlow.flatMap { it.progressSteps }
@ -456,8 +460,8 @@ class FlowFrameworkTests {
val node3 = net.createNode(node1.info.address)
net.runNetwork()
node3.initiateSingleShotFlow(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Chain") } }
node2.initiateSingleShotFlow(ReceiveFlow::class) { ReceiveFlow(node3.info.legalIdentity) }
node3.registerFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Chain") } }
node2.registerFlowFactory(ReceiveFlow::class) { ReceiveFlow(node3.info.legalIdentity) }
val receivingFiber = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity))
net.runNetwork()
assertThatExceptionOfType(MyFlowException::class.java)
@ -473,9 +477,9 @@ class FlowFrameworkTests {
// Node 2 will send its payload and then block waiting for the receive from node 1. Meanwhile node 1 will move
// onto node 3 which will throw the exception
val node2Fiber = node2
.initiateSingleShotFlow(ReceiveFlow::class) { SendAndReceiveFlow(it, "Hello") }
.registerFlowFactory(ReceiveFlow::class) { SendAndReceiveFlow(it, "Hello") }
.map { it.stateMachine }
node3.initiateSingleShotFlow(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Nothing useful") } }
node3.registerFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Nothing useful") } }
val node1Fiber = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity, node3.info.legalIdentity)) as FlowStateMachineImpl
net.runNetwork()
@ -528,7 +532,7 @@ class FlowFrameworkTests {
}
}
node2.registerServiceFlow(AskForExceptionFlow::class) { ConditionalExceptionFlow(it, "Hello") }
node2.registerFlowFactory(AskForExceptionFlow::class) { ConditionalExceptionFlow(it, "Hello") }
val resultFuture = node1.services.startFlow(RetryOnExceptionFlow(node2.info.legalIdentity)).resultFuture
net.runNetwork()
assertThat(resultFuture.getOrThrow()).isEqualTo("Hello")
@ -536,7 +540,7 @@ class FlowFrameworkTests {
@Test
fun `serialisation issue in counterparty`() {
node2.registerServiceFlow(ReceiveFlow::class) { SendFlow(NonSerialisableData(1), it) }
node2.registerFlowFactory(ReceiveFlow::class) { SendFlow(NonSerialisableData(1), it) }
val result = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)).resultFuture
net.runNetwork()
assertThatExceptionOfType(FlowSessionException::class.java).isThrownBy {
@ -546,7 +550,7 @@ class FlowFrameworkTests {
@Test
fun `FlowException has non-serialisable object`() {
node2.initiateSingleShotFlow(ReceiveFlow::class) {
node2.registerFlowFactory(ReceiveFlow::class) {
ExceptionFlow { NonSerialisableFlowException(NonSerialisableData(1)) }
}
val result = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)).resultFuture
@ -562,9 +566,9 @@ class FlowFrameworkTests {
ptx.addOutputState(DummyState())
val stx = node1.services.signInitialTransaction(ptx)
val committerFiber = node1
.initiateSingleShotFlow(WaitingFlows.Waiter::class) { WaitingFlows.Committer(it) }
.map { it.stateMachine }
val committerFiber = node1.registerFlowFactory(WaitingFlows.Waiter::class) {
WaitingFlows.Committer(it)
}.map { it.stateMachine }
val waiterStx = node2.services.startFlow(WaitingFlows.Waiter(stx, node1.info.legalIdentity)).resultFuture
net.runNetwork()
assertThat(waiterStx.getOrThrow()).isEqualTo(committerFiber.getOrThrow().resultFuture.getOrThrow())
@ -576,7 +580,7 @@ class FlowFrameworkTests {
ptx.addOutputState(DummyState())
val stx = node1.services.signInitialTransaction(ptx)
node1.registerServiceFlow(WaitingFlows.Waiter::class) {
node1.registerFlowFactory(WaitingFlows.Waiter::class) {
WaitingFlows.Committer(it) { throw Exception("Error") }
}
val waiter = node2.services.startFlow(WaitingFlows.Waiter(stx, node1.info.legalIdentity)).resultFuture
@ -594,13 +598,22 @@ class FlowFrameworkTests {
}
@Test
fun `custom client flow`() {
val receiveFlowFuture = node2.initiateSingleShotFlow(SendFlow::class) { ReceiveFlow(it) }
fun `customised client flow`() {
val receiveFlowFuture = node2.registerFlowFactory(SendFlow::class) { ReceiveFlow(it) }
node1.services.startFlow(CustomSendFlow("Hello", node2.info.legalIdentity)).resultFuture
net.runNetwork()
assertThat(receiveFlowFuture.getOrThrow().receivedPayloads).containsOnly("Hello")
}
@Test
fun `customised client flow which has annotated @InitiatingFlow again`() {
val result = node1.services.startFlow(IncorrectCustomSendFlow("Hello", node2.info.legalIdentity)).resultFuture
net.runNetwork()
assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy {
result.getOrThrow()
}.withMessageContaining(InitiatingFlow::class.java.simpleName)
}
@Test
fun `upgraded flow`() {
node1.services.startFlow(UpgradedFlow(node2.info.legalIdentity))
@ -612,7 +625,11 @@ class FlowFrameworkTests {
@Test
fun `unsupported new flow version`() {
node2.registerServiceFlow(UpgradedFlow::class, flowVersion = 1) { SendFlow("Hello", it) }
node2.registerFlowFactory(
UpgradedFlow::class.java,
InitiatedFlowFactory.CorDapp(version = 1, factory = ::DoubleInlinedSubFlow),
DoubleInlinedSubFlow::class.java,
track = false)
val result = node1.services.startFlow(UpgradedFlow(node2.info.legalIdentity)).resultFuture
net.runNetwork()
assertThatExceptionOfType(FlowSessionException::class.java).isThrownBy {
@ -622,7 +639,7 @@ class FlowFrameworkTests {
@Test
fun `single inlined sub-flow`() {
node2.registerServiceFlow(SendAndReceiveFlow::class) { SingleInlinedSubFlow(it) }
node2.registerFlowFactory(SendAndReceiveFlow::class, ::SingleInlinedSubFlow)
val result = node1.services.startFlow(SendAndReceiveFlow(node2.info.legalIdentity, "Hello")).resultFuture
net.runNetwork()
assertThat(result.getOrThrow()).isEqualTo("HelloHello")
@ -630,7 +647,7 @@ class FlowFrameworkTests {
@Test
fun `double inlined sub-flow`() {
node2.registerServiceFlow(SendAndReceiveFlow::class) { DoubleInlinedSubFlow(it) }
node2.registerFlowFactory(SendAndReceiveFlow::class, ::DoubleInlinedSubFlow)
val result = node1.services.startFlow(SendAndReceiveFlow(node2.info.legalIdentity, "Hello")).resultFuture
net.runNetwork()
assertThat(result.getOrThrow()).isEqualTo("HelloHello")
@ -654,6 +671,18 @@ class FlowFrameworkTests {
return smm.findStateMachines(P::class.java).single()
}
private inline fun <reified P : FlowLogic<*>> MockNode.registerFlowFactory(
initiatingFlowClass: KClass<out FlowLogic<*>>,
noinline flowFactory: (Party) -> P): ListenableFuture<P>
{
val observable = registerFlowFactory(initiatingFlowClass.java, object : InitiatedFlowFactory<P> {
override fun createFlow(platformVersion: Int, otherParty: Party, sessionInit: SessionInit): P {
return flowFactory(otherParty)
}
}, P::class.java, track = true)
return observable.toFuture()
}
private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): SessionInit {
return SessionInit(0, clientFlowClass.java, flowVersion, payload)
}
@ -730,8 +759,12 @@ class FlowFrameworkTests {
}
private interface CustomInterface
private class CustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty)
@InitiatingFlow
private class IncorrectCustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty)
@InitiatingFlow
private class ReceiveFlow(vararg val otherParties: Party) : FlowLogic<Unit>() {
object START_STEP : ProgressTracker.Step("Starting")
@ -852,7 +885,7 @@ class FlowFrameworkTests {
}
private data class NonSerialisableData(val a: Int)
private class NonSerialisableFlowException(val data: NonSerialisableData) : FlowException()
private class NonSerialisableFlowException(@Suppress("unused") val data: NonSerialisableData) : FlowException()
//endregion Helpers
}