Session handshake optimised to carry the first send payload in the init message

This commit is contained in:
Shams Asari 2016-10-10 15:29:45 +01:00
parent d2983d6a7a
commit e48e09f04e
9 changed files with 231 additions and 137 deletions

View File

@ -122,9 +122,14 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
payload: Any, payload: Any,
receiveType: Class<T>, receiveType: Class<T>,
sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T> { sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T> {
val session = getSession(otherParty, sessionProtocol) val (session, new) = getSession(otherParty, sessionProtocol, payload)
val sendSessionData = createSessionData(session, payload) val receivedSessionData = if (new) {
val receivedSessionData = sendAndReceiveInternal(session, sendSessionData, SessionData::class.java) // Only do a receive here as the session init has carried the payload
receiveInternal<SessionData>(session)
} else {
val sendSessionData = createSessionData(session, payload)
sendAndReceiveInternal<SessionData>(session, sendSessionData)
}
return UntrustworthyData(receiveType.cast(receivedSessionData.payload)) return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
} }
@ -132,15 +137,18 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
override fun <T : Any> receive(otherParty: Party, override fun <T : Any> receive(otherParty: Party,
receiveType: Class<T>, receiveType: Class<T>,
sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T> { sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T> {
val receivedSessionData = receiveInternal(getSession(otherParty, sessionProtocol), SessionData::class.java) val session = getSession(otherParty, sessionProtocol, null).first
val receivedSessionData = receiveInternal<SessionData>(session)
return UntrustworthyData(receiveType.cast(receivedSessionData.payload)) return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
} }
@Suspendable @Suspendable
override fun send(otherParty: Party, payload: Any, sessionProtocol: ProtocolLogic<*>) { override fun send(otherParty: Party, payload: Any, sessionProtocol: ProtocolLogic<*>) {
val session = getSession(otherParty, sessionProtocol) val (session, new) = getSession(otherParty, sessionProtocol, payload)
val sendSessionData = createSessionData(session, payload) if (!new) {
sendInternal(session, sendSessionData) // Don't send the payload again if it was already piggy-backed on a session init
sendInternal(session, createSessionData(session, payload))
}
} }
private fun createSessionData(session: ProtocolSession, payload: Any): SessionData { private fun createSessionData(session: ProtocolSession, payload: Any): SessionData {
@ -155,27 +163,31 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
} }
@Suspendable @Suspendable
private fun <T : SessionMessage> receiveInternal(session: ProtocolSession, receiveType: Class<T>): T { private inline fun <reified M : SessionMessage> receiveInternal(session: ProtocolSession): M {
return suspendAndExpectReceive(ReceiveOnly(session, receiveType)) return suspendAndExpectReceive(ReceiveOnly(session, M::class.java))
}
private inline fun <reified M : SessionMessage> sendAndReceiveInternal(session: ProtocolSession, message: SessionMessage): M {
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java))
} }
@Suspendable @Suspendable
private fun <T : SessionMessage> sendAndReceiveInternal(session: ProtocolSession, message: SessionMessage, receiveType: Class<T>): T { private fun getSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>, firstPayload: Any?): Pair<ProtocolSession, Boolean> {
return suspendAndExpectReceive(SendAndReceive(session, message, receiveType)) val session = openSessions[Pair(sessionProtocol, otherParty)]
return if (session != null) {
Pair(session, false)
} else {
Pair(startNewSession(otherParty, sessionProtocol, firstPayload), true)
}
} }
@Suspendable @Suspendable
private fun getSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>): ProtocolSession { private fun startNewSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>, firstPayload: Any?) : ProtocolSession {
return openSessions[Pair(sessionProtocol, otherParty)] ?: startNewSession(otherParty, sessionProtocol)
}
@Suspendable
private fun startNewSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>) : ProtocolSession {
val session = ProtocolSession(sessionProtocol, otherParty, random63BitValue(), null) val session = ProtocolSession(sessionProtocol, otherParty, random63BitValue(), null)
openSessions[Pair(sessionProtocol, otherParty)] = session openSessions[Pair(sessionProtocol, otherParty)] = session
val counterpartyProtocol = sessionProtocol.getCounterpartyMarker(otherParty).name val counterpartyProtocol = sessionProtocol.getCounterpartyMarker(otherParty).name
val sessionInit = SessionInit(session.ourSessionId, serviceHub.myInfo.legalIdentity, counterpartyProtocol) val sessionInit = SessionInit(session.ourSessionId, serviceHub.myInfo.legalIdentity, counterpartyProtocol, firstPayload)
val sessionInitResponse = sendAndReceiveInternal(session, sessionInit, SessionInitResponse::class.java) val sessionInitResponse = sendAndReceiveInternal<SessionInitResponse>(session, sessionInit)
if (sessionInitResponse is SessionConfirm) { if (sessionInitResponse is SessionConfirm) {
session.otherPartySessionId = sessionInitResponse.initiatedSessionId session.otherPartySessionId = sessionInitResponse.initiatedSessionId
return session return session
@ -186,21 +198,26 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
} }
@Suspendable @Suspendable
private fun <T : SessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<T>): T { private fun <M : SessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<M>): M {
fun getReceivedMessage(): ExistingSessionMessage? = receiveRequest.session.receivedMessages.poll() fun getReceivedMessage(): ExistingSessionMessage? = receiveRequest.session.receivedMessages.poll()
val receivedMessage = getReceivedMessage() ?: run { val polledMessage = getReceivedMessage()
// Suspend while we wait for the receive val receivedMessage = if (polledMessage != null) {
receiveRequest.session.waitingForResponse = true if (receiveRequest is SendAndReceive) {
// We've already received a message but we suspend so that the send can be performed
suspend(receiveRequest)
}
polledMessage
} else {
// Suspend while we wait for a receive
suspend(receiveRequest) suspend(receiveRequest)
receiveRequest.session.waitingForResponse = false
getReceivedMessage() getReceivedMessage()
?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $receiveRequest") ?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $receiveRequest")
} }
if (receivedMessage is SessionEnd) { if (receivedMessage is SessionEnd) {
openSessions.values.remove(receiveRequest.session) openSessions.values.remove(receiveRequest.session)
throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurely ended") throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurely ended on $receiveRequest")
} else if (receiveRequest.receiveType.isInstance(receivedMessage)) { } else if (receiveRequest.receiveType.isInstance(receivedMessage)) {
return receiveRequest.receiveType.cast(receivedMessage) return receiveRequest.receiveType.cast(receivedMessage)
} else { } else {
@ -213,6 +230,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
// we have to pass the Thread local Transaction across via a transient field as the Fiber Park swaps them out. // we have to pass the Thread local Transaction across via a transient field as the Fiber Park swaps them out.
txTrampoline = TransactionManager.currentOrNull() txTrampoline = TransactionManager.currentOrNull()
StrandLocalTransactionManager.setThreadLocalTx(null) StrandLocalTransactionManager.setThreadLocalTx(null)
ioRequest.session.waitingForResponse = true
parkAndSerialize { fiber, serializer -> parkAndSerialize { fiber, serializer ->
logger.trace { "Suspended on $ioRequest" } logger.trace { "Suspended on $ioRequest" }
// restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB // restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB
@ -228,6 +246,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
processException(t) processException(t)
} }
} }
ioRequest.session.waitingForResponse = false
createTransaction() createTransaction()
} }

