Merge pull request #582 from corda/shams-flow-counterpartymarker-cleanup

Deprecated FlowLogic.getCounterpartyMarker as it's complicated and probably not used
This commit is contained in:
Shams Asari 2017-04-26 09:40:48 +01:00 committed by GitHub
commit c5a9312e07
31 changed files with 123 additions and 145 deletions

View File

@ -42,11 +42,9 @@ abstract class FlowLogic<out T> {
*/ */
val serviceHub: ServiceHub get() = stateMachine.serviceHub val serviceHub: ServiceHub get() = stateMachine.serviceHub
/** @Deprecated("This is no longer used and will be removed in a future release. If you are using this to communicate " +
* Return the marker [Class] which [party] has used to register the counterparty flow that is to execute on the "with the same party but for two different message streams, then the correct way of doing that is to use sub-flows",
* other side. The default implementation returns the class object of this FlowLogic, but any [Class] instance level = DeprecationLevel.ERROR)
* will do as long as the other side registers with it.
*/
open fun getCounterpartyMarker(party: Party): Class<*> = javaClass open fun getCounterpartyMarker(party: Party): Class<*> = javaClass
/** /**
@ -190,9 +188,10 @@ abstract class FlowLogic<out T> {
private var _stateMachine: FlowStateMachine<*>? = null private var _stateMachine: FlowStateMachine<*>? = null
/** /**
* Internal only. Reference to the [Fiber] instance that is the top level controller for the entire flow. When * @suppress
* inside a flow this is equivalent to [Strand.currentStrand]. This is public only because it must be accessed * Internal only. Reference to the [co.paralleluniverse.fibers.Fiber] instance that is the top level controller for
* across module boundaries. * the entire flow. When inside a flow this is equivalent to [co.paralleluniverse.strands.Strand.currentStrand]. This
* is public only because it must be accessed across module boundaries.
*/ */
var stateMachine: FlowStateMachine<*> var stateMachine: FlowStateMachine<*>
get() = _stateMachine ?: throw IllegalStateException("This can only be done after the flow has been started.") get() = _stateMachine ?: throw IllegalStateException("This can only be done after the flow has been started.")

View File

@ -2,31 +2,26 @@ package net.corda.core.node
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import kotlin.reflect.KClass
/** /**
* A service hub to be used by the [CordaPluginRegistry] * A service hub to be used by the [CordaPluginRegistry]
*/ */
interface PluginServiceHub : ServiceHub { interface PluginServiceHub : ServiceHub {
/** /**
* Register the flow factory we wish to use when a initiating party attempts to communicate with us. The * Register the service flow factory to use when an initiating party attempts to communicate with us. The registration
* registration is done against a marker [Class] which is sent in the session handshake by the other party. If this * is done against the [Class] object of the client flow to the service flow. What this means is if a counterparty
* marker class has been registered then the corresponding factory will be used to create the flow which will * starts a [FlowLogic] represented by [clientFlowClass] and starts communication with us, we will execute the service
* communicate with the other side. If there is no mapping then the session attempt is rejected. * flow produced by [serviceFlowFactory]. This service flow has respond correctly to the sends and receives the client
* @param markerClass The marker [Class] present in a session initiation attempt. Conventionally this is a [FlowLogic] * does.
* subclass, however any class can be used, with the default being the class of the initiating flow. This enables * @param clientFlowClass [Class] of the client flow involved in this client-server communication.
* the registration to be of the form: `registerFlowInitiator(InitiatorFlow.class, InitiatedFlow::new)` * @param serviceFlowFactory Lambda which produces a new service flow for each new client flow communication. The
* @param flowFactory The flow factory generating the initiated flow. * [Party] parameter of the factory is the client's identity.
*/ */
fun registerFlowInitiator(markerClass: Class<*>, flowFactory: (Party) -> FlowLogic<*>) fun registerServiceFlow(clientFlowClass: Class<out FlowLogic<*>>, serviceFlowFactory: (Party) -> FlowLogic<*>)
@Deprecated(message = "Use overloaded method which uses Class instead of KClass. This is scheduled for removal in a future release.") @Suppress("UNCHECKED_CAST")
fun registerFlowInitiator(markerClass: KClass<*>, flowFactory: (Party) -> FlowLogic<*>) { @Deprecated("This is scheduled to be removed in a future release", ReplaceWith("registerServiceFlow"))
registerFlowInitiator(markerClass.java, flowFactory) fun registerFlowInitiator(markerClass: Class<*>, flowFactory: (Party) -> FlowLogic<*>) {
registerServiceFlow(markerClass as Class<out FlowLogic<*>>, flowFactory)
} }
/**
* Return the flow factory that has been registered with [markerClass], or null if no factory is found.
*/
fun getFlowFactory(markerClass: Class<*>): ((Party) -> FlowLogic<*>)?
} }

View File

@ -82,6 +82,7 @@ object DefaultKryoCustomizer {
register(MetaData::class.java, MetaDataSerializer) register(MetaData::class.java, MetaDataSerializer)
register(BitSet::class.java, BitSetSerializer()) register(BitSet::class.java, BitSetSerializer())
register(Class::class.java, ClassSerializer)
addDefaultSerializer(Logger::class.java, LoggerSerializer) addDefaultSerializer(Logger::class.java, LoggerSerializer)

View File

@ -565,6 +565,17 @@ object LoggerSerializer : Serializer<Logger>() {
} }
} }
object ClassSerializer : Serializer<Class<*>>() {
override fun read(kryo: Kryo, input: Input, type: Class<Class<*>>): Class<*> {
val className = input.readString()
return Class.forName(className)
}
override fun write(kryo: Kryo, output: Output, clazz: Class<*>) {
output.writeString(clazz.name)
}
}
/** /**
* For serialising an [X500Name] without touching Sun internal classes. * For serialising an [X500Name] without touching Sun internal classes.
*/ */

