End flow if waiting for ledger commit and committer flow errors

This commit is contained in:
Shams Asari 2017-02-08 16:00:39 +00:00 committed by Chris Rankin
parent 3c0d6fd14f
commit 71182ec8c1
7 changed files with 244 additions and 152 deletions

View File

@ -203,7 +203,7 @@ abstract class FlowLogic<out T> {
val theirs = subLogic.progressTracker
if (ours != null && theirs != null) {
if (ours.currentStep == ProgressTracker.UNSTARTED) {
logger.warn("ProgressTracker has not been started for $this")
logger.warn("ProgressTracker has not been started")
ours.nextStep()
}
ours.setChildProgressTracker(ours.currentStep, theirs)

View File

@ -10,6 +10,8 @@ interface FlowIORequest {
val stackTraceInCaseOfProblems: StackSnapshot
}
interface WaitingRequest : FlowIORequest
interface SessionedFlowIORequest : FlowIORequest {
val session: FlowSession
}
@ -18,7 +20,7 @@ interface SendRequest : SessionedFlowIORequest {
val message: SessionMessage
}
interface ReceiveRequest<T : SessionMessage> : SessionedFlowIORequest {
interface ReceiveRequest<T : SessionMessage> : SessionedFlowIORequest, WaitingRequest {
val receiveType: Class<T>
}
@ -40,7 +42,7 @@ data class SendOnly(override val session: FlowSession, override val message: Ses
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}
data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachineImpl<*>) : FlowIORequest {
data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachineImpl<*>) : WaitingRequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}

View File

@ -51,7 +51,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
@Transient override lateinit var serviceHub: ServiceHubInternal
@Transient internal lateinit var database: Database
@Transient internal lateinit var actionOnSuspend: (FlowIORequest) -> Unit
@Transient internal lateinit var actionOnEnd: (Pair<FlowException, Boolean>?) -> Unit
@Transient internal lateinit var actionOnEnd: (Throwable?, Boolean) -> Unit
@Transient internal var fromCheckpoint: Boolean = false
@Transient private var txTrampoline: Transaction? = null
@ -76,7 +76,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
// This state IS serialised, as we need it to know what the fiber is waiting for.
internal val openSessions = HashMap<Pair<FlowLogic<*>, Party>, FlowSession>()
internal var waitingForLedgerCommitOf: SecureHash? = null
internal var waitingForResponse: WaitingRequest? = null
init {
logic.stateMachine = this
@ -91,11 +91,11 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
} catch (e: FlowException) {
// Check if the FlowException was propagated by looking at where the stack trace originates (see suspendAndExpectReceive).
val propagated = e.stackTrace[0].className == javaClass.name
actionOnEnd(Pair(e, propagated))
actionOnEnd(e, propagated)
_resultFuture?.setException(e)
return
} catch (t: Throwable) {
actionOnEnd(null)
actionOnEnd(t, false)
_resultFuture?.setException(t)
throw ExecutionException(t)
}
@ -105,7 +105,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
.filter { it.state is FlowSessionState.Initiating }
.forEach { it.waitForConfirmation() }
// This is to prevent actionOnEnd being called twice if it throws an exception
actionOnEnd(null)
actionOnEnd(null, false)
_resultFuture?.set(result)
}
@ -136,10 +136,11 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
sessionFlow: FlowLogic<*>): UntrustworthyData<T> {
val session = getConfirmedSession(otherParty, sessionFlow)
return if (session == null) {
val newSession = startNewSession(otherParty, sessionFlow, payload, waitForConfirmation = true)
// Only do a receive here as the session init has carried the payload
receiveInternal<SessionData>(startNewSession(otherParty, sessionFlow, payload, waitForConfirmation = true))
receiveInternal<SessionData>(newSession, receiveType)
} else {
sendAndReceiveInternal<SessionData>(session, createSessionData(session, payload))
sendAndReceiveInternal<SessionData>(session, createSessionData(session, payload), receiveType)
}.checkPayloadIs(receiveType)
}
@ -147,8 +148,9 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
override fun <T : Any> receive(receiveType: Class<T>,
otherParty: Party,
sessionFlow: FlowLogic<*>): UntrustworthyData<T> {
val session = getConfirmedSession(otherParty, sessionFlow) ?: startNewSession(otherParty, sessionFlow, null, waitForConfirmation = true)
return receiveInternal<SessionData>(session).checkPayloadIs(receiveType)
val session = getConfirmedSession(otherParty, sessionFlow) ?:
startNewSession(otherParty, sessionFlow, null, waitForConfirmation = true)
return receiveInternal<SessionData>(session, receiveType).checkPayloadIs(receiveType)
}
@Suspendable
@ -167,7 +169,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
*/
@Suspendable
private fun FlowSession.waitForConfirmation() {
val (peerParty, sessionInitResponse) = receiveInternal<SessionInitResponse>(this)
val (peerParty, sessionInitResponse) = receiveInternal<SessionInitResponse>(this, null)
if (sessionInitResponse is SessionConfirm) {
state = FlowSessionState.Initiated(peerParty, sessionInitResponse.initiatedSessionId)
} else {
@ -178,12 +180,19 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
@Suspendable
override fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>): SignedTransaction {
waitingForLedgerCommitOf = hash
logger.info("Waiting for transaction $hash to commit")
suspend(WaitForLedgerCommit(hash, sessionFlow.stateMachine as FlowStateMachineImpl<*>))
logger.info("Transaction $hash has committed to the ledger, resuming")
val stx = serviceHub.storageService.validatedTransactions.getTransaction(hash)
return stx ?: throw IllegalStateException("We were resumed after waiting for $hash but it wasn't found in our local storage")
if (stx != null) return stx
// If the tx isn't committed then we may have been resumed due to an session ending in an error
for (session in openSessions.values) {
for (receivedMessage in session.receivedMessages) {
if (receivedMessage.message is ErrorSessionEnd) {
session.erroredEnd(receivedMessage.message)
}
}
}
throw IllegalStateException("We were resumed after waiting for $hash but it wasn't found in our local storage")
}
private fun createSessionData(session: FlowSession, payload: Any): SessionData {
@ -200,14 +209,17 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
suspend(SendOnly(session, message))
}
private inline fun <reified M : ExistingSessionMessage> receiveInternal(session: FlowSession): ReceivedSessionMessage<M> {
return suspendAndExpectReceive(ReceiveOnly(session, M::class.java))
private inline fun <reified M : ExistingSessionMessage> receiveInternal(
session: FlowSession,
userReceiveType: Class<*>?): ReceivedSessionMessage<M> {
return waitForMessage(ReceiveOnly(session, M::class.java), userReceiveType)
}
private inline fun <reified M : ExistingSessionMessage> sendAndReceiveInternal(
session: FlowSession,
message: SessionMessage): ReceivedSessionMessage<M> {
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java))
message: SessionMessage,
userReceiveType: Class<*>?): ReceivedSessionMessage<M> {
return waitForMessage(SendAndReceive(session, message, M::class.java), userReceiveType)
}
@Suspendable
@ -241,51 +253,72 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
}
@Suspendable
@Suppress("UNCHECKED_CAST", "PLATFORM_CLASS_MAPPED_TO_KOTLIN")
private fun <M : ExistingSessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> {
val session = receiveRequest.session
fun getReceivedMessage(): ReceivedSessionMessage<ExistingSessionMessage>? = session.receivedMessages.poll()
private fun <M : ExistingSessionMessage> waitForMessage(
receiveRequest: ReceiveRequest<M>,
userReceiveType: Class<*>?): ReceivedSessionMessage<M> {
val receivedMessage = receiveRequest.suspendAndExpectReceive()
return receivedMessage.confirmReceiveType(receiveRequest, userReceiveType)
}
val polledMessage = getReceivedMessage()
val receivedMessage = if (polledMessage != null) {
if (receiveRequest is SendAndReceive) {
@Suspendable
private fun ReceiveRequest<*>.suspendAndExpectReceive(): ReceivedSessionMessage<*> {
fun pollForMessage() = session.receivedMessages.poll()
val polledMessage = pollForMessage()
return if (polledMessage != null) {
if (this is SendAndReceive) {
// We've already received a message but we suspend so that the send can be performed
suspend(receiveRequest)
suspend(this)
}
polledMessage
} else {
// Suspend while we wait for a receive
suspend(receiveRequest)
getReceivedMessage() ?:
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but instead " +
"got nothing for $receiveRequest")
}
if (receiveRequest.receiveType.isInstance(receivedMessage.message)) {
return receivedMessage as ReceivedSessionMessage<M>
} else if (receivedMessage.message is SessionEnd) {
openSessions.values.remove(session)
if (receivedMessage.message.errorResponse != null) {
(receivedMessage.message.errorResponse as java.lang.Throwable).fillInStackTrace()
throw receivedMessage.message.errorResponse
} else {
throw FlowSessionException("${session.state.sendToParty} has ended their flow but we were expecting " +
"to receive ${receiveRequest.receiveType.simpleName} from them")
}
} else {
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but instead got " +
"${receivedMessage.message} for $receiveRequest")
suspend(this)
pollForMessage() ?:
throw IllegalStateException("Was expecting a ${receiveType.simpleName} but instead got nothing for $this")
}
}
private fun <M : ExistingSessionMessage> ReceivedSessionMessage<*>.confirmReceiveType(
receiveRequest: ReceiveRequest<M>,
userReceiveType: Class<*>?): ReceivedSessionMessage<M> {
val session = receiveRequest.session
val receiveType = receiveRequest.receiveType
if (receiveType.isInstance(message)) {
@Suppress("UNCHECKED_CAST")
return this as ReceivedSessionMessage<M>
} else if (message is SessionEnd) {
openSessions.values.remove(session)
if (message is ErrorSessionEnd) {
session.erroredEnd(message)
} else {
val expectedType = userReceiveType?.name ?: receiveType.simpleName
throw FlowSessionException("Counterparty flow on ${session.state.sendToParty} has completed without " +
"sending a $expectedType")
}
} else {
throw IllegalStateException("Was expecting a ${receiveType.simpleName} but instead got $message for $receiveRequest")
}
}
private fun FlowSession.erroredEnd(end: ErrorSessionEnd): Nothing {
if (end.errorResponse != null) {
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
(end.errorResponse as java.lang.Throwable).fillInStackTrace()
throw end.errorResponse
} else {
throw FlowSessionException("Counterparty flow on ${state.sendToParty} had an internal error and has terminated")
}
}
@Suspendable
private fun suspend(ioRequest: FlowIORequest) {
// We have to pass the thread local database transaction across via a transient field as the fiber park
// swaps them out.
txTrampoline = TransactionManager.currentOrNull()
StrandLocalTransactionManager.setThreadLocalTx(null)
if (ioRequest is SessionedFlowIORequest)
ioRequest.session.waitingForResponse = (ioRequest is ReceiveRequest<*>)
if (ioRequest is WaitingRequest)
waitingForResponse = ioRequest
var exceptionDuringSuspend: Throwable? = null
parkAndSerialize { fiber, serializer ->

View File

@ -7,20 +7,10 @@ import net.corda.core.utilities.UntrustworthyData
interface SessionMessage
interface ExistingSessionMessage : SessionMessage {
val recipientSessionId: Long
}
data class SessionInit(val initiatorSessionId: Long, val flowName: String, val firstPayload: Any?) : SessionMessage
interface SessionInitResponse : ExistingSessionMessage
data class SessionConfirm(val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse {
override val recipientSessionId: Long get() = initiatorSessionId
}
data class SessionReject(val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse {
override val recipientSessionId: Long get() = initiatorSessionId
interface ExistingSessionMessage : SessionMessage {
val recipientSessionId: Long
}
data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage {
@ -29,7 +19,16 @@ data class SessionData(override val recipientSessionId: Long, val payload: Any)
}
}
data class SessionEnd(override val recipientSessionId: Long, val errorResponse: FlowException?) : ExistingSessionMessage
interface SessionInitResponse : ExistingSessionMessage {
val initiatorSessionId: Long
override val recipientSessionId: Long get() = initiatorSessionId
}
data class SessionConfirm(override val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse
data class SessionReject(override val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse
interface SessionEnd : ExistingSessionMessage
data class NormalSessionEnd(override val recipientSessionId: Long) : SessionEnd
data class ErrorSessionEnd(override val recipientSessionId: Long, val errorResponse: FlowException?) : SessionEnd
data class ReceivedSessionMessage<out M : ExistingSessionMessage>(val sender: Party, val message: M)

View File

@ -164,13 +164,13 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
// Observe the stream of committed, validated transactions and resume fibers that are waiting for them.
serviceHub.storageService.validatedTransactions.updates.subscribe { stx ->
val hash = stx.id
val flows: Set<FlowStateMachineImpl<*>> = mutex.locked { fibersWaitingForLedgerCommit.removeAll(hash) }
if (flows.isNotEmpty()) {
val fibers: Set<FlowStateMachineImpl<*>> = mutex.locked { fibersWaitingForLedgerCommit.removeAll(hash) }
if (fibers.isNotEmpty()) {
executor.executeASAP {
for (flow in flows) {
logger.info("Resuming ${flow.id} because it was waiting for tx ${flow.waitingForLedgerCommitOf!!} which is now committed.")
flow.waitingForLedgerCommitOf = null
resumeFiber(flow)
for (fiber in fibers) {
fiber.logger.info("Transaction $hash has committed to the ledger, resuming")
fiber.waitingForResponse = null
resumeFiber(fiber)
}
}
}
@ -239,19 +239,22 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
private fun resumeRestoredFiber(fiber: FlowStateMachineImpl<*>) {
fiber.openSessions.values.forEach { openSessions[it.ourSessionId] = it }
val waitingForHash = fiber.waitingForLedgerCommitOf
if (fiber.openSessions.values.any { it.waitingForResponse }) {
fiber.logger.info("Restored, pending on receive")
} else if (waitingForHash != null) {
val stx = databaseTransaction(database) {
serviceHub.storageService.validatedTransactions.getTransaction(waitingForHash)
}
if (stx != null) {
fiber.logger.info("Resuming fiber as tx $waitingForHash has committed.")
resumeFiber(fiber)
val waitingForResponse = fiber.waitingForResponse
if (waitingForResponse != null) {
if (waitingForResponse is WaitForLedgerCommit) {
val stx = databaseTransaction(database) {
serviceHub.storageService.validatedTransactions.getTransaction(waitingForResponse.hash)
}
if (stx != null) {
fiber.logger.info("Resuming fiber as tx ${waitingForResponse.hash} has committed.")
fiber.waitingForResponse = null
resumeFiber(fiber)
} else {
fiber.logger.info("Restored, pending on ledger commit of ${waitingForResponse.hash}")
mutex.locked { fibersWaitingForLedgerCommit.put(waitingForResponse.hash, fiber) }
}
} else {
fiber.logger.info("Restored, pending on ledger commit of $waitingForHash")
mutex.locked { fibersWaitingForLedgerCommit.put(waitingForHash, fiber) }
fiber.logger.info("Restored, pending on receive")
}
} else {
resumeFiber(fiber)
@ -275,15 +278,17 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
private fun onExistingSessionMessage(message: ExistingSessionMessage, sender: Party) {
val session = openSessions[message.recipientSessionId]
if (session != null) {
session.fiber.logger.trace { "Received $message on $session" }
session.fiber.logger.trace { "Received $message on $session from $sender" }
if (message is SessionEnd) {
openSessions.remove(message.recipientSessionId)
}
session.receivedMessages += ReceivedSessionMessage(sender, message)
if (session.waitingForResponse) {
// We only want to resume once, so immediately reset the flag.
session.waitingForResponse = false
if (resumeOnMessage(message, session)) {
// It's important that we reset here and not after the fiber's resumed, in case we receive another message
// before then.
session.fiber.waitingForResponse = null
updateCheckpoint(session.fiber)
session.fiber.logger.debug { "About to resume due to $message" }
resumeFiber(session.fiber)
}
} else {
@ -291,7 +296,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
if (peerParty != null) {
if (message is SessionConfirm) {
logger.debug { "Received session confirmation but associated fiber has already terminated, so sending session end" }
sendSessionMessage(peerParty, SessionEnd(message.initiatedSessionId, null))
sendSessionMessage(peerParty, NormalSessionEnd(message.initiatedSessionId))
} else {
logger.trace { "Ignoring session end message for already closed session: $message" }
}
@ -301,6 +306,14 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
}
}
// We resume the fiber if it's received a response for which it was waiting for or it's waiting for a ledger
// commit but a counterparty flow has ended with an error (in which case our flow also has to end)
private fun resumeOnMessage(message: ExistingSessionMessage, session: FlowSession): Boolean {
val waitingForResponse = session.fiber.waitingForResponse
return (waitingForResponse as? ReceiveRequest<*>)?.session === session ||
waitingForResponse is WaitForLedgerCommit && message is ErrorSessionEnd
}
private fun onSessionInit(sessionInit: SessionInit, sender: Party) {
logger.trace { "Received $sessionInit $sender" }
val otherPartySessionId = sessionInit.initiatorSessionId
@ -379,14 +392,14 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
processIORequest(ioRequest)
decrementLiveFibers()
}
fiber.actionOnEnd = { errorResponse: Pair<FlowException, Boolean>? ->
fiber.actionOnEnd = { exception, propagated ->
try {
fiber.logic.progressTracker?.currentStep = ProgressTracker.DONE
mutex.locked {
stateMachines.remove(fiber)?.let { checkpointStorage.removeCheckpoint(it) }
notifyChangeObservers(fiber, AddOrRemove.REMOVE)
}
endAllFiberSessions(fiber, errorResponse)
endAllFiberSessions(fiber, exception, propagated)
} finally {
fiber.commitTransaction()
decrementLiveFibers()
@ -401,10 +414,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
}
}
private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>, errorResponse: Pair<FlowException, Boolean>?) {
private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>, exception: Throwable?, propagated: Boolean) {
openSessions.values.removeIf { session ->
if (session.fiber == fiber) {
session.endSession(errorResponse)
session.endSession(exception, propagated)
true
} else {
false
@ -412,22 +425,21 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
}
}
private fun FlowSession.endSession(errorResponse: Pair<FlowException, Boolean>?) {
private fun FlowSession.endSession(exception: Throwable?, propagated: Boolean) {
val initiatedState = state as? Initiated ?: return
val propagatedException = errorResponse?.let {
val (exception, propagated) = it
if (propagated) {
// This exception was propagated to us. We only propagate it down the invocation chain to the flow that
// initiated us, not to flows we've started sessions with.
if (initiatingParty != null) exception else null
val sessionEnd = if (exception == null) {
NormalSessionEnd(initiatedState.peerSessionId)
} else {
val errorResponse = if (exception is FlowException && (!propagated || initiatingParty != null)) {
// Only propagate this FlowException if our local flow threw it or it was propagated to us and we only
// pass it down invocation chain to the flow that initiated us, not to flows we've started sessions with.
exception
} else {
exception // Our local flow threw the exception so propagate it
null
}
ErrorSessionEnd(initiatedState.peerSessionId, errorResponse)
}
sendSessionMessage(
initiatedState.peerParty,
SessionEnd(initiatedState.peerSessionId, propagatedException),
fiber)
sendSessionMessage(initiatedState.peerParty, sessionEnd, fiber)
recentlyClosedSessions[ourSessionId] = initiatedState.peerParty
}
@ -570,10 +582,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
val flow: FlowLogic<*>,
val ourSessionId: Long,
val initiatingParty: Party?,
var state: FlowSessionState,
@Volatile var waitingForResponse: Boolean = false
) {
val receivedMessages = ConcurrentLinkedQueue<ReceivedSessionMessage<ExistingSessionMessage>>()
var state: FlowSessionState)
{
val receivedMessages = ConcurrentLinkedQueue<ReceivedSessionMessage<*>>()
val fiber: FlowStateMachineImpl<*> get() = flow.stateMachine as FlowStateMachineImpl<*>
}
}

View File

@ -8,7 +8,6 @@ import net.corda.core.contracts.DOLLARS
import net.corda.core.contracts.DummyState
import net.corda.core.contracts.issuedBy
import net.corda.core.crypto.Party
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.generateKeyPair
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
@ -21,9 +20,9 @@ import net.corda.core.random63BitValue
import net.corda.core.rootCause
import net.corda.core.serialization.OpaqueBytes
import net.corda.core.serialization.deserialize
import net.corda.core.utilities.unwrap
import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.unwrap
import net.corda.flows.CashIssueFlow
import net.corda.flows.CashPaymentFlow
import net.corda.flows.FinalityFlow
@ -36,6 +35,7 @@ import net.corda.testing.expectEvents
import net.corda.testing.initiateSingleShotFlow
import net.corda.testing.node.InMemoryMessagingNetwork
import net.corda.testing.node.InMemoryMessagingNetwork.MessageTransfer
import net.corda.testing.node.InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin
import net.corda.testing.node.MockNetwork
import net.corda.testing.node.MockNetwork.MockNode
import net.corda.testing.sequence
@ -49,10 +49,11 @@ import rx.Observable
import java.util.*
import kotlin.reflect.KClass
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertTrue
class StateMachineManagerTests {
private val net = MockNetwork(servicePeerAllocationStrategy = InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin())
private val net = MockNetwork(servicePeerAllocationStrategy = RoundRobin())
private val sessionTransfers = ArrayList<SessionTransfer>()
private lateinit var node1: MockNode
private lateinit var node2: MockNode
@ -102,7 +103,7 @@ class StateMachineManagerTests {
@Test
fun `exception while fiber suspended`() {
node2.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(2, it) }
node2.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow("Hello", it) }
val flow = ReceiveFlow(node2.info.legalIdentity)
val fiber = node1.services.startFlow(flow) as FlowStateMachineImpl
// Before the flow runs change the suspend action to throw an exception
@ -128,8 +129,7 @@ class StateMachineManagerTests {
@Test
fun `flow restarted just after receiving payload`() {
node2.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
val payload = random63BitValue()
node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity))
node1.services.startFlow(SendFlow("Hello", node2.info.legalIdentity))
// We push through just enough messages to get only the payload sent
node2.pumpReceive()
@ -138,7 +138,7 @@ class StateMachineManagerTests {
node2.stop()
net.runNetwork()
val restoredFlow = node2.restartAndGetRestoredFlow<ReceiveFlow>(node1)
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo(payload)
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello")
}
@Test
@ -178,15 +178,14 @@ class StateMachineManagerTests {
@Test
fun `flow loaded from checkpoint will respond to messages from before start`() {
val payload = random63BitValue()
node1.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(payload, it) }
node1.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow("Hello", it) }
node2.services.startFlow(ReceiveFlow(node1.info.legalIdentity).nonTerminating()) // Prepare checkpointed receive flow
// Make sure the add() has finished initial processing.
node2.smm.executor.flush()
node2.disableDBCloseOnStop()
node2.stop() // kill receiver
val restoredFlow = node2.restartAndGetRestoredFlow<ReceiveFlow>(node1)
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo(payload)
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello")
}
@Test
@ -245,7 +244,7 @@ class StateMachineManagerTests {
net.runNetwork()
node2.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
node3.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
val payload = random63BitValue()
val payload = "Hello World"
node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity, node3.info.legalIdentity))
net.runNetwork()
val node2Flow = node2.getSingleFlow<ReceiveFlow>().first
@ -256,14 +255,14 @@ class StateMachineManagerTests {
assertSessionTransfers(node2,
node1 sent sessionInit(SendFlow::class, payload) to node2,
node2 sent sessionConfirm to node1,
node1 sent sessionEnd() to node2
node1 sent normalEnd to node2
//There's no session end from the other flows as they're manually suspended
)
assertSessionTransfers(node3,
node1 sent sessionInit(SendFlow::class, payload) to node3,
node3 sent sessionConfirm to node1,
node1 sent sessionEnd() to node3
node1 sent normalEnd to node3
//There's no session end from the other flows as they're manually suspended
)
@ -275,8 +274,8 @@ class StateMachineManagerTests {
fun `receiving from multiple parties`() {
val node3 = net.createNode(node1.info.address)
net.runNetwork()
val node2Payload = random63BitValue()
val node3Payload = random63BitValue()
val node2Payload = "Test 1"
val node3Payload = "Test 2"
node2.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(node2Payload, it) }
node3.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(node3Payload, it) }
val multiReceiveFlow = ReceiveFlow(node2.info.legalIdentity, node3.info.legalIdentity).nonTerminating()
@ -290,14 +289,14 @@ class StateMachineManagerTests {
node1 sent sessionInit(ReceiveFlow::class) to node2,
node2 sent sessionConfirm to node1,
node2 sent sessionData(node2Payload) to node1,
node2 sent sessionEnd() to node1
node2 sent normalEnd to node1
)
assertSessionTransfers(node3,
node1 sent sessionInit(ReceiveFlow::class) to node3,
node3 sent sessionConfirm to node1,
node3 sent sessionData(node3Payload) to node1,
node3 sent sessionEnd() to node1
node3 sent normalEnd to node1
)
}
@ -313,7 +312,7 @@ class StateMachineManagerTests {
node2 sent sessionData(20L) to node1,
node1 sent sessionData(11L) to node2,
node2 sent sessionData(21L) to node1,
node1 sent sessionEnd() to node2
node1 sent normalEnd to node2
)
}
@ -321,14 +320,14 @@ class StateMachineManagerTests {
fun `different notaries are picked when addressing shared notary identity`() {
assertEquals(notary1.info.notaryIdentity, notary2.info.notaryIdentity)
node1.services.startFlow(CashIssueFlow(
DOLLARS(2000),
2000.DOLLARS,
OpaqueBytes.of(0x01),
node1.info.legalIdentity,
notary1.info.notaryIdentity))
// We pay a couple of times, the notary picking should go round robin
for (i in 1 .. 3) {
node1.services.startFlow(CashPaymentFlow(
DOLLARS(500).issuedBy(node1.info.legalIdentity.ref(0x01)),
500.DOLLARS.issuedBy(node1.info.legalIdentity.ref(0x01)),
node2.info.legalIdentity))
net.runNetwork()
}
@ -336,7 +335,7 @@ class StateMachineManagerTests {
val party1Info = notary1.services.networkMapCache.getPartyInfo(notary1.info.notaryIdentity)!!
assert(party1Info is PartyInfo.Service)
val notary1Address: MessageRecipients = endpoint.getAddressOfParty(notary1.services.networkMapCache.getPartyInfo(notary1.info.notaryIdentity)!!)
assert(notary1Address is InMemoryMessagingNetwork.ServiceHandle)
assertThat(notary1Address).isInstanceOf(InMemoryMessagingNetwork.ServiceHandle::class.java)
assertEquals(notary1Address, endpoint.getAddressOfParty(notary2.services.networkMapCache.getPartyInfo(notary2.info.notaryIdentity)!!))
sessionTransfers.expectEvents(isStrict = false) {
sequence(
@ -368,12 +367,38 @@ class StateMachineManagerTests {
},
expect(match = { it.message is SessionConfirm }) {
it.message as SessionConfirm
require(it.from == notary1.id)
assertEquals(it.from, notary1.id)
}
)
}
}
@Test
fun `other side ends before doing expected send`() {
node2.services.registerFlowInitiator(ReceiveFlow::class) { NoOpFlow() }
val resultFuture = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)).resultFuture
net.runNetwork()
assertThatExceptionOfType(FlowSessionException::class.java).isThrownBy {
resultFuture.getOrThrow()
}.withMessageContaining(String::class.java.name)
}
@Test
fun `non-FlowException thrown on other side`() {
node2.services.registerFlowInitiator(ReceiveFlow::class) { ExceptionFlow { Exception("evil bug!") } }
val resultFuture = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)).resultFuture
net.runNetwork()
val exceptionResult = assertFailsWith(FlowSessionException::class) {
resultFuture.getOrThrow()
}
assertThat(exceptionResult.message).doesNotContain("evil bug!")
assertSessionTransfers(
node1 sent sessionInit(ReceiveFlow::class) to node2,
node2 sent sessionConfirm to node1,
node2 sent erroredEnd() to node1
)
}
@Test
fun `FlowException thrown on other side`() {
val erroringFlowFuture = node2.initiateSingleShotFlow(ReceiveFlow::class) {
@ -384,7 +409,7 @@ class StateMachineManagerTests {
assertThatExceptionOfType(MyFlowException::class.java)
.isThrownBy { receivingFiber.resultFuture.getOrThrow() }
.withMessage("Nothing useful")
.withStackTraceContaining("ReceiveFlow") // Make sure the stack trace is that of the receiving flow
.withStackTraceContaining(ReceiveFlow::class.java.name) // Make sure the stack trace is that of the receiving flow
databaseTransaction(node2.database) {
assertThat(node2.checkpointStorage.checkpoints()).isEmpty()
}
@ -394,10 +419,10 @@ class StateMachineManagerTests {
assertSessionTransfers(
node1 sent sessionInit(ReceiveFlow::class) to node2,
node2 sent sessionConfirm to node1,
node2 sent sessionEnd(errorFlow.exceptionThrown) to node1
node2 sent erroredEnd(errorFlow.exceptionThrown) to node1
)
// Make sure the original stack trace isn't sent down the wire
assertThat((sessionTransfers.last().message as SessionEnd).errorResponse!!.stackTrace).isEmpty()
assertThat((sessionTransfers.last().message as ErrorSessionEnd).errorResponse!!.stackTrace).isEmpty()
}
@Test
@ -450,7 +475,7 @@ class StateMachineManagerTests {
node1 sent sessionInit(ReceiveFlow::class) to node2,
node2 sent sessionConfirm to node1,
node2 sent sessionData("Hello") to node1,
node1 sent sessionEnd() to node2 // Unexpected session-end
node1 sent erroredEnd() to node2
)
}
@ -496,11 +521,29 @@ class StateMachineManagerTests {
ptx.signWith(node1.services.legalIdentityKey)
val stx = ptx.toSignedTransaction()
val future1 = node2.services.startFlow(WaitingFlows.Waiter(stx.id)).resultFuture
val future2 = node1.services.startFlow(WaitingFlows.Committer(stx, node2.info.legalIdentity)).resultFuture
val committerFiber = node1
.initiateSingleShotFlow(WaitingFlows.Waiter::class) { WaitingFlows.Committer(it) }
.map { it.stateMachine }
val waiterStx = node2.services.startFlow(WaitingFlows.Waiter(stx, node1.info.legalIdentity)).resultFuture
net.runNetwork()
future1.getOrThrow()
future2.getOrThrow()
assertThat(waiterStx.getOrThrow()).isEqualTo(committerFiber.getOrThrow().resultFuture.getOrThrow())
}
@Test
fun `committer throws exception before calling the finality flow`() {
val ptx = TransactionBuilder(notary = notary1.info.notaryIdentity)
ptx.addOutputState(DummyState())
ptx.signWith(node1.services.legalIdentityKey)
val stx = ptx.toSignedTransaction()
node1.services.registerFlowInitiator(WaitingFlows.Waiter::class) {
WaitingFlows.Committer(it) { throw Exception("Error") }
}
val waiter = node2.services.startFlow(WaitingFlows.Waiter(stx, node1.info.legalIdentity)).resultFuture
net.runNetwork()
assertThatExceptionOfType(FlowSessionException::class.java).isThrownBy {
waiter.getOrThrow()
}
}
@ -522,12 +565,10 @@ class StateMachineManagerTests {
}
private fun sessionInit(flowMarker: KClass<*>, payload: Any? = null) = SessionInit(0, flowMarker.java.name, payload)
private val sessionConfirm = SessionConfirm(0, 0)
private fun sessionData(payload: Any) = SessionData(0, payload)
private fun sessionEnd(error: FlowException? = null) = SessionEnd(0, error)
private val normalEnd = NormalSessionEnd(0)
private fun erroredEnd(errorResponse: FlowException? = null) = ErrorSessionEnd(0, errorResponse)
private fun assertSessionTransfers(vararg expected: SessionTransfer) {
assertThat(sessionTransfers).containsExactly(*expected)
@ -557,7 +598,8 @@ class StateMachineManagerTests {
is SessionData -> message.copy(recipientSessionId = 0)
is SessionInit -> message.copy(initiatorSessionId = 0)
is SessionConfirm -> message.copy(initiatorSessionId = 0, initiatedSessionId = 0)
is SessionEnd -> message.copy(recipientSessionId = 0)
is NormalSessionEnd -> message.copy(recipientSessionId = 0)
is ErrorSessionEnd -> message.copy(recipientSessionId = 0)
else -> message
}
}
@ -578,7 +620,7 @@ class StateMachineManagerTests {
}
private class SendFlow(val payload: Any, vararg val otherParties: Party) : FlowLogic<Unit>() {
private class SendFlow(val payload: String, vararg val otherParties: Party) : FlowLogic<Unit>() {
init {
require(otherParties.isNotEmpty())
}
@ -595,11 +637,11 @@ class StateMachineManagerTests {
require(otherParties.isNotEmpty())
}
@Transient var receivedPayloads: List<Any> = emptyList()
@Transient var receivedPayloads: List<String> = emptyList()
@Suspendable
override fun call() {
receivedPayloads = otherParties.map { receive<Any>(it).unwrap { it } }
receivedPayloads = otherParties.map { receive<String>(it).unwrap { it } }
if (nonTerminating) {
Fiber.park()
}
@ -630,23 +672,26 @@ class StateMachineManagerTests {
}
}
private class MyFlowException(message: String) : FlowException(message) {
private class MyFlowException(override val message: String) : FlowException() {
override fun equals(other: Any?): Boolean = other is MyFlowException && other.message == this.message
override fun hashCode(): Int = message?.hashCode() ?: 31
override fun hashCode(): Int = message.hashCode()
}
private object WaitingFlows {
class Waiter(private val hash: SecureHash) : FlowLogic<Unit>() {
class Waiter(val stx: SignedTransaction, val otherParty: Party) : FlowLogic<SignedTransaction>() {
@Suspendable
override fun call() {
waitForLedgerCommit(hash)
override fun call(): SignedTransaction {
send(otherParty, stx)
return waitForLedgerCommit(stx.id)
}
}
class Committer(private val stx: SignedTransaction, private val otherParty: Party) : FlowLogic<Unit>() {
class Committer(val otherParty: Party, val throwException: (() -> Exception)? = null) : FlowLogic<SignedTransaction>() {
@Suspendable
override fun call() {
subFlow(FinalityFlow(stx, setOf(otherParty)))
override fun call(): SignedTransaction {
val stx = receive<SignedTransaction>(otherParty).unwrap { it }
if (throwException != null) throw throwException.invoke()
return subFlow(FinalityFlow(stx, setOf(otherParty))).single()
}
}
}

View File

@ -282,6 +282,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
* parameter set to -1 (the default) which simply runs as many rounds as necessary to result in network
* stability (no nodes sent any messages in the last round).
*/
@JvmOverloads
fun runNetwork(rounds: Int = -1) {
check(!networkSendManuallyPumped)
fun pumpAll() = messagingNetwork.endpoints.map { it.pumpReceive(false) }
@ -324,6 +325,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
* Sets up a network with the requested number of nodes (defaulting to two), with one or more service nodes that
* run a notary, network map, any oracles etc. Can't be combined with [createTwoNodes].
*/
@JvmOverloads
fun createSomeNodes(numPartyNodes: Int = 2, nodeFactory: Factory = defaultFactory, notaryKeyPair: KeyPair? = DUMMY_NOTARY_KEY): BasketOfNodes {
require(nodes.isEmpty())
val notaryServiceInfo = ServiceInfo(SimpleNotaryService.type)