View File

@ -233,6 +233,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
val protocol = protocolFactory(otherParty) val protocol = protocolFactory(otherParty)
val psm = createFiber(protocol) val psm = createFiber(protocol)
val session = ProtocolSession(protocol, otherParty, random63BitValue(), otherPartySessionId) val session = ProtocolSession(protocol, otherParty, random63BitValue(), otherPartySessionId)
if (sessionInit.firstPayload != null) {
session.receivedMessages += SessionData(session.ourSessionId, sessionInit.firstPayload)
}
openSessions[session.ourSessionId] = session openSessions[session.ourSessionId] = session
psm.openSessions[Pair(protocol, otherParty)] = session psm.openSessions[Pair(protocol, otherParty)] = session
updateCheckpoint(psm) updateCheckpoint(psm)
@ -400,7 +403,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
val recipientSessionId: Long val recipientSessionId: Long
} }
data class SessionInit(val initiatorSessionId: Long, val initiatorParty: Party, val protocolName: String) : SessionMessage data class SessionInit(val initiatorSessionId: Long,
val initiatorParty: Party,
val protocolName: String,
val firstPayload: Any?) : SessionMessage
interface SessionInitResponse : ExistingSessionMessage interface SessionInitResponse : ExistingSessionMessage

View File

@ -7,6 +7,7 @@ import com.r3corda.core.contracts.*
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.days import com.r3corda.core.days
import com.r3corda.core.map
import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.node.services.* import com.r3corda.core.node.services.*
import com.r3corda.core.protocols.ProtocolStateMachine import com.r3corda.core.protocols.ProtocolStateMachine
@ -143,15 +144,15 @@ class TwoPartyTradeProtocolTests {
// Everything is on this thread so we can now step through the protocol one step at a time. // Everything is on this thread so we can now step through the protocol one step at a time.
// Seller Alice already sent a message to Buyer Bob. Pump once: // Seller Alice already sent a message to Buyer Bob. Pump once:
bobNode.pumpReceive(false) bobNode.pumpReceive()
// Bob sends a couple of queries for the dependencies back to Alice. Alice reponds. // Bob sends a couple of queries for the dependencies back to Alice. Alice reponds.
aliceNode.pumpReceive(false) aliceNode.pumpReceive()
bobNode.pumpReceive(false) bobNode.pumpReceive()
aliceNode.pumpReceive(false) aliceNode.pumpReceive()
bobNode.pumpReceive(false) bobNode.pumpReceive()
aliceNode.pumpReceive(false) aliceNode.pumpReceive()
bobNode.pumpReceive(false) bobNode.pumpReceive()
// OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature. // OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature.
assertThat(bobNode.checkpointStorage.checkpoints()).hasSize(1) assertThat(bobNode.checkpointStorage.checkpoints()).hasSize(1)
@ -164,7 +165,7 @@ class TwoPartyTradeProtocolTests {
// Alice doesn't know that and carries on: she wants to know about the cash transactions he's trying to use. // Alice doesn't know that and carries on: she wants to know about the cash transactions he's trying to use.
// She will wait around until Bob comes back. // She will wait around until Bob comes back.
assertThat(aliceNode.pumpReceive(false)).isNotNull() assertThat(aliceNode.pumpReceive()).isNotNull()
// ... bring the node back up ... the act of constructing the SMM will re-register the message handlers // ... bring the node back up ... the act of constructing the SMM will re-register the message handlers
// that Bob was waiting on before the reboot occurred. // that Bob was waiting on before the reboot occurred.
@ -386,7 +387,7 @@ class TwoPartyTradeProtocolTests {
private data class RunResult( private data class RunResult(
// The buyer is not created immediately, only when the seller starts running // The buyer is not created immediately, only when the seller starts running
val buyer: Future<ProtocolStateMachine<SignedTransaction>>, val buyer: Future<ProtocolStateMachine<*>>,
val sellerResult: Future<SignedTransaction>, val sellerResult: Future<SignedTransaction>,
val sellerId: StateMachineRunId val sellerId: StateMachineRunId
) )
@ -394,7 +395,7 @@ class TwoPartyTradeProtocolTests {
private fun runBuyerAndSeller(assetToSell: StateAndRef<OwnableState>) : RunResult { private fun runBuyerAndSeller(assetToSell: StateAndRef<OwnableState>) : RunResult {
val buyerFuture = bobNode.initiateSingleShotProtocol(Seller::class) { otherParty -> val buyerFuture = bobNode.initiateSingleShotProtocol(Seller::class) { otherParty ->
Buyer(otherParty, notaryNode.info.notaryIdentity, 1000.DOLLARS, CommercialPaper.State::class.java) Buyer(otherParty, notaryNode.info.notaryIdentity, 1000.DOLLARS, CommercialPaper.State::class.java)
} }.map { it.psm }
val seller = Seller(bobNode.info.legalIdentity, notaryNode.info, assetToSell, 1000.DOLLARS, ALICE_KEY) val seller = Seller(bobNode.info.legalIdentity, notaryNode.info, assetToSell, 1000.DOLLARS, ALICE_KEY)
val sellerResultFuture = aliceNode.smm.add(seller).resultFuture val sellerResultFuture = aliceNode.smm.add(seller).resultFuture
return RunResult(buyerFuture, sellerResultFuture, seller.psm.id) return RunResult(buyerFuture, sellerResultFuture, seller.psm.id)

View File

@ -4,15 +4,15 @@ import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolSessionException import com.r3corda.core.protocols.ProtocolSessionException
import com.r3corda.core.random63BitValue import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.deserialize import com.r3corda.core.serialization.deserialize
import com.r3corda.node.services.persistence.checkpoints import com.r3corda.node.services.persistence.checkpoints
import com.r3corda.node.services.statemachine.StateMachineManager.SessionData import com.r3corda.node.services.statemachine.StateMachineManager.*
import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage import com.r3corda.testing.initiateSingleShotProtocol
import com.r3corda.testing.node.InMemoryMessagingNetwork import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.InMemoryMessagingNetwork.MessageTransfer
import com.r3corda.testing.node.MockNetwork import com.r3corda.testing.node.MockNetwork
import com.r3corda.testing.node.MockNetwork.MockNode import com.r3corda.testing.node.MockNetwork.MockNode
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
@ -20,20 +20,25 @@ import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import rx.Observable
import java.util.*
import kotlin.reflect.KClass
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertTrue import kotlin.test.assertTrue
class StateMachineManagerTests { class StateMachineManagerTests {
val net = MockNetwork() private val net = MockNetwork()
lateinit var node1: MockNode private val sessionTransfers = ArrayList<SessionTransfer>()
lateinit var node2: MockNode private lateinit var node1: MockNode
private lateinit var node2: MockNode
@Before @Before
fun start() { fun start() {
val nodes = net.createTwoNodes() val nodes = net.createTwoNodes()
node1 = nodes.first node1 = nodes.first
node2 = nodes.second node2 = nodes.second
net.messagingNetwork.receivedMessages.toSessionTransfers().forEach { sessionTransfers += it }
net.runNetwork() net.runNetwork()
} }
@ -44,14 +49,18 @@ class StateMachineManagerTests {
@Test @Test
fun `newly added protocol is preserved on restart`() { fun `newly added protocol is preserved on restart`() {
node1.smm.add(ProtocolWithoutCheckpoints()) node1.smm.add(NoOpProtocol(nonTerminating = true))
val restoredProtocol = node1.restartAndGetRestoredProtocol<ProtocolWithoutCheckpoints>() val restoredProtocol = node1.restartAndGetRestoredProtocol<NoOpProtocol>()
assertThat(restoredProtocol.protocolStarted).isTrue() assertThat(restoredProtocol.protocolStarted).isTrue()
} }
@Test @Test
fun `protocol can lazily use the serviceHub in its constructor`() { fun `protocol can lazily use the serviceHub in its constructor`() {
val protocol = ProtocolWithLazyServiceHub() val protocol = object : ProtocolLogic<Unit>() {
val lazyTime by lazy { serviceHub.clock.instant() }
@Suspendable
override fun call() = Unit
}
node1.smm.add(protocol) node1.smm.add(protocol)
assertThat(protocol.lazyTime).isNotNull() assertThat(protocol.lazyTime).isNotNull()
} }
@ -62,19 +71,18 @@ class StateMachineManagerTests {
val payload = random63BitValue() val payload = random63BitValue()
node1.smm.add(SendProtocol(payload, node2.info.legalIdentity)) node1.smm.add(SendProtocol(payload, node2.info.legalIdentity))
// We push through just enough messages to get only the SessionData sent // We push through just enough messages to get only the payload sent
// TODO We should be able to give runNetwork a predicate for when to stop node2.pumpReceive()
net.runNetwork(2)
node2.stop() node2.stop()
net.runNetwork() net.runNetwork()
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1.info.address) val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1)
assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload) assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload)
} }
@Test @Test
fun `protocol added before network map does run after init`() { fun `protocol added before network map does run after init`() {
val node3 = net.createNode(node1.info.address) //create vanilla node val node3 = net.createNode(node1.info.address) //create vanilla node
val protocol = ProtocolNoBlocking() val protocol = NoOpProtocol()
node3.smm.add(protocol) node3.smm.add(protocol)
assertEquals(false, protocol.protocolStarted) // Not started yet as no network activity has been allowed yet assertEquals(false, protocol.protocolStarted) // Not started yet as no network activity has been allowed yet
net.runNetwork() // Allow network map messages to flow net.runNetwork() // Allow network map messages to flow
@ -84,13 +92,13 @@ class StateMachineManagerTests {
@Test @Test
fun `protocol added before network map will be init checkpointed`() { fun `protocol added before network map will be init checkpointed`() {
var node3 = net.createNode(node1.info.address) //create vanilla node var node3 = net.createNode(node1.info.address) //create vanilla node
val protocol = ProtocolNoBlocking() val protocol = NoOpProtocol()
node3.smm.add(protocol) node3.smm.add(protocol)
assertEquals(false, protocol.protocolStarted) // Not started yet as no network activity has been allowed yet assertEquals(false, protocol.protocolStarted) // Not started yet as no network activity has been allowed yet
node3.stop() node3.stop()
node3 = net.createNode(node1.info.address, forcedID = node3.id) node3 = net.createNode(node1.info.address, forcedID = node3.id)
val restoredProtocol = node3.getSingleProtocol<ProtocolNoBlocking>().first val restoredProtocol = node3.getSingleProtocol<NoOpProtocol>().first
assertEquals(false, restoredProtocol.protocolStarted) // Not started yet as no network activity has been allowed yet assertEquals(false, restoredProtocol.protocolStarted) // Not started yet as no network activity has been allowed yet
net.runNetwork() // Allow network map messages to flow net.runNetwork() // Allow network map messages to flow
node3.smm.executor.flush() node3.smm.executor.flush()
@ -101,17 +109,16 @@ class StateMachineManagerTests {
node3 = net.createNode(node1.info.address, forcedID = node3.id) node3 = net.createNode(node1.info.address, forcedID = node3.id)
net.runNetwork() // Allow network map messages to flow net.runNetwork() // Allow network map messages to flow
node3.smm.executor.flush() node3.smm.executor.flush()
assertTrue(node3.smm.findStateMachines(ProtocolNoBlocking::class.java).isEmpty()) assertTrue(node3.smm.findStateMachines(NoOpProtocol::class.java).isEmpty())
} }
@Test @Test
fun `protocol loaded from checkpoint will respond to messages from before start`() { fun `protocol loaded from checkpoint will respond to messages from before start`() {
val payload = random63BitValue() val payload = random63BitValue()
node1.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(payload, it) } node1.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(payload, it) }
val receiveProtocol = ReceiveThenSuspendProtocol(node1.info.legalIdentity) node2.smm.add(ReceiveThenSuspendProtocol(node1.info.legalIdentity)) // Prepare checkpointed receive protocol
node2.smm.add(receiveProtocol) // Prepare checkpointed receive protocol
node2.stop() // kill receiver node2.stop() // kill receiver
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1.info.address) val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1)
assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload) assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload)
} }
@ -121,41 +128,36 @@ class StateMachineManagerTests {
val payload2 = random63BitValue() val payload2 = random63BitValue()
var sentCount = 0 var sentCount = 0
var receivedCount = 0 net.messagingNetwork.sentMessages.toSessionTransfers().filter { it.isPayloadTransfer }.forEach { sentCount++ }
net.messagingNetwork.sentMessages.subscribe { if (isDataMessage(it)) sentCount++ }
net.messagingNetwork.receivedMessages.subscribe { if (isDataMessage(it)) receivedCount++ }
val node3 = net.createNode(node1.info.address)
net.runNetwork()
var secondProtocol: PingPongProtocol? = null val node3 = net.createNode(node1.info.address)
node3.services.registerProtocolInitiator(PingPongProtocol::class) { val secondProtocol = node3.initiateSingleShotProtocol(PingPongProtocol::class) { PingPongProtocol(it, payload2) }
val protocol = PingPongProtocol(it, payload2) net.runNetwork()
secondProtocol = protocol
protocol
}
// Kick off first send and receive // Kick off first send and receive
node2.smm.add(PingPongProtocol(node3.info.legalIdentity, payload)) node2.smm.add(PingPongProtocol(node3.info.legalIdentity, payload))
assertEquals(1, node2.checkpointStorage.checkpoints().count()) assertEquals(1, node2.checkpointStorage.checkpoints().size)
// Restart node and thus reload the checkpoint and resend the message with same UUID // Restart node and thus reload the checkpoint and resend the message with same UUID
node2.stop() node2.stop()
val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray()) val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray())
val (firstAgain, fut1) = node2b.getSingleProtocol<PingPongProtocol>() val (firstAgain, fut1) = node2b.getSingleProtocol<PingPongProtocol>()
net.runNetwork()
assertEquals(1, node2.checkpointStorage.checkpoints().count())
// Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync. // Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync.
net.runNetwork() net.runNetwork()
assertEquals(1, node2.checkpointStorage.checkpoints().size)
node2b.smm.executor.flush() node2b.smm.executor.flush()
fut1.get() fut1.get()
val receivedCount = sessionTransfers.count { it.isPayloadTransfer }
// Check protocols completed cleanly and didn't get out of phase // Check protocols completed cleanly and didn't get out of phase
assertEquals(4, receivedCount, "Protocol should have exchanged 4 unique messages")// Two messages each way assertEquals(4, receivedCount, "Protocol should have exchanged 4 unique messages")// Two messages each way
assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages") // can't give a precise value as every addMessageHandler re-runs the undelivered messages // can't give a precise value as every addMessageHandler re-runs the undelivered messages
assertEquals(0, node2b.checkpointStorage.checkpoints().count(), "Checkpoints left after restored protocol should have ended") assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages")
assertEquals(0, node3.checkpointStorage.checkpoints().count(), "Checkpoints left after restored protocol should have ended") assertEquals(0, node2b.checkpointStorage.checkpoints().size, "Checkpoints left after restored protocol should have ended")
assertEquals(0, node3.checkpointStorage.checkpoints().size, "Checkpoints left after restored protocol should have ended")
assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3") assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3")
assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3") assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3")
assertEquals(payload, secondProtocol!!.receivedPayload, "Received payload does not match the (restarted) first value on Node 2") assertEquals(payload, secondProtocol.get().receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
assertEquals(payload + 1, secondProtocol!!.receivedPayload2, "Received payload does not match the expected second value on Node 2") assertEquals(payload + 1, secondProtocol.get().receivedPayload2, "Received payload does not match the expected second value on Node 2")
} }
@Test @Test
@ -171,6 +173,20 @@ class StateMachineManagerTests {
val node3Protocol = node3.getSingleProtocol<ReceiveThenSuspendProtocol>().first val node3Protocol = node3.getSingleProtocol<ReceiveThenSuspendProtocol>().first
assertThat(node2Protocol.receivedPayloads[0]).isEqualTo(payload) assertThat(node2Protocol.receivedPayloads[0]).isEqualTo(payload)
assertThat(node3Protocol.receivedPayloads[0]).isEqualTo(payload) assertThat(node3Protocol.receivedPayloads[0]).isEqualTo(payload)
assertSessionTransfers(node2,
node1 sent sessionInit(node1, SendProtocol::class, payload) to node2,
node2 sent sessionConfirm() to node1,
node1 sent sessionEnd() to node2
//There's no session end from the other protocols as they're manually suspended
)
assertSessionTransfers(node3,
node1 sent sessionInit(node1, SendProtocol::class, payload) to node3,
node3 sent sessionConfirm() to node1,
node1 sent sessionEnd() to node3
//There's no session end from the other protocols as they're manually suspended
)
} }
@Test @Test
@ -183,13 +199,39 @@ class StateMachineManagerTests {
node3.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node3Payload, it) } node3.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node3Payload, it) }
val multiReceiveProtocol = ReceiveThenSuspendProtocol(node2.info.legalIdentity, node3.info.legalIdentity) val multiReceiveProtocol = ReceiveThenSuspendProtocol(node2.info.legalIdentity, node3.info.legalIdentity)
node1.smm.add(multiReceiveProtocol) node1.smm.add(multiReceiveProtocol)
net.runNetwork(1) // session handshaking net.runNetwork()
// have the messages arrive in reverse order of receive
node3.pumpReceive(false)
node2.pumpReceive(false)
net.runNetwork() // pump remaining messages
assertThat(multiReceiveProtocol.receivedPayloads[0]).isEqualTo(node2Payload) assertThat(multiReceiveProtocol.receivedPayloads[0]).isEqualTo(node2Payload)
assertThat(multiReceiveProtocol.receivedPayloads[1]).isEqualTo(node3Payload) assertThat(multiReceiveProtocol.receivedPayloads[1]).isEqualTo(node3Payload)
assertSessionTransfers(node2,
node1 sent sessionInit(node1, ReceiveThenSuspendProtocol::class) to node2,
node2 sent sessionConfirm() to node1,
node2 sent sessionData(node2Payload) to node1,
node2 sent sessionEnd() to node1
)
assertSessionTransfers(node3,
node1 sent sessionInit(node1, ReceiveThenSuspendProtocol::class) to node3,
node3 sent sessionConfirm() to node1,
node3 sent sessionData(node3Payload) to node1,
node3 sent sessionEnd() to node1
)
}
@Test
fun `both sides do a send as their first IO request`() {
node2.services.registerProtocolInitiator(PingPongProtocol::class) { PingPongProtocol(it, 20L) }
node1.smm.add(PingPongProtocol(node2.info.legalIdentity, 10L))
net.runNetwork()
assertSessionTransfers(
node1 sent sessionInit(node1, PingPongProtocol::class, 10L) to node2,
node2 sent sessionConfirm() to node1,
node2 sent sessionData(20L) to node1,
node1 sent sessionData(11L) to node2,
node2 sent sessionData(21L) to node1,
node1 sent sessionEnd() to node2
)
} }
@Test @Test
@ -198,16 +240,17 @@ class StateMachineManagerTests {
val future = node1.smm.add(ReceiveThenSuspendProtocol(node2.info.legalIdentity)).resultFuture val future = node1.smm.add(ReceiveThenSuspendProtocol(node2.info.legalIdentity)).resultFuture
net.runNetwork() net.runNetwork()
assertThatThrownBy { future.get() }.hasCauseInstanceOf(ProtocolSessionException::class.java) assertThatThrownBy { future.get() }.hasCauseInstanceOf(ProtocolSessionException::class.java)
assertSessionTransfers(
node1 sent sessionInit(node1, ReceiveThenSuspendProtocol::class) to node2,
node2 sent sessionConfirm() to node1,
node2 sent sessionEnd() to node1
)
} }
private fun isDataMessage(transfer: InMemoryMessagingNetwork.MessageTransfer): Boolean { private inline fun <reified P : ProtocolLogic<*>> MockNode.restartAndGetRestoredProtocol(
return transfer.message.topicSession == StateMachineManager.sessionTopic networkMapNode: MockNode? = null): P {
&& transfer.message.data.deserialize<SessionMessage>() is SessionData
}
private inline fun <reified P : NonTerminatingProtocol> MockNode.restartAndGetRestoredProtocol(networkMapAddress: SingleMessageRecipient? = null): P {
stop() stop()
val newNode = mockNet.createNode(networkMapAddress, id, advertisedServices = *advertisedServices.toTypedArray()) val newNode = mockNet.createNode(networkMapNode?.info?.address, id, advertisedServices = *advertisedServices.toTypedArray())
mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine
return newNode.getSingleProtocol<P>().first return newNode.getSingleProtocol<P>().first
} }
@ -216,35 +259,66 @@ class StateMachineManagerTests {
return smm.findStateMachines(P::class.java).single() return smm.findStateMachines(P::class.java).single()
} }
private fun sessionInit(initiatorNode: MockNode, protocolMarker: KClass<*>, payload: Any? = null): SessionInit {
return SessionInit(0, initiatorNode.info.legalIdentity, protocolMarker.java.name, payload)
}
private fun sessionConfirm() = SessionConfirm(0, 0)
private fun sessionData(payload: Any) = SessionData(0, payload)
private fun sessionEnd() = SessionEnd(0)
private fun assertSessionTransfers(vararg expected: SessionTransfer) {
assertThat(sessionTransfers).containsExactly(*expected)
}
private fun assertSessionTransfers(node: MockNode, vararg expected: SessionTransfer) {
val actualForNode = sessionTransfers.filter { it.from == node.id || it.to == node.id }
assertThat(actualForNode).containsExactly(*expected)
}
private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: Int) {
val isPayloadTransfer: Boolean get() = message is SessionData || message is SessionInit && message.firstPayload != null
override fun toString(): String = "$from sent $message to $to"
}
private fun Observable<MessageTransfer>.toSessionTransfers(): Observable<SessionTransfer> {
return filter { it.message.topicSession == StateMachineManager.sessionTopic }.map {
val from = it.sender.myAddress.id
val message = it.message.data.deserialize<SessionMessage>()
val to = (it.recipients as InMemoryMessagingNetwork.Handle).id
SessionTransfer(from, sanitise(message), to)
}
}
private fun sanitise(message: SessionMessage): SessionMessage {
return when (message) {
is SessionData -> message.copy(recipientSessionId = 0)
is SessionInit -> message.copy(initiatorSessionId = 0)
is SessionConfirm -> message.copy(initiatorSessionId = 0, initiatedSessionId = 0)
is SessionEnd -> message.copy(recipientSessionId = 0)
else -> message
}
}
private infix fun MockNode.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(id, message)
private infix fun Pair<Int, SessionMessage>.to(node: MockNode): SessionTransfer = SessionTransfer(first, second, node.id)
private class NoOpProtocol(val nonTerminating: Boolean = false) : ProtocolLogic<Unit>() {
private class ProtocolNoBlocking : ProtocolLogic<Unit>() {
@Transient var protocolStarted = false @Transient var protocolStarted = false
@Suspendable @Suspendable
override fun call() { override fun call() {
protocolStarted = true protocolStarted = true
if (nonTerminating) {
Fiber.park()
}
} }
} }
private class ProtocolWithoutCheckpoints : NonTerminatingProtocol() {
@Transient var protocolStarted = false
@Suspendable
override fun doCall() {
protocolStarted = true
}
}
private class ProtocolWithLazyServiceHub : ProtocolLogic<Unit>() {
val lazyTime by lazy { serviceHub.clock.instant() }
@Suspendable
override fun call() = Unit
}
private class SendProtocol(val payload: Any, vararg val otherParties: Party) : ProtocolLogic<Unit>() { private class SendProtocol(val payload: Any, vararg val otherParties: Party) : ProtocolLogic<Unit>() {
@ -257,7 +331,7 @@ class StateMachineManagerTests {
} }
private class ReceiveThenSuspendProtocol(vararg val otherParties: Party) : NonTerminatingProtocol() { private class ReceiveThenSuspendProtocol(vararg val otherParties: Party) : ProtocolLogic<Unit>() {
init { init {
require(otherParties.isNotEmpty()) require(otherParties.isNotEmpty())
@ -266,8 +340,10 @@ class StateMachineManagerTests {
@Transient var receivedPayloads: List<Any> = emptyList() @Transient var receivedPayloads: List<Any> = emptyList()
@Suspendable @Suspendable
override fun doCall() { override fun call() {
receivedPayloads = otherParties.map { receive<Any>(it).unwrap { it } } receivedPayloads = otherParties.map { receive<Any>(it).unwrap { it } }
println(receivedPayloads)
Fiber.park()
} }
} }
@ -279,7 +355,9 @@ class StateMachineManagerTests {
@Suspendable @Suspendable
override fun call() { override fun call() {
receivedPayload = sendAndReceive<Long>(otherParty, payload).unwrap { it } receivedPayload = sendAndReceive<Long>(otherParty, payload).unwrap { it }
receivedPayload2 = sendAndReceive<Long>(otherParty, (payload + 1)).unwrap { it } println("${psm.id} Received $receivedPayload")
receivedPayload2 = sendAndReceive<Long>(otherParty, payload + 1).unwrap { it }
println("${psm.id} Received $receivedPayload2")
} }
} }
@ -287,20 +365,4 @@ class StateMachineManagerTests {
override fun call(): Nothing = throw Exception() override fun call(): Nothing = throw Exception()
} }
/**
* A protocol that suspends forever after doing some work. This is to allow it to be retrieved from the SMM after
* restart for testing checkpoint restoration. Store any results as @Transient fields.
*/
private abstract class NonTerminatingProtocol : ProtocolLogic<Unit>() {
@Suspendable
override fun call() {
doCall()
Fiber.park()
}
@Suspendable
abstract fun doCall()
}
} }