View File

@ -1,13 +1,10 @@
package net.corda.flows package net.corda.flows
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.DealState import net.corda.core.contracts.DealState
import net.corda.core.contracts.StateRef
import net.corda.core.crypto.* import net.corda.core.crypto.*
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.node.services.ServiceType
import net.corda.core.seconds import net.corda.core.seconds
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
@ -38,12 +35,6 @@ object TwoPartyDealFlow {
@CordaSerializable @CordaSerializable
class SignaturesFromPrimary(val sellerSig: DigitalSignature.WithKey, val notarySigs: List<DigitalSignature.WithKey>) class SignaturesFromPrimary(val sellerSig: DigitalSignature.WithKey, val notarySigs: List<DigitalSignature.WithKey>)
/**
* [Primary] at the end sends the signed tx to all the regulator parties. This a seperate workflow which needs a
* sepearate session with the regulator. This interface is used to do that in [Primary.getCounterpartyMarker].
*/
interface MarkerForBogusRegulatorFlow
/** /**
* Abstracted bilateral deal flow participant that initiates communication/handshake. * Abstracted bilateral deal flow participant that initiates communication/handshake.
* *
@ -69,14 +60,6 @@ object TwoPartyDealFlow {
abstract val otherParty: Party abstract val otherParty: Party
abstract val myKeyPair: KeyPair abstract val myKeyPair: KeyPair
override fun getCounterpartyMarker(party: Party): Class<*> {
return if (serviceHub.networkMapCache.regulatorNodes.any { it.legalIdentity == party }) {
MarkerForBogusRegulatorFlow::class.java
} else {
super.getCounterpartyMarker(party)
}
}
@Suspendable @Suspendable
fun getPartialTransaction(): UntrustworthyData<SignedTransaction> { fun getPartialTransaction(): UntrustworthyData<SignedTransaction> {
progressTracker.currentStep = AWAITING_PROPOSAL progressTracker.currentStep = AWAITING_PROPOSAL
@ -146,9 +129,8 @@ object TwoPartyDealFlow {
progressTracker.currentStep = COPYING_TO_REGULATOR progressTracker.currentStep = COPYING_TO_REGULATOR
val regulators = serviceHub.networkMapCache.regulatorNodes val regulators = serviceHub.networkMapCache.regulatorNodes
if (regulators.isNotEmpty()) { if (regulators.isNotEmpty()) {
// Copy the transaction to every regulator in the network. This is obviously completely bogus, it's // If there are regulators in the network, then we could copy them in on the transaction via a sub-flow
// just for demo purposes. // which would simply send them the transaction.
regulators.forEach { send(it.serviceIdentities(ServiceType.regulator).first(), fullySigned) }
} }
return fullySigned return fullySigned

View File

@ -30,7 +30,7 @@ public class FlowsInJavaTest {
@Test @Test
public void suspendableActionInsideUnwrap() throws Exception { public void suspendableActionInsideUnwrap() throws Exception {
node2.getServices().registerFlowInitiator(SendInUnwrapFlow.class, (otherParty) -> new OtherFlow(otherParty, "Hello")); node2.getServices().registerServiceFlow(SendInUnwrapFlow.class, (otherParty) -> new OtherFlow(otherParty, "Hello"));
Future<String> result = node1.getServices().startFlow(new SendInUnwrapFlow(node2.getInfo().getLegalIdentity())).getResultFuture(); Future<String> result = node1.getServices().startFlow(new SendInUnwrapFlow(node2.getInfo().getLegalIdentity())).getResultFuture();
net.runNetwork(); net.runNetwork();
assertThat(result.get()).isEqualTo("Hello"); assertThat(result.get()).isEqualTo("Hello");

View File

@ -15,7 +15,7 @@ import java.security.cert.Certificate
*/ */
object TxKeyFlow { object TxKeyFlow {
fun registerFlowInitiator(services: PluginServiceHub) { fun registerFlowInitiator(services: PluginServiceHub) {
services.registerFlowInitiator(Requester::class.java, ::Provider) services.registerServiceFlow(Requester::class.java, ::Provider)
} }
class Requester(val otherSide: Party, class Requester(val otherSide: Party,

View File

@ -24,7 +24,7 @@ import java.util.*
object FxTransactionDemoTutorial { object FxTransactionDemoTutorial {
// Would normally be called by custom service init in a CorDapp // Would normally be called by custom service init in a CorDapp
fun registerFxProtocols(pluginHub: PluginServiceHub) { fun registerFxProtocols(pluginHub: PluginServiceHub) {
pluginHub.registerFlowInitiator(ForeignExchangeFlow::class.java, ::ForeignExchangeRemoteFlow) pluginHub.registerServiceFlow(ForeignExchangeFlow::class.java, ::ForeignExchangeRemoteFlow)
} }
} }

View File

@ -17,7 +17,7 @@ import java.time.Duration
object WorkflowTransactionBuildTutorial { object WorkflowTransactionBuildTutorial {
// Would normally be called by custom service init in a CorDapp // Would normally be called by custom service init in a CorDapp
fun registerWorkflowProtocols(pluginHub: PluginServiceHub) { fun registerWorkflowProtocols(pluginHub: PluginServiceHub) {
pluginHub.registerFlowInitiator(SubmitCompletionFlow::class.java, ::RecordCompletionFlow) pluginHub.registerServiceFlow(SubmitCompletionFlow::class.java, ::RecordCompletionFlow)
} }
} }

View File

@ -96,7 +96,7 @@ object IssuerFlow {
class Service(services: PluginServiceHub) { class Service(services: PluginServiceHub) {
init { init {
services.registerFlowInitiator(IssuanceRequester::class.java, ::Issuer) services.registerServiceFlow(IssuanceRequester::class.java, ::Issuer)
} }
} }
} }

View File

@ -2,11 +2,8 @@
package net.corda.nodeapi package net.corda.nodeapi
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Registration import com.esotericsoftware.kryo.Registration
import com.esotericsoftware.kryo.Serializer import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.flows.FlowException import net.corda.core.flows.FlowException
import net.corda.core.serialization.* import net.corda.core.serialization.*
@ -71,17 +68,6 @@ open class RPCException(msg: String, cause: Throwable?) : RuntimeException(msg,
class DeadlineExceeded(rpcName: String) : RPCException("Deadline exceeded on call to $rpcName") class DeadlineExceeded(rpcName: String) : RPCException("Deadline exceeded on call to $rpcName")
} }
object ClassSerializer : Serializer<Class<*>>() {
override fun read(kryo: Kryo, input: Input, type: Class<Class<*>>): Class<*> {
val className = input.readString()
return Class.forName(className)
}
override fun write(kryo: Kryo, output: Output, clazz: Class<*>) {
output.writeString(clazz.name)
}
}
@CordaSerializable @CordaSerializable
class PermissionException(msg: String) : RuntimeException(msg) class PermissionException(msg: String) : RuntimeException(msg)
@ -99,7 +85,6 @@ class RPCKryo(observableSerializer: Serializer<Observable<Any>>) : CordaKryo(mak
DefaultKryoCustomizer.customize(this) DefaultKryoCustomizer.customize(this)
// RPC specific classes // RPC specific classes
register(Class::class.java, ClassSerializer)
register(MultipartStream.ItemInputStream::class.java, InputStreamSerializer) register(MultipartStream.ItemInputStream::class.java, InputStreamSerializer)
register(MarshalledObservation::class.java, ImmutableClassSerializer(MarshalledObservation::class)) register(MarshalledObservation::class.java, ImmutableClassSerializer(MarshalledObservation::class))
register(Observable::class.java, observableSerializer) register(Observable::class.java, observableSerializer)

View File

@ -230,7 +230,7 @@ abstract class MQSecurityTest : NodeBasedTest() {
private fun startBobAndCommunicateWithAlice(): Party { private fun startBobAndCommunicateWithAlice(): Party {
val bob = startNode(BOB.name).getOrThrow() val bob = startNode(BOB.name).getOrThrow()
bob.services.registerFlowInitiator(SendFlow::class.java, ::ReceiveFlow) bob.services.registerServiceFlow(SendFlow::class.java, ::ReceiveFlow)
val bobParty = bob.info.legalIdentity val bobParty = bob.info.legalIdentity
// Perform a protocol exchange to force the peer queue to be created // Perform a protocol exchange to force the peer queue to be created
alice.services.startFlow(SendFlow(bobParty, 0)).resultFuture.getOrThrow() alice.services.startFlow(SendFlow(bobParty, 0)).resultFuture.getOrThrow()

View File

@ -52,7 +52,6 @@ import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.configureDatabase import net.corda.node.utilities.configureDatabase
import net.corda.node.utilities.transaction import net.corda.node.utilities.transaction
import org.apache.activemq.artemis.utils.ReusableLatch import org.apache.activemq.artemis.utils.ReusableLatch
import org.bouncycastle.asn1.x500.X500Name
import org.jetbrains.exposed.sql.Database import org.jetbrains.exposed.sql.Database
import org.slf4j.Logger import org.slf4j.Logger
import java.io.IOException import java.io.IOException
@ -108,7 +107,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
// low-performance prototyping period. // low-performance prototyping period.
protected abstract val serverThread: AffinityExecutor protected abstract val serverThread: AffinityExecutor
private val flowFactories = ConcurrentHashMap<Class<*>, (Party) -> FlowLogic<*>>() private val serviceFlowFactories = ConcurrentHashMap<Class<*>, (Party) -> FlowLogic<*>>()
protected val partyKeys = mutableSetOf<KeyPair>() protected val partyKeys = mutableSetOf<KeyPair>()
val services = object : ServiceHubInternal() { val services = object : ServiceHubInternal() {
@ -132,14 +131,14 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
return serverThread.fetchFrom { smm.add(logic, flowInitiator) } return serverThread.fetchFrom { smm.add(logic, flowInitiator) }
} }
override fun registerFlowInitiator(markerClass: Class<*>, flowFactory: (Party) -> FlowLogic<*>) { override fun registerServiceFlow(clientFlowClass: Class<out FlowLogic<*>>, serviceFlowFactory: (Party) -> FlowLogic<*>) {
require(markerClass !in flowFactories) { "${markerClass.name} has already been used to register a flow" } require(clientFlowClass !in serviceFlowFactories) { "${clientFlowClass.name} has already been used to register a service flow" }
log.info("Registering flow ${markerClass.name}") log.info("Registering service flow for ${clientFlowClass.name}")
flowFactories[markerClass] = flowFactory serviceFlowFactories[clientFlowClass] = serviceFlowFactory
} }
override fun getFlowFactory(markerClass: Class<*>): ((Party) -> FlowLogic<*>)? { override fun getServiceFlowFactory(clientFlowClass: Class<out FlowLogic<*>>): ((Party) -> FlowLogic<*>)? {
return flowFactories[markerClass] return serviceFlowFactories[clientFlowClass]
} }
override fun recordTransactions(txs: Iterable<SignedTransaction>) { override fun recordTransactions(txs: Iterable<SignedTransaction>) {
@ -236,7 +235,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
false false
} }
startMessagingService(rpcOps) startMessagingService(rpcOps)
services.registerFlowInitiator(ContractUpgradeFlow.Instigator::class.java) { ContractUpgradeFlow.Acceptor(it) } services.registerServiceFlow(ContractUpgradeFlow.Instigator::class.java) { ContractUpgradeFlow.Acceptor(it) }
runOnStop += Runnable { net.stop() } runOnStop += Runnable { net.stop() }
_networkMapRegistrationFuture.setFuture(registerWithNetworkMapIfConfigured()) _networkMapRegistrationFuture.setFuture(registerWithNetworkMapIfConfigured())
smm.start() smm.start()

View File

@ -17,7 +17,7 @@ object NotaryChange {
*/ */
class Service(services: PluginServiceHub) : SingletonSerializeAsToken() { class Service(services: PluginServiceHub) : SingletonSerializeAsToken() {
init { init {
services.registerFlowInitiator(NotaryChangeFlow.Instigator::class.java) { NotaryChangeFlow.Acceptor(it) } services.registerServiceFlow(NotaryChangeFlow.Instigator::class.java) { NotaryChangeFlow.Acceptor(it) }
} }
} }
} }

View File

@ -2,6 +2,7 @@ package net.corda.node.services.api
import com.google.common.annotations.VisibleForTesting import com.google.common.annotations.VisibleForTesting
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.crypto.Party
import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowLogicRefFactory import net.corda.core.flows.FlowLogicRefFactory
@ -85,7 +86,8 @@ abstract class ServiceHubInternal : PluginServiceHub {
* Note that you must be on the server thread to call this method. [flowInitiator] points how flow was started, * Note that you must be on the server thread to call this method. [flowInitiator] points how flow was started,
* See: [FlowInitiator]. * See: [FlowInitiator].
* *
* @throws IllegalFlowLogicException or IllegalArgumentException if there are problems with the [logicType] or [args]. * @throws net.corda.core.flows.IllegalFlowLogicException or IllegalArgumentException if there are problems with the
* [logicType] or [args].
*/ */
fun <T : Any> invokeFlowAsync( fun <T : Any> invokeFlowAsync(
logicType: Class<out FlowLogic<T>>, logicType: Class<out FlowLogic<T>>,
@ -96,4 +98,6 @@ abstract class ServiceHubInternal : PluginServiceHub {
val logic = flowLogicRefFactory.toFlowLogic(logicRef) as FlowLogic<T> val logic = flowLogicRefFactory.toFlowLogic(logicRef) as FlowLogic<T>
return startFlow(logic, flowInitiator) return startFlow(logic, flowInitiator)
} }
abstract fun getServiceFlowFactory(clientFlowClass: Class<out FlowLogic<*>>): ((Party) -> FlowLogic<*>)?
} }

View File

@ -34,9 +34,9 @@ object DataVending {
@ThreadSafe @ThreadSafe
class Service(services: PluginServiceHub) : SingletonSerializeAsToken() { class Service(services: PluginServiceHub) : SingletonSerializeAsToken() {
init { init {
services.registerFlowInitiator(FetchTransactionsFlow::class.java, ::FetchTransactionsHandler) services.registerServiceFlow(FetchTransactionsFlow::class.java, ::FetchTransactionsHandler)
services.registerFlowInitiator(FetchAttachmentsFlow::class.java, ::FetchAttachmentsHandler) services.registerServiceFlow(FetchAttachmentsFlow::class.java, ::FetchAttachmentsHandler)
services.registerFlowInitiator(BroadcastTransactionFlow::class.java, ::NotifyTransactionHandler) services.registerServiceFlow(BroadcastTransactionFlow::class.java, ::NotifyTransactionHandler)
} }
private class FetchTransactionsHandler(otherParty: Party) : FetchDataHandler<SignedTransaction>(otherParty) { private class FetchTransactionsHandler(otherParty: Party) : FetchDataHandler<SignedTransaction>(otherParty) {

View File

@ -11,11 +11,7 @@ import net.corda.core.ErrorOr
import net.corda.core.abbreviate import net.corda.core.abbreviate
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowException import net.corda.core.flows.*
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine
import net.corda.core.flows.StateMachineRunId
import net.corda.core.messaging.FlowHandle import net.corda.core.messaging.FlowHandle
import net.corda.core.messaging.FlowProgressHandle import net.corda.core.messaging.FlowProgressHandle
import net.corda.core.random63BitValue import net.corda.core.random63BitValue
@ -34,6 +30,7 @@ import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.slf4j.Logger import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import rx.Observable import rx.Observable
import java.lang.reflect.Modifier
import java.sql.Connection import java.sql.Connection
import java.sql.SQLException import java.sql.SQLException
import java.util.* import java.util.*
@ -304,8 +301,9 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
logger.trace { "Initiating a new session with $otherParty" } logger.trace { "Initiating a new session with $otherParty" }
val session = FlowSession(sessionFlow, random63BitValue(), null, FlowSessionState.Initiating(otherParty)) val session = FlowSession(sessionFlow, random63BitValue(), null, FlowSessionState.Initiating(otherParty))
openSessions[Pair(sessionFlow, otherParty)] = session openSessions[Pair(sessionFlow, otherParty)] = session
val counterpartyFlow = sessionFlow.getCounterpartyMarker(otherParty).name // We get the top-most concrete class object to cater for the case where the client flow is customised via a sub-class
val sessionInit = SessionInit(session.ourSessionId, counterpartyFlow, firstPayload) val clientFlowClass = sessionFlow.topConcreteFlowClass
val sessionInit = SessionInit(session.ourSessionId, clientFlowClass, firstPayload)
sendInternal(session, sessionInit) sendInternal(session, sessionInit)
if (waitForConfirmation) { if (waitForConfirmation) {
session.waitForConfirmation() session.waitForConfirmation()
@ -313,6 +311,15 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
return session return session
} }
@Suppress("UNCHECKED_CAST")
private val FlowLogic<*>.topConcreteFlowClass: Class<out FlowLogic<*>> get() {
var current: Class<out FlowLogic<*>> = javaClass
while (!Modifier.isAbstract(current.superclass.modifiers)) {
current = current.superclass as Class<out FlowLogic<*>>
}
return current
}
@Suspendable @Suspendable
private fun <M : ExistingSessionMessage> waitForMessage(receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> { private fun <M : ExistingSessionMessage> waitForMessage(receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> {
return receiveRequest.suspendAndExpectReceive().confirmReceiveType(receiveRequest) return receiveRequest.suspendAndExpectReceive().confirmReceiveType(receiveRequest)

View File

@ -2,6 +2,7 @@ package net.corda.node.services.statemachine
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.flows.FlowException import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.UntrustworthyData import net.corda.core.utilities.UntrustworthyData
@ -12,7 +13,9 @@ import net.corda.core.utilities.UntrustworthyData
@CordaSerializable @CordaSerializable
interface SessionMessage interface SessionMessage
data class SessionInit(val initiatorSessionId: Long, val flowName: String, val firstPayload: Any?) : SessionMessage data class SessionInit(val initiatorSessionId: Long,
val clientFlowClass: Class<out FlowLogic<*>>,
val firstPayload: Any?) : SessionMessage
interface ExistingSessionMessage : SessionMessage { interface ExistingSessionMessage : SessionMessage {
val recipientSessionId: Long val recipientSessionId: Long

View File

@ -19,11 +19,7 @@ import net.corda.core.bufferUntilSubscribed
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.commonName import net.corda.core.crypto.commonName
import net.corda.core.flows.FlowInitiator import net.corda.core.flows.*
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine
import net.corda.core.flows.StateMachineRunId
import net.corda.core.messaging.ReceivedMessage import net.corda.core.messaging.ReceivedMessage
import net.corda.core.messaging.TopicSession import net.corda.core.messaging.TopicSession
import net.corda.core.messaging.send import net.corda.core.messaging.send
@ -345,18 +341,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
fun sendSessionReject(message: String) = sendSessionMessage(sender, SessionReject(otherPartySessionId, message)) fun sendSessionReject(message: String) = sendSessionMessage(sender, SessionReject(otherPartySessionId, message))
val markerClass = try { val flowFactory = serviceHub.getServiceFlowFactory(sessionInit.clientFlowClass)
Class.forName(sessionInit.flowName)
} catch (e: Exception) {
logger.warn("Received invalid $sessionInit", e)
sendSessionReject("Don't know ${sessionInit.flowName}")
return
}
val flowFactory = serviceHub.getFlowFactory(markerClass)
if (flowFactory == null) { if (flowFactory == null) {
logger.warn("Unknown flow marker class in $sessionInit") logger.warn("${sessionInit.clientFlowClass} has not been registered with a service flow: $sessionInit")
sendSessionReject("Don't know ${markerClass.name}") sendSessionReject("Don't know ${sessionInit.clientFlowClass.name}")
return return
} }
@ -378,7 +366,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
} }
sendSessionMessage(sender, SessionConfirm(otherPartySessionId, session.ourSessionId), session.fiber) sendSessionMessage(sender, SessionConfirm(otherPartySessionId, session.ourSessionId), session.fiber)
session.fiber.logger.debug { "Initiated by $sender using marker ${markerClass.name}" } session.fiber.logger.debug { "Initiated by $sender using ${sessionInit.clientFlowClass.name}" }
session.fiber.logger.trace { "Initiated from $sessionInit on $session" } session.fiber.logger.trace { "Initiated from $sessionInit on $session" }
resumeFiber(session.fiber) resumeFiber(session.fiber)
} }

View File

@ -18,7 +18,7 @@ import net.corda.node.services.api.ServiceHubInternal
abstract class NotaryService(services: ServiceHubInternal) : SingletonSerializeAsToken() { abstract class NotaryService(services: ServiceHubInternal) : SingletonSerializeAsToken() {
init { init {
services.registerFlowInitiator(NotaryFlow.Client::class.java) { createFlow(it) } services.registerServiceFlow(NotaryFlow.Client::class.java) { createFlow(it) }
} }
/** Implement a factory that specifies the transaction commit flow for the notary service to use */ /** Implement a factory that specifies the transaction commit flow for the notary service to use */

View File

@ -22,7 +22,6 @@ import net.corda.testing.MOCK_IDENTITY_SERVICE
import net.corda.testing.node.MockNetworkMapCache import net.corda.testing.node.MockNetworkMapCache
import net.corda.testing.node.MockStorageService import net.corda.testing.node.MockStorageService
import java.time.Clock import java.time.Clock
import java.util.concurrent.ConcurrentHashMap
open class MockServiceHubInternal( open class MockServiceHubInternal(
val customVault: VaultService? = null, val customVault: VaultService? = null,
@ -68,8 +67,6 @@ open class MockServiceHubInternal(
private val txStorageService: TxWritableStorageService private val txStorageService: TxWritableStorageService
get() = storage ?: throw UnsupportedOperationException() get() = storage ?: throw UnsupportedOperationException()
private val flowFactories = ConcurrentHashMap<Class<*>, (Party) -> FlowLogic<*>>()
lateinit var smm: StateMachineManager lateinit var smm: StateMachineManager
init { init {
@ -86,11 +83,7 @@ open class MockServiceHubInternal(
return smm.executor.fetchFrom { smm.add(logic, flowInitiator) } return smm.executor.fetchFrom { smm.add(logic, flowInitiator) }
} }
override fun registerFlowInitiator(markerClass: Class<*>, flowFactory: (Party) -> FlowLogic<*>) { override fun registerServiceFlow(clientFlowClass: Class<out FlowLogic<*>>, serviceFlowFactory: (Party) -> FlowLogic<*>) = Unit
flowFactories[markerClass] = flowFactory
}
override fun getFlowFactory(markerClass: Class<*>): ((Party) -> FlowLogic<*>)? { override fun getServiceFlowFactory(clientFlowClass: Class<out FlowLogic<*>>): ((Party) -> FlowLogic<*>)? = null
return flowFactories[markerClass]
}
} }

View File

@ -89,7 +89,7 @@ class DataVendingServiceTests {
} }
private fun MockNode.sendNotifyTx(tx: SignedTransaction, walletServiceNode: MockNode) { private fun MockNode.sendNotifyTx(tx: SignedTransaction, walletServiceNode: MockNode) {
walletServiceNode.services.registerFlowInitiator(NotifyTxFlow::class.java, ::NotifyTransactionHandler) walletServiceNode.services.registerServiceFlow(NotifyTxFlow::class.java, ::NotifyTransactionHandler)
services.startFlow(NotifyTxFlow(walletServiceNode.info.legalIdentity, tx)) services.startFlow(NotifyTxFlow(walletServiceNode.info.legalIdentity, tx))
network.runNetwork() network.runNetwork()
} }

View File

@ -8,7 +8,6 @@ import net.corda.core.*
import net.corda.core.contracts.DOLLARS import net.corda.core.contracts.DOLLARS
import net.corda.core.contracts.DummyState import net.corda.core.contracts.DummyState
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.crypto.X509Utilities
import net.corda.core.crypto.generateKeyPair import net.corda.core.crypto.generateKeyPair
import net.corda.core.flows.FlowException import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
@ -111,7 +110,7 @@ class StateMachineManagerTests {
@Test @Test
fun `exception while fiber suspended`() { fun `exception while fiber suspended`() {
node2.services.registerFlowInitiator(ReceiveFlow::class.java) { SendFlow("Hello", it) } node2.services.registerServiceFlow(ReceiveFlow::class.java) { SendFlow("Hello", it) }
val flow = ReceiveFlow(node2.info.legalIdentity) val flow = ReceiveFlow(node2.info.legalIdentity)
val fiber = node1.services.startFlow(flow) as FlowStateMachineImpl val fiber = node1.services.startFlow(flow) as FlowStateMachineImpl
// Before the flow runs change the suspend action to throw an exception // Before the flow runs change the suspend action to throw an exception
@ -130,7 +129,7 @@ class StateMachineManagerTests {
@Test @Test
fun `flow restarted just after receiving payload`() { fun `flow restarted just after receiving payload`() {
node2.services.registerFlowInitiator(SendFlow::class.java) { ReceiveFlow(it).nonTerminating() } node2.services.registerServiceFlow(SendFlow::class.java) { ReceiveFlow(it).nonTerminating() }
node1.services.startFlow(SendFlow("Hello", node2.info.legalIdentity)) node1.services.startFlow(SendFlow("Hello", node2.info.legalIdentity))
// We push through just enough messages to get only the payload sent // We push through just enough messages to get only the payload sent
@ -180,7 +179,7 @@ class StateMachineManagerTests {
@Test @Test
fun `flow loaded from checkpoint will respond to messages from before start`() { fun `flow loaded from checkpoint will respond to messages from before start`() {
node1.services.registerFlowInitiator(ReceiveFlow::class.java) { SendFlow("Hello", it) } node1.services.registerServiceFlow(ReceiveFlow::class.java) { SendFlow("Hello", it) }
node2.services.startFlow(ReceiveFlow(node1.info.legalIdentity).nonTerminating()) // Prepare checkpointed receive flow node2.services.startFlow(ReceiveFlow(node1.info.legalIdentity).nonTerminating()) // Prepare checkpointed receive flow
// Make sure the add() has finished initial processing. // Make sure the add() has finished initial processing.
node2.smm.executor.flush() node2.smm.executor.flush()
@ -244,8 +243,8 @@ class StateMachineManagerTests {
fun `sending to multiple parties`() { fun `sending to multiple parties`() {
val node3 = net.createNode(node1.info.address) val node3 = net.createNode(node1.info.address)
net.runNetwork() net.runNetwork()
node2.services.registerFlowInitiator(SendFlow::class.java) { ReceiveFlow(it).nonTerminating() } node2.services.registerServiceFlow(SendFlow::class.java) { ReceiveFlow(it).nonTerminating() }
node3.services.registerFlowInitiator(SendFlow::class.java) { ReceiveFlow(it).nonTerminating() } node3.services.registerServiceFlow(SendFlow::class.java) { ReceiveFlow(it).nonTerminating() }
val payload = "Hello World" val payload = "Hello World"
node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity, node3.info.legalIdentity)) node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity, node3.info.legalIdentity))
net.runNetwork() net.runNetwork()
@ -278,8 +277,8 @@ class StateMachineManagerTests {
net.runNetwork() net.runNetwork()
val node2Payload = "Test 1" val node2Payload = "Test 1"
val node3Payload = "Test 2" val node3Payload = "Test 2"
node2.services.registerFlowInitiator(ReceiveFlow::class.java) { SendFlow(node2Payload, it) } node2.services.registerServiceFlow(ReceiveFlow::class.java) { SendFlow(node2Payload, it) }
node3.services.registerFlowInitiator(ReceiveFlow::class.java) { SendFlow(node3Payload, it) } node3.services.registerServiceFlow(ReceiveFlow::class.java) { SendFlow(node3Payload, it) }
val multiReceiveFlow = ReceiveFlow(node2.info.legalIdentity, node3.info.legalIdentity).nonTerminating() val multiReceiveFlow = ReceiveFlow(node2.info.legalIdentity, node3.info.legalIdentity).nonTerminating()
node1.services.startFlow(multiReceiveFlow) node1.services.startFlow(multiReceiveFlow)
node1.acceptableLiveFiberCountOnStop = 1 node1.acceptableLiveFiberCountOnStop = 1
@ -304,7 +303,7 @@ class StateMachineManagerTests {
@Test @Test
fun `both sides do a send as their first IO request`() { fun `both sides do a send as their first IO request`() {
node2.services.registerFlowInitiator(PingPongFlow::class.java) { PingPongFlow(it, 20L) } node2.services.registerServiceFlow(PingPongFlow::class.java) { PingPongFlow(it, 20L) }
node1.services.startFlow(PingPongFlow(node2.info.legalIdentity, 10L)) node1.services.startFlow(PingPongFlow(node2.info.legalIdentity, 10L))
net.runNetwork() net.runNetwork()
@ -340,7 +339,7 @@ class StateMachineManagerTests {
sessionTransfers.expectEvents(isStrict = false) { sessionTransfers.expectEvents(isStrict = false) {
sequence( sequence(
// First Pay // First Pay
expect(match = { it.message is SessionInit && it.message.flowName == NotaryFlow.Client::class.java.name }) { expect(match = { it.message is SessionInit && it.message.clientFlowClass == NotaryFlow.Client::class.java }) {
it.message as SessionInit it.message as SessionInit
assertEquals(node1.id, it.from) assertEquals(node1.id, it.from)
assertEquals(notary1Address, it.to) assertEquals(notary1Address, it.to)
@ -350,7 +349,7 @@ class StateMachineManagerTests {
assertEquals(notary1.id, it.from) assertEquals(notary1.id, it.from)
}, },
// Second pay // Second pay
expect(match = { it.message is SessionInit && it.message.flowName == NotaryFlow.Client::class.java.name }) { expect(match = { it.message is SessionInit && it.message.clientFlowClass == NotaryFlow.Client::class.java }) {
it.message as SessionInit it.message as SessionInit
assertEquals(node1.id, it.from) assertEquals(node1.id, it.from)
assertEquals(notary1Address, it.to) assertEquals(notary1Address, it.to)
@ -360,7 +359,7 @@ class StateMachineManagerTests {
assertEquals(notary2.id, it.from) assertEquals(notary2.id, it.from)
}, },
// Third pay // Third pay
expect(match = { it.message is SessionInit && it.message.flowName == NotaryFlow.Client::class.java.name }) { expect(match = { it.message is SessionInit && it.message.clientFlowClass == NotaryFlow.Client::class.java }) {
it.message as SessionInit it.message as SessionInit
assertEquals(node1.id, it.from) assertEquals(node1.id, it.from)
assertEquals(notary1Address, it.to) assertEquals(notary1Address, it.to)
@ -375,7 +374,7 @@ class StateMachineManagerTests {
@Test @Test
fun `other side ends before doing expected send`() { fun `other side ends before doing expected send`() {
node2.services.registerFlowInitiator(ReceiveFlow::class.java) { NoOpFlow() } node2.services.registerServiceFlow(ReceiveFlow::class.java) { NoOpFlow() }
val resultFuture = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)).resultFuture val resultFuture = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)).resultFuture
net.runNetwork() net.runNetwork()
assertThatExceptionOfType(FlowSessionException::class.java).isThrownBy { assertThatExceptionOfType(FlowSessionException::class.java).isThrownBy {
@ -535,7 +534,7 @@ class StateMachineManagerTests {
} }
} }
node2.services.registerFlowInitiator(AskForExceptionFlow::class.java) { ConditionalExceptionFlow(it, "Hello") } node2.services.registerServiceFlow(AskForExceptionFlow::class.java) { ConditionalExceptionFlow(it, "Hello") }
val resultFuture = node1.services.startFlow(RetryOnExceptionFlow(node2.info.legalIdentity)).resultFuture val resultFuture = node1.services.startFlow(RetryOnExceptionFlow(node2.info.legalIdentity)).resultFuture
net.runNetwork() net.runNetwork()
assertThat(resultFuture.getOrThrow()).isEqualTo("Hello") assertThat(resultFuture.getOrThrow()).isEqualTo("Hello")
@ -563,7 +562,7 @@ class StateMachineManagerTests {
ptx.signWith(node1.services.legalIdentityKey) ptx.signWith(node1.services.legalIdentityKey)
val stx = ptx.toSignedTransaction() val stx = ptx.toSignedTransaction()
node1.services.registerFlowInitiator(WaitingFlows.Waiter::class.java) { node1.services.registerServiceFlow(WaitingFlows.Waiter::class.java) {
WaitingFlows.Committer(it) { throw Exception("Error") } WaitingFlows.Committer(it) { throw Exception("Error") }
} }
val waiter = node2.services.startFlow(WaitingFlows.Waiter(stx, node1.info.legalIdentity)).resultFuture val waiter = node2.services.startFlow(WaitingFlows.Waiter(stx, node1.info.legalIdentity)).resultFuture
@ -580,6 +579,14 @@ class StateMachineManagerTests {
assertThatThrownBy { result.getOrThrow() }.hasMessageContaining("Vault").hasMessageContaining("private method") assertThatThrownBy { result.getOrThrow() }.hasMessageContaining("Vault").hasMessageContaining("private method")
} }
@Test
fun `custom client flow`() {
val receiveFlowFuture = node2.initiateSingleShotFlow(SendFlow::class) { ReceiveFlow(it) }
node1.services.startFlow(CustomSendFlow("Hello", node2.info.legalIdentity)).resultFuture
net.runNetwork()
assertThat(receiveFlowFuture.getOrThrow().receivedPayloads).containsOnly("Hello")
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////
//region Helpers //region Helpers
@ -598,7 +605,9 @@ class StateMachineManagerTests {
return smm.findStateMachines(P::class.java).single() return smm.findStateMachines(P::class.java).single()
} }
private fun sessionInit(flowMarker: KClass<*>, payload: Any? = null) = SessionInit(0, flowMarker.java.name, payload) private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, payload: Any? = null): SessionInit {
return SessionInit(0, clientFlowClass.java, payload)
}
private val sessionConfirm = SessionConfirm(0, 0) private val sessionConfirm = SessionConfirm(0, 0)
private fun sessionData(payload: Any) = SessionData(0, payload) private fun sessionData(payload: Any) = SessionData(0, payload)
private val normalEnd = NormalSessionEnd(0) private val normalEnd = NormalSessionEnd(0)
@ -663,7 +672,7 @@ class StateMachineManagerTests {
} }
private class SendFlow(val payload: String, vararg val otherParties: Party) : FlowLogic<Unit>() { private open class SendFlow(val payload: String, vararg val otherParties: Party) : FlowLogic<Unit>() {
init { init {
require(otherParties.isNotEmpty()) require(otherParties.isNotEmpty())
} }
@ -672,6 +681,8 @@ class StateMachineManagerTests {
override fun call() = otherParties.forEach { send(it, payload) } override fun call() = otherParties.forEach { send(it, payload) }
} }
private interface CustomInterface
private class CustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty)
private class ReceiveFlow(vararg val otherParties: Party) : FlowLogic<Unit>() { private class ReceiveFlow(vararg val otherParties: Party) : FlowLogic<Unit>() {
object START_STEP : ProgressTracker.Step("Starting") object START_STEP : ProgressTracker.Step("Starting")

View File

@ -75,8 +75,8 @@ object NodeInterestRates {
// Note: access to the singleton oracle property is via the registered SingletonSerializeAsToken Service. // Note: access to the singleton oracle property is via the registered SingletonSerializeAsToken Service.
// Otherwise the Kryo serialisation of the call stack in the Quasar Fiber extends to include // Otherwise the Kryo serialisation of the call stack in the Quasar Fiber extends to include
// the framework Oracle and the flow will crash. // the framework Oracle and the flow will crash.
services.registerFlowInitiator(RatesFixFlow.FixSignFlow::class.java) { FixSignHandler(it, this) } services.registerServiceFlow(RatesFixFlow.FixSignFlow::class.java) { FixSignHandler(it, this) }
services.registerFlowInitiator(RatesFixFlow.FixQueryFlow::class.java) { FixQueryHandler(it, this) } services.registerServiceFlow(RatesFixFlow.FixQueryFlow::class.java) { FixQueryHandler(it, this) }
} }
private class FixSignHandler(val otherParty: Party, val service: Service) : FlowLogic<Unit>() { private class FixSignHandler(val otherParty: Party, val service: Service) : FlowLogic<Unit>() {

View File

@ -31,7 +31,7 @@ object AutoOfferFlow {
class Service(services: PluginServiceHub) : SingletonSerializeAsToken() { class Service(services: PluginServiceHub) : SingletonSerializeAsToken() {
init { init {
services.registerFlowInitiator(Instigator::class.java) { Acceptor(it) } services.registerServiceFlow(Instigator::class.java) { Acceptor(it) }
} }
} }

View File

@ -24,7 +24,7 @@ object FixingFlow {
class Service(services: PluginServiceHub) { class Service(services: PluginServiceHub) {
init { init {
services.registerFlowInitiator(Floater::class.java) { Fixer(it) } services.registerServiceFlow(Floater::class.java) { Fixer(it) }
} }
} }

View File

@ -30,7 +30,7 @@ object UpdateBusinessDayFlow {
class Service(services: PluginServiceHub) { class Service(services: PluginServiceHub) {
init { init {
services.registerFlowInitiator(Broadcast::class.java, ::UpdateBusinessDayHandler) services.registerServiceFlow(Broadcast::class.java, ::UpdateBusinessDayHandler)
} }
} }

View File

@ -44,7 +44,7 @@ object IRSTradeFlow {
class Service(services: PluginServiceHub) { class Service(services: PluginServiceHub) {
init { init {
services.registerFlowInitiator(Requester::class.java, ::Receiver) services.registerServiceFlow(Requester::class.java, ::Receiver)
} }
} }

View File

@ -184,7 +184,7 @@ object SimmFlow {
*/ */
class Service(services: PluginServiceHub) { class Service(services: PluginServiceHub) {
init { init {
services.registerFlowInitiator(Requester::class.java, ::Receiver) services.registerServiceFlow(Requester::class.java, ::Receiver)
} }
} }

View File

@ -31,7 +31,7 @@ class BuyerFlow(val otherParty: Party,
it.automaticallyExtractAttachments = true it.automaticallyExtractAttachments = true
it.storePath it.storePath
} }
services.registerFlowInitiator(SellerFlow::class.java) { BuyerFlow(it, attachmentsPath.toString()) } services.registerServiceFlow(SellerFlow::class.java) { BuyerFlow(it, attachmentsPath.toString()) }
} }
} }

View File

@ -140,14 +140,14 @@ fun getFreeLocalPorts(hostName: String, numberToAlloc: Int): List<HostAndPort> {
/** /**
* The given flow factory will be used to initiate just one instance of a flow of type [P] when a counterparty * The given flow factory will be used to initiate just one instance of a flow of type [P] when a counterparty
* flow requests for it using [markerClass]. * flow requests for it using [clientFlowClass].
* @return Returns a [ListenableFuture] holding the single [FlowStateMachineImpl] created by the request. * @return Returns a [ListenableFuture] holding the single [FlowStateMachineImpl] created by the request.
*/ */
inline fun <reified P : FlowLogic<*>> AbstractNode.initiateSingleShotFlow( inline fun <reified P : FlowLogic<*>> AbstractNode.initiateSingleShotFlow(
markerClass: KClass<out FlowLogic<*>>, clientFlowClass: KClass<out FlowLogic<*>>,
noinline flowFactory: (Party) -> P): ListenableFuture<P> { noinline flowFactory: (Party) -> P): ListenableFuture<P> {
val future = smm.changes.filter { it is StateMachineManager.Change.Add && it.logic is P }.map { it.logic as P }.toFuture() val future = smm.changes.filter { it is StateMachineManager.Change.Add && it.logic is P }.map { it.logic as P }.toFuture()
services.registerFlowInitiator(markerClass.java, flowFactory) services.registerServiceFlow(clientFlowClass.java, flowFactory)
return future return future
} }