diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt index ea32e377af..59fe50ba61 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt @@ -203,7 +203,7 @@ abstract class FlowLogic { val theirs = subLogic.progressTracker if (ours != null && theirs != null) { if (ours.currentStep == ProgressTracker.UNSTARTED) { - logger.warn("ProgressTracker has not been started for $this") + logger.warn("ProgressTracker has not been started") ours.nextStep() } ours.setChildProgressTracker(ours.currentStep, theirs) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt index 9496feeb2e..5820facedc 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt @@ -10,6 +10,8 @@ interface FlowIORequest { val stackTraceInCaseOfProblems: StackSnapshot } +interface WaitingRequest : FlowIORequest + interface SessionedFlowIORequest : FlowIORequest { val session: FlowSession } @@ -18,7 +20,7 @@ interface SendRequest : SessionedFlowIORequest { val message: SessionMessage } -interface ReceiveRequest : SessionedFlowIORequest { +interface ReceiveRequest : SessionedFlowIORequest, WaitingRequest { val receiveType: Class } @@ -40,7 +42,7 @@ data class SendOnly(override val session: FlowSession, override val message: Ses override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() } -data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachineImpl<*>) : FlowIORequest { +data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachineImpl<*>) : WaitingRequest { @Transient override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index 4beab873e2..8ad9335c5a 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -51,7 +51,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, @Transient override lateinit var serviceHub: ServiceHubInternal @Transient internal lateinit var database: Database @Transient internal lateinit var actionOnSuspend: (FlowIORequest) -> Unit - @Transient internal lateinit var actionOnEnd: (Pair?) -> Unit + @Transient internal lateinit var actionOnEnd: (Throwable?, Boolean) -> Unit @Transient internal var fromCheckpoint: Boolean = false @Transient private var txTrampoline: Transaction? = null @@ -76,7 +76,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, // This state IS serialised, as we need it to know what the fiber is waiting for. internal val openSessions = HashMap, Party>, FlowSession>() - internal var waitingForLedgerCommitOf: SecureHash? = null + internal var waitingForResponse: WaitingRequest? = null init { logic.stateMachine = this @@ -91,11 +91,11 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } catch (e: FlowException) { // Check if the FlowException was propagated by looking at where the stack trace originates (see suspendAndExpectReceive). val propagated = e.stackTrace[0].className == javaClass.name - actionOnEnd(Pair(e, propagated)) + actionOnEnd(e, propagated) _resultFuture?.setException(e) return } catch (t: Throwable) { - actionOnEnd(null) + actionOnEnd(t, false) _resultFuture?.setException(t) throw ExecutionException(t) } @@ -105,7 +105,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, .filter { it.state is FlowSessionState.Initiating } .forEach { it.waitForConfirmation() } // This is to prevent actionOnEnd being called twice if it throws an exception - actionOnEnd(null) + actionOnEnd(null, false) _resultFuture?.set(result) } @@ -136,10 +136,11 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, sessionFlow: FlowLogic<*>): UntrustworthyData { val session = getConfirmedSession(otherParty, sessionFlow) return if (session == null) { + val newSession = startNewSession(otherParty, sessionFlow, payload, waitForConfirmation = true) // Only do a receive here as the session init has carried the payload - receiveInternal(startNewSession(otherParty, sessionFlow, payload, waitForConfirmation = true)) + receiveInternal(newSession, receiveType) } else { - sendAndReceiveInternal(session, createSessionData(session, payload)) + sendAndReceiveInternal(session, createSessionData(session, payload), receiveType) }.checkPayloadIs(receiveType) } @@ -147,8 +148,9 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, override fun receive(receiveType: Class, otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData { - val session = getConfirmedSession(otherParty, sessionFlow) ?: startNewSession(otherParty, sessionFlow, null, waitForConfirmation = true) - return receiveInternal(session).checkPayloadIs(receiveType) + val session = getConfirmedSession(otherParty, sessionFlow) ?: + startNewSession(otherParty, sessionFlow, null, waitForConfirmation = true) + return receiveInternal(session, receiveType).checkPayloadIs(receiveType) } @Suspendable @@ -167,7 +169,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, */ @Suspendable private fun FlowSession.waitForConfirmation() { - val (peerParty, sessionInitResponse) = receiveInternal(this) + val (peerParty, sessionInitResponse) = receiveInternal(this, null) if (sessionInitResponse is SessionConfirm) { state = FlowSessionState.Initiated(peerParty, sessionInitResponse.initiatedSessionId) } else { @@ -178,12 +180,19 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, @Suspendable override fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>): SignedTransaction { - waitingForLedgerCommitOf = hash logger.info("Waiting for transaction $hash to commit") suspend(WaitForLedgerCommit(hash, sessionFlow.stateMachine as FlowStateMachineImpl<*>)) - logger.info("Transaction $hash has committed to the ledger, resuming") val stx = serviceHub.storageService.validatedTransactions.getTransaction(hash) - return stx ?: throw IllegalStateException("We were resumed after waiting for $hash but it wasn't found in our local storage") + if (stx != null) return stx + // If the tx isn't committed then we may have been resumed due to an session ending in an error + for (session in openSessions.values) { + for (receivedMessage in session.receivedMessages) { + if (receivedMessage.message is ErrorSessionEnd) { + session.erroredEnd(receivedMessage.message) + } + } + } + throw IllegalStateException("We were resumed after waiting for $hash but it wasn't found in our local storage") } private fun createSessionData(session: FlowSession, payload: Any): SessionData { @@ -200,14 +209,17 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, suspend(SendOnly(session, message)) } - private inline fun receiveInternal(session: FlowSession): ReceivedSessionMessage { - return suspendAndExpectReceive(ReceiveOnly(session, M::class.java)) + private inline fun receiveInternal( + session: FlowSession, + userReceiveType: Class<*>?): ReceivedSessionMessage { + return waitForMessage(ReceiveOnly(session, M::class.java), userReceiveType) } private inline fun sendAndReceiveInternal( session: FlowSession, - message: SessionMessage): ReceivedSessionMessage { - return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java)) + message: SessionMessage, + userReceiveType: Class<*>?): ReceivedSessionMessage { + return waitForMessage(SendAndReceive(session, message, M::class.java), userReceiveType) } @Suspendable @@ -241,51 +253,72 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } @Suspendable - @Suppress("UNCHECKED_CAST", "PLATFORM_CLASS_MAPPED_TO_KOTLIN") - private fun suspendAndExpectReceive(receiveRequest: ReceiveRequest): ReceivedSessionMessage { - val session = receiveRequest.session - fun getReceivedMessage(): ReceivedSessionMessage? = session.receivedMessages.poll() + private fun waitForMessage( + receiveRequest: ReceiveRequest, + userReceiveType: Class<*>?): ReceivedSessionMessage { + val receivedMessage = receiveRequest.suspendAndExpectReceive() + return receivedMessage.confirmReceiveType(receiveRequest, userReceiveType) + } - val polledMessage = getReceivedMessage() - val receivedMessage = if (polledMessage != null) { - if (receiveRequest is SendAndReceive) { + @Suspendable + private fun ReceiveRequest<*>.suspendAndExpectReceive(): ReceivedSessionMessage<*> { + fun pollForMessage() = session.receivedMessages.poll() + + val polledMessage = pollForMessage() + return if (polledMessage != null) { + if (this is SendAndReceive) { // We've already received a message but we suspend so that the send can be performed - suspend(receiveRequest) + suspend(this) } polledMessage } else { // Suspend while we wait for a receive - suspend(receiveRequest) - getReceivedMessage() ?: - throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but instead " + - "got nothing for $receiveRequest") - } - - if (receiveRequest.receiveType.isInstance(receivedMessage.message)) { - return receivedMessage as ReceivedSessionMessage - } else if (receivedMessage.message is SessionEnd) { - openSessions.values.remove(session) - if (receivedMessage.message.errorResponse != null) { - (receivedMessage.message.errorResponse as java.lang.Throwable).fillInStackTrace() - throw receivedMessage.message.errorResponse - } else { - throw FlowSessionException("${session.state.sendToParty} has ended their flow but we were expecting " + - "to receive ${receiveRequest.receiveType.simpleName} from them") - } - } else { - throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but instead got " + - "${receivedMessage.message} for $receiveRequest") + suspend(this) + pollForMessage() ?: + throw IllegalStateException("Was expecting a ${receiveType.simpleName} but instead got nothing for $this") } } + private fun ReceivedSessionMessage<*>.confirmReceiveType( + receiveRequest: ReceiveRequest, + userReceiveType: Class<*>?): ReceivedSessionMessage { + val session = receiveRequest.session + val receiveType = receiveRequest.receiveType + if (receiveType.isInstance(message)) { + @Suppress("UNCHECKED_CAST") + return this as ReceivedSessionMessage + } else if (message is SessionEnd) { + openSessions.values.remove(session) + if (message is ErrorSessionEnd) { + session.erroredEnd(message) + } else { + val expectedType = userReceiveType?.name ?: receiveType.simpleName + throw FlowSessionException("Counterparty flow on ${session.state.sendToParty} has completed without " + + "sending a $expectedType") + } + } else { + throw IllegalStateException("Was expecting a ${receiveType.simpleName} but instead got $message for $receiveRequest") + } + } + + private fun FlowSession.erroredEnd(end: ErrorSessionEnd): Nothing { + if (end.errorResponse != null) { + @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") + (end.errorResponse as java.lang.Throwable).fillInStackTrace() + throw end.errorResponse + } else { + throw FlowSessionException("Counterparty flow on ${state.sendToParty} had an internal error and has terminated") + } + } + @Suspendable private fun suspend(ioRequest: FlowIORequest) { // We have to pass the thread local database transaction across via a transient field as the fiber park // swaps them out. txTrampoline = TransactionManager.currentOrNull() StrandLocalTransactionManager.setThreadLocalTx(null) - if (ioRequest is SessionedFlowIORequest) - ioRequest.session.waitingForResponse = (ioRequest is ReceiveRequest<*>) + if (ioRequest is WaitingRequest) + waitingForResponse = ioRequest var exceptionDuringSuspend: Throwable? = null parkAndSerialize { fiber, serializer -> diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt index c1811c98b8..246d0cf07f 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt @@ -7,20 +7,10 @@ import net.corda.core.utilities.UntrustworthyData interface SessionMessage -interface ExistingSessionMessage : SessionMessage { - val recipientSessionId: Long -} - data class SessionInit(val initiatorSessionId: Long, val flowName: String, val firstPayload: Any?) : SessionMessage -interface SessionInitResponse : ExistingSessionMessage - -data class SessionConfirm(val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse { - override val recipientSessionId: Long get() = initiatorSessionId -} - -data class SessionReject(val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse { - override val recipientSessionId: Long get() = initiatorSessionId +interface ExistingSessionMessage : SessionMessage { + val recipientSessionId: Long } data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage { @@ -29,7 +19,16 @@ data class SessionData(override val recipientSessionId: Long, val payload: Any) } } -data class SessionEnd(override val recipientSessionId: Long, val errorResponse: FlowException?) : ExistingSessionMessage +interface SessionInitResponse : ExistingSessionMessage { + val initiatorSessionId: Long + override val recipientSessionId: Long get() = initiatorSessionId +} +data class SessionConfirm(override val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse +data class SessionReject(override val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse + +interface SessionEnd : ExistingSessionMessage +data class NormalSessionEnd(override val recipientSessionId: Long) : SessionEnd +data class ErrorSessionEnd(override val recipientSessionId: Long, val errorResponse: FlowException?) : SessionEnd data class ReceivedSessionMessage(val sender: Party, val message: M) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index 4ed04014fc..d1bb24322a 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -164,13 +164,13 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, // Observe the stream of committed, validated transactions and resume fibers that are waiting for them. serviceHub.storageService.validatedTransactions.updates.subscribe { stx -> val hash = stx.id - val flows: Set> = mutex.locked { fibersWaitingForLedgerCommit.removeAll(hash) } - if (flows.isNotEmpty()) { + val fibers: Set> = mutex.locked { fibersWaitingForLedgerCommit.removeAll(hash) } + if (fibers.isNotEmpty()) { executor.executeASAP { - for (flow in flows) { - logger.info("Resuming ${flow.id} because it was waiting for tx ${flow.waitingForLedgerCommitOf!!} which is now committed.") - flow.waitingForLedgerCommitOf = null - resumeFiber(flow) + for (fiber in fibers) { + fiber.logger.info("Transaction $hash has committed to the ledger, resuming") + fiber.waitingForResponse = null + resumeFiber(fiber) } } } @@ -239,19 +239,22 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, private fun resumeRestoredFiber(fiber: FlowStateMachineImpl<*>) { fiber.openSessions.values.forEach { openSessions[it.ourSessionId] = it } - val waitingForHash = fiber.waitingForLedgerCommitOf - if (fiber.openSessions.values.any { it.waitingForResponse }) { - fiber.logger.info("Restored, pending on receive") - } else if (waitingForHash != null) { - val stx = databaseTransaction(database) { - serviceHub.storageService.validatedTransactions.getTransaction(waitingForHash) - } - if (stx != null) { - fiber.logger.info("Resuming fiber as tx $waitingForHash has committed.") - resumeFiber(fiber) + val waitingForResponse = fiber.waitingForResponse + if (waitingForResponse != null) { + if (waitingForResponse is WaitForLedgerCommit) { + val stx = databaseTransaction(database) { + serviceHub.storageService.validatedTransactions.getTransaction(waitingForResponse.hash) + } + if (stx != null) { + fiber.logger.info("Resuming fiber as tx ${waitingForResponse.hash} has committed.") + fiber.waitingForResponse = null + resumeFiber(fiber) + } else { + fiber.logger.info("Restored, pending on ledger commit of ${waitingForResponse.hash}") + mutex.locked { fibersWaitingForLedgerCommit.put(waitingForResponse.hash, fiber) } + } } else { - fiber.logger.info("Restored, pending on ledger commit of $waitingForHash") - mutex.locked { fibersWaitingForLedgerCommit.put(waitingForHash, fiber) } + fiber.logger.info("Restored, pending on receive") } } else { resumeFiber(fiber) @@ -275,15 +278,17 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, private fun onExistingSessionMessage(message: ExistingSessionMessage, sender: Party) { val session = openSessions[message.recipientSessionId] if (session != null) { - session.fiber.logger.trace { "Received $message on $session" } + session.fiber.logger.trace { "Received $message on $session from $sender" } if (message is SessionEnd) { openSessions.remove(message.recipientSessionId) } session.receivedMessages += ReceivedSessionMessage(sender, message) - if (session.waitingForResponse) { - // We only want to resume once, so immediately reset the flag. - session.waitingForResponse = false + if (resumeOnMessage(message, session)) { + // It's important that we reset here and not after the fiber's resumed, in case we receive another message + // before then. + session.fiber.waitingForResponse = null updateCheckpoint(session.fiber) + session.fiber.logger.debug { "About to resume due to $message" } resumeFiber(session.fiber) } } else { @@ -291,7 +296,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, if (peerParty != null) { if (message is SessionConfirm) { logger.debug { "Received session confirmation but associated fiber has already terminated, so sending session end" } - sendSessionMessage(peerParty, SessionEnd(message.initiatedSessionId, null)) + sendSessionMessage(peerParty, NormalSessionEnd(message.initiatedSessionId)) } else { logger.trace { "Ignoring session end message for already closed session: $message" } } @@ -301,6 +306,14 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, } } + // We resume the fiber if it's received a response for which it was waiting for or it's waiting for a ledger + // commit but a counterparty flow has ended with an error (in which case our flow also has to end) + private fun resumeOnMessage(message: ExistingSessionMessage, session: FlowSession): Boolean { + val waitingForResponse = session.fiber.waitingForResponse + return (waitingForResponse as? ReceiveRequest<*>)?.session === session || + waitingForResponse is WaitForLedgerCommit && message is ErrorSessionEnd + } + private fun onSessionInit(sessionInit: SessionInit, sender: Party) { logger.trace { "Received $sessionInit $sender" } val otherPartySessionId = sessionInit.initiatorSessionId @@ -379,14 +392,14 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, processIORequest(ioRequest) decrementLiveFibers() } - fiber.actionOnEnd = { errorResponse: Pair? -> + fiber.actionOnEnd = { exception, propagated -> try { fiber.logic.progressTracker?.currentStep = ProgressTracker.DONE mutex.locked { stateMachines.remove(fiber)?.let { checkpointStorage.removeCheckpoint(it) } notifyChangeObservers(fiber, AddOrRemove.REMOVE) } - endAllFiberSessions(fiber, errorResponse) + endAllFiberSessions(fiber, exception, propagated) } finally { fiber.commitTransaction() decrementLiveFibers() @@ -401,10 +414,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, } } - private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>, errorResponse: Pair?) { + private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>, exception: Throwable?, propagated: Boolean) { openSessions.values.removeIf { session -> if (session.fiber == fiber) { - session.endSession(errorResponse) + session.endSession(exception, propagated) true } else { false @@ -412,22 +425,21 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, } } - private fun FlowSession.endSession(errorResponse: Pair?) { + private fun FlowSession.endSession(exception: Throwable?, propagated: Boolean) { val initiatedState = state as? Initiated ?: return - val propagatedException = errorResponse?.let { - val (exception, propagated) = it - if (propagated) { - // This exception was propagated to us. We only propagate it down the invocation chain to the flow that - // initiated us, not to flows we've started sessions with. - if (initiatingParty != null) exception else null + val sessionEnd = if (exception == null) { + NormalSessionEnd(initiatedState.peerSessionId) + } else { + val errorResponse = if (exception is FlowException && (!propagated || initiatingParty != null)) { + // Only propagate this FlowException if our local flow threw it or it was propagated to us and we only + // pass it down invocation chain to the flow that initiated us, not to flows we've started sessions with. + exception } else { - exception // Our local flow threw the exception so propagate it + null } + ErrorSessionEnd(initiatedState.peerSessionId, errorResponse) } - sendSessionMessage( - initiatedState.peerParty, - SessionEnd(initiatedState.peerSessionId, propagatedException), - fiber) + sendSessionMessage(initiatedState.peerParty, sessionEnd, fiber) recentlyClosedSessions[ourSessionId] = initiatedState.peerParty } @@ -570,10 +582,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, val flow: FlowLogic<*>, val ourSessionId: Long, val initiatingParty: Party?, - var state: FlowSessionState, - @Volatile var waitingForResponse: Boolean = false - ) { - val receivedMessages = ConcurrentLinkedQueue>() + var state: FlowSessionState) + { + val receivedMessages = ConcurrentLinkedQueue>() val fiber: FlowStateMachineImpl<*> get() = flow.stateMachine as FlowStateMachineImpl<*> } } diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt index f7a26dfcf7..1c24c70f82 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt @@ -8,7 +8,6 @@ import net.corda.core.contracts.DOLLARS import net.corda.core.contracts.DummyState import net.corda.core.contracts.issuedBy import net.corda.core.crypto.Party -import net.corda.core.crypto.SecureHash import net.corda.core.crypto.generateKeyPair import net.corda.core.flows.FlowException import net.corda.core.flows.FlowLogic @@ -21,9 +20,9 @@ import net.corda.core.random63BitValue import net.corda.core.rootCause import net.corda.core.serialization.OpaqueBytes import net.corda.core.serialization.deserialize -import net.corda.core.utilities.unwrap import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.unwrap import net.corda.flows.CashIssueFlow import net.corda.flows.CashPaymentFlow import net.corda.flows.FinalityFlow @@ -36,6 +35,7 @@ import net.corda.testing.expectEvents import net.corda.testing.initiateSingleShotFlow import net.corda.testing.node.InMemoryMessagingNetwork import net.corda.testing.node.InMemoryMessagingNetwork.MessageTransfer +import net.corda.testing.node.InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork.MockNode import net.corda.testing.sequence @@ -49,10 +49,11 @@ import rx.Observable import java.util.* import kotlin.reflect.KClass import kotlin.test.assertEquals +import kotlin.test.assertFailsWith import kotlin.test.assertTrue class StateMachineManagerTests { - private val net = MockNetwork(servicePeerAllocationStrategy = InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin()) + private val net = MockNetwork(servicePeerAllocationStrategy = RoundRobin()) private val sessionTransfers = ArrayList() private lateinit var node1: MockNode private lateinit var node2: MockNode @@ -102,7 +103,7 @@ class StateMachineManagerTests { @Test fun `exception while fiber suspended`() { - node2.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(2, it) } + node2.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow("Hello", it) } val flow = ReceiveFlow(node2.info.legalIdentity) val fiber = node1.services.startFlow(flow) as FlowStateMachineImpl // Before the flow runs change the suspend action to throw an exception @@ -128,8 +129,7 @@ class StateMachineManagerTests { @Test fun `flow restarted just after receiving payload`() { node2.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() } - val payload = random63BitValue() - node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity)) + node1.services.startFlow(SendFlow("Hello", node2.info.legalIdentity)) // We push through just enough messages to get only the payload sent node2.pumpReceive() @@ -138,7 +138,7 @@ class StateMachineManagerTests { node2.stop() net.runNetwork() val restoredFlow = node2.restartAndGetRestoredFlow(node1) - assertThat(restoredFlow.receivedPayloads[0]).isEqualTo(payload) + assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello") } @Test @@ -178,15 +178,14 @@ class StateMachineManagerTests { @Test fun `flow loaded from checkpoint will respond to messages from before start`() { - val payload = random63BitValue() - node1.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(payload, it) } + node1.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow("Hello", it) } node2.services.startFlow(ReceiveFlow(node1.info.legalIdentity).nonTerminating()) // Prepare checkpointed receive flow // Make sure the add() has finished initial processing. node2.smm.executor.flush() node2.disableDBCloseOnStop() node2.stop() // kill receiver val restoredFlow = node2.restartAndGetRestoredFlow(node1) - assertThat(restoredFlow.receivedPayloads[0]).isEqualTo(payload) + assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello") } @Test @@ -245,7 +244,7 @@ class StateMachineManagerTests { net.runNetwork() node2.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() } node3.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() } - val payload = random63BitValue() + val payload = "Hello World" node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity, node3.info.legalIdentity)) net.runNetwork() val node2Flow = node2.getSingleFlow().first @@ -256,14 +255,14 @@ class StateMachineManagerTests { assertSessionTransfers(node2, node1 sent sessionInit(SendFlow::class, payload) to node2, node2 sent sessionConfirm to node1, - node1 sent sessionEnd() to node2 + node1 sent normalEnd to node2 //There's no session end from the other flows as they're manually suspended ) assertSessionTransfers(node3, node1 sent sessionInit(SendFlow::class, payload) to node3, node3 sent sessionConfirm to node1, - node1 sent sessionEnd() to node3 + node1 sent normalEnd to node3 //There's no session end from the other flows as they're manually suspended ) @@ -275,8 +274,8 @@ class StateMachineManagerTests { fun `receiving from multiple parties`() { val node3 = net.createNode(node1.info.address) net.runNetwork() - val node2Payload = random63BitValue() - val node3Payload = random63BitValue() + val node2Payload = "Test 1" + val node3Payload = "Test 2" node2.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(node2Payload, it) } node3.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(node3Payload, it) } val multiReceiveFlow = ReceiveFlow(node2.info.legalIdentity, node3.info.legalIdentity).nonTerminating() @@ -290,14 +289,14 @@ class StateMachineManagerTests { node1 sent sessionInit(ReceiveFlow::class) to node2, node2 sent sessionConfirm to node1, node2 sent sessionData(node2Payload) to node1, - node2 sent sessionEnd() to node1 + node2 sent normalEnd to node1 ) assertSessionTransfers(node3, node1 sent sessionInit(ReceiveFlow::class) to node3, node3 sent sessionConfirm to node1, node3 sent sessionData(node3Payload) to node1, - node3 sent sessionEnd() to node1 + node3 sent normalEnd to node1 ) } @@ -313,7 +312,7 @@ class StateMachineManagerTests { node2 sent sessionData(20L) to node1, node1 sent sessionData(11L) to node2, node2 sent sessionData(21L) to node1, - node1 sent sessionEnd() to node2 + node1 sent normalEnd to node2 ) } @@ -321,14 +320,14 @@ class StateMachineManagerTests { fun `different notaries are picked when addressing shared notary identity`() { assertEquals(notary1.info.notaryIdentity, notary2.info.notaryIdentity) node1.services.startFlow(CashIssueFlow( - DOLLARS(2000), + 2000.DOLLARS, OpaqueBytes.of(0x01), node1.info.legalIdentity, notary1.info.notaryIdentity)) // We pay a couple of times, the notary picking should go round robin for (i in 1 .. 3) { node1.services.startFlow(CashPaymentFlow( - DOLLARS(500).issuedBy(node1.info.legalIdentity.ref(0x01)), + 500.DOLLARS.issuedBy(node1.info.legalIdentity.ref(0x01)), node2.info.legalIdentity)) net.runNetwork() } @@ -336,7 +335,7 @@ class StateMachineManagerTests { val party1Info = notary1.services.networkMapCache.getPartyInfo(notary1.info.notaryIdentity)!! assert(party1Info is PartyInfo.Service) val notary1Address: MessageRecipients = endpoint.getAddressOfParty(notary1.services.networkMapCache.getPartyInfo(notary1.info.notaryIdentity)!!) - assert(notary1Address is InMemoryMessagingNetwork.ServiceHandle) + assertThat(notary1Address).isInstanceOf(InMemoryMessagingNetwork.ServiceHandle::class.java) assertEquals(notary1Address, endpoint.getAddressOfParty(notary2.services.networkMapCache.getPartyInfo(notary2.info.notaryIdentity)!!)) sessionTransfers.expectEvents(isStrict = false) { sequence( @@ -368,12 +367,38 @@ class StateMachineManagerTests { }, expect(match = { it.message is SessionConfirm }) { it.message as SessionConfirm - require(it.from == notary1.id) + assertEquals(it.from, notary1.id) } ) } } + @Test + fun `other side ends before doing expected send`() { + node2.services.registerFlowInitiator(ReceiveFlow::class) { NoOpFlow() } + val resultFuture = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)).resultFuture + net.runNetwork() + assertThatExceptionOfType(FlowSessionException::class.java).isThrownBy { + resultFuture.getOrThrow() + }.withMessageContaining(String::class.java.name) + } + + @Test + fun `non-FlowException thrown on other side`() { + node2.services.registerFlowInitiator(ReceiveFlow::class) { ExceptionFlow { Exception("evil bug!") } } + val resultFuture = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)).resultFuture + net.runNetwork() + val exceptionResult = assertFailsWith(FlowSessionException::class) { + resultFuture.getOrThrow() + } + assertThat(exceptionResult.message).doesNotContain("evil bug!") + assertSessionTransfers( + node1 sent sessionInit(ReceiveFlow::class) to node2, + node2 sent sessionConfirm to node1, + node2 sent erroredEnd() to node1 + ) + } + @Test fun `FlowException thrown on other side`() { val erroringFlowFuture = node2.initiateSingleShotFlow(ReceiveFlow::class) { @@ -384,7 +409,7 @@ class StateMachineManagerTests { assertThatExceptionOfType(MyFlowException::class.java) .isThrownBy { receivingFiber.resultFuture.getOrThrow() } .withMessage("Nothing useful") - .withStackTraceContaining("ReceiveFlow") // Make sure the stack trace is that of the receiving flow + .withStackTraceContaining(ReceiveFlow::class.java.name) // Make sure the stack trace is that of the receiving flow databaseTransaction(node2.database) { assertThat(node2.checkpointStorage.checkpoints()).isEmpty() } @@ -394,10 +419,10 @@ class StateMachineManagerTests { assertSessionTransfers( node1 sent sessionInit(ReceiveFlow::class) to node2, node2 sent sessionConfirm to node1, - node2 sent sessionEnd(errorFlow.exceptionThrown) to node1 + node2 sent erroredEnd(errorFlow.exceptionThrown) to node1 ) // Make sure the original stack trace isn't sent down the wire - assertThat((sessionTransfers.last().message as SessionEnd).errorResponse!!.stackTrace).isEmpty() + assertThat((sessionTransfers.last().message as ErrorSessionEnd).errorResponse!!.stackTrace).isEmpty() } @Test @@ -450,7 +475,7 @@ class StateMachineManagerTests { node1 sent sessionInit(ReceiveFlow::class) to node2, node2 sent sessionConfirm to node1, node2 sent sessionData("Hello") to node1, - node1 sent sessionEnd() to node2 // Unexpected session-end + node1 sent erroredEnd() to node2 ) } @@ -496,11 +521,29 @@ class StateMachineManagerTests { ptx.signWith(node1.services.legalIdentityKey) val stx = ptx.toSignedTransaction() - val future1 = node2.services.startFlow(WaitingFlows.Waiter(stx.id)).resultFuture - val future2 = node1.services.startFlow(WaitingFlows.Committer(stx, node2.info.legalIdentity)).resultFuture + val committerFiber = node1 + .initiateSingleShotFlow(WaitingFlows.Waiter::class) { WaitingFlows.Committer(it) } + .map { it.stateMachine } + val waiterStx = node2.services.startFlow(WaitingFlows.Waiter(stx, node1.info.legalIdentity)).resultFuture net.runNetwork() - future1.getOrThrow() - future2.getOrThrow() + assertThat(waiterStx.getOrThrow()).isEqualTo(committerFiber.getOrThrow().resultFuture.getOrThrow()) + } + + @Test + fun `committer throws exception before calling the finality flow`() { + val ptx = TransactionBuilder(notary = notary1.info.notaryIdentity) + ptx.addOutputState(DummyState()) + ptx.signWith(node1.services.legalIdentityKey) + val stx = ptx.toSignedTransaction() + + node1.services.registerFlowInitiator(WaitingFlows.Waiter::class) { + WaitingFlows.Committer(it) { throw Exception("Error") } + } + val waiter = node2.services.startFlow(WaitingFlows.Waiter(stx, node1.info.legalIdentity)).resultFuture + net.runNetwork() + assertThatExceptionOfType(FlowSessionException::class.java).isThrownBy { + waiter.getOrThrow() + } } @@ -522,12 +565,10 @@ class StateMachineManagerTests { } private fun sessionInit(flowMarker: KClass<*>, payload: Any? = null) = SessionInit(0, flowMarker.java.name, payload) - private val sessionConfirm = SessionConfirm(0, 0) - private fun sessionData(payload: Any) = SessionData(0, payload) - - private fun sessionEnd(error: FlowException? = null) = SessionEnd(0, error) + private val normalEnd = NormalSessionEnd(0) + private fun erroredEnd(errorResponse: FlowException? = null) = ErrorSessionEnd(0, errorResponse) private fun assertSessionTransfers(vararg expected: SessionTransfer) { assertThat(sessionTransfers).containsExactly(*expected) @@ -557,7 +598,8 @@ class StateMachineManagerTests { is SessionData -> message.copy(recipientSessionId = 0) is SessionInit -> message.copy(initiatorSessionId = 0) is SessionConfirm -> message.copy(initiatorSessionId = 0, initiatedSessionId = 0) - is SessionEnd -> message.copy(recipientSessionId = 0) + is NormalSessionEnd -> message.copy(recipientSessionId = 0) + is ErrorSessionEnd -> message.copy(recipientSessionId = 0) else -> message } } @@ -578,7 +620,7 @@ class StateMachineManagerTests { } - private class SendFlow(val payload: Any, vararg val otherParties: Party) : FlowLogic() { + private class SendFlow(val payload: String, vararg val otherParties: Party) : FlowLogic() { init { require(otherParties.isNotEmpty()) } @@ -595,11 +637,11 @@ class StateMachineManagerTests { require(otherParties.isNotEmpty()) } - @Transient var receivedPayloads: List = emptyList() + @Transient var receivedPayloads: List = emptyList() @Suspendable override fun call() { - receivedPayloads = otherParties.map { receive(it).unwrap { it } } + receivedPayloads = otherParties.map { receive(it).unwrap { it } } if (nonTerminating) { Fiber.park() } @@ -630,23 +672,26 @@ class StateMachineManagerTests { } } - private class MyFlowException(message: String) : FlowException(message) { + private class MyFlowException(override val message: String) : FlowException() { override fun equals(other: Any?): Boolean = other is MyFlowException && other.message == this.message - override fun hashCode(): Int = message?.hashCode() ?: 31 + override fun hashCode(): Int = message.hashCode() } private object WaitingFlows { - class Waiter(private val hash: SecureHash) : FlowLogic() { + class Waiter(val stx: SignedTransaction, val otherParty: Party) : FlowLogic() { @Suspendable - override fun call() { - waitForLedgerCommit(hash) + override fun call(): SignedTransaction { + send(otherParty, stx) + return waitForLedgerCommit(stx.id) } } - class Committer(private val stx: SignedTransaction, private val otherParty: Party) : FlowLogic() { + class Committer(val otherParty: Party, val throwException: (() -> Exception)? = null) : FlowLogic() { @Suspendable - override fun call() { - subFlow(FinalityFlow(stx, setOf(otherParty))) + override fun call(): SignedTransaction { + val stx = receive(otherParty).unwrap { it } + if (throwException != null) throw throwException.invoke() + return subFlow(FinalityFlow(stx, setOf(otherParty))).single() } } } diff --git a/test-utils/src/main/kotlin/net/corda/testing/node/MockNode.kt b/test-utils/src/main/kotlin/net/corda/testing/node/MockNode.kt index 89d05e3626..aaeb831d48 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/node/MockNode.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/node/MockNode.kt @@ -282,6 +282,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, * parameter set to -1 (the default) which simply runs as many rounds as necessary to result in network * stability (no nodes sent any messages in the last round). */ + @JvmOverloads fun runNetwork(rounds: Int = -1) { check(!networkSendManuallyPumped) fun pumpAll() = messagingNetwork.endpoints.map { it.pumpReceive(false) } @@ -324,6 +325,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, * Sets up a network with the requested number of nodes (defaulting to two), with one or more service nodes that * run a notary, network map, any oracles etc. Can't be combined with [createTwoNodes]. */ + @JvmOverloads fun createSomeNodes(numPartyNodes: Int = 2, nodeFactory: Factory = defaultFactory, notaryKeyPair: KeyPair? = DUMMY_NOTARY_KEY): BasketOfNodes { require(nodes.isEmpty()) val notaryServiceInfo = ServiceInfo(SimpleNotaryService.type)