View File

@ -12,6 +12,7 @@ import com.r3corda.core.contracts.UniqueIdentifier
import com.r3corda.core.flatMap import com.r3corda.core.flatMap
import com.r3corda.core.map import com.r3corda.core.map
import com.r3corda.core.node.services.linearHeadsOfType import com.r3corda.core.node.services.linearHeadsOfType
import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.success import com.r3corda.core.success
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.protocols.TwoPartyDealProtocol.Acceptor import com.r3corda.protocols.TwoPartyDealProtocol.Acceptor
@ -111,7 +112,9 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten
irs.fixedLeg.fixedRatePayer = node1.info.legalIdentity irs.fixedLeg.fixedRatePayer = node1.info.legalIdentity
irs.floatingLeg.floatingRatePayer = node2.info.legalIdentity irs.floatingLeg.floatingRatePayer = node2.info.legalIdentity
val acceptorTx = node2.initiateSingleShotProtocol(Instigator::class) { Acceptor(it) }.flatMap { it.resultFuture } val acceptorTx = node2.initiateSingleShotProtocol(Instigator::class) { Acceptor(it) }.flatMap {
(it.psm as ProtocolStateMachine<SignedTransaction>).resultFuture
}
showProgressFor(listOf(node1, node2)) showProgressFor(listOf(node1, node2))
showConsensusFor(listOf(node1, node2, regulators[0])) showConsensusFor(listOf(node1, node2, regulators[0]))

