Deprecated FlowLogic.getCounterpartyMarker as it's complicated and probably not used (replacement is to use sub-flows).

Also made flow registration require the client flow class rather than any old class.
This commit is contained in:
Shams Asari
2017-04-25 12:21:15 +01:00
parent 922a760a09
commit 913487cb32
31 changed files with 123 additions and 145 deletions

View File

@ -230,7 +230,7 @@ abstract class MQSecurityTest : NodeBasedTest() {
private fun startBobAndCommunicateWithAlice(): Party {
val bob = startNode(BOB.name).getOrThrow()
bob.services.registerFlowInitiator(SendFlow::class.java, ::ReceiveFlow)
bob.services.registerServiceFlow(SendFlow::class.java, ::ReceiveFlow)
val bobParty = bob.info.legalIdentity
// Perform a protocol exchange to force the peer queue to be created
alice.services.startFlow(SendFlow(bobParty, 0)).resultFuture.getOrThrow()

View File

@ -52,7 +52,6 @@ import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.configureDatabase
import net.corda.node.utilities.transaction
import org.apache.activemq.artemis.utils.ReusableLatch
import org.bouncycastle.asn1.x500.X500Name
import org.jetbrains.exposed.sql.Database
import org.slf4j.Logger
import java.io.IOException
@ -108,7 +107,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
// low-performance prototyping period.
protected abstract val serverThread: AffinityExecutor
private val flowFactories = ConcurrentHashMap<Class<*>, (Party) -> FlowLogic<*>>()
private val serviceFlowFactories = ConcurrentHashMap<Class<*>, (Party) -> FlowLogic<*>>()
protected val partyKeys = mutableSetOf<KeyPair>()
val services = object : ServiceHubInternal() {
@ -132,14 +131,14 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
return serverThread.fetchFrom { smm.add(logic, flowInitiator) }
}
override fun registerFlowInitiator(markerClass: Class<*>, flowFactory: (Party) -> FlowLogic<*>) {
require(markerClass !in flowFactories) { "${markerClass.name} has already been used to register a flow" }
log.info("Registering flow ${markerClass.name}")
flowFactories[markerClass] = flowFactory
override fun registerServiceFlow(clientFlowClass: Class<out FlowLogic<*>>, serviceFlowFactory: (Party) -> FlowLogic<*>) {
require(clientFlowClass !in serviceFlowFactories) { "${clientFlowClass.name} has already been used to register a service flow" }
log.info("Registering service flow for ${clientFlowClass.name}")
serviceFlowFactories[clientFlowClass] = serviceFlowFactory
}
override fun getFlowFactory(markerClass: Class<*>): ((Party) -> FlowLogic<*>)? {
return flowFactories[markerClass]
override fun getServiceFlowFactory(clientFlowClass: Class<out FlowLogic<*>>): ((Party) -> FlowLogic<*>)? {
return serviceFlowFactories[clientFlowClass]
}
override fun recordTransactions(txs: Iterable<SignedTransaction>) {
@ -236,7 +235,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
false
}
startMessagingService(rpcOps)
services.registerFlowInitiator(ContractUpgradeFlow.Instigator::class.java) { ContractUpgradeFlow.Acceptor(it) }
services.registerServiceFlow(ContractUpgradeFlow.Instigator::class.java) { ContractUpgradeFlow.Acceptor(it) }
runOnStop += Runnable { net.stop() }
_networkMapRegistrationFuture.setFuture(registerWithNetworkMapIfConfigured())
smm.start()

View File

@ -17,7 +17,7 @@ object NotaryChange {
*/
class Service(services: PluginServiceHub) : SingletonSerializeAsToken() {
init {
services.registerFlowInitiator(NotaryChangeFlow.Instigator::class.java) { NotaryChangeFlow.Acceptor(it) }
services.registerServiceFlow(NotaryChangeFlow.Instigator::class.java) { NotaryChangeFlow.Acceptor(it) }
}
}
}

View File

@ -2,6 +2,7 @@ package net.corda.node.services.api
import com.google.common.annotations.VisibleForTesting
import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.crypto.Party
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowLogicRefFactory
@ -85,7 +86,8 @@ abstract class ServiceHubInternal : PluginServiceHub {
* Note that you must be on the server thread to call this method. [flowInitiator] points how flow was started,
* See: [FlowInitiator].
*
* @throws IllegalFlowLogicException or IllegalArgumentException if there are problems with the [logicType] or [args].
* @throws net.corda.core.flows.IllegalFlowLogicException or IllegalArgumentException if there are problems with the
* [logicType] or [args].
*/
fun <T : Any> invokeFlowAsync(
logicType: Class<out FlowLogic<T>>,
@ -96,4 +98,6 @@ abstract class ServiceHubInternal : PluginServiceHub {
val logic = flowLogicRefFactory.toFlowLogic(logicRef) as FlowLogic<T>
return startFlow(logic, flowInitiator)
}
abstract fun getServiceFlowFactory(clientFlowClass: Class<out FlowLogic<*>>): ((Party) -> FlowLogic<*>)?
}

View File

@ -34,9 +34,9 @@ object DataVending {
@ThreadSafe
class Service(services: PluginServiceHub) : SingletonSerializeAsToken() {
init {
services.registerFlowInitiator(FetchTransactionsFlow::class.java, ::FetchTransactionsHandler)
services.registerFlowInitiator(FetchAttachmentsFlow::class.java, ::FetchAttachmentsHandler)
services.registerFlowInitiator(BroadcastTransactionFlow::class.java, ::NotifyTransactionHandler)
services.registerServiceFlow(FetchTransactionsFlow::class.java, ::FetchTransactionsHandler)
services.registerServiceFlow(FetchAttachmentsFlow::class.java, ::FetchAttachmentsHandler)
services.registerServiceFlow(BroadcastTransactionFlow::class.java, ::NotifyTransactionHandler)
}
private class FetchTransactionsHandler(otherParty: Party) : FetchDataHandler<SignedTransaction>(otherParty) {

View File

@ -11,11 +11,7 @@ import net.corda.core.ErrorOr
import net.corda.core.abbreviate
import net.corda.core.crypto.Party
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine
import net.corda.core.flows.StateMachineRunId
import net.corda.core.flows.*
import net.corda.core.messaging.FlowHandle
import net.corda.core.messaging.FlowProgressHandle
import net.corda.core.random63BitValue
@ -34,6 +30,7 @@ import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import rx.Observable
import java.lang.reflect.Modifier
import java.sql.Connection
import java.sql.SQLException
import java.util.*
@ -304,8 +301,9 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
logger.trace { "Initiating a new session with $otherParty" }
val session = FlowSession(sessionFlow, random63BitValue(), null, FlowSessionState.Initiating(otherParty))
openSessions[Pair(sessionFlow, otherParty)] = session
val counterpartyFlow = sessionFlow.getCounterpartyMarker(otherParty).name
val sessionInit = SessionInit(session.ourSessionId, counterpartyFlow, firstPayload)
// We get the top-most concrete class object to cater for the case where the client flow is customised via a sub-class
val clientFlowClass = sessionFlow.topConcreteFlowClass
val sessionInit = SessionInit(session.ourSessionId, clientFlowClass, firstPayload)
sendInternal(session, sessionInit)
if (waitForConfirmation) {
session.waitForConfirmation()
@ -313,6 +311,15 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
return session
}
@Suppress("UNCHECKED_CAST")
private val FlowLogic<*>.topConcreteFlowClass: Class<out FlowLogic<*>> get() {
var current: Class<out FlowLogic<*>> = javaClass
while (!Modifier.isAbstract(current.superclass.modifiers)) {
current = current.superclass as Class<out FlowLogic<*>>
}
return current
}
@Suspendable
private fun <M : ExistingSessionMessage> waitForMessage(receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> {
return receiveRequest.suspendAndExpectReceive().confirmReceiveType(receiveRequest)

View File

@ -2,6 +2,7 @@ package net.corda.node.services.statemachine
import net.corda.core.crypto.Party
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.UntrustworthyData
@ -12,7 +13,9 @@ import net.corda.core.utilities.UntrustworthyData
@CordaSerializable
interface SessionMessage
data class SessionInit(val initiatorSessionId: Long, val flowName: String, val firstPayload: Any?) : SessionMessage
data class SessionInit(val initiatorSessionId: Long,
val clientFlowClass: Class<out FlowLogic<*>>,
val firstPayload: Any?) : SessionMessage
interface ExistingSessionMessage : SessionMessage {
val recipientSessionId: Long

View File

@ -19,11 +19,7 @@ import net.corda.core.bufferUntilSubscribed
import net.corda.core.crypto.Party
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.commonName
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine
import net.corda.core.flows.StateMachineRunId
import net.corda.core.flows.*
import net.corda.core.messaging.ReceivedMessage
import net.corda.core.messaging.TopicSession
import net.corda.core.messaging.send
@ -345,18 +341,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
fun sendSessionReject(message: String) = sendSessionMessage(sender, SessionReject(otherPartySessionId, message))
val markerClass = try {
Class.forName(sessionInit.flowName)
} catch (e: Exception) {
logger.warn("Received invalid $sessionInit", e)
sendSessionReject("Don't know ${sessionInit.flowName}")
return
}
val flowFactory = serviceHub.getFlowFactory(markerClass)
val flowFactory = serviceHub.getServiceFlowFactory(sessionInit.clientFlowClass)
if (flowFactory == null) {
logger.warn("Unknown flow marker class in $sessionInit")
sendSessionReject("Don't know ${markerClass.name}")
logger.warn("${sessionInit.clientFlowClass} has not been registered with a service flow: $sessionInit")
sendSessionReject("Don't know ${sessionInit.clientFlowClass.name}")
return
}
@ -378,7 +366,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
}
sendSessionMessage(sender, SessionConfirm(otherPartySessionId, session.ourSessionId), session.fiber)
session.fiber.logger.debug { "Initiated by $sender using marker ${markerClass.name}" }
session.fiber.logger.debug { "Initiated by $sender using ${sessionInit.clientFlowClass.name}" }
session.fiber.logger.trace { "Initiated from $sessionInit on $session" }
resumeFiber(session.fiber)
}

View File

@ -18,7 +18,7 @@ import net.corda.node.services.api.ServiceHubInternal
abstract class NotaryService(services: ServiceHubInternal) : SingletonSerializeAsToken() {
init {
services.registerFlowInitiator(NotaryFlow.Client::class.java) { createFlow(it) }
services.registerServiceFlow(NotaryFlow.Client::class.java) { createFlow(it) }
}
/** Implement a factory that specifies the transaction commit flow for the notary service to use */

View File

@ -22,7 +22,6 @@ import net.corda.testing.MOCK_IDENTITY_SERVICE
import net.corda.testing.node.MockNetworkMapCache
import net.corda.testing.node.MockStorageService
import java.time.Clock
import java.util.concurrent.ConcurrentHashMap
open class MockServiceHubInternal(
val customVault: VaultService? = null,
@ -68,8 +67,6 @@ open class MockServiceHubInternal(
private val txStorageService: TxWritableStorageService
get() = storage ?: throw UnsupportedOperationException()
private val flowFactories = ConcurrentHashMap<Class<*>, (Party) -> FlowLogic<*>>()
lateinit var smm: StateMachineManager
init {
@ -86,11 +83,7 @@ open class MockServiceHubInternal(
return smm.executor.fetchFrom { smm.add(logic, flowInitiator) }
}
override fun registerFlowInitiator(markerClass: Class<*>, flowFactory: (Party) -> FlowLogic<*>) {
flowFactories[markerClass] = flowFactory
}
override fun registerServiceFlow(clientFlowClass: Class<out FlowLogic<*>>, serviceFlowFactory: (Party) -> FlowLogic<*>) = Unit
override fun getFlowFactory(markerClass: Class<*>): ((Party) -> FlowLogic<*>)? {
return flowFactories[markerClass]
}
override fun getServiceFlowFactory(clientFlowClass: Class<out FlowLogic<*>>): ((Party) -> FlowLogic<*>)? = null
}

View File

@ -89,7 +89,7 @@ class DataVendingServiceTests {
}
private fun MockNode.sendNotifyTx(tx: SignedTransaction, walletServiceNode: MockNode) {
walletServiceNode.services.registerFlowInitiator(NotifyTxFlow::class.java, ::NotifyTransactionHandler)
walletServiceNode.services.registerServiceFlow(NotifyTxFlow::class.java, ::NotifyTransactionHandler)
services.startFlow(NotifyTxFlow(walletServiceNode.info.legalIdentity, tx))
network.runNetwork()
}

View File

@ -8,7 +8,6 @@ import net.corda.core.*
import net.corda.core.contracts.DOLLARS
import net.corda.core.contracts.DummyState
import net.corda.core.crypto.Party
import net.corda.core.crypto.X509Utilities
import net.corda.core.crypto.generateKeyPair
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
@ -111,7 +110,7 @@ class StateMachineManagerTests {
@Test
fun `exception while fiber suspended`() {
node2.services.registerFlowInitiator(ReceiveFlow::class.java) { SendFlow("Hello", it) }
node2.services.registerServiceFlow(ReceiveFlow::class.java) { SendFlow("Hello", it) }
val flow = ReceiveFlow(node2.info.legalIdentity)
val fiber = node1.services.startFlow(flow) as FlowStateMachineImpl
// Before the flow runs change the suspend action to throw an exception
@ -130,7 +129,7 @@ class StateMachineManagerTests {
@Test
fun `flow restarted just after receiving payload`() {
node2.services.registerFlowInitiator(SendFlow::class.java) { ReceiveFlow(it).nonTerminating() }
node2.services.registerServiceFlow(SendFlow::class.java) { ReceiveFlow(it).nonTerminating() }
node1.services.startFlow(SendFlow("Hello", node2.info.legalIdentity))
// We push through just enough messages to get only the payload sent
@ -180,7 +179,7 @@ class StateMachineManagerTests {
@Test
fun `flow loaded from checkpoint will respond to messages from before start`() {
node1.services.registerFlowInitiator(ReceiveFlow::class.java) { SendFlow("Hello", it) }
node1.services.registerServiceFlow(ReceiveFlow::class.java) { SendFlow("Hello", it) }
node2.services.startFlow(ReceiveFlow(node1.info.legalIdentity).nonTerminating()) // Prepare checkpointed receive flow
// Make sure the add() has finished initial processing.
node2.smm.executor.flush()
@ -244,8 +243,8 @@ class StateMachineManagerTests {
fun `sending to multiple parties`() {
val node3 = net.createNode(node1.info.address)
net.runNetwork()
node2.services.registerFlowInitiator(SendFlow::class.java) { ReceiveFlow(it).nonTerminating() }
node3.services.registerFlowInitiator(SendFlow::class.java) { ReceiveFlow(it).nonTerminating() }
node2.services.registerServiceFlow(SendFlow::class.java) { ReceiveFlow(it).nonTerminating() }
node3.services.registerServiceFlow(SendFlow::class.java) { ReceiveFlow(it).nonTerminating() }
val payload = "Hello World"
node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity, node3.info.legalIdentity))
net.runNetwork()
@ -278,8 +277,8 @@ class StateMachineManagerTests {
net.runNetwork()
val node2Payload = "Test 1"
val node3Payload = "Test 2"
node2.services.registerFlowInitiator(ReceiveFlow::class.java) { SendFlow(node2Payload, it) }
node3.services.registerFlowInitiator(ReceiveFlow::class.java) { SendFlow(node3Payload, it) }
node2.services.registerServiceFlow(ReceiveFlow::class.java) { SendFlow(node2Payload, it) }
node3.services.registerServiceFlow(ReceiveFlow::class.java) { SendFlow(node3Payload, it) }
val multiReceiveFlow = ReceiveFlow(node2.info.legalIdentity, node3.info.legalIdentity).nonTerminating()
node1.services.startFlow(multiReceiveFlow)
node1.acceptableLiveFiberCountOnStop = 1
@ -304,7 +303,7 @@ class StateMachineManagerTests {
@Test
fun `both sides do a send as their first IO request`() {
node2.services.registerFlowInitiator(PingPongFlow::class.java) { PingPongFlow(it, 20L) }
node2.services.registerServiceFlow(PingPongFlow::class.java) { PingPongFlow(it, 20L) }
node1.services.startFlow(PingPongFlow(node2.info.legalIdentity, 10L))
net.runNetwork()
@ -340,7 +339,7 @@ class StateMachineManagerTests {
sessionTransfers.expectEvents(isStrict = false) {
sequence(
// First Pay
expect(match = { it.message is SessionInit && it.message.flowName == NotaryFlow.Client::class.java.name }) {
expect(match = { it.message is SessionInit && it.message.clientFlowClass == NotaryFlow.Client::class.java }) {
it.message as SessionInit
assertEquals(node1.id, it.from)
assertEquals(notary1Address, it.to)
@ -350,7 +349,7 @@ class StateMachineManagerTests {
assertEquals(notary1.id, it.from)
},
// Second pay
expect(match = { it.message is SessionInit && it.message.flowName == NotaryFlow.Client::class.java.name }) {
expect(match = { it.message is SessionInit && it.message.clientFlowClass == NotaryFlow.Client::class.java }) {
it.message as SessionInit
assertEquals(node1.id, it.from)
assertEquals(notary1Address, it.to)
@ -360,7 +359,7 @@ class StateMachineManagerTests {
assertEquals(notary2.id, it.from)
},
// Third pay
expect(match = { it.message is SessionInit && it.message.flowName == NotaryFlow.Client::class.java.name }) {
expect(match = { it.message is SessionInit && it.message.clientFlowClass == NotaryFlow.Client::class.java }) {
it.message as SessionInit
assertEquals(node1.id, it.from)
assertEquals(notary1Address, it.to)
@ -375,7 +374,7 @@ class StateMachineManagerTests {
@Test
fun `other side ends before doing expected send`() {
node2.services.registerFlowInitiator(ReceiveFlow::class.java) { NoOpFlow() }
node2.services.registerServiceFlow(ReceiveFlow::class.java) { NoOpFlow() }
val resultFuture = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)).resultFuture
net.runNetwork()
assertThatExceptionOfType(FlowSessionException::class.java).isThrownBy {
@ -535,7 +534,7 @@ class StateMachineManagerTests {
}
}
node2.services.registerFlowInitiator(AskForExceptionFlow::class.java) { ConditionalExceptionFlow(it, "Hello") }
node2.services.registerServiceFlow(AskForExceptionFlow::class.java) { ConditionalExceptionFlow(it, "Hello") }
val resultFuture = node1.services.startFlow(RetryOnExceptionFlow(node2.info.legalIdentity)).resultFuture
net.runNetwork()
assertThat(resultFuture.getOrThrow()).isEqualTo("Hello")
@ -563,7 +562,7 @@ class StateMachineManagerTests {
ptx.signWith(node1.services.legalIdentityKey)
val stx = ptx.toSignedTransaction()
node1.services.registerFlowInitiator(WaitingFlows.Waiter::class.java) {
node1.services.registerServiceFlow(WaitingFlows.Waiter::class.java) {
WaitingFlows.Committer(it) { throw Exception("Error") }
}
val waiter = node2.services.startFlow(WaitingFlows.Waiter(stx, node1.info.legalIdentity)).resultFuture
@ -580,6 +579,14 @@ class StateMachineManagerTests {
assertThatThrownBy { result.getOrThrow() }.hasMessageContaining("Vault").hasMessageContaining("private method")
}
@Test
fun `custom client flow`() {
val receiveFlowFuture = node2.initiateSingleShotFlow(SendFlow::class) { ReceiveFlow(it) }
node1.services.startFlow(CustomSendFlow("Hello", node2.info.legalIdentity)).resultFuture
net.runNetwork()
assertThat(receiveFlowFuture.getOrThrow().receivedPayloads).containsOnly("Hello")
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////
//region Helpers
@ -598,7 +605,9 @@ class StateMachineManagerTests {
return smm.findStateMachines(P::class.java).single()
}
private fun sessionInit(flowMarker: KClass<*>, payload: Any? = null) = SessionInit(0, flowMarker.java.name, payload)
private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, payload: Any? = null): SessionInit {
return SessionInit(0, clientFlowClass.java, payload)
}
private val sessionConfirm = SessionConfirm(0, 0)
private fun sessionData(payload: Any) = SessionData(0, payload)
private val normalEnd = NormalSessionEnd(0)
@ -663,7 +672,7 @@ class StateMachineManagerTests {
}
private class SendFlow(val payload: String, vararg val otherParties: Party) : FlowLogic<Unit>() {
private open class SendFlow(val payload: String, vararg val otherParties: Party) : FlowLogic<Unit>() {
init {
require(otherParties.isNotEmpty())
}
@ -672,6 +681,8 @@ class StateMachineManagerTests {
override fun call() = otherParties.forEach { send(it, payload) }
}
private interface CustomInterface
private class CustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty)
private class ReceiveFlow(vararg val otherParties: Party) : FlowLogic<Unit>() {
object START_STEP : ProgressTracker.Step("Starting")