diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt index e31a0cba35..7b31b6a787 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt @@ -122,9 +122,14 @@ class ProtocolStateMachineImpl(override val id: StateMachineRunId, payload: Any, receiveType: Class, sessionProtocol: ProtocolLogic<*>): UntrustworthyData { - val session = getSession(otherParty, sessionProtocol) - val sendSessionData = createSessionData(session, payload) - val receivedSessionData = sendAndReceiveInternal(session, sendSessionData, SessionData::class.java) + val (session, new) = getSession(otherParty, sessionProtocol, payload) + val receivedSessionData = if (new) { + // Only do a receive here as the session init has carried the payload + receiveInternal(session) + } else { + val sendSessionData = createSessionData(session, payload) + sendAndReceiveInternal(session, sendSessionData) + } return UntrustworthyData(receiveType.cast(receivedSessionData.payload)) } @@ -132,15 +137,18 @@ class ProtocolStateMachineImpl(override val id: StateMachineRunId, override fun receive(otherParty: Party, receiveType: Class, sessionProtocol: ProtocolLogic<*>): UntrustworthyData { - val receivedSessionData = receiveInternal(getSession(otherParty, sessionProtocol), SessionData::class.java) + val session = getSession(otherParty, sessionProtocol, null).first + val receivedSessionData = receiveInternal(session) return UntrustworthyData(receiveType.cast(receivedSessionData.payload)) } @Suspendable override fun send(otherParty: Party, payload: Any, sessionProtocol: ProtocolLogic<*>) { - val session = getSession(otherParty, sessionProtocol) - val sendSessionData = createSessionData(session, payload) - sendInternal(session, sendSessionData) + val (session, new) = getSession(otherParty, sessionProtocol, payload) + if (!new) { + // Don't send the payload again if it was already piggy-backed on a session init + sendInternal(session, createSessionData(session, payload)) + } } private fun createSessionData(session: ProtocolSession, payload: Any): SessionData { @@ -155,27 +163,31 @@ class ProtocolStateMachineImpl(override val id: StateMachineRunId, } @Suspendable - private fun receiveInternal(session: ProtocolSession, receiveType: Class): T { - return suspendAndExpectReceive(ReceiveOnly(session, receiveType)) + private inline fun receiveInternal(session: ProtocolSession): M { + return suspendAndExpectReceive(ReceiveOnly(session, M::class.java)) + } + + private inline fun sendAndReceiveInternal(session: ProtocolSession, message: SessionMessage): M { + return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java)) } @Suspendable - private fun sendAndReceiveInternal(session: ProtocolSession, message: SessionMessage, receiveType: Class): T { - return suspendAndExpectReceive(SendAndReceive(session, message, receiveType)) + private fun getSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>, firstPayload: Any?): Pair { + val session = openSessions[Pair(sessionProtocol, otherParty)] + return if (session != null) { + Pair(session, false) + } else { + Pair(startNewSession(otherParty, sessionProtocol, firstPayload), true) + } } @Suspendable - private fun getSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>): ProtocolSession { - return openSessions[Pair(sessionProtocol, otherParty)] ?: startNewSession(otherParty, sessionProtocol) - } - - @Suspendable - private fun startNewSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>) : ProtocolSession { + private fun startNewSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>, firstPayload: Any?) : ProtocolSession { val session = ProtocolSession(sessionProtocol, otherParty, random63BitValue(), null) openSessions[Pair(sessionProtocol, otherParty)] = session val counterpartyProtocol = sessionProtocol.getCounterpartyMarker(otherParty).name - val sessionInit = SessionInit(session.ourSessionId, serviceHub.myInfo.legalIdentity, counterpartyProtocol) - val sessionInitResponse = sendAndReceiveInternal(session, sessionInit, SessionInitResponse::class.java) + val sessionInit = SessionInit(session.ourSessionId, serviceHub.myInfo.legalIdentity, counterpartyProtocol, firstPayload) + val sessionInitResponse = sendAndReceiveInternal(session, sessionInit) if (sessionInitResponse is SessionConfirm) { session.otherPartySessionId = sessionInitResponse.initiatedSessionId return session @@ -186,21 +198,26 @@ class ProtocolStateMachineImpl(override val id: StateMachineRunId, } @Suspendable - private fun suspendAndExpectReceive(receiveRequest: ReceiveRequest): T { + private fun suspendAndExpectReceive(receiveRequest: ReceiveRequest): M { fun getReceivedMessage(): ExistingSessionMessage? = receiveRequest.session.receivedMessages.poll() - val receivedMessage = getReceivedMessage() ?: run { - // Suspend while we wait for the receive - receiveRequest.session.waitingForResponse = true + val polledMessage = getReceivedMessage() + val receivedMessage = if (polledMessage != null) { + if (receiveRequest is SendAndReceive) { + // We've already received a message but we suspend so that the send can be performed + suspend(receiveRequest) + } + polledMessage + } else { + // Suspend while we wait for a receive suspend(receiveRequest) - receiveRequest.session.waitingForResponse = false getReceivedMessage() ?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $receiveRequest") } if (receivedMessage is SessionEnd) { openSessions.values.remove(receiveRequest.session) - throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurely ended") + throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurely ended on $receiveRequest") } else if (receiveRequest.receiveType.isInstance(receivedMessage)) { return receiveRequest.receiveType.cast(receivedMessage) } else { @@ -213,6 +230,7 @@ class ProtocolStateMachineImpl(override val id: StateMachineRunId, // we have to pass the Thread local Transaction across via a transient field as the Fiber Park swaps them out. txTrampoline = TransactionManager.currentOrNull() StrandLocalTransactionManager.setThreadLocalTx(null) + ioRequest.session.waitingForResponse = true parkAndSerialize { fiber, serializer -> logger.trace { "Suspended on $ioRequest" } // restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB @@ -228,6 +246,7 @@ class ProtocolStateMachineImpl(override val id: StateMachineRunId, processException(t) } } + ioRequest.session.waitingForResponse = false createTransaction() } diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt index 86ae7a21ba..246c5f32c3 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt @@ -233,6 +233,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, val protocol = protocolFactory(otherParty) val psm = createFiber(protocol) val session = ProtocolSession(protocol, otherParty, random63BitValue(), otherPartySessionId) + if (sessionInit.firstPayload != null) { + session.receivedMessages += SessionData(session.ourSessionId, sessionInit.firstPayload) + } openSessions[session.ourSessionId] = session psm.openSessions[Pair(protocol, otherParty)] = session updateCheckpoint(psm) @@ -400,7 +403,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, val recipientSessionId: Long } - data class SessionInit(val initiatorSessionId: Long, val initiatorParty: Party, val protocolName: String) : SessionMessage + data class SessionInit(val initiatorSessionId: Long, + val initiatorParty: Party, + val protocolName: String, + val firstPayload: Any?) : SessionMessage interface SessionInitResponse : ExistingSessionMessage diff --git a/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt b/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt index d61a0c1026..b5a3041475 100644 --- a/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt @@ -7,6 +7,7 @@ import com.r3corda.core.contracts.* import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.SecureHash import com.r3corda.core.days +import com.r3corda.core.map import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.core.node.services.* import com.r3corda.core.protocols.ProtocolStateMachine @@ -143,15 +144,15 @@ class TwoPartyTradeProtocolTests { // Everything is on this thread so we can now step through the protocol one step at a time. // Seller Alice already sent a message to Buyer Bob. Pump once: - bobNode.pumpReceive(false) + bobNode.pumpReceive() // Bob sends a couple of queries for the dependencies back to Alice. Alice reponds. - aliceNode.pumpReceive(false) - bobNode.pumpReceive(false) - aliceNode.pumpReceive(false) - bobNode.pumpReceive(false) - aliceNode.pumpReceive(false) - bobNode.pumpReceive(false) + aliceNode.pumpReceive() + bobNode.pumpReceive() + aliceNode.pumpReceive() + bobNode.pumpReceive() + aliceNode.pumpReceive() + bobNode.pumpReceive() // OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature. assertThat(bobNode.checkpointStorage.checkpoints()).hasSize(1) @@ -164,7 +165,7 @@ class TwoPartyTradeProtocolTests { // Alice doesn't know that and carries on: she wants to know about the cash transactions he's trying to use. // She will wait around until Bob comes back. - assertThat(aliceNode.pumpReceive(false)).isNotNull() + assertThat(aliceNode.pumpReceive()).isNotNull() // ... bring the node back up ... the act of constructing the SMM will re-register the message handlers // that Bob was waiting on before the reboot occurred. @@ -386,7 +387,7 @@ class TwoPartyTradeProtocolTests { private data class RunResult( // The buyer is not created immediately, only when the seller starts running - val buyer: Future>, + val buyer: Future>, val sellerResult: Future, val sellerId: StateMachineRunId ) @@ -394,7 +395,7 @@ class TwoPartyTradeProtocolTests { private fun runBuyerAndSeller(assetToSell: StateAndRef) : RunResult { val buyerFuture = bobNode.initiateSingleShotProtocol(Seller::class) { otherParty -> Buyer(otherParty, notaryNode.info.notaryIdentity, 1000.DOLLARS, CommercialPaper.State::class.java) - } + }.map { it.psm } val seller = Seller(bobNode.info.legalIdentity, notaryNode.info, assetToSell, 1000.DOLLARS, ALICE_KEY) val sellerResultFuture = aliceNode.smm.add(seller).resultFuture return RunResult(buyerFuture, sellerResultFuture, seller.psm.id) diff --git a/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt b/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt index e0fac2f406..10d30cd5d7 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt @@ -4,15 +4,15 @@ import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Suspendable import com.google.common.util.concurrent.ListenableFuture import com.r3corda.core.crypto.Party -import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolSessionException import com.r3corda.core.random63BitValue import com.r3corda.core.serialization.deserialize import com.r3corda.node.services.persistence.checkpoints -import com.r3corda.node.services.statemachine.StateMachineManager.SessionData -import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage +import com.r3corda.node.services.statemachine.StateMachineManager.* +import com.r3corda.testing.initiateSingleShotProtocol import com.r3corda.testing.node.InMemoryMessagingNetwork +import com.r3corda.testing.node.InMemoryMessagingNetwork.MessageTransfer import com.r3corda.testing.node.MockNetwork import com.r3corda.testing.node.MockNetwork.MockNode import org.assertj.core.api.Assertions.assertThat @@ -20,20 +20,25 @@ import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.After import org.junit.Before import org.junit.Test +import rx.Observable +import java.util.* +import kotlin.reflect.KClass import kotlin.test.assertEquals import kotlin.test.assertTrue class StateMachineManagerTests { - val net = MockNetwork() - lateinit var node1: MockNode - lateinit var node2: MockNode + private val net = MockNetwork() + private val sessionTransfers = ArrayList() + private lateinit var node1: MockNode + private lateinit var node2: MockNode @Before fun start() { val nodes = net.createTwoNodes() node1 = nodes.first node2 = nodes.second + net.messagingNetwork.receivedMessages.toSessionTransfers().forEach { sessionTransfers += it } net.runNetwork() } @@ -44,14 +49,18 @@ class StateMachineManagerTests { @Test fun `newly added protocol is preserved on restart`() { - node1.smm.add(ProtocolWithoutCheckpoints()) - val restoredProtocol = node1.restartAndGetRestoredProtocol() + node1.smm.add(NoOpProtocol(nonTerminating = true)) + val restoredProtocol = node1.restartAndGetRestoredProtocol() assertThat(restoredProtocol.protocolStarted).isTrue() } @Test fun `protocol can lazily use the serviceHub in its constructor`() { - val protocol = ProtocolWithLazyServiceHub() + val protocol = object : ProtocolLogic() { + val lazyTime by lazy { serviceHub.clock.instant() } + @Suspendable + override fun call() = Unit + } node1.smm.add(protocol) assertThat(protocol.lazyTime).isNotNull() } @@ -62,19 +71,18 @@ class StateMachineManagerTests { val payload = random63BitValue() node1.smm.add(SendProtocol(payload, node2.info.legalIdentity)) - // We push through just enough messages to get only the SessionData sent - // TODO We should be able to give runNetwork a predicate for when to stop - net.runNetwork(2) + // We push through just enough messages to get only the payload sent + node2.pumpReceive() node2.stop() net.runNetwork() - val restoredProtocol = node2.restartAndGetRestoredProtocol(node1.info.address) + val restoredProtocol = node2.restartAndGetRestoredProtocol(node1) assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload) } @Test fun `protocol added before network map does run after init`() { val node3 = net.createNode(node1.info.address) //create vanilla node - val protocol = ProtocolNoBlocking() + val protocol = NoOpProtocol() node3.smm.add(protocol) assertEquals(false, protocol.protocolStarted) // Not started yet as no network activity has been allowed yet net.runNetwork() // Allow network map messages to flow @@ -84,13 +92,13 @@ class StateMachineManagerTests { @Test fun `protocol added before network map will be init checkpointed`() { var node3 = net.createNode(node1.info.address) //create vanilla node - val protocol = ProtocolNoBlocking() + val protocol = NoOpProtocol() node3.smm.add(protocol) assertEquals(false, protocol.protocolStarted) // Not started yet as no network activity has been allowed yet node3.stop() node3 = net.createNode(node1.info.address, forcedID = node3.id) - val restoredProtocol = node3.getSingleProtocol().first + val restoredProtocol = node3.getSingleProtocol().first assertEquals(false, restoredProtocol.protocolStarted) // Not started yet as no network activity has been allowed yet net.runNetwork() // Allow network map messages to flow node3.smm.executor.flush() @@ -101,17 +109,16 @@ class StateMachineManagerTests { node3 = net.createNode(node1.info.address, forcedID = node3.id) net.runNetwork() // Allow network map messages to flow node3.smm.executor.flush() - assertTrue(node3.smm.findStateMachines(ProtocolNoBlocking::class.java).isEmpty()) + assertTrue(node3.smm.findStateMachines(NoOpProtocol::class.java).isEmpty()) } @Test fun `protocol loaded from checkpoint will respond to messages from before start`() { val payload = random63BitValue() node1.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(payload, it) } - val receiveProtocol = ReceiveThenSuspendProtocol(node1.info.legalIdentity) - node2.smm.add(receiveProtocol) // Prepare checkpointed receive protocol + node2.smm.add(ReceiveThenSuspendProtocol(node1.info.legalIdentity)) // Prepare checkpointed receive protocol node2.stop() // kill receiver - val restoredProtocol = node2.restartAndGetRestoredProtocol(node1.info.address) + val restoredProtocol = node2.restartAndGetRestoredProtocol(node1) assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload) } @@ -121,41 +128,36 @@ class StateMachineManagerTests { val payload2 = random63BitValue() var sentCount = 0 - var receivedCount = 0 - net.messagingNetwork.sentMessages.subscribe { if (isDataMessage(it)) sentCount++ } - net.messagingNetwork.receivedMessages.subscribe { if (isDataMessage(it)) receivedCount++ } - val node3 = net.createNode(node1.info.address) - net.runNetwork() + net.messagingNetwork.sentMessages.toSessionTransfers().filter { it.isPayloadTransfer }.forEach { sentCount++ } - var secondProtocol: PingPongProtocol? = null - node3.services.registerProtocolInitiator(PingPongProtocol::class) { - val protocol = PingPongProtocol(it, payload2) - secondProtocol = protocol - protocol - } + val node3 = net.createNode(node1.info.address) + val secondProtocol = node3.initiateSingleShotProtocol(PingPongProtocol::class) { PingPongProtocol(it, payload2) } + net.runNetwork() // Kick off first send and receive node2.smm.add(PingPongProtocol(node3.info.legalIdentity, payload)) - assertEquals(1, node2.checkpointStorage.checkpoints().count()) + assertEquals(1, node2.checkpointStorage.checkpoints().size) // Restart node and thus reload the checkpoint and resend the message with same UUID node2.stop() val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray()) val (firstAgain, fut1) = node2b.getSingleProtocol() - net.runNetwork() - assertEquals(1, node2.checkpointStorage.checkpoints().count()) // Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync. net.runNetwork() + assertEquals(1, node2.checkpointStorage.checkpoints().size) node2b.smm.executor.flush() fut1.get() + + val receivedCount = sessionTransfers.count { it.isPayloadTransfer } // Check protocols completed cleanly and didn't get out of phase assertEquals(4, receivedCount, "Protocol should have exchanged 4 unique messages")// Two messages each way - assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages") // can't give a precise value as every addMessageHandler re-runs the undelivered messages - assertEquals(0, node2b.checkpointStorage.checkpoints().count(), "Checkpoints left after restored protocol should have ended") - assertEquals(0, node3.checkpointStorage.checkpoints().count(), "Checkpoints left after restored protocol should have ended") + // can't give a precise value as every addMessageHandler re-runs the undelivered messages + assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages") + assertEquals(0, node2b.checkpointStorage.checkpoints().size, "Checkpoints left after restored protocol should have ended") + assertEquals(0, node3.checkpointStorage.checkpoints().size, "Checkpoints left after restored protocol should have ended") assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3") assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3") - assertEquals(payload, secondProtocol!!.receivedPayload, "Received payload does not match the (restarted) first value on Node 2") - assertEquals(payload + 1, secondProtocol!!.receivedPayload2, "Received payload does not match the expected second value on Node 2") + assertEquals(payload, secondProtocol.get().receivedPayload, "Received payload does not match the (restarted) first value on Node 2") + assertEquals(payload + 1, secondProtocol.get().receivedPayload2, "Received payload does not match the expected second value on Node 2") } @Test @@ -171,6 +173,20 @@ class StateMachineManagerTests { val node3Protocol = node3.getSingleProtocol().first assertThat(node2Protocol.receivedPayloads[0]).isEqualTo(payload) assertThat(node3Protocol.receivedPayloads[0]).isEqualTo(payload) + + assertSessionTransfers(node2, + node1 sent sessionInit(node1, SendProtocol::class, payload) to node2, + node2 sent sessionConfirm() to node1, + node1 sent sessionEnd() to node2 + //There's no session end from the other protocols as they're manually suspended + ) + + assertSessionTransfers(node3, + node1 sent sessionInit(node1, SendProtocol::class, payload) to node3, + node3 sent sessionConfirm() to node1, + node1 sent sessionEnd() to node3 + //There's no session end from the other protocols as they're manually suspended + ) } @Test @@ -183,13 +199,39 @@ class StateMachineManagerTests { node3.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node3Payload, it) } val multiReceiveProtocol = ReceiveThenSuspendProtocol(node2.info.legalIdentity, node3.info.legalIdentity) node1.smm.add(multiReceiveProtocol) - net.runNetwork(1) // session handshaking - // have the messages arrive in reverse order of receive - node3.pumpReceive(false) - node2.pumpReceive(false) - net.runNetwork() // pump remaining messages + net.runNetwork() assertThat(multiReceiveProtocol.receivedPayloads[0]).isEqualTo(node2Payload) assertThat(multiReceiveProtocol.receivedPayloads[1]).isEqualTo(node3Payload) + + assertSessionTransfers(node2, + node1 sent sessionInit(node1, ReceiveThenSuspendProtocol::class) to node2, + node2 sent sessionConfirm() to node1, + node2 sent sessionData(node2Payload) to node1, + node2 sent sessionEnd() to node1 + ) + + assertSessionTransfers(node3, + node1 sent sessionInit(node1, ReceiveThenSuspendProtocol::class) to node3, + node3 sent sessionConfirm() to node1, + node3 sent sessionData(node3Payload) to node1, + node3 sent sessionEnd() to node1 + ) + } + + @Test + fun `both sides do a send as their first IO request`() { + node2.services.registerProtocolInitiator(PingPongProtocol::class) { PingPongProtocol(it, 20L) } + node1.smm.add(PingPongProtocol(node2.info.legalIdentity, 10L)) + net.runNetwork() + + assertSessionTransfers( + node1 sent sessionInit(node1, PingPongProtocol::class, 10L) to node2, + node2 sent sessionConfirm() to node1, + node2 sent sessionData(20L) to node1, + node1 sent sessionData(11L) to node2, + node2 sent sessionData(21L) to node1, + node1 sent sessionEnd() to node2 + ) } @Test @@ -198,16 +240,17 @@ class StateMachineManagerTests { val future = node1.smm.add(ReceiveThenSuspendProtocol(node2.info.legalIdentity)).resultFuture net.runNetwork() assertThatThrownBy { future.get() }.hasCauseInstanceOf(ProtocolSessionException::class.java) + assertSessionTransfers( + node1 sent sessionInit(node1, ReceiveThenSuspendProtocol::class) to node2, + node2 sent sessionConfirm() to node1, + node2 sent sessionEnd() to node1 + ) } - private fun isDataMessage(transfer: InMemoryMessagingNetwork.MessageTransfer): Boolean { - return transfer.message.topicSession == StateMachineManager.sessionTopic - && transfer.message.data.deserialize() is SessionData - } - - private inline fun MockNode.restartAndGetRestoredProtocol(networkMapAddress: SingleMessageRecipient? = null): P { + private inline fun > MockNode.restartAndGetRestoredProtocol( + networkMapNode: MockNode? = null): P { stop() - val newNode = mockNet.createNode(networkMapAddress, id, advertisedServices = *advertisedServices.toTypedArray()) + val newNode = mockNet.createNode(networkMapNode?.info?.address, id, advertisedServices = *advertisedServices.toTypedArray()) mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine return newNode.getSingleProtocol

().first } @@ -216,35 +259,66 @@ class StateMachineManagerTests { return smm.findStateMachines(P::class.java).single() } + private fun sessionInit(initiatorNode: MockNode, protocolMarker: KClass<*>, payload: Any? = null): SessionInit { + return SessionInit(0, initiatorNode.info.legalIdentity, protocolMarker.java.name, payload) + } + + private fun sessionConfirm() = SessionConfirm(0, 0) + + private fun sessionData(payload: Any) = SessionData(0, payload) + + private fun sessionEnd() = SessionEnd(0) + + private fun assertSessionTransfers(vararg expected: SessionTransfer) { + assertThat(sessionTransfers).containsExactly(*expected) + } + + private fun assertSessionTransfers(node: MockNode, vararg expected: SessionTransfer) { + val actualForNode = sessionTransfers.filter { it.from == node.id || it.to == node.id } + assertThat(actualForNode).containsExactly(*expected) + } + + private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: Int) { + val isPayloadTransfer: Boolean get() = message is SessionData || message is SessionInit && message.firstPayload != null + override fun toString(): String = "$from sent $message to $to" + } + + private fun Observable.toSessionTransfers(): Observable { + return filter { it.message.topicSession == StateMachineManager.sessionTopic }.map { + val from = it.sender.myAddress.id + val message = it.message.data.deserialize() + val to = (it.recipients as InMemoryMessagingNetwork.Handle).id + SessionTransfer(from, sanitise(message), to) + } + } + + private fun sanitise(message: SessionMessage): SessionMessage { + return when (message) { + is SessionData -> message.copy(recipientSessionId = 0) + is SessionInit -> message.copy(initiatorSessionId = 0) + is SessionConfirm -> message.copy(initiatorSessionId = 0, initiatedSessionId = 0) + is SessionEnd -> message.copy(recipientSessionId = 0) + else -> message + } + } + + private infix fun MockNode.sent(message: SessionMessage): Pair = Pair(id, message) + private infix fun Pair.to(node: MockNode): SessionTransfer = SessionTransfer(first, second, node.id) + + + private class NoOpProtocol(val nonTerminating: Boolean = false) : ProtocolLogic() { - private class ProtocolNoBlocking : ProtocolLogic() { @Transient var protocolStarted = false @Suspendable override fun call() { protocolStarted = true + if (nonTerminating) { + Fiber.park() + } } } - private class ProtocolWithoutCheckpoints : NonTerminatingProtocol() { - - @Transient var protocolStarted = false - - @Suspendable - override fun doCall() { - protocolStarted = true - } - } - - - private class ProtocolWithLazyServiceHub : ProtocolLogic() { - - val lazyTime by lazy { serviceHub.clock.instant() } - - @Suspendable - override fun call() = Unit - } - private class SendProtocol(val payload: Any, vararg val otherParties: Party) : ProtocolLogic() { @@ -257,7 +331,7 @@ class StateMachineManagerTests { } - private class ReceiveThenSuspendProtocol(vararg val otherParties: Party) : NonTerminatingProtocol() { + private class ReceiveThenSuspendProtocol(vararg val otherParties: Party) : ProtocolLogic() { init { require(otherParties.isNotEmpty()) @@ -266,8 +340,10 @@ class StateMachineManagerTests { @Transient var receivedPayloads: List = emptyList() @Suspendable - override fun doCall() { + override fun call() { receivedPayloads = otherParties.map { receive(it).unwrap { it } } + println(receivedPayloads) + Fiber.park() } } @@ -279,7 +355,9 @@ class StateMachineManagerTests { @Suspendable override fun call() { receivedPayload = sendAndReceive(otherParty, payload).unwrap { it } - receivedPayload2 = sendAndReceive(otherParty, (payload + 1)).unwrap { it } + println("${psm.id} Received $receivedPayload") + receivedPayload2 = sendAndReceive(otherParty, payload + 1).unwrap { it } + println("${psm.id} Received $receivedPayload2") } } @@ -287,20 +365,4 @@ class StateMachineManagerTests { override fun call(): Nothing = throw Exception() } - /** - * A protocol that suspends forever after doing some work. This is to allow it to be retrieved from the SMM after - * restart for testing checkpoint restoration. Store any results as @Transient fields. - */ - private abstract class NonTerminatingProtocol : ProtocolLogic() { - - @Suspendable - override fun call() { - doCall() - Fiber.park() - } - - @Suspendable - abstract fun doCall() - } - } diff --git a/src/main/kotlin/com/r3corda/simulation/IRSSimulation.kt b/src/main/kotlin/com/r3corda/simulation/IRSSimulation.kt index 15993d5572..542105edf2 100644 --- a/src/main/kotlin/com/r3corda/simulation/IRSSimulation.kt +++ b/src/main/kotlin/com/r3corda/simulation/IRSSimulation.kt @@ -12,6 +12,7 @@ import com.r3corda.core.contracts.UniqueIdentifier import com.r3corda.core.flatMap import com.r3corda.core.map import com.r3corda.core.node.services.linearHeadsOfType +import com.r3corda.core.protocols.ProtocolStateMachine import com.r3corda.core.success import com.r3corda.core.transactions.SignedTransaction import com.r3corda.protocols.TwoPartyDealProtocol.Acceptor @@ -111,7 +112,9 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten irs.fixedLeg.fixedRatePayer = node1.info.legalIdentity irs.floatingLeg.floatingRatePayer = node2.info.legalIdentity - val acceptorTx = node2.initiateSingleShotProtocol(Instigator::class) { Acceptor(it) }.flatMap { it.resultFuture } + val acceptorTx = node2.initiateSingleShotProtocol(Instigator::class) { Acceptor(it) }.flatMap { + (it.psm as ProtocolStateMachine).resultFuture + } showProgressFor(listOf(node1, node2)) showConsensusFor(listOf(node1, node2, regulators[0])) diff --git a/src/main/kotlin/com/r3corda/simulation/TradeSimulation.kt b/src/main/kotlin/com/r3corda/simulation/TradeSimulation.kt index 7b89d20eda..47cdfabaa6 100644 --- a/src/main/kotlin/com/r3corda/simulation/TradeSimulation.kt +++ b/src/main/kotlin/com/r3corda/simulation/TradeSimulation.kt @@ -11,6 +11,7 @@ import com.r3corda.core.contracts.`issued by` import com.r3corda.core.days import com.r3corda.core.flatMap import com.r3corda.core.node.recordTransactions +import com.r3corda.core.protocols.ProtocolStateMachine import com.r3corda.core.seconds import com.r3corda.core.transactions.SignedTransaction import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer @@ -51,7 +52,7 @@ class TradeSimulation(runAsync: Boolean, latencyInjector: InMemoryMessagingNetwo val buyerFuture = buyer.initiateSingleShotProtocol(Seller::class) { Buyer(it, notary.info.notaryIdentity, amount, CommercialPaper.State::class.java) - }.flatMap { it.resultFuture } + }.flatMap { (it.psm as ProtocolStateMachine).resultFuture } val sellerKey = seller.services.legalIdentityKey val sellerProtocol = Seller( diff --git a/test-utils/src/main/kotlin/com/r3corda/testing/CoreTestUtils.kt b/test-utils/src/main/kotlin/com/r3corda/testing/CoreTestUtils.kt index 591617ca34..a6e3327f0d 100644 --- a/test-utils/src/main/kotlin/com/r3corda/testing/CoreTestUtils.kt +++ b/test-utils/src/main/kotlin/com/r3corda/testing/CoreTestUtils.kt @@ -12,11 +12,11 @@ import com.r3corda.core.crypto.SecureHash import com.r3corda.core.crypto.generateKeyPair import com.r3corda.core.node.ServiceHub import com.r3corda.core.protocols.ProtocolLogic -import com.r3corda.core.protocols.ProtocolStateMachine import com.r3corda.core.transactions.TransactionBuilder import com.r3corda.core.utilities.DUMMY_NOTARY import com.r3corda.core.utilities.DUMMY_NOTARY_KEY import com.r3corda.node.internal.AbstractNode +import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl import com.r3corda.node.services.statemachine.StateMachineManager.Change import com.r3corda.node.utilities.AddOrRemove.ADD import com.r3corda.testing.node.MockIdentityService @@ -138,20 +138,20 @@ fun getFreeLocalPorts(hostName: String, numberToAlloc: Int): List { /** * The given protocol factory will be used to initiate just one instance of a protocol of type [P] when a counterparty * protocol requests for it using [markerClass]. - * @return Returns a [ListenableFuture] holding the single [ProtocolStateMachine] created by the request. + * @return Returns a [ListenableFuture] holding the single [ProtocolStateMachineImpl] created by the request. */ -inline fun > AbstractNode.initiateSingleShotProtocol( +inline fun > AbstractNode.initiateSingleShotProtocol( markerClass: KClass>, - noinline protocolFactory: (Party) -> P): ListenableFuture> { + noinline protocolFactory: (Party) -> P): ListenableFuture

{ services.registerProtocolInitiator(markerClass, protocolFactory) - val future = SettableFuture.create>() + val future = SettableFuture.create

() val subscriber = object : Subscriber() { override fun onNext(change: Change) { if (change.logic is P && change.addOrRemove == ADD) { unsubscribe() - future.set(change.logic.psm as ProtocolStateMachine) + future.set(change.logic as P) } } override fun onError(e: Throwable) { diff --git a/test-utils/src/main/kotlin/com/r3corda/testing/node/InMemoryMessagingNetwork.kt b/test-utils/src/main/kotlin/com/r3corda/testing/node/InMemoryMessagingNetwork.kt index 45f63e0c47..7cace615ca 100644 --- a/test-utils/src/main/kotlin/com/r3corda/testing/node/InMemoryMessagingNetwork.kt +++ b/test-utils/src/main/kotlin/com/r3corda/testing/node/InMemoryMessagingNetwork.kt @@ -223,7 +223,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria private val state = ThreadBox(InnerState()) private val processedMessages: MutableSet = Collections.synchronizedSet(HashSet()) - override val myAddress: SingleMessageRecipient = handle + override val myAddress: Handle get() = handle private val backgroundThread = if (manuallyPumped) null else thread(isDaemon = true, name = "In-memory message dispatcher") { diff --git a/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt b/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt index f1727a627f..3cface82e5 100644 --- a/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt +++ b/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt @@ -1,5 +1,6 @@ package com.r3corda.testing.node +import com.google.common.jimfs.Configuration.unix import com.google.common.jimfs.Jimfs import com.google.common.util.concurrent.Futures import com.r3corda.core.crypto.Party @@ -27,6 +28,7 @@ import com.r3corda.node.services.transactions.SimpleNotaryService import com.r3corda.node.services.transactions.ValidatingNotaryService import com.r3corda.node.utilities.databaseTransaction import org.slf4j.Logger +import java.nio.file.FileSystem import java.nio.file.Files import java.nio.file.Path import java.security.KeyPair @@ -49,7 +51,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, private val threadPerNode: Boolean = false, private val defaultFactory: Factory = MockNetwork.DefaultFactory) { private var counter = 0 - val filesystem = com.google.common.jimfs.Jimfs.newFileSystem(com.google.common.jimfs.Configuration.unix()) + val filesystem: FileSystem = Jimfs.newFileSystem(unix()) val messagingNetwork = InMemoryMessagingNetwork(networkSendManuallyPumped) // A unique identifier for this network to segregate databases with the same nodeID but different networks. @@ -138,7 +140,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, // It is used from the network visualiser tool. @Suppress("unused") val place: PhysicalLocation get() = findMyLocation()!! - fun pumpReceive(block: Boolean): InMemoryMessagingNetwork.MessageTransfer? { + fun pumpReceive(block: Boolean = false): InMemoryMessagingNetwork.MessageTransfer? { return (net as InMemoryMessagingNetwork.InMemoryMessaging).pumpReceive(block) }