Merge pull request #1803 from corda/mnesbit-CordaServices-startFlows

Start Flows from services
This commit is contained in:
Matthew Nesbit
2017-10-06 13:49:18 +01:00
committed by GitHub
12 changed files with 254 additions and 11 deletions

View File

@ -20,9 +20,8 @@ import net.corda.core.internal.concurrent.flatMap
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.toX509CertHolder
import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.RPCOps
import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.messaging.*
import net.corda.core.node.AppServiceHub
import net.corda.core.node.NodeInfo
import net.corda.core.node.ServiceHub
import net.corda.core.node.services.*
@ -264,6 +263,49 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
}
}
/**
* This customizes the ServiceHub for each CordaService that is initiating flows
*/
private class AppServiceHubImpl<T : SerializeAsToken>(val serviceHub: ServiceHubInternal): AppServiceHub, ServiceHub by serviceHub {
lateinit var serviceInstance: T
override fun <T> startTrackedFlow(flow: FlowLogic<T>): FlowProgressHandle<T> {
val stateMachine = startFlowChecked(flow)
return FlowProgressHandleImpl(
id = stateMachine.id,
returnValue = stateMachine.resultFuture,
progress = stateMachine.logic.track()?.updates ?: Observable.empty()
)
}
override fun <T> startFlow(flow: FlowLogic<T>): FlowHandle<T> {
val stateMachine = startFlowChecked(flow)
return FlowHandleImpl(id = stateMachine.id, returnValue = stateMachine.resultFuture)
}
private fun <T> startFlowChecked(flow: FlowLogic<T>): FlowStateMachineImpl<T> {
val logicType = flow.javaClass
require(logicType.isAnnotationPresent(StartableByService::class.java)) { "${logicType.name} was not designed for starting by a CordaService" }
val currentUser = FlowInitiator.Service(serviceInstance.javaClass.name)
return serviceHub.startFlow(flow, currentUser)
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is AppServiceHubImpl<*>) return false
if (serviceHub != other.serviceHub) return false
if (serviceInstance != other.serviceInstance) return false
return true
}
override fun hashCode(): Int {
var result = serviceHub.hashCode()
result = 31 * result + serviceInstance.hashCode()
return result
}
}
/**
* Use this method to install your Corda services in your tests. This is automatically done by the node when it
* starts up for all classes it finds which are annotated with [CordaService].
@ -276,8 +318,16 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
val constructor = serviceClass.getDeclaredConstructor(ServiceHub::class.java, PublicKey::class.java).apply { isAccessible = true }
constructor.newInstance(services, myNotaryIdentity!!.owningKey)
} else {
val constructor = serviceClass.getDeclaredConstructor(ServiceHub::class.java).apply { isAccessible = true }
constructor.newInstance(services)
try {
val extendedServiceConstructor = serviceClass.getDeclaredConstructor(AppServiceHub::class.java).apply { isAccessible = true }
val serviceContext = AppServiceHubImpl<T>(services)
serviceContext.serviceInstance = extendedServiceConstructor.newInstance(serviceContext)
serviceContext.serviceInstance
} catch (ex: NoSuchMethodException) {
val constructor = serviceClass.getDeclaredConstructor(ServiceHub::class.java).apply { isAccessible = true }
log.warn("${serviceClass.name} is using legacy CordaService constructor with ServiceHub parameter. Upgrade to an AppServiceHub parameter to enable updated API features.")
constructor.newInstance(services)
}
}
} catch (e: InvocationTargetException) {
throw ServiceInstantiationException(e.cause)

View File

@ -169,6 +169,7 @@ class CordappLoader private constructor(private val cordappJarPaths: List<URL>)
listOf(),
listOf(),
listOf(),
listOf(),
setOf(),
ContractUpgradeFlow.javaClass.protectionDomain.codeSource.location // Core JAR location
)
@ -180,6 +181,7 @@ class CordappLoader private constructor(private val cordappJarPaths: List<URL>)
CordappImpl(findContractClassNames(scanResult),
findInitiatedFlows(scanResult),
findRPCFlows(scanResult),
findServiceFlows(scanResult),
findSchedulableFlows(scanResult),
findServices(scanResult),
findPlugins(it),
@ -207,14 +209,18 @@ class CordappLoader private constructor(private val cordappJarPaths: List<URL>)
}
}
private fun findRPCFlows(scanResult: ScanResult): List<Class<out FlowLogic<*>>> {
fun Class<out FlowLogic<*>>.isUserInvokable(): Boolean {
return Modifier.isPublic(modifiers) && !isLocalClass && !isAnonymousClass && (!isMemberClass || Modifier.isStatic(modifiers))
}
private fun Class<out FlowLogic<*>>.isUserInvokable(): Boolean {
return Modifier.isPublic(modifiers) && !isLocalClass && !isAnonymousClass && (!isMemberClass || Modifier.isStatic(modifiers))
}
private fun findRPCFlows(scanResult: ScanResult): List<Class<out FlowLogic<*>>> {
return scanResult.getClassesWithAnnotation(FlowLogic::class, StartableByRPC::class).filter { it.isUserInvokable() }
}
private fun findServiceFlows(scanResult: ScanResult): List<Class<out FlowLogic<*>>> {
return scanResult.getClassesWithAnnotation(FlowLogic::class, StartableByService::class)
}
private fun findSchedulableFlows(scanResult: ScanResult): List<Class<out FlowLogic<*>>> {
return scanResult.getClassesWithAnnotation(FlowLogic::class, SchedulableFlow::class)
}

View File

@ -111,6 +111,7 @@ class FlowWatchPrintingSubscriber(private val toStream: RenderPrintWriter) : Sub
is FlowInitiator.Shell -> "Shell" // TODO Change when we will have more information on shell user.
is FlowInitiator.Peer -> flowInitiator.party.name.organisation
is FlowInitiator.RPC -> "RPC: " + flowInitiator.username
is FlowInitiator.Service -> "Service: " + flowInitiator.name
}
}

View File

@ -0,0 +1,132 @@
package net.corda.node.internal
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StartableByService
import net.corda.core.node.AppServiceHub
import net.corda.core.node.ServiceHub
import net.corda.core.node.services.CordaService
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.ProgressTracker
import net.corda.finance.DOLLARS
import net.corda.finance.flows.CashIssueFlow
import net.corda.node.internal.cordapp.DummyRPCFlow
import net.corda.node.services.transactions.ValidatingNotaryService
import net.corda.nodeapi.internal.ServiceInfo
import net.corda.testing.DUMMY_NOTARY
import net.corda.testing.node.MockNetwork
import net.corda.testing.setCordappPackages
import net.corda.testing.unsetCordappPackages
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.util.concurrent.atomic.AtomicInteger
import kotlin.test.*
@StartableByService
class DummyServiceFlow : FlowLogic<FlowInitiator>() {
companion object {
object TEST_STEP : ProgressTracker.Step("Custom progress step")
}
override val progressTracker: ProgressTracker = ProgressTracker(TEST_STEP)
@Suspendable
override fun call(): FlowInitiator {
// We call a subFlow, otehrwise there is no chance to subscribe to the ProgressTracker
subFlow(CashIssueFlow(100.DOLLARS, OpaqueBytes.of(1), serviceHub.networkMapCache.notaryIdentities.first()))
progressTracker.currentStep = TEST_STEP
return stateMachine.flowInitiator
}
}
@CordaService
class TestCordaService(val appServiceHub: AppServiceHub): SingletonSerializeAsToken() {
fun startServiceFlow() {
val handle = appServiceHub.startFlow(DummyServiceFlow())
val initiator = handle.returnValue.get()
initiator as FlowInitiator.Service
assertEquals(this.javaClass.name, initiator.serviceClassName)
}
fun startServiceFlowAndTrack() {
val handle = appServiceHub.startTrackedFlow(DummyServiceFlow())
val count = AtomicInteger(0)
val subscriber = handle.progress.subscribe { count.incrementAndGet() }
handle.returnValue.get()
// Simply prove some progress was made.
// The actual number is currently 11, but don't want to hard code an implementation detail.
assertTrue(count.get() > 1)
subscriber.unsubscribe()
}
}
@CordaService
class TestCordaService2(val appServiceHub: AppServiceHub): SingletonSerializeAsToken() {
fun startInvalidRPCFlow() {
val handle = appServiceHub.startFlow(DummyRPCFlow())
handle.returnValue.get()
}
}
@CordaService
class LegacyCordaService(val simpleServiceHub: ServiceHub): SingletonSerializeAsToken() {
}
class CordaServiceTest {
lateinit var mockNet: MockNetwork
lateinit var notaryNode: StartedNode<MockNetwork.MockNode>
lateinit var nodeA: StartedNode<MockNetwork.MockNode>
@Before
fun start() {
setCordappPackages("net.corda.node.internal","net.corda.finance")
mockNet = MockNetwork(threadPerNode = true)
notaryNode = mockNet.createNode(
legalName = DUMMY_NOTARY.name,
advertisedServices = *arrayOf(ServiceInfo(ValidatingNotaryService.type)))
nodeA = mockNet.createNode(notaryNode.network.myAddress)
mockNet.startNodes()
}
@After
fun cleanUp() {
mockNet.stopNodes()
unsetCordappPackages()
}
@Test
fun `Can find distinct services on node`() {
val service = nodeA.services.cordaService(TestCordaService::class.java)
val service2 = nodeA.services.cordaService(TestCordaService2::class.java)
val legacyService = nodeA.services.cordaService(LegacyCordaService::class.java)
assertEquals(TestCordaService::class.java, service.javaClass)
assertEquals(TestCordaService2::class.java, service2.javaClass)
assertNotEquals(service.appServiceHub, service2.appServiceHub) // Each gets a customised AppServiceHub
assertEquals(LegacyCordaService::class.java, legacyService.javaClass)
}
@Test
fun `Can start StartableByService flows`() {
val service = nodeA.services.cordaService(TestCordaService::class.java)
service.startServiceFlow()
}
@Test
fun `Can't start StartableByRPC flows`() {
val service = nodeA.services.cordaService(TestCordaService2::class.java)
assertFailsWith<IllegalArgumentException> { service.startInvalidRPCFlow() }
}
@Test
fun `Test flow with progress tracking`() {
val service = nodeA.services.cordaService(TestCordaService::class.java)
service.startServiceFlowAndTrack()
}
}

View File

@ -1,5 +1,6 @@
package net.corda.node.internal.cordapp
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.*
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
@ -7,21 +8,25 @@ import java.nio.file.Paths
@InitiatingFlow
class DummyFlow : FlowLogic<Unit>() {
@Suspendable
override fun call() { }
}
@InitiatedBy(DummyFlow::class)
class LoaderTestFlow(unusedSession: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() { }
}
@SchedulableFlow
class DummySchedulableFlow : FlowLogic<Unit>() {
@Suspendable
override fun call() { }
}
@StartableByRPC
class DummyRPCFlow : FlowLogic<Unit>() {
@Suspendable
override fun call() { }
}