Merge pull request #1803 from corda/mnesbit-CordaServices-startFlows

Start Flows from services
This commit is contained in:
Matthew Nesbit 2017-10-06 13:49:18 +01:00 committed by GitHub
commit 2564b7da1d
12 changed files with 254 additions and 11 deletions

View File

@ -17,6 +17,7 @@ import java.net.URL
* @property contractClassNames List of contracts
* @property initiatedFlows List of initiatable flow classes
* @property rpcFlows List of RPC initiable flows classes
* @property serviceFlows List of [CordaService] initiable flows classes
* @property schedulableFlows List of flows startable by the scheduler
* @property servies List of RPC services
* @property serializationWhitelists List of Corda plugin registries
@ -28,6 +29,7 @@ interface Cordapp {
val contractClassNames: List<String>
val initiatedFlows: List<Class<out FlowLogic<*>>>
val rpcFlows: List<Class<out FlowLogic<*>>>
val serviceFlows: List<Class<out FlowLogic<*>>>
val schedulableFlows: List<Class<out FlowLogic<*>>>
val services: List<Class<out SerializeAsToken>>
val serializationWhitelists: List<SerializationWhitelist>

View File

@ -20,6 +20,10 @@ sealed class FlowInitiator : Principal {
data class Peer(val party: Party) : FlowInitiator() {
override fun getName(): String = party.name.toString()
}
/** Started by a CordaService. */
data class Service(val serviceClassName: String) : FlowInitiator() {
override fun getName(): String = serviceClassName
}
/** Started as scheduled activity. */
data class Scheduled(val scheduledState: ScheduledStateRef) : FlowInitiator() {
override fun getName(): String = "Scheduler"

View File

@ -0,0 +1,12 @@
package net.corda.core.flows
import kotlin.annotation.AnnotationTarget.CLASS
/**
* Any [FlowLogic] which is to be started by the [AppServiceHub] interface from
* within a [CordaService] must have this annotation. If it's missing the
* flow will not be allowed to start and an exception will be thrown.
*/
@Target(CLASS)
@MustBeDocumented
annotation class StartableByService

View File

@ -12,6 +12,7 @@ data class CordappImpl(
override val contractClassNames: List<String>,
override val initiatedFlows: List<Class<out FlowLogic<*>>>,
override val rpcFlows: List<Class<out FlowLogic<*>>>,
override val serviceFlows: List<Class<out FlowLogic<*>>>,
override val schedulableFlows: List<Class<out FlowLogic<*>>>,
override val services: List<Class<out SerializeAsToken>>,
override val serializationWhitelists: List<SerializationWhitelist>,

View File

@ -0,0 +1,30 @@
package net.corda.core.node
import net.corda.core.flows.FlowLogic
import net.corda.core.messaging.FlowHandle
import net.corda.core.messaging.FlowProgressHandle
import rx.Observable
/**
* A [CordaService] annotated class requires a constructor taking a
* single parameter of type [AppServiceHub].
* With the [AppServiceHub] parameter a [CordaService] is able to access to privileged operations.
* In particular such a [CordaService] can initiate and track flows marked with [net.corda.core.flows.StartableByService].
*/
interface AppServiceHub : ServiceHub {
/**
* Start the given flow with the given arguments. [flow] must be annotated
* with [net.corda.core.flows.StartableByService].
* TODO it is assumed here that the flow object has an appropriate classloader.
*/
fun <T> startFlow(flow: FlowLogic<T>): FlowHandle<T>
/**
* Start the given flow with the given arguments, returning an [Observable] with a single observation of the
* result of running the flow. [flow] must be annotated with [net.corda.core.flows.StartableByService].
* TODO it is assumed here that the flow object has an appropriate classloader.
*/
fun <T> startTrackedFlow(flow: FlowLogic<T>): FlowProgressHandle<T>
}

View File

@ -7,7 +7,7 @@ import kotlin.annotation.AnnotationTarget.CLASS
/**
* Annotate any class that needs to be a long-lived service within the node, such as an oracle, with this annotation.
* Such a class needs to have a constructor with a single parameter of type [ServiceHub]. This constructor will be invoked
* Such a class needs to have a constructor with a single parameter of type [AppServiceHub]. This constructor will be invoked
* during node start to initialise the service. The service hub provided can be used to get information about the node
* that may be necessary for the service. Corda services are created as singletons within the node and are available
* to flows via [ServiceHub.cordaService].

View File

@ -20,9 +20,8 @@ import net.corda.core.internal.concurrent.flatMap
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.toX509CertHolder
import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.RPCOps
import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.messaging.*
import net.corda.core.node.AppServiceHub
import net.corda.core.node.NodeInfo
import net.corda.core.node.ServiceHub
import net.corda.core.node.services.*
@ -264,6 +263,49 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
}
}
/**
* This customizes the ServiceHub for each CordaService that is initiating flows
*/
private class AppServiceHubImpl<T : SerializeAsToken>(val serviceHub: ServiceHubInternal): AppServiceHub, ServiceHub by serviceHub {
lateinit var serviceInstance: T
override fun <T> startTrackedFlow(flow: FlowLogic<T>): FlowProgressHandle<T> {
val stateMachine = startFlowChecked(flow)
return FlowProgressHandleImpl(
id = stateMachine.id,
returnValue = stateMachine.resultFuture,
progress = stateMachine.logic.track()?.updates ?: Observable.empty()
)
}
override fun <T> startFlow(flow: FlowLogic<T>): FlowHandle<T> {
val stateMachine = startFlowChecked(flow)
return FlowHandleImpl(id = stateMachine.id, returnValue = stateMachine.resultFuture)
}
private fun <T> startFlowChecked(flow: FlowLogic<T>): FlowStateMachineImpl<T> {
val logicType = flow.javaClass
require(logicType.isAnnotationPresent(StartableByService::class.java)) { "${logicType.name} was not designed for starting by a CordaService" }
val currentUser = FlowInitiator.Service(serviceInstance.javaClass.name)
return serviceHub.startFlow(flow, currentUser)
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is AppServiceHubImpl<*>) return false
if (serviceHub != other.serviceHub) return false
if (serviceInstance != other.serviceInstance) return false
return true
}
override fun hashCode(): Int {
var result = serviceHub.hashCode()
result = 31 * result + serviceInstance.hashCode()
return result
}
}
/**
* Use this method to install your Corda services in your tests. This is automatically done by the node when it
* starts up for all classes it finds which are annotated with [CordaService].
@ -276,8 +318,16 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
val constructor = serviceClass.getDeclaredConstructor(ServiceHub::class.java, PublicKey::class.java).apply { isAccessible = true }
constructor.newInstance(services, myNotaryIdentity!!.owningKey)
} else {
val constructor = serviceClass.getDeclaredConstructor(ServiceHub::class.java).apply { isAccessible = true }
constructor.newInstance(services)
try {
val extendedServiceConstructor = serviceClass.getDeclaredConstructor(AppServiceHub::class.java).apply { isAccessible = true }
val serviceContext = AppServiceHubImpl<T>(services)
serviceContext.serviceInstance = extendedServiceConstructor.newInstance(serviceContext)
serviceContext.serviceInstance
} catch (ex: NoSuchMethodException) {
val constructor = serviceClass.getDeclaredConstructor(ServiceHub::class.java).apply { isAccessible = true }
log.warn("${serviceClass.name} is using legacy CordaService constructor with ServiceHub parameter. Upgrade to an AppServiceHub parameter to enable updated API features.")
constructor.newInstance(services)
}
}
} catch (e: InvocationTargetException) {
throw ServiceInstantiationException(e.cause)

View File

@ -169,6 +169,7 @@ class CordappLoader private constructor(private val cordappJarPaths: List<URL>)
listOf(),
listOf(),
listOf(),
listOf(),
setOf(),
ContractUpgradeFlow.javaClass.protectionDomain.codeSource.location // Core JAR location
)
@ -180,6 +181,7 @@ class CordappLoader private constructor(private val cordappJarPaths: List<URL>)
CordappImpl(findContractClassNames(scanResult),
findInitiatedFlows(scanResult),
findRPCFlows(scanResult),
findServiceFlows(scanResult),
findSchedulableFlows(scanResult),
findServices(scanResult),
findPlugins(it),
@ -207,14 +209,18 @@ class CordappLoader private constructor(private val cordappJarPaths: List<URL>)
}
}
private fun findRPCFlows(scanResult: ScanResult): List<Class<out FlowLogic<*>>> {
fun Class<out FlowLogic<*>>.isUserInvokable(): Boolean {
return Modifier.isPublic(modifiers) && !isLocalClass && !isAnonymousClass && (!isMemberClass || Modifier.isStatic(modifiers))
}
private fun Class<out FlowLogic<*>>.isUserInvokable(): Boolean {
return Modifier.isPublic(modifiers) && !isLocalClass && !isAnonymousClass && (!isMemberClass || Modifier.isStatic(modifiers))
}
private fun findRPCFlows(scanResult: ScanResult): List<Class<out FlowLogic<*>>> {
return scanResult.getClassesWithAnnotation(FlowLogic::class, StartableByRPC::class).filter { it.isUserInvokable() }
}
private fun findServiceFlows(scanResult: ScanResult): List<Class<out FlowLogic<*>>> {
return scanResult.getClassesWithAnnotation(FlowLogic::class, StartableByService::class)
}
private fun findSchedulableFlows(scanResult: ScanResult): List<Class<out FlowLogic<*>>> {
return scanResult.getClassesWithAnnotation(FlowLogic::class, SchedulableFlow::class)
}

View File

@ -111,6 +111,7 @@ class FlowWatchPrintingSubscriber(private val toStream: RenderPrintWriter) : Sub
is FlowInitiator.Shell -> "Shell" // TODO Change when we will have more information on shell user.
is FlowInitiator.Peer -> flowInitiator.party.name.organisation
is FlowInitiator.RPC -> "RPC: " + flowInitiator.username
is FlowInitiator.Service -> "Service: " + flowInitiator.name
}
}

View File

@ -0,0 +1,132 @@
package net.corda.node.internal
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StartableByService
import net.corda.core.node.AppServiceHub
import net.corda.core.node.ServiceHub
import net.corda.core.node.services.CordaService
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.ProgressTracker
import net.corda.finance.DOLLARS
import net.corda.finance.flows.CashIssueFlow
import net.corda.node.internal.cordapp.DummyRPCFlow
import net.corda.node.services.transactions.ValidatingNotaryService
import net.corda.nodeapi.internal.ServiceInfo
import net.corda.testing.DUMMY_NOTARY
import net.corda.testing.node.MockNetwork
import net.corda.testing.setCordappPackages
import net.corda.testing.unsetCordappPackages
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.util.concurrent.atomic.AtomicInteger
import kotlin.test.*
@StartableByService
class DummyServiceFlow : FlowLogic<FlowInitiator>() {
companion object {
object TEST_STEP : ProgressTracker.Step("Custom progress step")
}
override val progressTracker: ProgressTracker = ProgressTracker(TEST_STEP)
@Suspendable
override fun call(): FlowInitiator {
// We call a subFlow, otehrwise there is no chance to subscribe to the ProgressTracker
subFlow(CashIssueFlow(100.DOLLARS, OpaqueBytes.of(1), serviceHub.networkMapCache.notaryIdentities.first()))
progressTracker.currentStep = TEST_STEP
return stateMachine.flowInitiator
}
}
@CordaService
class TestCordaService(val appServiceHub: AppServiceHub): SingletonSerializeAsToken() {
fun startServiceFlow() {
val handle = appServiceHub.startFlow(DummyServiceFlow())
val initiator = handle.returnValue.get()
initiator as FlowInitiator.Service
assertEquals(this.javaClass.name, initiator.serviceClassName)
}
fun startServiceFlowAndTrack() {
val handle = appServiceHub.startTrackedFlow(DummyServiceFlow())
val count = AtomicInteger(0)
val subscriber = handle.progress.subscribe { count.incrementAndGet() }
handle.returnValue.get()
// Simply prove some progress was made.
// The actual number is currently 11, but don't want to hard code an implementation detail.
assertTrue(count.get() > 1)
subscriber.unsubscribe()
}
}
@CordaService
class TestCordaService2(val appServiceHub: AppServiceHub): SingletonSerializeAsToken() {
fun startInvalidRPCFlow() {
val handle = appServiceHub.startFlow(DummyRPCFlow())
handle.returnValue.get()
}
}
@CordaService
class LegacyCordaService(val simpleServiceHub: ServiceHub): SingletonSerializeAsToken() {
}
class CordaServiceTest {
lateinit var mockNet: MockNetwork
lateinit var notaryNode: StartedNode<MockNetwork.MockNode>
lateinit var nodeA: StartedNode<MockNetwork.MockNode>
@Before
fun start() {
setCordappPackages("net.corda.node.internal","net.corda.finance")
mockNet = MockNetwork(threadPerNode = true)
notaryNode = mockNet.createNode(
legalName = DUMMY_NOTARY.name,
advertisedServices = *arrayOf(ServiceInfo(ValidatingNotaryService.type)))
nodeA = mockNet.createNode(notaryNode.network.myAddress)
mockNet.startNodes()
}
@After
fun cleanUp() {
mockNet.stopNodes()
unsetCordappPackages()
}
@Test
fun `Can find distinct services on node`() {
val service = nodeA.services.cordaService(TestCordaService::class.java)
val service2 = nodeA.services.cordaService(TestCordaService2::class.java)
val legacyService = nodeA.services.cordaService(LegacyCordaService::class.java)
assertEquals(TestCordaService::class.java, service.javaClass)
assertEquals(TestCordaService2::class.java, service2.javaClass)
assertNotEquals(service.appServiceHub, service2.appServiceHub) // Each gets a customised AppServiceHub
assertEquals(LegacyCordaService::class.java, legacyService.javaClass)
}
@Test
fun `Can start StartableByService flows`() {
val service = nodeA.services.cordaService(TestCordaService::class.java)
service.startServiceFlow()
}
@Test
fun `Can't start StartableByRPC flows`() {
val service = nodeA.services.cordaService(TestCordaService2::class.java)
assertFailsWith<IllegalArgumentException> { service.startInvalidRPCFlow() }
}
@Test
fun `Test flow with progress tracking`() {
val service = nodeA.services.cordaService(TestCordaService::class.java)
service.startServiceFlowAndTrack()
}
}

View File

@ -1,5 +1,6 @@
package net.corda.node.internal.cordapp
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.*
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
@ -7,21 +8,25 @@ import java.nio.file.Paths
@InitiatingFlow
class DummyFlow : FlowLogic<Unit>() {
@Suspendable
override fun call() { }
}
@InitiatedBy(DummyFlow::class)
class LoaderTestFlow(unusedSession: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() { }
}
@SchedulableFlow
class DummySchedulableFlow : FlowLogic<Unit>() {
@Suspendable
override fun call() { }
}
@StartableByRPC
class DummyRPCFlow : FlowLogic<Unit>() {
@Suspendable
override fun call() { }
}

View File

@ -14,7 +14,7 @@ class MockCordappProvider(cordappLoader: CordappLoader) : CordappProviderImpl(co
val cordappRegistry = mutableListOf<Pair<Cordapp, AttachmentId>>()
fun addMockCordapp(contractClassName: ContractClassName, services: ServiceHub) {
val cordapp = CordappImpl(listOf(contractClassName), emptyList(), emptyList(), emptyList(), emptyList(), emptyList(), emptySet(), Paths.get(".").toUri().toURL())
val cordapp = CordappImpl(listOf(contractClassName), emptyList(), emptyList(), emptyList(), emptyList(), emptyList(), emptyList(), emptySet(), Paths.get(".").toUri().toURL())
if (cordappRegistry.none { it.first.contractClassNames.contains(contractClassName) }) {
cordappRegistry.add(Pair(cordapp, findOrImportAttachment(contractClassName.toByteArray(), services)))
}