diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt index d6752b4383..ced7bc8ea9 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt @@ -53,15 +53,19 @@ abstract class FlowLogic { */ val serviceHub: ServiceHub get() = stateMachine.serviceHub + @Suspendable + fun initiateFlow(party: Party): FlowSession = stateMachine.initiateFlow(party, flowUsedForSessions) + /** - * Returns a [FlowContext] object describing the flow [otherParty] is using. With [FlowContext.flowVersion] it + * Returns a [FlowInfo] object describing the flow [otherParty] is using. With [FlowInfo.flowVersion] it * provides the necessary information needed for the evolution of flows and enabling backwards compatibility. * * This method can be called before any send or receive has been done with [otherParty]. In such a case this will force * them to start their flow. */ + @Deprecated("Use FlowSession.getFlowInfo()", level = DeprecationLevel.WARNING) @Suspendable - fun getFlowContext(otherParty: Party): FlowContext = stateMachine.getFlowContext(otherParty, flowUsedForSessions) + fun getFlowInfo(otherParty: Party): FlowInfo = stateMachine.getFlowInfo(otherParty, flowUsedForSessions) /** * Serializes and queues the given [payload] object for sending to the [otherParty]. Suspends until a response @@ -76,6 +80,7 @@ abstract class FlowLogic { * * @returns an [UntrustworthyData] wrapper around the received object. */ + @Deprecated("Use FlowSession.sendAndReceive()", level = DeprecationLevel.WARNING) inline fun sendAndReceive(otherParty: Party, payload: Any): UntrustworthyData { return sendAndReceive(R::class.java, otherParty, payload) } @@ -91,6 +96,7 @@ abstract class FlowLogic { * * @returns an [UntrustworthyData] wrapper around the received object. */ + @Deprecated("Use FlowSession.sendAndReceive()", level = DeprecationLevel.WARNING) @Suspendable open fun sendAndReceive(receiveType: Class, otherParty: Party, payload: Any): UntrustworthyData { return stateMachine.sendAndReceive(receiveType, otherParty, payload, flowUsedForSessions) @@ -105,8 +111,13 @@ abstract class FlowLogic { * oracle services. If one or more nodes in the service cluster go down mid-session, the message will be redelivered * to a different one, so there is no need to wait until the initial node comes back up to obtain a response. */ + @Deprecated("Use FlowSession.sendAndReceiveWithRetry()", level = DeprecationLevel.WARNING) internal inline fun sendAndReceiveWithRetry(otherParty: Party, payload: Any): UntrustworthyData { - return stateMachine.sendAndReceive(R::class.java, otherParty, payload, flowUsedForSessions, true) + return stateMachine.sendAndReceive(R::class.java, otherParty, payload, flowUsedForSessions, retrySend = true) + } + @Suspendable + internal fun FlowSession.sendAndReceiveWithRetry(receiveType: Class, payload: Any): UntrustworthyData { + return stateMachine.sendAndReceive(receiveType, counterparty, payload, flowUsedForSessions, retrySend = true) } /** @@ -116,6 +127,7 @@ abstract class FlowLogic { * verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly * corrupted data in order to exploit your code. */ + @Deprecated("Use FlowSession.receive()", level = DeprecationLevel.WARNING) inline fun receive(otherParty: Party): UntrustworthyData = receive(R::class.java, otherParty) /** @@ -127,6 +139,7 @@ abstract class FlowLogic { * * @returns an [UntrustworthyData] wrapper around the received object. */ + @Deprecated("Use FlowSession.receive()", level = DeprecationLevel.WARNING) @Suspendable open fun receive(receiveType: Class, otherParty: Party): UntrustworthyData { return stateMachine.receive(receiveType, otherParty, flowUsedForSessions) @@ -139,6 +152,7 @@ abstract class FlowLogic { * is offline then message delivery will be retried until it comes back or until the message is older than the * network's event horizon time. */ + @Deprecated("Use FlowSession.send()", level = DeprecationLevel.WARNING) @Suspendable open fun send(otherParty: Party, payload: Any) = stateMachine.send(otherParty, payload, flowUsedForSessions) @@ -294,7 +308,7 @@ abstract class FlowLogic { * Version and name of the CorDapp hosting the other side of the flow. */ @CordaSerializable -data class FlowContext( +data class FlowInfo( /** * The integer flow version the other side is using. * @see InitiatingFlow diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowSession.kt b/core/src/main/kotlin/net/corda/core/flows/FlowSession.kt new file mode 100644 index 0000000000..e3423bd302 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/flows/FlowSession.kt @@ -0,0 +1,109 @@ +package net.corda.core.flows + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.identity.Party +import net.corda.core.utilities.UntrustworthyData + +/** + * To port existing flows: + * + * Look for [Deprecated] usages of send/receive/sendAndReceive/getFlowInfo. + * + * If it's an InitiatingFlow: + * + * Look for the send/receive that kicks off the counter flow. Insert a + * + * val session = initiateFlow(party) + * + * and use this session afterwards for send/receives. + * For example: + * send(party, something) + * will become + * session.send(something) + * + * If it's an InitiatedBy flow: + * + * Change the constructor to take an initiatingSession: FlowSession instead of a counterparty: Party + * Then look for usages of the deprecated functions and change them to use the FlowSession + * For example: + * send(counterparty, something) + * will become + * initiatingSession.send(something) + */ +abstract class FlowSession { + abstract val counterparty: Party + + /** + * Returns a [FlowInfo] object describing the flow [counterparty] is using. With [FlowInfo.flowVersion] it + * provides the necessary information needed for the evolution of flows and enabling backwards compatibility. + * + * This method can be called before any send or receive has been done with [counterparty]. In such a case this will force + * them to start their flow. + */ + @Suspendable + abstract fun getCounterpartyFlowInfo(): FlowInfo + + /** + * Serializes and queues the given [payload] object for sending to the [counterparty]. Suspends until a response + * is received, which must be of the given [R] type. + * + * Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly + * verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly + * corrupted data in order to exploit your code. + * + * Note that this function is not just a simple send+receive pair: it is more efficient and more correct to + * use this when you expect to do a message swap than do use [send] and then [receive] in turn. + * + * @returns an [UntrustworthyData] wrapper around the received object. + */ + @Suspendable + inline fun sendAndReceive(payload: Any): UntrustworthyData { + return sendAndReceive(R::class.java, payload) + } + /** + * Serializes and queues the given [payload] object for sending to the [counterparty]. Suspends until a response + * is received, which must be of the given [receiveType]. Remember that when receiving data from other parties the data + * should not be trusted until it's been thoroughly verified for consistency and that all expectations are + * satisfied, as a malicious peer may send you subtly corrupted data in order to exploit your code. + * + * Note that this function is not just a simple send+receive pair: it is more efficient and more correct to + * use this when you expect to do a message swap than do use [send] and then [receive] in turn. + * + * @returns an [UntrustworthyData] wrapper around the received object. + */ + @Suspendable + abstract fun sendAndReceive(receiveType: Class, payload: Any): UntrustworthyData + + /** + * Suspends until [counterparty] sends us a message of type [R]. + * + * Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly + * verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly + * corrupted data in order to exploit your code. + */ + @Suspendable + inline fun receive(): UntrustworthyData { + return receive(R::class.java) + } + /** + * Suspends until [counterparty] sends us a message of type [receiveType]. + * + * Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly + * verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly + * corrupted data in order to exploit your code. + * + * @returns an [UntrustworthyData] wrapper around the received object. + */ + @Suspendable + abstract fun receive(receiveType: Class): UntrustworthyData + + /** + * Queues the given [payload] for sending to the [counterparty] and continues without suspending. + * + * Note that the other party may receive the message at some arbitrary later point or not at all: if [counterparty] + * is offline then message delivery will be retried until it comes back or until the message is older than the + * network's event horizon time. + */ + @Suspendable + abstract fun send(payload: Any) +} diff --git a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt index c855b8ea70..8fec6891c3 100644 --- a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt +++ b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt @@ -13,7 +13,10 @@ import org.slf4j.Logger /** This is an internal interface that is implemented by code in the node module. You should look at [FlowLogic]. */ interface FlowStateMachine { @Suspendable - fun getFlowContext(otherParty: Party, sessionFlow: FlowLogic<*>): FlowContext + fun getFlowInfo(otherParty: Party, sessionFlow: FlowLogic<*>): FlowInfo + + @Suspendable + fun initiateFlow(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSession @Suspendable fun sendAndReceive(receiveType: Class, diff --git a/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt b/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt index 49b415fe4e..1f8573bbaf 100644 --- a/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt +++ b/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt @@ -147,7 +147,7 @@ class AttachmentSerializationTest { private fun launchFlow(clientLogic: ClientLogic, rounds: Int, sendData: Boolean = false) { server.internals.internalRegisterFlowFactory( ClientLogic::class.java, - InitiatedFlowFactory.Core { ServerLogic(it, sendData) }, + InitiatedFlowFactory.Core { ServerLogic(it.counterparty, sendData) }, ServerLogic::class.java, track = false) client.services.startFlow(clientLogic) diff --git a/node/src/integration-test/kotlin/net/corda/node/services/statemachine/FlowVersioningTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/statemachine/FlowVersioningTest.kt index 1b271eae18..8f54ff553d 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/statemachine/FlowVersioningTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/statemachine/FlowVersioningTest.kt @@ -34,14 +34,14 @@ class FlowVersioningTest : NodeBasedTest() { val alicePlatformVersionAccordingToBob = receive(initiatedParty).unwrap { it } return Pair( alicePlatformVersionAccordingToBob, - getFlowContext(initiatedParty).flowVersion + getFlowInfo(initiatedParty).flowVersion ) } } private class PretendInitiatedCoreFlow(val initiatingParty: Party) : FlowLogic() { @Suspendable - override fun call() = send(initiatingParty, getFlowContext(initiatingParty).flowVersion) + override fun call() = send(initiatingParty, getFlowInfo(initiatingParty).flowVersion) } } \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index 8046e535d1..168d2c252b 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -293,14 +293,35 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, return registerInitiatedFlowInternal(initiatedFlowClass, track = true) } + // TODO remove once not needed + private fun deprecatedFlowConstructorMessage(flowClass: Class<*>): String { + return "Installing flow factory for $flowClass accepting a ${Party::class.java.simpleName}, which is deprecated. " + + "It should accept a ${FlowSession::class.java.simpleName} instead" + } + private fun > registerInitiatedFlowInternal(initiatedFlow: Class, track: Boolean): Observable { - val ctor = initiatedFlow.getDeclaredConstructor(Party::class.java).apply { isAccessible = true } + val constructors = initiatedFlow.declaredConstructors.associateBy { it.parameterTypes.toList() } + val flowSessionCtor = constructors[listOf(FlowSession::class.java)]?.apply { isAccessible = true } + val ctor: (FlowSession) -> F = if (flowSessionCtor == null) { + // Try to fallback to a Party constructor + val partyCtor = constructors[listOf(Party::class.java)]?.apply { isAccessible = true } + if (partyCtor == null) { + throw IllegalArgumentException("$initiatedFlow must have a constructor accepting a ${FlowSession::class.java.name}") + } else { + log.warn(deprecatedFlowConstructorMessage(initiatedFlow)) + } + @Suppress("UNCHECKED_CAST") + { flowSession: FlowSession -> partyCtor.newInstance(flowSession.counterparty) as F } + } else { + @Suppress("UNCHECKED_CAST") + { flowSession: FlowSession -> flowSessionCtor.newInstance(flowSession) as F } + } val initiatingFlow = initiatedFlow.requireAnnotation().value.java val (version, classWithAnnotation) = initiatingFlow.flowVersionAndInitiatingClass require(classWithAnnotation == initiatingFlow) { "${InitiatedBy::class.java.name} must point to ${classWithAnnotation.name} and not ${initiatingFlow.name}" } - val flowFactory = InitiatedFlowFactory.CorDapp(version, initiatedFlow.appName, { ctor.newInstance(it) }) + val flowFactory = InitiatedFlowFactory.CorDapp(version, initiatedFlow.appName, ctor) val observable = internalRegisterFlowFactory(initiatingFlow, flowFactory, initiatedFlow, track) log.info("Registered ${initiatingFlow.name} to initiate ${initiatedFlow.name} (version $version)") return observable @@ -326,8 +347,15 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, * compatibility [flowFactory] provides a second parameter which is the platform version of the initiating party. * @suppress */ + @Deprecated("Use installCoreFlowExpectingFlowSession() instead") @VisibleForTesting fun installCoreFlow(clientFlowClass: KClass>, flowFactory: (Party) -> FlowLogic<*>) { + log.warn(deprecatedFlowConstructorMessage(clientFlowClass.java)) + installCoreFlowExpectingFlowSession(clientFlowClass, { flowSession -> flowFactory(flowSession.counterparty) }) + } + + @VisibleForTesting + fun installCoreFlowExpectingFlowSession(clientFlowClass: KClass>, flowFactory: (FlowSession) -> FlowLogic<*>) { require(clientFlowClass.java.flowVersionAndInitiatingClass.first == 1) { "${InitiatingFlow::class.java.name}.version not applicable for core flows; their version is the node's platform version" } @@ -335,6 +363,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, log.debug { "Installed core flow ${clientFlowClass.java.name}" } } + private fun installCoreFlows() { installCoreFlow(BroadcastTransactionFlow::class, ::NotifyTransactionHandler) installCoreFlow(NotaryChangeFlow::class, ::NotaryChangeHandler) diff --git a/node/src/main/kotlin/net/corda/node/internal/InitiatedFlowFactory.kt b/node/src/main/kotlin/net/corda/node/internal/InitiatedFlowFactory.kt index aaa5053627..f259512109 100644 --- a/node/src/main/kotlin/net/corda/node/internal/InitiatedFlowFactory.kt +++ b/node/src/main/kotlin/net/corda/node/internal/InitiatedFlowFactory.kt @@ -1,15 +1,15 @@ package net.corda.node.internal import net.corda.core.flows.FlowLogic -import net.corda.core.identity.Party +import net.corda.core.flows.FlowSession sealed class InitiatedFlowFactory> { - protected abstract val factory: (Party) -> F - fun createFlow(otherParty: Party): F = factory(otherParty) + protected abstract val factory: (FlowSession) -> F + fun createFlow(initiatingFlowSession: FlowSession): F = factory(initiatingFlowSession) - data class Core>(override val factory: (Party) -> F) : InitiatedFlowFactory() + data class Core>(override val factory: (FlowSession) -> F) : InitiatedFlowFactory() data class CorDapp>(val flowVersion: Int, val appName: String, - override val factory: (Party) -> F) : InitiatedFlowFactory() + override val factory: (FlowSession) -> F) : InitiatedFlowFactory() } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt index 8a1d2342e3..748cce9bd8 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt @@ -11,7 +11,7 @@ interface FlowIORequest { interface WaitingRequest : FlowIORequest interface SessionedFlowIORequest : FlowIORequest { - val session: FlowSession + val session: FlowSessionInternal } interface SendRequest : SessionedFlowIORequest { @@ -23,7 +23,7 @@ interface ReceiveRequest : SessionedFlowIORequest, WaitingRe val userReceiveType: Class<*>? } -data class SendAndReceive(override val session: FlowSession, +data class SendAndReceive(override val session: FlowSessionInternal, override val message: SessionMessage, override val receiveType: Class, override val userReceiveType: Class<*>?) : SendRequest, ReceiveRequest { @@ -31,14 +31,14 @@ data class SendAndReceive(override val session: FlowSession, override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() } -data class ReceiveOnly(override val session: FlowSession, +data class ReceiveOnly(override val session: FlowSessionInternal, override val receiveType: Class, override val userReceiveType: Class<*>?) : ReceiveRequest { @Transient override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() } -data class SendOnly(override val session: FlowSession, override val message: SessionMessage) : SendRequest { +data class SendOnly(override val session: FlowSessionInternal, override val message: SessionMessage) : SendRequest { @Transient override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionImpl.kt new file mode 100644 index 0000000000..044fdc0dbe --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionImpl.kt @@ -0,0 +1,43 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.flows.FlowInfo +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.FlowSession +import net.corda.core.identity.Party +import net.corda.core.internal.FlowStateMachine +import net.corda.core.utilities.UntrustworthyData + +class FlowSessionImpl( + override val counterparty: Party +) : FlowSession() { + + internal lateinit var stateMachine: FlowStateMachine<*> + internal lateinit var sessionFlow: FlowLogic<*> + + @Suspendable + override fun getCounterpartyFlowInfo(): FlowInfo { + return stateMachine.getFlowInfo(counterparty, sessionFlow) + } + + @Suspendable + override fun sendAndReceive(receiveType: Class, payload: Any): UntrustworthyData { + return stateMachine.sendAndReceive(receiveType, counterparty, payload, sessionFlow) + } + + @Suspendable + internal fun sendAndReceiveWithRetry(receiveType: Class, payload: Any): UntrustworthyData { + return stateMachine.sendAndReceive(receiveType, counterparty, payload, sessionFlow, retrySend = true) + } + + @Suspendable + override fun receive(receiveType: Class): UntrustworthyData { + return stateMachine.receive(receiveType, counterparty, sessionFlow) + } + + @Suspendable + override fun send(payload: Any) { + return stateMachine.send(counterparty, payload, sessionFlow) + } +} + diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSession.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionInternal.kt similarity index 93% rename from node/src/main/kotlin/net/corda/node/services/statemachine/FlowSession.kt rename to node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionInternal.kt index e17cf976a1..58f08dfa63 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSession.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionInternal.kt @@ -1,6 +1,6 @@ package net.corda.node.services.statemachine -import net.corda.core.flows.FlowContext +import net.corda.core.flows.FlowInfo import net.corda.core.flows.FlowLogic import net.corda.core.identity.Party import net.corda.node.services.statemachine.FlowSessionState.Initiated @@ -12,7 +12,8 @@ import java.util.concurrent.ConcurrentLinkedQueue * is received. Note that this requires the party on the other end to be a distributed service and run an idempotent flow * that only sends back a single [SessionData] message before termination. */ -class FlowSession( +// TODO rename this +class FlowSessionInternal( val flow: FlowLogic<*>, val ourSessionId: Long, val initiatingParty: Party?, @@ -42,7 +43,7 @@ sealed class FlowSessionState { override val sendToParty: Party get() = otherParty } - data class Initiated(val peerParty: Party, val peerSessionId: Long, val context: FlowContext) : FlowSessionState() { + data class Initiated(val peerParty: Party, val peerSessionId: Long, val context: FlowInfo) : FlowSessionState() { override val sendToParty: Party get() = peerParty } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index df6daf71e6..56d48162ec 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -11,12 +11,9 @@ import net.corda.core.crypto.SecureHash import net.corda.core.crypto.random63BitValue import net.corda.core.flows.* import net.corda.core.identity.Party -import net.corda.core.internal.FlowStateMachine -import net.corda.core.internal.abbreviate +import net.corda.core.internal.* import net.corda.core.internal.concurrent.OpenFuture import net.corda.core.internal.concurrent.openFuture -import net.corda.core.internal.isRegularFile -import net.corda.core.internal.staticField import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.* import net.corda.node.services.api.FlowAppAuditEvent @@ -86,7 +83,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, get() = _resultFuture ?: openFuture().also { _resultFuture = it } // This state IS serialised, as we need it to know what the fiber is waiting for. - internal val openSessions = HashMap, Party>, FlowSession>() + internal val openSessions = HashMap, Party>, FlowSessionInternal>() internal var waitingForResponse: WaitingRequest? = null internal var hasSoftLockedStates: Boolean = false set(value) { @@ -158,7 +155,15 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } @Suspendable - override fun getFlowContext(otherParty: Party, sessionFlow: FlowLogic<*>): FlowContext { + override fun initiateFlow(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSession { + val flowSession = FlowSessionImpl(otherParty) + flowSession.stateMachine = this + flowSession.sessionFlow = sessionFlow + return flowSession + } + + @Suspendable + override fun getFlowInfo(otherParty: Party, sessionFlow: FlowLogic<*>): FlowInfo { val state = getConfirmedSession(otherParty, sessionFlow).state as FlowSessionState.Initiated return state.context } @@ -279,20 +284,20 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, * This method will suspend the state machine and wait for incoming session init response from other party. */ @Suspendable - private fun FlowSession.waitForConfirmation() { + private fun FlowSessionInternal.waitForConfirmation() { val (peerParty, sessionInitResponse) = receiveInternal(this, null) if (sessionInitResponse is SessionConfirm) { state = FlowSessionState.Initiated( peerParty, sessionInitResponse.initiatedSessionId, - FlowContext(sessionInitResponse.flowVersion, sessionInitResponse.appName)) + FlowInfo(sessionInitResponse.flowVersion, sessionInitResponse.appName)) } else { sessionInitResponse as SessionReject throw UnexpectedFlowEndException("Party ${state.sendToParty} rejected session request: ${sessionInitResponse.errorMessage}") } } - private fun createSessionData(session: FlowSession, payload: Any): SessionData { + private fun createSessionData(session: FlowSessionInternal, payload: Any): SessionData { val sessionState = session.state val peerSessionId = when (sessionState) { is FlowSessionState.Initiating -> throw IllegalStateException("We've somehow held onto an unconfirmed session: $session") @@ -302,23 +307,23 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } @Suspendable - private fun sendInternal(session: FlowSession, message: SessionMessage) = suspend(SendOnly(session, message)) + private fun sendInternal(session: FlowSessionInternal, message: SessionMessage) = suspend(SendOnly(session, message)) private inline fun receiveInternal( - session: FlowSession, + session: FlowSessionInternal, userReceiveType: Class<*>?): ReceivedSessionMessage { return waitForMessage(ReceiveOnly(session, M::class.java, userReceiveType)) } private inline fun sendAndReceiveInternal( - session: FlowSession, + session: FlowSessionInternal, message: SessionMessage, userReceiveType: Class<*>?): ReceivedSessionMessage { return waitForMessage(SendAndReceive(session, message, M::class.java, userReceiveType)) } @Suspendable - private fun getConfirmedSessionIfPresent(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSession? { + private fun getConfirmedSessionIfPresent(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSessionInternal? { return openSessions[Pair(sessionFlow, otherParty)]?.apply { if (state is FlowSessionState.Initiating) { // Session still initiating, wait for the confirmation @@ -328,7 +333,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } @Suspendable - private fun getConfirmedSession(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSession { + private fun getConfirmedSession(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSessionInternal { return getConfirmedSessionIfPresent(otherParty, sessionFlow) ?: startNewSession(otherParty, sessionFlow, null, waitForConfirmation = true) } @@ -344,9 +349,9 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, sessionFlow: FlowLogic<*>, firstPayload: Any?, waitForConfirmation: Boolean, - retryable: Boolean = false): FlowSession { + retryable: Boolean = false): FlowSessionInternal { logger.trace { "Initiating a new session with $otherParty" } - val session = FlowSession(sessionFlow, random63BitValue(), null, FlowSessionState.Initiating(otherParty), retryable) + val session = FlowSessionInternal(sessionFlow, random63BitValue(), null, FlowSessionState.Initiating(otherParty), retryable) openSessions[Pair(sessionFlow, otherParty)] = session val (version, initiatingFlowClass) = sessionFlow.javaClass.flowVersionAndInitiatingClass val sessionInit = SessionInit(session.ourSessionId, initiatingFlowClass.name, version, sessionFlow.javaClass.appName, firstPayload) @@ -403,7 +408,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } } - private fun FlowSession.erroredEnd(end: ErrorSessionEnd): Nothing { + private fun FlowSessionInternal.erroredEnd(end: ErrorSessionEnd): Nothing { if (end.errorResponse != null) { @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") (end.errorResponse as java.lang.Throwable).fillInStackTrace() diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index 8895db530c..08ed716df5 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -136,7 +136,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, private val totalStartedFlows = metrics.counter("Flows.Started") private val totalFinishedFlows = metrics.counter("Flows.Finished") - private val openSessions = ConcurrentHashMap() + private val openSessions = ConcurrentHashMap() private val recentlyClosedSessions = ConcurrentHashMap() internal val tokenizableServices = ArrayList() @@ -341,7 +341,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, // We resume the fiber if it's received a response for which it was waiting for or it's waiting for a ledger // commit but a counterparty flow has ended with an error (in which case our flow also has to end) - private fun resumeOnMessage(message: ExistingSessionMessage, session: FlowSession): Boolean { + private fun resumeOnMessage(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean { val waitingForResponse = session.fiber.waitingForResponse return (waitingForResponse as? ReceiveRequest<*>)?.session === session || waitingForResponse is WaitForLedgerCommit && message is ErrorSessionEnd @@ -355,21 +355,24 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, val (session, initiatedFlowFactory) = try { val initiatedFlowFactory = getInitiatedFlowFactory(sessionInit) - val flow = initiatedFlowFactory.createFlow(sender) + val flowSession = FlowSessionImpl(sender) + val flow = initiatedFlowFactory.createFlow(flowSession) val senderFlowVersion = when (initiatedFlowFactory) { is InitiatedFlowFactory.Core -> receivedMessage.platformVersion // The flow version for the core flows is the platform version is InitiatedFlowFactory.CorDapp -> sessionInit.flowVersion } - val session = FlowSession( + val session = FlowSessionInternal( flow, random63BitValue(), sender, - FlowSessionState.Initiated(sender, senderSessionId, FlowContext(senderFlowVersion, sessionInit.appName))) + FlowSessionState.Initiated(sender, senderSessionId, FlowInfo(senderFlowVersion, sessionInit.appName))) if (sessionInit.firstPayload != null) { session.receivedMessages += ReceivedSessionMessage(sender, SessionData(session.ourSessionId, sessionInit.firstPayload)) } openSessions[session.ourSessionId] = session val fiber = createFiber(flow, FlowInitiator.Peer(sender)) + flowSession.sessionFlow = flow + flowSession.stateMachine = fiber fiber.openSessions[Pair(flow, sender)] = session updateCheckpoint(fiber) session to initiatedFlowFactory @@ -484,7 +487,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, } } - private fun FlowSession.endSession(exception: Throwable?, propagated: Boolean) { + private fun FlowSessionInternal.endSession(exception: Throwable?, propagated: Boolean) { val initiatedState = state as? FlowSessionState.Initiated ?: return val sessionEnd = if (exception == null) { NormalSessionEnd(initiatedState.peerSessionId) diff --git a/node/src/smoke-test/kotlin/net/corda/node/CordappSmokeTest.kt b/node/src/smoke-test/kotlin/net/corda/node/CordappSmokeTest.kt index f6dd8060af..4d0e050364 100644 --- a/node/src/smoke-test/kotlin/net/corda/node/CordappSmokeTest.kt +++ b/node/src/smoke-test/kotlin/net/corda/node/CordappSmokeTest.kt @@ -4,10 +4,7 @@ import co.paralleluniverse.fibers.Suspendable import net.corda.core.flows.* import net.corda.core.identity.CordaX500Name import net.corda.core.identity.Party -import net.corda.core.internal.copyToDirectory -import net.corda.core.internal.createDirectories -import net.corda.core.internal.div -import net.corda.core.internal.list +import net.corda.core.internal.* import net.corda.core.messaging.startFlow import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.unwrap @@ -66,14 +63,14 @@ class CordappSmokeTest { @InitiatingFlow @StartableByRPC - class GatherContextsFlow(private val otherParty: Party) : FlowLogic>() { + class GatherContextsFlow(private val otherParty: Party) : FlowLogic>() { @Suspendable - override fun call(): Pair { + override fun call(): Pair { // This receive will kick off SendBackInitiatorFlowContext by sending a session-init with our app name. // SendBackInitiatorFlowContext will send back our context using the information from this session-init - val sessionInitContext = receive(otherParty).unwrap { it } + val sessionInitContext = receive(otherParty).unwrap { it } // This context is taken from the session-confirm message - val sessionConfirmContext = getFlowContext(otherParty) + val sessionConfirmContext = getFlowInfo(otherParty) return Pair(sessionInitContext, sessionConfirmContext) } } @@ -84,7 +81,7 @@ class CordappSmokeTest { @Suspendable override fun call() { // An initiated flow calling getFlowContext on its initiator will get the context from the session-init - val sessionInitContext = getFlowContext(otherParty) + val sessionInitContext = getFlowInfo(otherParty) send(otherParty, sessionInitContext) } } diff --git a/node/src/test/kotlin/net/corda/node/InteractiveShellTest.kt b/node/src/test/kotlin/net/corda/node/InteractiveShellTest.kt index a7cdbe34d7..3762380ab2 100644 --- a/node/src/test/kotlin/net/corda/node/InteractiveShellTest.kt +++ b/node/src/test/kotlin/net/corda/node/InteractiveShellTest.kt @@ -1,16 +1,13 @@ package net.corda.node import com.fasterxml.jackson.dataformat.yaml.YAMLFactory -import net.corda.core.concurrent.CordaFuture +import com.nhaarman.mockito_kotlin.mock +import net.corda.client.jackson.JacksonSupport import net.corda.core.contracts.Amount import net.corda.core.crypto.SecureHash -import net.corda.core.flows.* +import net.corda.core.flows.FlowLogic import net.corda.core.identity.Party import net.corda.core.internal.FlowStateMachine -import net.corda.core.node.ServiceHub -import net.corda.core.transactions.SignedTransaction -import net.corda.core.utilities.UntrustworthyData -import net.corda.client.jackson.JacksonSupport import net.corda.core.utilities.ProgressTracker import net.corda.node.services.identity.InMemoryIdentityService import net.corda.node.shell.InteractiveShell @@ -18,7 +15,6 @@ import net.corda.testing.DUMMY_CA import net.corda.testing.MEGA_CORP import net.corda.testing.MEGA_CORP_IDENTITY import org.junit.Test -import org.slf4j.Logger import java.util.* import kotlin.test.assertEquals @@ -51,9 +47,11 @@ class InteractiveShellTest { check("b: 12, c: Yo", "12Yo") } - @Test fun flowStartWithComplexTypes() = check("amount: £10", "10.00 GBP") + @Test + fun flowStartWithComplexTypes() = check("amount: £10", "10.00 GBP") - @Test fun flowStartWithNestedTypes() = check( + @Test + fun flowStartWithNestedTypes() = check( "pair: { first: $100.12, second: df489807f81c8c8829e509e1bcb92e6692b9dd9d624b7456435cb2f51dc82587 }", "($100.12, df489807f81c8c8829e509e1bcb92e6692b9dd9d624b7456435cb2f51dc82587)" ) @@ -70,30 +68,5 @@ class InteractiveShellTest { @Test fun party() = check("party: \"${MEGA_CORP.name}\"", MEGA_CORP.name.toString()) - class DummyFSM(val logic: FlowA) : FlowStateMachine { - override fun getFlowContext(otherParty: Party, sessionFlow: FlowLogic<*>): FlowContext { - throw UnsupportedOperationException("not implemented") - } - override fun sendAndReceive(receiveType: Class, otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>, retrySend: Boolean): UntrustworthyData { - throw UnsupportedOperationException("not implemented") - } - override fun receive(receiveType: Class, otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData { - throw UnsupportedOperationException("not implemented") - } - override fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>) { - throw UnsupportedOperationException("not implemented") - } - override fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>): SignedTransaction { - throw UnsupportedOperationException("not implemented") - } - override val serviceHub: ServiceHub get() = throw UnsupportedOperationException() - override val logger: Logger get() = throw UnsupportedOperationException() - override val id: StateMachineRunId get() = throw UnsupportedOperationException() - override val resultFuture: CordaFuture get() = throw UnsupportedOperationException() - override val flowInitiator: FlowInitiator get() = throw UnsupportedOperationException() - override fun checkFlowPermission(permissionName: String, extraAuditData: Map) = Unit - override fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map) = Unit - override fun flowStackSnapshot(flowClass: Class>): FlowStackSnapshot? = null - override fun persistFlowStackSnapshot(flowClass: Class>) = Unit - } + class DummyFSM(val logic: FlowA) : FlowStateMachine by mock() } diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt index 03bef9cb85..e6fcb52dae 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt @@ -692,7 +692,7 @@ class FlowFrameworkTests { node1 sent sessionInit(SendFlow::class, flowVersion = 1, payload = "Old initiating") to node2, node2 sent sessionConfirm(flowVersion = 2) to node1 ) - assertThat(initiatingFlow.getFlowContext(node2.info.legalIdentity).flowVersion).isEqualTo(2) + assertThat(initiatingFlow.getFlowInfo(node2.info.legalIdentity).flowVersion).isEqualTo(2) } @Test @@ -756,10 +756,19 @@ class FlowFrameworkTests { return smm.findStateMachines(P::class.java).single() } + @Deprecated("Use registerFlowFactoryExpectingFlowSession() instead") private inline fun > StartedNode<*>.registerFlowFactory( initiatingFlowClass: KClass>, initiatedFlowVersion: Int = 1, noinline flowFactory: (Party) -> P): CordaFuture

+ { + return registerFlowFactoryExpectingFlowSession(initiatingFlowClass, initiatedFlowVersion, { flowFactory(it.counterparty) }) + } + + private inline fun > StartedNode<*>.registerFlowFactoryExpectingFlowSession( + initiatingFlowClass: KClass>, + initiatedFlowVersion: Int = 1, + noinline flowFactory: (FlowSession) -> P): CordaFuture

{ val observable = internals.internalRegisterFlowFactory( initiatingFlowClass.java, @@ -976,7 +985,7 @@ class FlowFrameworkTests { @Suspendable override fun call(): Pair { val received = receive(otherParty).unwrap { it } - val otherFlowVersion = getFlowContext(otherParty).flowVersion + val otherFlowVersion = getFlowInfo(otherParty).flowVersion return Pair(received, otherFlowVersion) } }