View File

@ -11,6 +11,7 @@ import com.r3corda.core.contracts.`issued by`
import com.r3corda.core.days import com.r3corda.core.days
import com.r3corda.core.flatMap import com.r3corda.core.flatMap
import com.r3corda.core.node.recordTransactions import com.r3corda.core.node.recordTransactions
import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.seconds import com.r3corda.core.seconds
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
@ -51,7 +52,7 @@ class TradeSimulation(runAsync: Boolean, latencyInjector: InMemoryMessagingNetwo
val buyerFuture = buyer.initiateSingleShotProtocol(Seller::class) { val buyerFuture = buyer.initiateSingleShotProtocol(Seller::class) {
Buyer(it, notary.info.notaryIdentity, amount, CommercialPaper.State::class.java) Buyer(it, notary.info.notaryIdentity, amount, CommercialPaper.State::class.java)
}.flatMap { it.resultFuture } }.flatMap { (it.psm as ProtocolStateMachine<SignedTransaction>).resultFuture }
val sellerKey = seller.services.legalIdentityKey val sellerKey = seller.services.legalIdentityKey
val sellerProtocol = Seller( val sellerProtocol = Seller(

View File

@ -12,11 +12,11 @@ import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.crypto.generateKeyPair import com.r3corda.core.crypto.generateKeyPair
import com.r3corda.core.node.ServiceHub import com.r3corda.core.node.ServiceHub
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.transactions.TransactionBuilder import com.r3corda.core.transactions.TransactionBuilder
import com.r3corda.core.utilities.DUMMY_NOTARY import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.node.internal.AbstractNode import com.r3corda.node.internal.AbstractNode
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
import com.r3corda.node.services.statemachine.StateMachineManager.Change import com.r3corda.node.services.statemachine.StateMachineManager.Change
import com.r3corda.node.utilities.AddOrRemove.ADD import com.r3corda.node.utilities.AddOrRemove.ADD
import com.r3corda.testing.node.MockIdentityService import com.r3corda.testing.node.MockIdentityService
@ -138,20 +138,20 @@ fun getFreeLocalPorts(hostName: String, numberToAlloc: Int): List<HostAndPort> {
/** /**
* The given protocol factory will be used to initiate just one instance of a protocol of type [P] when a counterparty * The given protocol factory will be used to initiate just one instance of a protocol of type [P] when a counterparty
* protocol requests for it using [markerClass]. * protocol requests for it using [markerClass].
* @return Returns a [ListenableFuture] holding the single [ProtocolStateMachine] created by the request. * @return Returns a [ListenableFuture] holding the single [ProtocolStateMachineImpl] created by the request.
*/ */
inline fun <R, reified P : ProtocolLogic<R>> AbstractNode.initiateSingleShotProtocol( inline fun <reified P : ProtocolLogic<*>> AbstractNode.initiateSingleShotProtocol(
markerClass: KClass<out ProtocolLogic<*>>, markerClass: KClass<out ProtocolLogic<*>>,
noinline protocolFactory: (Party) -> P): ListenableFuture<ProtocolStateMachine<R>> { noinline protocolFactory: (Party) -> P): ListenableFuture<P> {
services.registerProtocolInitiator(markerClass, protocolFactory) services.registerProtocolInitiator(markerClass, protocolFactory)
val future = SettableFuture.create<ProtocolStateMachine<R>>() val future = SettableFuture.create<P>()
val subscriber = object : Subscriber<Change>() { val subscriber = object : Subscriber<Change>() {
override fun onNext(change: Change) { override fun onNext(change: Change) {
if (change.logic is P && change.addOrRemove == ADD) { if (change.logic is P && change.addOrRemove == ADD) {
unsubscribe() unsubscribe()
future.set(change.logic.psm as ProtocolStateMachine<R>) future.set(change.logic as P)
} }
} }
override fun onError(e: Throwable) { override fun onError(e: Throwable) {

View File

@ -223,7 +223,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
private val state = ThreadBox(InnerState()) private val state = ThreadBox(InnerState())
private val processedMessages: MutableSet<UUID> = Collections.synchronizedSet(HashSet<UUID>()) private val processedMessages: MutableSet<UUID> = Collections.synchronizedSet(HashSet<UUID>())
override val myAddress: SingleMessageRecipient = handle override val myAddress: Handle get() = handle
private val backgroundThread = if (manuallyPumped) null else private val backgroundThread = if (manuallyPumped) null else
thread(isDaemon = true, name = "In-memory message dispatcher") { thread(isDaemon = true, name = "In-memory message dispatcher") {

View File

@ -1,5 +1,6 @@
package com.r3corda.testing.node package com.r3corda.testing.node
import com.google.common.jimfs.Configuration.unix
import com.google.common.jimfs.Jimfs import com.google.common.jimfs.Jimfs
import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.Futures
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
@ -27,6 +28,7 @@ import com.r3corda.node.services.transactions.SimpleNotaryService
import com.r3corda.node.services.transactions.ValidatingNotaryService import com.r3corda.node.services.transactions.ValidatingNotaryService
import com.r3corda.node.utilities.databaseTransaction import com.r3corda.node.utilities.databaseTransaction
import org.slf4j.Logger import org.slf4j.Logger
import java.nio.file.FileSystem
import java.nio.file.Files import java.nio.file.Files
import java.nio.file.Path import java.nio.file.Path
import java.security.KeyPair import java.security.KeyPair
@ -49,7 +51,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
private val threadPerNode: Boolean = false, private val threadPerNode: Boolean = false,
private val defaultFactory: Factory = MockNetwork.DefaultFactory) { private val defaultFactory: Factory = MockNetwork.DefaultFactory) {
private var counter = 0 private var counter = 0
val filesystem = com.google.common.jimfs.Jimfs.newFileSystem(com.google.common.jimfs.Configuration.unix()) val filesystem: FileSystem = Jimfs.newFileSystem(unix())
val messagingNetwork = InMemoryMessagingNetwork(networkSendManuallyPumped) val messagingNetwork = InMemoryMessagingNetwork(networkSendManuallyPumped)
// A unique identifier for this network to segregate databases with the same nodeID but different networks. // A unique identifier for this network to segregate databases with the same nodeID but different networks.
@ -138,7 +140,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
// It is used from the network visualiser tool. // It is used from the network visualiser tool.
@Suppress("unused") val place: PhysicalLocation get() = findMyLocation()!! @Suppress("unused") val place: PhysicalLocation get() = findMyLocation()!!
fun pumpReceive(block: Boolean): InMemoryMessagingNetwork.MessageTransfer? { fun pumpReceive(block: Boolean = false): InMemoryMessagingNetwork.MessageTransfer? {
return (net as InMemoryMessagingNetwork.InMemoryMessaging).pumpReceive(block) return (net as InMemoryMessagingNetwork.InMemoryMessaging).pumpReceive(block)
} }