Eliminate circular dependency of NodeSchedulerService on ServiceHub. (#1891)

This commit is contained in:
Andrzej Cichocki 2017-10-19 09:26:26 +01:00 committed by GitHub
parent cf3b080d0c
commit b2454c646c
15 changed files with 96 additions and 96 deletions

View File

@ -119,7 +119,7 @@ class ContractUpgradeFlowTest {
return startRpcClient<CordaRPCOps>( return startRpcClient<CordaRPCOps>(
rpcAddress = startRpcServer( rpcAddress = startRpcServer(
rpcUser = user, rpcUser = user,
ops = CordaRPCOpsImpl(node.services, node.smm, node.database) ops = CordaRPCOpsImpl(node.services, node.smm, node.database, node.services)
).get().broker.hostAndPort!!, ).get().broker.hostAndPort!!,
username = user.username, username = user.username,
password = user.password password = user.password

View File

@ -6,7 +6,7 @@ import net.corda.core.utilities.getOrThrow
import net.corda.finance.POUNDS import net.corda.finance.POUNDS
import net.corda.finance.contracts.asset.Cash import net.corda.finance.contracts.asset.Cash
import net.corda.finance.issuedBy import net.corda.finance.issuedBy
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.StartedNodeServices
import net.corda.testing.* import net.corda.testing.*
import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork
import org.junit.After import org.junit.After
@ -17,8 +17,8 @@ import kotlin.test.assertFailsWith
class FinalityFlowTests { class FinalityFlowTests {
private lateinit var mockNet: MockNetwork private lateinit var mockNet: MockNetwork
private lateinit var aliceServices: ServiceHubInternal private lateinit var aliceServices: StartedNodeServices
private lateinit var bobServices: ServiceHubInternal private lateinit var bobServices: StartedNodeServices
private lateinit var alice: Party private lateinit var alice: Party
private lateinit var bob: Party private lateinit var bob: Party
private lateinit var notary: Party private lateinit var notary: Party

View File

@ -32,7 +32,7 @@ class CustomVaultQueryTest {
nodeA = mockNet.createPartyNode() nodeA = mockNet.createPartyNode()
nodeB = mockNet.createPartyNode() nodeB = mockNet.createPartyNode()
nodeA.internals.registerInitiatedFlow(TopupIssuerFlow.TopupIssuer::class.java) nodeA.internals.registerInitiatedFlow(TopupIssuerFlow.TopupIssuer::class.java)
nodeA.internals.installCordaService(CustomVaultQuery.Service::class.java) nodeA.installCordaService(CustomVaultQuery.Service::class.java)
notary = nodeA.services.getDefaultNotary() notary = nodeA.services.getDefaultNotary()
} }

View File

@ -9,7 +9,7 @@ import net.corda.core.node.services.queryBy
import net.corda.core.node.services.vault.QueryCriteria import net.corda.core.node.services.vault.QueryCriteria
import net.corda.core.toFuture import net.corda.core.toFuture
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.StartedNodeServices
import net.corda.testing.* import net.corda.testing.*
import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork
import org.junit.After import org.junit.After
@ -19,8 +19,8 @@ import kotlin.test.assertEquals
class WorkflowTransactionBuildTutorialTest { class WorkflowTransactionBuildTutorialTest {
lateinit var mockNet: MockNetwork lateinit var mockNet: MockNetwork
lateinit var aliceServices: ServiceHubInternal lateinit var aliceServices: StartedNodeServices
lateinit var bobServices: ServiceHubInternal lateinit var bobServices: StartedNodeServices
lateinit var alice: Party lateinit var alice: Party
lateinit var bob: Party lateinit var bob: Party

View File

@ -78,6 +78,7 @@ import java.security.cert.CertificateFactory
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.sql.Connection import java.sql.Connection
import java.time.Clock import java.time.Clock
import java.util.*
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService import java.util.concurrent.ExecutorService
import java.util.concurrent.TimeUnit.SECONDS import java.util.concurrent.TimeUnit.SECONDS
@ -108,7 +109,7 @@ abstract class AbstractNode(config: NodeConfiguration,
private class StartedNodeImpl<out N : AbstractNode>( private class StartedNodeImpl<out N : AbstractNode>(
override val internals: N, override val internals: N,
override val services: ServiceHubInternalImpl, services: ServiceHubInternalImpl,
override val info: NodeInfo, override val info: NodeInfo,
override val checkpointStorage: CheckpointStorage, override val checkpointStorage: CheckpointStorage,
override val smm: StateMachineManager, override val smm: StateMachineManager,
@ -116,8 +117,11 @@ abstract class AbstractNode(config: NodeConfiguration,
override val inNodeNetworkMapService: NetworkMapService, override val inNodeNetworkMapService: NetworkMapService,
override val network: MessagingService, override val network: MessagingService,
override val database: CordaPersistence, override val database: CordaPersistence,
override val rpcOps: CordaRPCOps) : StartedNode<N> override val rpcOps: CordaRPCOps,
flowStarter: FlowStarter,
internal val schedulerService: NodeSchedulerService) : StartedNode<N> {
override val services: StartedNodeServices = object : StartedNodeServices, ServiceHubInternal by services, FlowStarter by flowStarter {}
}
// TODO: Persist this, as well as whether the node is registered. // TODO: Persist this, as well as whether the node is registered.
/** /**
* Sequence number of changes sent to the network map service, when registering/de-registering this node. * Sequence number of changes sent to the network map service, when registering/de-registering this node.
@ -167,8 +171,8 @@ abstract class AbstractNode(config: NodeConfiguration,
@Volatile private var _started: StartedNode<AbstractNode>? = null @Volatile private var _started: StartedNode<AbstractNode>? = null
/** The implementation of the [CordaRPCOps] interface used by this node. */ /** The implementation of the [CordaRPCOps] interface used by this node. */
open fun makeRPCOps(): CordaRPCOps { open fun makeRPCOps(flowStarter: FlowStarter): CordaRPCOps {
return CordaRPCOpsImpl(services, smm, database) return CordaRPCOpsImpl(services, smm, database, flowStarter)
} }
private fun saveOwnNodeInfo() { private fun saveOwnNodeInfo() {
@ -190,7 +194,8 @@ abstract class AbstractNode(config: NodeConfiguration,
log.info("Generating nodeInfo ...") log.info("Generating nodeInfo ...")
val schemaService = makeSchemaService() val schemaService = makeSchemaService()
initialiseDatabasePersistence(schemaService) { initialiseDatabasePersistence(schemaService) {
makeServices(schemaService) val transactionStorage = makeTransactionStorage()
makeServices(schemaService, transactionStorage, StateLoaderImpl(transactionStorage))
saveOwnNodeInfo() saveOwnNodeInfo()
} }
} }
@ -202,7 +207,9 @@ abstract class AbstractNode(config: NodeConfiguration,
val schemaService = makeSchemaService() val schemaService = makeSchemaService()
// Do all of this in a database transaction so anything that might need a connection has one. // Do all of this in a database transaction so anything that might need a connection has one.
val startedImpl = initialiseDatabasePersistence(schemaService) { val startedImpl = initialiseDatabasePersistence(schemaService) {
val tokenizableServices = makeServices(schemaService) val transactionStorage = makeTransactionStorage()
val stateLoader = StateLoaderImpl(transactionStorage)
val tokenizableServices = makeServices(schemaService, transactionStorage, stateLoader)
saveOwnNodeInfo() saveOwnNodeInfo()
smm = StateMachineManager(services, smm = StateMachineManager(services,
checkpointStorage, checkpointStorage,
@ -210,9 +217,10 @@ abstract class AbstractNode(config: NodeConfiguration,
database, database,
busyNodeLatch, busyNodeLatch,
cordappLoader.appClassLoader) cordappLoader.appClassLoader)
val flowStarter = FlowStarterImpl(serverThread, smm)
val schedulerService = NodeSchedulerService(platformClock, this@AbstractNode.database, flowStarter, stateLoader, unfinishedSchedules = busyNodeLatch, serverThread = serverThread)
smm.tokenizableServices.addAll(tokenizableServices) smm.tokenizableServices.addAll(tokenizableServices)
smm.tokenizableServices.add(schedulerService)
if (serverThread is ExecutorService) { if (serverThread is ExecutorService) {
runOnStop += { runOnStop += {
// We wait here, even though any in-flight messages should have been drained away because the // We wait here, even though any in-flight messages should have been drained away because the
@ -221,20 +229,17 @@ abstract class AbstractNode(config: NodeConfiguration,
MoreExecutors.shutdownAndAwaitTermination(serverThread as ExecutorService, 50, SECONDS) MoreExecutors.shutdownAndAwaitTermination(serverThread as ExecutorService, 50, SECONDS)
} }
} }
makeVaultObservers(schedulerService)
makeVaultObservers() val rpcOps = makeRPCOps(flowStarter)
val rpcOps = makeRPCOps()
startMessagingService(rpcOps) startMessagingService(rpcOps)
installCoreFlows() installCoreFlows()
installCordaServices(flowStarter)
installCordaServices()
registerCordappFlows() registerCordappFlows()
_services.rpcFlows += cordappLoader.cordapps.flatMap { it.rpcFlows } _services.rpcFlows += cordappLoader.cordapps.flatMap { it.rpcFlows }
FlowLogicRefFactoryImpl.classloader = cordappLoader.appClassLoader FlowLogicRefFactoryImpl.classloader = cordappLoader.appClassLoader
runOnStop += network::stop runOnStop += network::stop
StartedNodeImpl(this, _services, info, checkpointStorage, smm, attachments, inNodeNetworkMapService, network, database, rpcOps) StartedNodeImpl(this, _services, info, checkpointStorage, smm, attachments, inNodeNetworkMapService, network, database, rpcOps, flowStarter, schedulerService)
} }
// If we successfully loaded network data from database, we set this future to Unit. // If we successfully loaded network data from database, we set this future to Unit.
_nodeReadyFuture.captureLater(registerWithNetworkMapIfConfigured()) _nodeReadyFuture.captureLater(registerWithNetworkMapIfConfigured())
@ -243,7 +248,7 @@ abstract class AbstractNode(config: NodeConfiguration,
smm.start() smm.start()
// Shut down the SMM so no Fibers are scheduled. // Shut down the SMM so no Fibers are scheduled.
runOnStop += { smm.stop(acceptableLiveFiberCountOnStop()) } runOnStop += { smm.stop(acceptableLiveFiberCountOnStop()) }
services.schedulerService.start() schedulerService.start()
} }
_started = this _started = this
} }
@ -251,11 +256,11 @@ abstract class AbstractNode(config: NodeConfiguration,
private class ServiceInstantiationException(cause: Throwable?) : CordaException("Service Instantiation Error", cause) private class ServiceInstantiationException(cause: Throwable?) : CordaException("Service Instantiation Error", cause)
private fun installCordaServices() { private fun installCordaServices(flowStarter: FlowStarter) {
val loadedServices = cordappLoader.cordapps.flatMap { it.services } val loadedServices = cordappLoader.cordapps.flatMap { it.services }
filterServicesToInstall(loadedServices).forEach { filterServicesToInstall(loadedServices).forEach {
try { try {
installCordaService(it) installCordaService(flowStarter, it)
} catch (e: NoSuchMethodException) { } catch (e: NoSuchMethodException) {
log.error("${it.name}, as a Corda service, must have a constructor with a single parameter of type " + log.error("${it.name}, as a Corda service, must have a constructor with a single parameter of type " +
ServiceHub::class.java.name) ServiceHub::class.java.name)
@ -288,7 +293,7 @@ abstract class AbstractNode(config: NodeConfiguration,
/** /**
* This customizes the ServiceHub for each CordaService that is initiating flows * This customizes the ServiceHub for each CordaService that is initiating flows
*/ */
private class AppServiceHubImpl<T : SerializeAsToken>(val serviceHub: ServiceHubInternal) : AppServiceHub, ServiceHub by serviceHub { private class AppServiceHubImpl<T : SerializeAsToken>(private val serviceHub: ServiceHub, private val flowStarter: FlowStarter) : AppServiceHub, ServiceHub by serviceHub {
lateinit var serviceInstance: T lateinit var serviceInstance: T
override fun <T> startTrackedFlow(flow: FlowLogic<T>): FlowProgressHandle<T> { override fun <T> startTrackedFlow(flow: FlowLogic<T>): FlowProgressHandle<T> {
val stateMachine = startFlowChecked(flow) val stateMachine = startFlowChecked(flow)
@ -308,34 +313,24 @@ abstract class AbstractNode(config: NodeConfiguration,
val logicType = flow.javaClass val logicType = flow.javaClass
require(logicType.isAnnotationPresent(StartableByService::class.java)) { "${logicType.name} was not designed for starting by a CordaService" } require(logicType.isAnnotationPresent(StartableByService::class.java)) { "${logicType.name} was not designed for starting by a CordaService" }
val currentUser = FlowInitiator.Service(serviceInstance.javaClass.name) val currentUser = FlowInitiator.Service(serviceInstance.javaClass.name)
return serviceHub.startFlow(flow, currentUser) return flowStarter.startFlow(flow, currentUser)
} }
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
if (this === other) return true if (this === other) return true
if (other !is AppServiceHubImpl<*>) return false if (other !is AppServiceHubImpl<*>) return false
return serviceHub == other.serviceHub
if (serviceHub != other.serviceHub) return false && flowStarter == other.flowStarter
if (serviceInstance != other.serviceInstance) return false && serviceInstance == other.serviceInstance
return true
} }
override fun hashCode(): Int { override fun hashCode() = Objects.hash(serviceHub, flowStarter, serviceInstance)
var result = serviceHub.hashCode()
result = 31 * result + serviceInstance.hashCode()
return result
}
} }
/** internal fun <T : SerializeAsToken> installCordaService(flowStarter: FlowStarter, serviceClass: Class<T>): T {
* Use this method to install your Corda services in your tests. This is automatically done by the node when it
* starts up for all classes it finds which are annotated with [CordaService].
*/
fun <T : SerializeAsToken> installCordaService(serviceClass: Class<T>): T {
serviceClass.requireAnnotation<CordaService>() serviceClass.requireAnnotation<CordaService>()
val service = try { val service = try {
val serviceContext = AppServiceHubImpl<T>(services) val serviceContext = AppServiceHubImpl<T>(services, flowStarter)
if (isNotaryService(serviceClass)) { if (isNotaryService(serviceClass)) {
check(myNotaryIdentity != null) { "Trying to install a notary service but no notary identity specified" } check(myNotaryIdentity != null) { "Trying to install a notary service but no notary identity specified" }
val constructor = serviceClass.getDeclaredConstructor(AppServiceHub::class.java, PublicKey::class.java).apply { isAccessible = true } val constructor = serviceClass.getDeclaredConstructor(AppServiceHub::class.java, PublicKey::class.java).apply { isAccessible = true }
@ -466,19 +461,18 @@ abstract class AbstractNode(config: NodeConfiguration,
* Builds node internal, advertised, and plugin services. * Builds node internal, advertised, and plugin services.
* Returns a list of tokenizable services to be added to the serialisation context. * Returns a list of tokenizable services to be added to the serialisation context.
*/ */
private fun makeServices(schemaService: SchemaService): MutableList<Any> { private fun makeServices(schemaService: SchemaService, transactionStorage: WritableTransactionStorage, stateLoader: StateLoader): MutableList<Any> {
checkpointStorage = DBCheckpointStorage() checkpointStorage = DBCheckpointStorage()
val transactionStorage = makeTransactionStorage()
val metrics = MetricRegistry() val metrics = MetricRegistry()
attachments = NodeAttachmentService(metrics) attachments = NodeAttachmentService(metrics)
val cordappProvider = CordappProviderImpl(cordappLoader, attachments) val cordappProvider = CordappProviderImpl(cordappLoader, attachments)
_services = ServiceHubInternalImpl(schemaService, transactionStorage, StateLoaderImpl(transactionStorage), MonitoringService(metrics), cordappProvider) _services = ServiceHubInternalImpl(schemaService, transactionStorage, stateLoader, MonitoringService(metrics), cordappProvider)
legalIdentity = obtainIdentity(notaryConfig = null) legalIdentity = obtainIdentity(notaryConfig = null)
network = makeMessagingService(legalIdentity) network = makeMessagingService(legalIdentity)
info = makeInfo(legalIdentity) info = makeInfo(legalIdentity)
val networkMapCache = services.networkMapCache val networkMapCache = services.networkMapCache
val tokenizableServices = mutableListOf(attachments, network, services.vaultService, val tokenizableServices = mutableListOf(attachments, network, services.vaultService,
services.keyManagementService, services.identityService, platformClock, services.schedulerService, services.keyManagementService, services.identityService, platformClock,
services.auditService, services.monitoringService, networkMapCache, services.schemaService, services.auditService, services.monitoringService, networkMapCache, services.schemaService,
services.transactionVerifierService, services.validatedTransactions, services.contractUpgradeService, services.transactionVerifierService, services.validatedTransactions, services.contractUpgradeService,
services, cordappProvider, this) services, cordappProvider, this)
@ -488,9 +482,9 @@ abstract class AbstractNode(config: NodeConfiguration,
protected open fun makeTransactionStorage(): WritableTransactionStorage = DBTransactionStorage() protected open fun makeTransactionStorage(): WritableTransactionStorage = DBTransactionStorage()
private fun makeVaultObservers() { private fun makeVaultObservers(schedulerService: SchedulerService) {
VaultSoftLockManager.install(services.vaultService, smm) VaultSoftLockManager.install(services.vaultService, smm)
ScheduledActivityObserver.install(services.vaultService, services.schedulerService) ScheduledActivityObserver.install(services.vaultService, schedulerService)
HibernateObserver.install(services.vaultService.rawUpdates, database.hibernateConfig) HibernateObserver.install(services.vaultService.rawUpdates, database.hibernateConfig)
} }
@ -788,7 +782,6 @@ abstract class AbstractNode(config: NodeConfiguration,
// the KMS is meant for derived temporary keys used in transactions, and we're not supposed to sign things with // the KMS is meant for derived temporary keys used in transactions, and we're not supposed to sign things with
// the identity key. But the infrastructure to make that easy isn't here yet. // the identity key. But the infrastructure to make that easy isn't here yet.
override val keyManagementService by lazy { makeKeyManagementService(identityService) } override val keyManagementService by lazy { makeKeyManagementService(identityService) }
override val schedulerService by lazy { NodeSchedulerService(this, unfinishedSchedules = busyNodeLatch, serverThread = serverThread) }
override val identityService by lazy { override val identityService by lazy {
val trustStore = KeyStoreWrapper(configuration.trustStoreFile, configuration.trustStorePassword) val trustStore = KeyStoreWrapper(configuration.trustStoreFile, configuration.trustStorePassword)
val caKeyStore = KeyStoreWrapper(configuration.nodeKeystore, configuration.keyStorePassword) val caKeyStore = KeyStoreWrapper(configuration.nodeKeystore, configuration.keyStorePassword)
@ -808,10 +801,6 @@ abstract class AbstractNode(config: NodeConfiguration,
return cordappServices.getInstance(type) ?: throw IllegalArgumentException("Corda service ${type.name} does not exist") return cordappServices.getInstance(type) ?: throw IllegalArgumentException("Corda service ${type.name} does not exist")
} }
override fun <T> startFlow(logic: FlowLogic<T>, flowInitiator: FlowInitiator, ourIdentity: Party?): FlowStateMachineImpl<T> {
return serverThread.fetchFrom { smm.add(logic, flowInitiator, ourIdentity) }
}
override fun getFlowFactory(initiatingFlowClass: Class<out FlowLogic<*>>): InitiatedFlowFactory<*>? { override fun getFlowFactory(initiatingFlowClass: Class<out FlowLogic<*>>): InitiatedFlowFactory<*>? {
return flowFactories[initiatingFlowClass] return flowFactories[initiatingFlowClass]
} }
@ -825,3 +814,9 @@ abstract class AbstractNode(config: NodeConfiguration,
override fun jdbcSession(): Connection = database.createSession() override fun jdbcSession(): Connection = database.createSession()
} }
} }
internal class FlowStarterImpl(private val serverThread: AffinityExecutor, private val smm: StateMachineManager) : FlowStarter {
override fun <T> startFlow(logic: FlowLogic<T>, flowInitiator: FlowInitiator, ourIdentity: Party?): FlowStateMachineImpl<T> {
return serverThread.fetchFrom { smm.add(logic, flowInitiator, ourIdentity) }
}
}

View File

@ -19,6 +19,7 @@ import net.corda.core.node.services.vault.QueryCriteria
import net.corda.core.node.services.vault.Sort import net.corda.core.node.services.vault.Sort
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.node.services.FlowPermissions.Companion.startFlowPermission import net.corda.node.services.FlowPermissions.Companion.startFlowPermission
import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.getRpcContext import net.corda.node.services.messaging.getRpcContext
import net.corda.node.services.messaging.requirePermission import net.corda.node.services.messaging.requirePermission
@ -37,7 +38,8 @@ import java.time.Instant
class CordaRPCOpsImpl( class CordaRPCOpsImpl(
private val services: ServiceHubInternal, private val services: ServiceHubInternal,
private val smm: StateMachineManager, private val smm: StateMachineManager,
private val database: CordaPersistence private val database: CordaPersistence,
private val flowStarter: FlowStarter
) : CordaRPCOps { ) : CordaRPCOps {
override fun networkMapSnapshot(): List<NodeInfo> { override fun networkMapSnapshot(): List<NodeInfo> {
val (snapshot, updates) = networkMapFeed() val (snapshot, updates) = networkMapFeed()
@ -150,7 +152,7 @@ class CordaRPCOpsImpl(
rpcContext.requirePermission(startFlowPermission(logicType)) rpcContext.requirePermission(startFlowPermission(logicType))
val currentUser = FlowInitiator.RPC(rpcContext.currentUser.username) val currentUser = FlowInitiator.RPC(rpcContext.currentUser.username)
// TODO RPC flows should have mapping user -> identity that should be resolved automatically on starting flow. // TODO RPC flows should have mapping user -> identity that should be resolved automatically on starting flow.
return services.invokeFlowAsync(logicType, currentUser, *args) return flowStarter.invokeFlowAsync(logicType, currentUser, *args)
} }
override fun attachmentExists(id: SecureHash): Boolean { override fun attachmentExists(id: SecureHash): Boolean {

View File

@ -7,9 +7,11 @@ import net.corda.core.flows.FlowLogic
import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.CordaRPCOps
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.node.StateLoader import net.corda.core.node.StateLoader
import net.corda.core.node.services.CordaService
import net.corda.core.node.services.TransactionStorage import net.corda.core.node.services.TransactionStorage
import net.corda.core.serialization.SerializeAsToken
import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.StartedNodeServices
import net.corda.node.services.messaging.MessagingService import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.network.NetworkMapService import net.corda.node.services.network.NetworkMapService
import net.corda.node.services.persistence.NodeAttachmentService import net.corda.node.services.persistence.NodeAttachmentService
@ -18,7 +20,7 @@ import net.corda.node.utilities.CordaPersistence
interface StartedNode<out N : AbstractNode> { interface StartedNode<out N : AbstractNode> {
val internals: N val internals: N
val services: ServiceHubInternal val services: StartedNodeServices
val info: NodeInfo val info: NodeInfo
val checkpointStorage: CheckpointStorage val checkpointStorage: CheckpointStorage
val smm: StateMachineManager val smm: StateMachineManager
@ -29,6 +31,11 @@ interface StartedNode<out N : AbstractNode> {
val rpcOps: CordaRPCOps val rpcOps: CordaRPCOps
fun dispose() = internals.stop() fun dispose() = internals.stop()
fun <T : FlowLogic<*>> registerInitiatedFlow(initiatedFlowClass: Class<T>) = internals.registerInitiatedFlow(initiatedFlowClass) fun <T : FlowLogic<*>> registerInitiatedFlow(initiatedFlowClass: Class<T>) = internals.registerInitiatedFlow(initiatedFlowClass)
/**
* Use this method to install your Corda services in your tests. This is automatically done by the node when it
* starts up for all classes it finds which are annotated with [CordaService].
*/
fun <T : SerializeAsToken> installCordaService(serviceClass: Class<T>) = internals.installCordaService(services, serviceClass)
} }
class StateLoaderImpl(private val validatedTransactions: TransactionStorage) : StateLoader { class StateLoaderImpl(private val validatedTransactions: TransactionStorage) : StateLoader {

View File

@ -84,7 +84,6 @@ interface ServiceHubInternal : ServiceHub {
val monitoringService: MonitoringService val monitoringService: MonitoringService
val schemaService: SchemaService val schemaService: SchemaService
override val networkMapCache: NetworkMapCacheInternal override val networkMapCache: NetworkMapCacheInternal
val schedulerService: SchedulerService
val auditService: AuditService val auditService: AuditService
val rpcFlows: List<Class<out FlowLogic<*>>> val rpcFlows: List<Class<out FlowLogic<*>>>
val networkService: MessagingService val networkService: MessagingService
@ -109,6 +108,10 @@ interface ServiceHubInternal : ServiceHub {
} }
} }
fun getFlowFactory(initiatingFlowClass: Class<out FlowLogic<*>>): InitiatedFlowFactory<*>?
}
interface FlowStarter {
/** /**
* Starts an already constructed flow. Note that you must be on the server thread to call this method. [FlowInitiator] * Starts an already constructed flow. Note that you must be on the server thread to call this method. [FlowInitiator]
* defaults to [FlowInitiator.RPC] with username "Only For Testing". * defaults to [FlowInitiator.RPC] with username "Only For Testing".
@ -138,10 +141,9 @@ interface ServiceHubInternal : ServiceHub {
val logic: FlowLogic<T> = uncheckedCast(FlowLogicRefFactoryImpl.toFlowLogic(logicRef)) val logic: FlowLogic<T> = uncheckedCast(FlowLogicRefFactoryImpl.toFlowLogic(logicRef))
return startFlow(logic, flowInitiator, ourIdentity = null) return startFlow(logic, flowInitiator, ourIdentity = null)
} }
fun getFlowFactory(initiatingFlowClass: Class<out FlowLogic<*>>): InitiatedFlowFactory<*>?
} }
interface StartedNodeServices : ServiceHubInternal, FlowStarter
/** /**
* Thread-safe storage of transactions. * Thread-safe storage of transactions.
*/ */

View File

@ -14,15 +14,17 @@ import net.corda.core.flows.FlowLogic
import net.corda.core.internal.ThreadBox import net.corda.core.internal.ThreadBox
import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.until import net.corda.core.internal.until
import net.corda.core.node.StateLoader
import net.corda.core.schemas.PersistentStateRef import net.corda.core.schemas.PersistentStateRef
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.utilities.loggerFor import net.corda.core.utilities.loggerFor
import net.corda.core.utilities.trace import net.corda.core.utilities.trace
import net.corda.node.internal.MutableClock import net.corda.node.internal.MutableClock
import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.SchedulerService import net.corda.node.services.api.SchedulerService
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.statemachine.FlowLogicRefFactoryImpl import net.corda.node.services.statemachine.FlowLogicRefFactoryImpl
import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.CordaPersistence
import net.corda.node.utilities.NODE_DATABASE_PREFIX import net.corda.node.utilities.NODE_DATABASE_PREFIX
import net.corda.node.utilities.PersistentMap import net.corda.node.utilities.PersistentMap
import org.apache.activemq.artemis.utils.ReusableLatch import org.apache.activemq.artemis.utils.ReusableLatch
@ -47,12 +49,14 @@ import javax.persistence.Entity
* in the nodes, maybe we can consider multiple activities and whether the activities have been completed or not, * in the nodes, maybe we can consider multiple activities and whether the activities have been completed or not,
* but that starts to sound a lot like off-ledger state. * but that starts to sound a lot like off-ledger state.
* *
* @param services Core node services.
* @param schedulerTimerExecutor The executor the scheduler blocks on waiting for the clock to advance to the next * @param schedulerTimerExecutor The executor the scheduler blocks on waiting for the clock to advance to the next
* activity. Only replace this for unit testing purposes. This is not the executor the [FlowLogic] is launched on. * activity. Only replace this for unit testing purposes. This is not the executor the [FlowLogic] is launched on.
*/ */
@ThreadSafe @ThreadSafe
class NodeSchedulerService(private val services: ServiceHubInternal, class NodeSchedulerService(private val clock: Clock,
private val database: CordaPersistence,
private val flowStarter: FlowStarter,
private val stateLoader: StateLoader,
private val schedulerTimerExecutor: Executor = Executors.newSingleThreadExecutor(), private val schedulerTimerExecutor: Executor = Executors.newSingleThreadExecutor(),
private val unfinishedSchedules: ReusableLatch = ReusableLatch(), private val unfinishedSchedules: ReusableLatch = ReusableLatch(),
private val serverThread: AffinityExecutor) private val serverThread: AffinityExecutor)
@ -108,8 +112,8 @@ class NodeSchedulerService(private val services: ServiceHubInternal,
toPersistentEntityKey = { PersistentStateRef(it.txhash.toString(), it.index) }, toPersistentEntityKey = { PersistentStateRef(it.txhash.toString(), it.index) },
fromPersistentEntity = { fromPersistentEntity = {
//TODO null check will become obsolete after making DB/JPA columns not nullable //TODO null check will become obsolete after making DB/JPA columns not nullable
var txId = it.output.txId ?: throw IllegalStateException("DB returned null SecureHash transactionId") val txId = it.output.txId ?: throw IllegalStateException("DB returned null SecureHash transactionId")
var index = it.output.index ?: throw IllegalStateException("DB returned null SecureHash index") val index = it.output.index ?: throw IllegalStateException("DB returned null SecureHash index")
Pair(StateRef(SecureHash.parse(txId), index), Pair(StateRef(SecureHash.parse(txId), index),
ScheduledStateRef(StateRef(SecureHash.parse(txId), index), it.scheduledAt)) ScheduledStateRef(StateRef(SecureHash.parse(txId), index), it.scheduledAt))
}, },
@ -172,7 +176,7 @@ class NodeSchedulerService(private val services: ServiceHubInternal,
mutex.locked { mutex.locked {
val previousState = scheduledStates[action.ref] val previousState = scheduledStates[action.ref]
scheduledStates[action.ref] = action scheduledStates[action.ref] = action
var previousEarliest = scheduledStatesQueue.peek() val previousEarliest = scheduledStatesQueue.peek()
scheduledStatesQueue.remove(previousState) scheduledStatesQueue.remove(previousState)
scheduledStatesQueue.add(action) scheduledStatesQueue.add(action)
if (previousState == null) { if (previousState == null) {
@ -223,7 +227,7 @@ class NodeSchedulerService(private val services: ServiceHubInternal,
log.trace { "Scheduling as next $scheduledState" } log.trace { "Scheduling as next $scheduledState" }
// This will block the scheduler single thread until the scheduled time (returns false) OR // This will block the scheduler single thread until the scheduled time (returns false) OR
// the Future is cancelled due to rescheduling (returns true). // the Future is cancelled due to rescheduling (returns true).
if (!awaitWithDeadline(services.clock, scheduledState.scheduledAt, ourRescheduledFuture)) { if (!awaitWithDeadline(clock, scheduledState.scheduledAt, ourRescheduledFuture)) {
log.trace { "Invoking as next $scheduledState" } log.trace { "Invoking as next $scheduledState" }
onTimeReached(scheduledState) onTimeReached(scheduledState)
} else { } else {
@ -237,11 +241,11 @@ class NodeSchedulerService(private val services: ServiceHubInternal,
serverThread.execute { serverThread.execute {
var flowName: String? = "(unknown)" var flowName: String? = "(unknown)"
try { try {
services.database.transaction { database.transaction {
val scheduledFlow = getScheduledFlow(scheduledState) val scheduledFlow = getScheduledFlow(scheduledState)
if (scheduledFlow != null) { if (scheduledFlow != null) {
flowName = scheduledFlow.javaClass.name flowName = scheduledFlow.javaClass.name
val future = services.startFlow(scheduledFlow, FlowInitiator.Scheduled(scheduledState)).resultFuture val future = flowStarter.startFlow(scheduledFlow, FlowInitiator.Scheduled(scheduledState)).resultFuture
future.then { future.then {
unfinishedSchedules.countDown() unfinishedSchedules.countDown()
} }
@ -265,9 +269,9 @@ class NodeSchedulerService(private val services: ServiceHubInternal,
unfinishedSchedules.countDown() unfinishedSchedules.countDown()
scheduledStates.remove(scheduledState.ref) scheduledStates.remove(scheduledState.ref)
scheduledStatesQueue.remove(scheduledState) scheduledStatesQueue.remove(scheduledState)
} else if (scheduledActivity.scheduledAt.isAfter(services.clock.instant())) { } else if (scheduledActivity.scheduledAt.isAfter(clock.instant())) {
log.info("Scheduled state $scheduledState has rescheduled to ${scheduledActivity.scheduledAt}.") log.info("Scheduled state $scheduledState has rescheduled to ${scheduledActivity.scheduledAt}.")
var newState = ScheduledStateRef(scheduledState.ref, scheduledActivity.scheduledAt) val newState = ScheduledStateRef(scheduledState.ref, scheduledActivity.scheduledAt)
scheduledStates[scheduledState.ref] = newState scheduledStates[scheduledState.ref] = newState
scheduledStatesQueue.remove(scheduledState) scheduledStatesQueue.remove(scheduledState)
scheduledStatesQueue.add(newState) scheduledStatesQueue.add(newState)
@ -286,7 +290,7 @@ class NodeSchedulerService(private val services: ServiceHubInternal,
} }
private fun getScheduledActivity(scheduledState: ScheduledStateRef): ScheduledActivity? { private fun getScheduledActivity(scheduledState: ScheduledStateRef): ScheduledActivity? {
val txState = services.loadState(scheduledState.ref) val txState = stateLoader.loadState(scheduledState.ref)
val state = txState.data as SchedulableState val state = txState.data as SchedulableState
return try { return try {
// This can throw as running contract code. // This can throw as running contract code.

View File

@ -65,7 +65,7 @@ class CordaRPCOpsImplTest {
mockNet = MockNetwork(cordappPackages = listOf("net.corda.finance.contracts.asset")) mockNet = MockNetwork(cordappPackages = listOf("net.corda.finance.contracts.asset"))
aliceNode = mockNet.createNode() aliceNode = mockNet.createNode()
notaryNode = mockNet.createNotaryNode(validating = false) notaryNode = mockNet.createNotaryNode(validating = false)
rpc = CordaRPCOpsImpl(aliceNode.services, aliceNode.smm, aliceNode.database) rpc = CordaRPCOpsImpl(aliceNode.services, aliceNode.smm, aliceNode.database, aliceNode.services)
CURRENT_RPC_CONTEXT.set(RpcContext(User("user", "pwd", permissions = setOf( CURRENT_RPC_CONTEXT.set(RpcContext(User("user", "pwd", permissions = setOf(
startFlowPermission<CashIssueFlow>(), startFlowPermission<CashIssueFlow>(),
startFlowPermission<CashPaymentFlow>() startFlowPermission<CashPaymentFlow>()

View File

@ -12,6 +12,7 @@ import net.corda.core.node.ServiceHub
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.days import net.corda.core.utilities.days
import net.corda.node.internal.FlowStarterImpl
import net.corda.node.internal.cordapp.CordappLoader import net.corda.node.internal.cordapp.CordappLoader
import net.corda.node.internal.cordapp.CordappProviderImpl import net.corda.node.internal.cordapp.CordappProviderImpl
import net.corda.node.services.api.VaultServiceInternal import net.corda.node.services.api.VaultServiceInternal
@ -100,8 +101,8 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
override val cordappProvider = CordappProviderImpl(CordappLoader.createWithTestPackages(listOf("net.corda.testing.contracts")), attachments) override val cordappProvider = CordappProviderImpl(CordappLoader.createWithTestPackages(listOf("net.corda.testing.contracts")), attachments)
} }
smmExecutor = AffinityExecutor.ServiceAffinityExecutor("test", 1) smmExecutor = AffinityExecutor.ServiceAffinityExecutor("test", 1)
scheduler = NodeSchedulerService(services, schedulerGatedExecutor, serverThread = smmExecutor)
val mockSMM = StateMachineManager(services, DBCheckpointStorage(), smmExecutor, database) val mockSMM = StateMachineManager(services, DBCheckpointStorage(), smmExecutor, database)
scheduler = NodeSchedulerService(testClock, database, FlowStarterImpl(smmExecutor, mockSMM), services.stateLoader, schedulerGatedExecutor, serverThread = smmExecutor)
mockSMM.changes.subscribe { change -> mockSMM.changes.subscribe { change ->
if (change is StateMachineManager.Change.Removed && mockSMM.allStateMachines.isEmpty()) { if (change is StateMachineManager.Change.Removed && mockSMM.allStateMachines.isEmpty()) {
smmHasRemovedAllFlows.countDown() smmHasRemovedAllFlows.countDown()

View File

@ -13,7 +13,7 @@ import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.seconds import net.corda.core.utilities.seconds
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.StartedNodeServices
import net.corda.testing.* import net.corda.testing.*
import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyContract
import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork
@ -28,8 +28,8 @@ import kotlin.test.assertFailsWith
class NotaryServiceTests { class NotaryServiceTests {
lateinit var mockNet: MockNetwork lateinit var mockNet: MockNetwork
lateinit var notaryServices: ServiceHubInternal lateinit var notaryServices: StartedNodeServices
lateinit var aliceServices: ServiceHubInternal lateinit var aliceServices: StartedNodeServices
lateinit var notary: Party lateinit var notary: Party
lateinit var alice: Party lateinit var alice: Party

View File

@ -13,7 +13,7 @@ import net.corda.core.node.ServiceHub
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.StartedNodeServices
import net.corda.node.services.issueInvalidState import net.corda.node.services.issueInvalidState
import net.corda.testing.* import net.corda.testing.*
import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyContract
@ -28,8 +28,8 @@ import kotlin.test.assertFailsWith
class ValidatingNotaryServiceTests { class ValidatingNotaryServiceTests {
lateinit var mockNet: MockNetwork lateinit var mockNet: MockNetwork
lateinit var notaryServices: ServiceHubInternal lateinit var notaryServices: StartedNodeServices
lateinit var aliceServices: ServiceHubInternal lateinit var aliceServices: StartedNodeServices
lateinit var notary: Party lateinit var notary: Party
lateinit var alice: Party lateinit var alice: Party

View File

@ -206,7 +206,7 @@ class NodeInterestRatesTest : TestDependencyInjectionBase() {
internals.registerInitiatedFlow(NodeInterestRates.FixQueryHandler::class.java) internals.registerInitiatedFlow(NodeInterestRates.FixQueryHandler::class.java)
internals.registerInitiatedFlow(NodeInterestRates.FixSignHandler::class.java) internals.registerInitiatedFlow(NodeInterestRates.FixSignHandler::class.java)
database.transaction { database.transaction {
internals.installCordaService(NodeInterestRates.Oracle::class.java).knownFixes = TEST_DATA installCordaService(NodeInterestRates.Oracle::class.java).knownFixes = TEST_DATA
} }
} }
val tx = makePartialTX() val tx = makePartialTX()

View File

@ -1,9 +1,7 @@
package net.corda.node.testing package net.corda.node.testing
import com.codahale.metrics.MetricRegistry import com.codahale.metrics.MetricRegistry
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.identity.Party
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.node.StateLoader import net.corda.core.node.StateLoader
import net.corda.core.node.services.* import net.corda.core.node.services.*
@ -17,7 +15,6 @@ import net.corda.node.serialization.NodeClock
import net.corda.node.services.api.* import net.corda.node.services.api.*
import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.messaging.MessagingService import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.node.services.statemachine.StateMachineManager import net.corda.node.services.statemachine.StateMachineManager
import net.corda.node.services.transactions.InMemoryTransactionVerifierService import net.corda.node.services.transactions.InMemoryTransactionVerifierService
import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.CordaPersistence
@ -43,12 +40,11 @@ open class MockServiceHubInternal(
override val validatedTransactions: WritableTransactionStorage = MockTransactionStorage(), override val validatedTransactions: WritableTransactionStorage = MockTransactionStorage(),
override val stateMachineRecordedTransactionMapping: StateMachineRecordedTransactionMappingStorage = MockStateMachineRecordedTransactionMappingStorage(), override val stateMachineRecordedTransactionMapping: StateMachineRecordedTransactionMappingStorage = MockStateMachineRecordedTransactionMappingStorage(),
val mapCache: NetworkMapCacheInternal? = null, val mapCache: NetworkMapCacheInternal? = null,
val scheduler: SchedulerService? = null,
val overrideClock: Clock? = NodeClock(), val overrideClock: Clock? = NodeClock(),
val customContractUpgradeService: ContractUpgradeService? = null, val customContractUpgradeService: ContractUpgradeService? = null,
val customTransactionVerifierService: TransactionVerifierService? = InMemoryTransactionVerifierService(2), val customTransactionVerifierService: TransactionVerifierService? = InMemoryTransactionVerifierService(2),
override val cordappProvider: CordappProviderInternal = CordappProviderImpl(CordappLoader.createDefault(Paths.get(".")), attachments), override val cordappProvider: CordappProviderInternal = CordappProviderImpl(CordappLoader.createDefault(Paths.get(".")), attachments),
protected val stateLoader: StateLoaderImpl = StateLoaderImpl(validatedTransactions) val stateLoader: StateLoaderImpl = StateLoaderImpl(validatedTransactions)
) : ServiceHubInternal, StateLoader by stateLoader { ) : ServiceHubInternal, StateLoader by stateLoader {
override val transactionVerifierService: TransactionVerifierService override val transactionVerifierService: TransactionVerifierService
get() = customTransactionVerifierService ?: throw UnsupportedOperationException() get() = customTransactionVerifierService ?: throw UnsupportedOperationException()
@ -64,8 +60,6 @@ open class MockServiceHubInternal(
get() = network ?: throw UnsupportedOperationException() get() = network ?: throw UnsupportedOperationException()
override val networkMapCache: NetworkMapCacheInternal override val networkMapCache: NetworkMapCacheInternal
get() = mapCache ?: MockNetworkMapCache(this) get() = mapCache ?: MockNetworkMapCache(this)
override val schedulerService: SchedulerService
get() = scheduler ?: throw UnsupportedOperationException()
override val clock: Clock override val clock: Clock
get() = overrideClock ?: throw UnsupportedOperationException() get() = overrideClock ?: throw UnsupportedOperationException()
override val myInfo: NodeInfo override val myInfo: NodeInfo
@ -79,11 +73,6 @@ open class MockServiceHubInternal(
lateinit var smm: StateMachineManager lateinit var smm: StateMachineManager
override fun <T : SerializeAsToken> cordaService(type: Class<T>): T = throw UnsupportedOperationException() override fun <T : SerializeAsToken> cordaService(type: Class<T>): T = throw UnsupportedOperationException()
override fun <T> startFlow(logic: FlowLogic<T>, flowInitiator: FlowInitiator, ourIdentity: Party?): FlowStateMachineImpl<T> {
return smm.executor.fetchFrom { smm.add(logic, flowInitiator, ourIdentity) }
}
override fun getFlowFactory(initiatingFlowClass: Class<out FlowLogic<*>>): InitiatedFlowFactory<*>? = null override fun getFlowFactory(initiatingFlowClass: Class<out FlowLogic<*>>): InitiatedFlowFactory<*>? = null
override fun jdbcSession(): Connection = database.createSession() override fun jdbcSession(): Connection = database.createSession()