mirror of
https://github.com/corda/corda.git
synced 2025-01-19 03:06:36 +00:00
Session handshake optimised to carry the first send payload in the init message
This commit is contained in:
parent
d2983d6a7a
commit
e48e09f04e
@ -122,9 +122,14 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
payload: Any,
|
||||
receiveType: Class<T>,
|
||||
sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T> {
|
||||
val session = getSession(otherParty, sessionProtocol)
|
||||
val sendSessionData = createSessionData(session, payload)
|
||||
val receivedSessionData = sendAndReceiveInternal(session, sendSessionData, SessionData::class.java)
|
||||
val (session, new) = getSession(otherParty, sessionProtocol, payload)
|
||||
val receivedSessionData = if (new) {
|
||||
// Only do a receive here as the session init has carried the payload
|
||||
receiveInternal<SessionData>(session)
|
||||
} else {
|
||||
val sendSessionData = createSessionData(session, payload)
|
||||
sendAndReceiveInternal<SessionData>(session, sendSessionData)
|
||||
}
|
||||
return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
|
||||
}
|
||||
|
||||
@ -132,15 +137,18 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
override fun <T : Any> receive(otherParty: Party,
|
||||
receiveType: Class<T>,
|
||||
sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T> {
|
||||
val receivedSessionData = receiveInternal(getSession(otherParty, sessionProtocol), SessionData::class.java)
|
||||
val session = getSession(otherParty, sessionProtocol, null).first
|
||||
val receivedSessionData = receiveInternal<SessionData>(session)
|
||||
return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
override fun send(otherParty: Party, payload: Any, sessionProtocol: ProtocolLogic<*>) {
|
||||
val session = getSession(otherParty, sessionProtocol)
|
||||
val sendSessionData = createSessionData(session, payload)
|
||||
sendInternal(session, sendSessionData)
|
||||
val (session, new) = getSession(otherParty, sessionProtocol, payload)
|
||||
if (!new) {
|
||||
// Don't send the payload again if it was already piggy-backed on a session init
|
||||
sendInternal(session, createSessionData(session, payload))
|
||||
}
|
||||
}
|
||||
|
||||
private fun createSessionData(session: ProtocolSession, payload: Any): SessionData {
|
||||
@ -155,27 +163,31 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun <T : SessionMessage> receiveInternal(session: ProtocolSession, receiveType: Class<T>): T {
|
||||
return suspendAndExpectReceive(ReceiveOnly(session, receiveType))
|
||||
private inline fun <reified M : SessionMessage> receiveInternal(session: ProtocolSession): M {
|
||||
return suspendAndExpectReceive(ReceiveOnly(session, M::class.java))
|
||||
}
|
||||
|
||||
private inline fun <reified M : SessionMessage> sendAndReceiveInternal(session: ProtocolSession, message: SessionMessage): M {
|
||||
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java))
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun <T : SessionMessage> sendAndReceiveInternal(session: ProtocolSession, message: SessionMessage, receiveType: Class<T>): T {
|
||||
return suspendAndExpectReceive(SendAndReceive(session, message, receiveType))
|
||||
private fun getSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>, firstPayload: Any?): Pair<ProtocolSession, Boolean> {
|
||||
val session = openSessions[Pair(sessionProtocol, otherParty)]
|
||||
return if (session != null) {
|
||||
Pair(session, false)
|
||||
} else {
|
||||
Pair(startNewSession(otherParty, sessionProtocol, firstPayload), true)
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun getSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>): ProtocolSession {
|
||||
return openSessions[Pair(sessionProtocol, otherParty)] ?: startNewSession(otherParty, sessionProtocol)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun startNewSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>) : ProtocolSession {
|
||||
private fun startNewSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>, firstPayload: Any?) : ProtocolSession {
|
||||
val session = ProtocolSession(sessionProtocol, otherParty, random63BitValue(), null)
|
||||
openSessions[Pair(sessionProtocol, otherParty)] = session
|
||||
val counterpartyProtocol = sessionProtocol.getCounterpartyMarker(otherParty).name
|
||||
val sessionInit = SessionInit(session.ourSessionId, serviceHub.myInfo.legalIdentity, counterpartyProtocol)
|
||||
val sessionInitResponse = sendAndReceiveInternal(session, sessionInit, SessionInitResponse::class.java)
|
||||
val sessionInit = SessionInit(session.ourSessionId, serviceHub.myInfo.legalIdentity, counterpartyProtocol, firstPayload)
|
||||
val sessionInitResponse = sendAndReceiveInternal<SessionInitResponse>(session, sessionInit)
|
||||
if (sessionInitResponse is SessionConfirm) {
|
||||
session.otherPartySessionId = sessionInitResponse.initiatedSessionId
|
||||
return session
|
||||
@ -186,21 +198,26 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun <T : SessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<T>): T {
|
||||
private fun <M : SessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<M>): M {
|
||||
fun getReceivedMessage(): ExistingSessionMessage? = receiveRequest.session.receivedMessages.poll()
|
||||
|
||||
val receivedMessage = getReceivedMessage() ?: run {
|
||||
// Suspend while we wait for the receive
|
||||
receiveRequest.session.waitingForResponse = true
|
||||
val polledMessage = getReceivedMessage()
|
||||
val receivedMessage = if (polledMessage != null) {
|
||||
if (receiveRequest is SendAndReceive) {
|
||||
// We've already received a message but we suspend so that the send can be performed
|
||||
suspend(receiveRequest)
|
||||
}
|
||||
polledMessage
|
||||
} else {
|
||||
// Suspend while we wait for a receive
|
||||
suspend(receiveRequest)
|
||||
receiveRequest.session.waitingForResponse = false
|
||||
getReceivedMessage()
|
||||
?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $receiveRequest")
|
||||
}
|
||||
|
||||
if (receivedMessage is SessionEnd) {
|
||||
openSessions.values.remove(receiveRequest.session)
|
||||
throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurely ended")
|
||||
throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurely ended on $receiveRequest")
|
||||
} else if (receiveRequest.receiveType.isInstance(receivedMessage)) {
|
||||
return receiveRequest.receiveType.cast(receivedMessage)
|
||||
} else {
|
||||
@ -213,6 +230,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
// we have to pass the Thread local Transaction across via a transient field as the Fiber Park swaps them out.
|
||||
txTrampoline = TransactionManager.currentOrNull()
|
||||
StrandLocalTransactionManager.setThreadLocalTx(null)
|
||||
ioRequest.session.waitingForResponse = true
|
||||
parkAndSerialize { fiber, serializer ->
|
||||
logger.trace { "Suspended on $ioRequest" }
|
||||
// restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB
|
||||
@ -228,6 +246,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
processException(t)
|
||||
}
|
||||
}
|
||||
ioRequest.session.waitingForResponse = false
|
||||
createTransaction()
|
||||
}
|
||||
|
||||
|
@ -233,6 +233,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
val protocol = protocolFactory(otherParty)
|
||||
val psm = createFiber(protocol)
|
||||
val session = ProtocolSession(protocol, otherParty, random63BitValue(), otherPartySessionId)
|
||||
if (sessionInit.firstPayload != null) {
|
||||
session.receivedMessages += SessionData(session.ourSessionId, sessionInit.firstPayload)
|
||||
}
|
||||
openSessions[session.ourSessionId] = session
|
||||
psm.openSessions[Pair(protocol, otherParty)] = session
|
||||
updateCheckpoint(psm)
|
||||
@ -400,7 +403,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
val recipientSessionId: Long
|
||||
}
|
||||
|
||||
data class SessionInit(val initiatorSessionId: Long, val initiatorParty: Party, val protocolName: String) : SessionMessage
|
||||
data class SessionInit(val initiatorSessionId: Long,
|
||||
val initiatorParty: Party,
|
||||
val protocolName: String,
|
||||
val firstPayload: Any?) : SessionMessage
|
||||
|
||||
interface SessionInitResponse : ExistingSessionMessage
|
||||
|
||||
|
@ -7,6 +7,7 @@ import com.r3corda.core.contracts.*
|
||||
import com.r3corda.core.crypto.Party
|
||||
import com.r3corda.core.crypto.SecureHash
|
||||
import com.r3corda.core.days
|
||||
import com.r3corda.core.map
|
||||
import com.r3corda.core.messaging.SingleMessageRecipient
|
||||
import com.r3corda.core.node.services.*
|
||||
import com.r3corda.core.protocols.ProtocolStateMachine
|
||||
@ -143,15 +144,15 @@ class TwoPartyTradeProtocolTests {
|
||||
|
||||
// Everything is on this thread so we can now step through the protocol one step at a time.
|
||||
// Seller Alice already sent a message to Buyer Bob. Pump once:
|
||||
bobNode.pumpReceive(false)
|
||||
bobNode.pumpReceive()
|
||||
|
||||
// Bob sends a couple of queries for the dependencies back to Alice. Alice reponds.
|
||||
aliceNode.pumpReceive(false)
|
||||
bobNode.pumpReceive(false)
|
||||
aliceNode.pumpReceive(false)
|
||||
bobNode.pumpReceive(false)
|
||||
aliceNode.pumpReceive(false)
|
||||
bobNode.pumpReceive(false)
|
||||
aliceNode.pumpReceive()
|
||||
bobNode.pumpReceive()
|
||||
aliceNode.pumpReceive()
|
||||
bobNode.pumpReceive()
|
||||
aliceNode.pumpReceive()
|
||||
bobNode.pumpReceive()
|
||||
|
||||
// OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature.
|
||||
assertThat(bobNode.checkpointStorage.checkpoints()).hasSize(1)
|
||||
@ -164,7 +165,7 @@ class TwoPartyTradeProtocolTests {
|
||||
|
||||
// Alice doesn't know that and carries on: she wants to know about the cash transactions he's trying to use.
|
||||
// She will wait around until Bob comes back.
|
||||
assertThat(aliceNode.pumpReceive(false)).isNotNull()
|
||||
assertThat(aliceNode.pumpReceive()).isNotNull()
|
||||
|
||||
// ... bring the node back up ... the act of constructing the SMM will re-register the message handlers
|
||||
// that Bob was waiting on before the reboot occurred.
|
||||
@ -386,7 +387,7 @@ class TwoPartyTradeProtocolTests {
|
||||
|
||||
private data class RunResult(
|
||||
// The buyer is not created immediately, only when the seller starts running
|
||||
val buyer: Future<ProtocolStateMachine<SignedTransaction>>,
|
||||
val buyer: Future<ProtocolStateMachine<*>>,
|
||||
val sellerResult: Future<SignedTransaction>,
|
||||
val sellerId: StateMachineRunId
|
||||
)
|
||||
@ -394,7 +395,7 @@ class TwoPartyTradeProtocolTests {
|
||||
private fun runBuyerAndSeller(assetToSell: StateAndRef<OwnableState>) : RunResult {
|
||||
val buyerFuture = bobNode.initiateSingleShotProtocol(Seller::class) { otherParty ->
|
||||
Buyer(otherParty, notaryNode.info.notaryIdentity, 1000.DOLLARS, CommercialPaper.State::class.java)
|
||||
}
|
||||
}.map { it.psm }
|
||||
val seller = Seller(bobNode.info.legalIdentity, notaryNode.info, assetToSell, 1000.DOLLARS, ALICE_KEY)
|
||||
val sellerResultFuture = aliceNode.smm.add(seller).resultFuture
|
||||
return RunResult(buyerFuture, sellerResultFuture, seller.psm.id)
|
||||
|
@ -4,15 +4,15 @@ import co.paralleluniverse.fibers.Fiber
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.r3corda.core.crypto.Party
|
||||
import com.r3corda.core.messaging.SingleMessageRecipient
|
||||
import com.r3corda.core.protocols.ProtocolLogic
|
||||
import com.r3corda.core.protocols.ProtocolSessionException
|
||||
import com.r3corda.core.random63BitValue
|
||||
import com.r3corda.core.serialization.deserialize
|
||||
import com.r3corda.node.services.persistence.checkpoints
|
||||
import com.r3corda.node.services.statemachine.StateMachineManager.SessionData
|
||||
import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage
|
||||
import com.r3corda.node.services.statemachine.StateMachineManager.*
|
||||
import com.r3corda.testing.initiateSingleShotProtocol
|
||||
import com.r3corda.testing.node.InMemoryMessagingNetwork
|
||||
import com.r3corda.testing.node.InMemoryMessagingNetwork.MessageTransfer
|
||||
import com.r3corda.testing.node.MockNetwork
|
||||
import com.r3corda.testing.node.MockNetwork.MockNode
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
@ -20,20 +20,25 @@ import org.assertj.core.api.Assertions.assertThatThrownBy
|
||||
import org.junit.After
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
import rx.Observable
|
||||
import java.util.*
|
||||
import kotlin.reflect.KClass
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class StateMachineManagerTests {
|
||||
|
||||
val net = MockNetwork()
|
||||
lateinit var node1: MockNode
|
||||
lateinit var node2: MockNode
|
||||
private val net = MockNetwork()
|
||||
private val sessionTransfers = ArrayList<SessionTransfer>()
|
||||
private lateinit var node1: MockNode
|
||||
private lateinit var node2: MockNode
|
||||
|
||||
@Before
|
||||
fun start() {
|
||||
val nodes = net.createTwoNodes()
|
||||
node1 = nodes.first
|
||||
node2 = nodes.second
|
||||
net.messagingNetwork.receivedMessages.toSessionTransfers().forEach { sessionTransfers += it }
|
||||
net.runNetwork()
|
||||
}
|
||||
|
||||
@ -44,14 +49,18 @@ class StateMachineManagerTests {
|
||||
|
||||
@Test
|
||||
fun `newly added protocol is preserved on restart`() {
|
||||
node1.smm.add(ProtocolWithoutCheckpoints())
|
||||
val restoredProtocol = node1.restartAndGetRestoredProtocol<ProtocolWithoutCheckpoints>()
|
||||
node1.smm.add(NoOpProtocol(nonTerminating = true))
|
||||
val restoredProtocol = node1.restartAndGetRestoredProtocol<NoOpProtocol>()
|
||||
assertThat(restoredProtocol.protocolStarted).isTrue()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `protocol can lazily use the serviceHub in its constructor`() {
|
||||
val protocol = ProtocolWithLazyServiceHub()
|
||||
val protocol = object : ProtocolLogic<Unit>() {
|
||||
val lazyTime by lazy { serviceHub.clock.instant() }
|
||||
@Suspendable
|
||||
override fun call() = Unit
|
||||
}
|
||||
node1.smm.add(protocol)
|
||||
assertThat(protocol.lazyTime).isNotNull()
|
||||
}
|
||||
@ -62,19 +71,18 @@ class StateMachineManagerTests {
|
||||
val payload = random63BitValue()
|
||||
node1.smm.add(SendProtocol(payload, node2.info.legalIdentity))
|
||||
|
||||
// We push through just enough messages to get only the SessionData sent
|
||||
// TODO We should be able to give runNetwork a predicate for when to stop
|
||||
net.runNetwork(2)
|
||||
// We push through just enough messages to get only the payload sent
|
||||
node2.pumpReceive()
|
||||
node2.stop()
|
||||
net.runNetwork()
|
||||
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1.info.address)
|
||||
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1)
|
||||
assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `protocol added before network map does run after init`() {
|
||||
val node3 = net.createNode(node1.info.address) //create vanilla node
|
||||
val protocol = ProtocolNoBlocking()
|
||||
val protocol = NoOpProtocol()
|
||||
node3.smm.add(protocol)
|
||||
assertEquals(false, protocol.protocolStarted) // Not started yet as no network activity has been allowed yet
|
||||
net.runNetwork() // Allow network map messages to flow
|
||||
@ -84,13 +92,13 @@ class StateMachineManagerTests {
|
||||
@Test
|
||||
fun `protocol added before network map will be init checkpointed`() {
|
||||
var node3 = net.createNode(node1.info.address) //create vanilla node
|
||||
val protocol = ProtocolNoBlocking()
|
||||
val protocol = NoOpProtocol()
|
||||
node3.smm.add(protocol)
|
||||
assertEquals(false, protocol.protocolStarted) // Not started yet as no network activity has been allowed yet
|
||||
node3.stop()
|
||||
|
||||
node3 = net.createNode(node1.info.address, forcedID = node3.id)
|
||||
val restoredProtocol = node3.getSingleProtocol<ProtocolNoBlocking>().first
|
||||
val restoredProtocol = node3.getSingleProtocol<NoOpProtocol>().first
|
||||
assertEquals(false, restoredProtocol.protocolStarted) // Not started yet as no network activity has been allowed yet
|
||||
net.runNetwork() // Allow network map messages to flow
|
||||
node3.smm.executor.flush()
|
||||
@ -101,17 +109,16 @@ class StateMachineManagerTests {
|
||||
node3 = net.createNode(node1.info.address, forcedID = node3.id)
|
||||
net.runNetwork() // Allow network map messages to flow
|
||||
node3.smm.executor.flush()
|
||||
assertTrue(node3.smm.findStateMachines(ProtocolNoBlocking::class.java).isEmpty())
|
||||
assertTrue(node3.smm.findStateMachines(NoOpProtocol::class.java).isEmpty())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `protocol loaded from checkpoint will respond to messages from before start`() {
|
||||
val payload = random63BitValue()
|
||||
node1.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(payload, it) }
|
||||
val receiveProtocol = ReceiveThenSuspendProtocol(node1.info.legalIdentity)
|
||||
node2.smm.add(receiveProtocol) // Prepare checkpointed receive protocol
|
||||
node2.smm.add(ReceiveThenSuspendProtocol(node1.info.legalIdentity)) // Prepare checkpointed receive protocol
|
||||
node2.stop() // kill receiver
|
||||
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1.info.address)
|
||||
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1)
|
||||
assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload)
|
||||
}
|
||||
|
||||
@ -121,41 +128,36 @@ class StateMachineManagerTests {
|
||||
val payload2 = random63BitValue()
|
||||
|
||||
var sentCount = 0
|
||||
var receivedCount = 0
|
||||
net.messagingNetwork.sentMessages.subscribe { if (isDataMessage(it)) sentCount++ }
|
||||
net.messagingNetwork.receivedMessages.subscribe { if (isDataMessage(it)) receivedCount++ }
|
||||
val node3 = net.createNode(node1.info.address)
|
||||
net.runNetwork()
|
||||
net.messagingNetwork.sentMessages.toSessionTransfers().filter { it.isPayloadTransfer }.forEach { sentCount++ }
|
||||
|
||||
var secondProtocol: PingPongProtocol? = null
|
||||
node3.services.registerProtocolInitiator(PingPongProtocol::class) {
|
||||
val protocol = PingPongProtocol(it, payload2)
|
||||
secondProtocol = protocol
|
||||
protocol
|
||||
}
|
||||
val node3 = net.createNode(node1.info.address)
|
||||
val secondProtocol = node3.initiateSingleShotProtocol(PingPongProtocol::class) { PingPongProtocol(it, payload2) }
|
||||
net.runNetwork()
|
||||
|
||||
// Kick off first send and receive
|
||||
node2.smm.add(PingPongProtocol(node3.info.legalIdentity, payload))
|
||||
assertEquals(1, node2.checkpointStorage.checkpoints().count())
|
||||
assertEquals(1, node2.checkpointStorage.checkpoints().size)
|
||||
// Restart node and thus reload the checkpoint and resend the message with same UUID
|
||||
node2.stop()
|
||||
val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray())
|
||||
val (firstAgain, fut1) = node2b.getSingleProtocol<PingPongProtocol>()
|
||||
net.runNetwork()
|
||||
assertEquals(1, node2.checkpointStorage.checkpoints().count())
|
||||
// Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync.
|
||||
net.runNetwork()
|
||||
assertEquals(1, node2.checkpointStorage.checkpoints().size)
|
||||
node2b.smm.executor.flush()
|
||||
fut1.get()
|
||||
|
||||
val receivedCount = sessionTransfers.count { it.isPayloadTransfer }
|
||||
// Check protocols completed cleanly and didn't get out of phase
|
||||
assertEquals(4, receivedCount, "Protocol should have exchanged 4 unique messages")// Two messages each way
|
||||
assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages") // can't give a precise value as every addMessageHandler re-runs the undelivered messages
|
||||
assertEquals(0, node2b.checkpointStorage.checkpoints().count(), "Checkpoints left after restored protocol should have ended")
|
||||
assertEquals(0, node3.checkpointStorage.checkpoints().count(), "Checkpoints left after restored protocol should have ended")
|
||||
// can't give a precise value as every addMessageHandler re-runs the undelivered messages
|
||||
assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages")
|
||||
assertEquals(0, node2b.checkpointStorage.checkpoints().size, "Checkpoints left after restored protocol should have ended")
|
||||
assertEquals(0, node3.checkpointStorage.checkpoints().size, "Checkpoints left after restored protocol should have ended")
|
||||
assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3")
|
||||
assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3")
|
||||
assertEquals(payload, secondProtocol!!.receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
|
||||
assertEquals(payload + 1, secondProtocol!!.receivedPayload2, "Received payload does not match the expected second value on Node 2")
|
||||
assertEquals(payload, secondProtocol.get().receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
|
||||
assertEquals(payload + 1, secondProtocol.get().receivedPayload2, "Received payload does not match the expected second value on Node 2")
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -171,6 +173,20 @@ class StateMachineManagerTests {
|
||||
val node3Protocol = node3.getSingleProtocol<ReceiveThenSuspendProtocol>().first
|
||||
assertThat(node2Protocol.receivedPayloads[0]).isEqualTo(payload)
|
||||
assertThat(node3Protocol.receivedPayloads[0]).isEqualTo(payload)
|
||||
|
||||
assertSessionTransfers(node2,
|
||||
node1 sent sessionInit(node1, SendProtocol::class, payload) to node2,
|
||||
node2 sent sessionConfirm() to node1,
|
||||
node1 sent sessionEnd() to node2
|
||||
//There's no session end from the other protocols as they're manually suspended
|
||||
)
|
||||
|
||||
assertSessionTransfers(node3,
|
||||
node1 sent sessionInit(node1, SendProtocol::class, payload) to node3,
|
||||
node3 sent sessionConfirm() to node1,
|
||||
node1 sent sessionEnd() to node3
|
||||
//There's no session end from the other protocols as they're manually suspended
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -183,13 +199,39 @@ class StateMachineManagerTests {
|
||||
node3.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node3Payload, it) }
|
||||
val multiReceiveProtocol = ReceiveThenSuspendProtocol(node2.info.legalIdentity, node3.info.legalIdentity)
|
||||
node1.smm.add(multiReceiveProtocol)
|
||||
net.runNetwork(1) // session handshaking
|
||||
// have the messages arrive in reverse order of receive
|
||||
node3.pumpReceive(false)
|
||||
node2.pumpReceive(false)
|
||||
net.runNetwork() // pump remaining messages
|
||||
net.runNetwork()
|
||||
assertThat(multiReceiveProtocol.receivedPayloads[0]).isEqualTo(node2Payload)
|
||||
assertThat(multiReceiveProtocol.receivedPayloads[1]).isEqualTo(node3Payload)
|
||||
|
||||
assertSessionTransfers(node2,
|
||||
node1 sent sessionInit(node1, ReceiveThenSuspendProtocol::class) to node2,
|
||||
node2 sent sessionConfirm() to node1,
|
||||
node2 sent sessionData(node2Payload) to node1,
|
||||
node2 sent sessionEnd() to node1
|
||||
)
|
||||
|
||||
assertSessionTransfers(node3,
|
||||
node1 sent sessionInit(node1, ReceiveThenSuspendProtocol::class) to node3,
|
||||
node3 sent sessionConfirm() to node1,
|
||||
node3 sent sessionData(node3Payload) to node1,
|
||||
node3 sent sessionEnd() to node1
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `both sides do a send as their first IO request`() {
|
||||
node2.services.registerProtocolInitiator(PingPongProtocol::class) { PingPongProtocol(it, 20L) }
|
||||
node1.smm.add(PingPongProtocol(node2.info.legalIdentity, 10L))
|
||||
net.runNetwork()
|
||||
|
||||
assertSessionTransfers(
|
||||
node1 sent sessionInit(node1, PingPongProtocol::class, 10L) to node2,
|
||||
node2 sent sessionConfirm() to node1,
|
||||
node2 sent sessionData(20L) to node1,
|
||||
node1 sent sessionData(11L) to node2,
|
||||
node2 sent sessionData(21L) to node1,
|
||||
node1 sent sessionEnd() to node2
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -198,16 +240,17 @@ class StateMachineManagerTests {
|
||||
val future = node1.smm.add(ReceiveThenSuspendProtocol(node2.info.legalIdentity)).resultFuture
|
||||
net.runNetwork()
|
||||
assertThatThrownBy { future.get() }.hasCauseInstanceOf(ProtocolSessionException::class.java)
|
||||
assertSessionTransfers(
|
||||
node1 sent sessionInit(node1, ReceiveThenSuspendProtocol::class) to node2,
|
||||
node2 sent sessionConfirm() to node1,
|
||||
node2 sent sessionEnd() to node1
|
||||
)
|
||||
}
|
||||
|
||||
private fun isDataMessage(transfer: InMemoryMessagingNetwork.MessageTransfer): Boolean {
|
||||
return transfer.message.topicSession == StateMachineManager.sessionTopic
|
||||
&& transfer.message.data.deserialize<SessionMessage>() is SessionData
|
||||
}
|
||||
|
||||
private inline fun <reified P : NonTerminatingProtocol> MockNode.restartAndGetRestoredProtocol(networkMapAddress: SingleMessageRecipient? = null): P {
|
||||
private inline fun <reified P : ProtocolLogic<*>> MockNode.restartAndGetRestoredProtocol(
|
||||
networkMapNode: MockNode? = null): P {
|
||||
stop()
|
||||
val newNode = mockNet.createNode(networkMapAddress, id, advertisedServices = *advertisedServices.toTypedArray())
|
||||
val newNode = mockNet.createNode(networkMapNode?.info?.address, id, advertisedServices = *advertisedServices.toTypedArray())
|
||||
mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine
|
||||
return newNode.getSingleProtocol<P>().first
|
||||
}
|
||||
@ -216,35 +259,66 @@ class StateMachineManagerTests {
|
||||
return smm.findStateMachines(P::class.java).single()
|
||||
}
|
||||
|
||||
private fun sessionInit(initiatorNode: MockNode, protocolMarker: KClass<*>, payload: Any? = null): SessionInit {
|
||||
return SessionInit(0, initiatorNode.info.legalIdentity, protocolMarker.java.name, payload)
|
||||
}
|
||||
|
||||
private fun sessionConfirm() = SessionConfirm(0, 0)
|
||||
|
||||
private fun sessionData(payload: Any) = SessionData(0, payload)
|
||||
|
||||
private fun sessionEnd() = SessionEnd(0)
|
||||
|
||||
private fun assertSessionTransfers(vararg expected: SessionTransfer) {
|
||||
assertThat(sessionTransfers).containsExactly(*expected)
|
||||
}
|
||||
|
||||
private fun assertSessionTransfers(node: MockNode, vararg expected: SessionTransfer) {
|
||||
val actualForNode = sessionTransfers.filter { it.from == node.id || it.to == node.id }
|
||||
assertThat(actualForNode).containsExactly(*expected)
|
||||
}
|
||||
|
||||
private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: Int) {
|
||||
val isPayloadTransfer: Boolean get() = message is SessionData || message is SessionInit && message.firstPayload != null
|
||||
override fun toString(): String = "$from sent $message to $to"
|
||||
}
|
||||
|
||||
private fun Observable<MessageTransfer>.toSessionTransfers(): Observable<SessionTransfer> {
|
||||
return filter { it.message.topicSession == StateMachineManager.sessionTopic }.map {
|
||||
val from = it.sender.myAddress.id
|
||||
val message = it.message.data.deserialize<SessionMessage>()
|
||||
val to = (it.recipients as InMemoryMessagingNetwork.Handle).id
|
||||
SessionTransfer(from, sanitise(message), to)
|
||||
}
|
||||
}
|
||||
|
||||
private fun sanitise(message: SessionMessage): SessionMessage {
|
||||
return when (message) {
|
||||
is SessionData -> message.copy(recipientSessionId = 0)
|
||||
is SessionInit -> message.copy(initiatorSessionId = 0)
|
||||
is SessionConfirm -> message.copy(initiatorSessionId = 0, initiatedSessionId = 0)
|
||||
is SessionEnd -> message.copy(recipientSessionId = 0)
|
||||
else -> message
|
||||
}
|
||||
}
|
||||
|
||||
private infix fun MockNode.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(id, message)
|
||||
private infix fun Pair<Int, SessionMessage>.to(node: MockNode): SessionTransfer = SessionTransfer(first, second, node.id)
|
||||
|
||||
|
||||
private class NoOpProtocol(val nonTerminating: Boolean = false) : ProtocolLogic<Unit>() {
|
||||
|
||||
private class ProtocolNoBlocking : ProtocolLogic<Unit>() {
|
||||
@Transient var protocolStarted = false
|
||||
|
||||
@Suspendable
|
||||
override fun call() {
|
||||
protocolStarted = true
|
||||
if (nonTerminating) {
|
||||
Fiber.park()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class ProtocolWithoutCheckpoints : NonTerminatingProtocol() {
|
||||
|
||||
@Transient var protocolStarted = false
|
||||
|
||||
@Suspendable
|
||||
override fun doCall() {
|
||||
protocolStarted = true
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private class ProtocolWithLazyServiceHub : ProtocolLogic<Unit>() {
|
||||
|
||||
val lazyTime by lazy { serviceHub.clock.instant() }
|
||||
|
||||
@Suspendable
|
||||
override fun call() = Unit
|
||||
}
|
||||
|
||||
|
||||
private class SendProtocol(val payload: Any, vararg val otherParties: Party) : ProtocolLogic<Unit>() {
|
||||
|
||||
@ -257,7 +331,7 @@ class StateMachineManagerTests {
|
||||
}
|
||||
|
||||
|
||||
private class ReceiveThenSuspendProtocol(vararg val otherParties: Party) : NonTerminatingProtocol() {
|
||||
private class ReceiveThenSuspendProtocol(vararg val otherParties: Party) : ProtocolLogic<Unit>() {
|
||||
|
||||
init {
|
||||
require(otherParties.isNotEmpty())
|
||||
@ -266,8 +340,10 @@ class StateMachineManagerTests {
|
||||
@Transient var receivedPayloads: List<Any> = emptyList()
|
||||
|
||||
@Suspendable
|
||||
override fun doCall() {
|
||||
override fun call() {
|
||||
receivedPayloads = otherParties.map { receive<Any>(it).unwrap { it } }
|
||||
println(receivedPayloads)
|
||||
Fiber.park()
|
||||
}
|
||||
}
|
||||
|
||||
@ -279,7 +355,9 @@ class StateMachineManagerTests {
|
||||
@Suspendable
|
||||
override fun call() {
|
||||
receivedPayload = sendAndReceive<Long>(otherParty, payload).unwrap { it }
|
||||
receivedPayload2 = sendAndReceive<Long>(otherParty, (payload + 1)).unwrap { it }
|
||||
println("${psm.id} Received $receivedPayload")
|
||||
receivedPayload2 = sendAndReceive<Long>(otherParty, payload + 1).unwrap { it }
|
||||
println("${psm.id} Received $receivedPayload2")
|
||||
}
|
||||
}
|
||||
|
||||
@ -287,20 +365,4 @@ class StateMachineManagerTests {
|
||||
override fun call(): Nothing = throw Exception()
|
||||
}
|
||||
|
||||
/**
|
||||
* A protocol that suspends forever after doing some work. This is to allow it to be retrieved from the SMM after
|
||||
* restart for testing checkpoint restoration. Store any results as @Transient fields.
|
||||
*/
|
||||
private abstract class NonTerminatingProtocol : ProtocolLogic<Unit>() {
|
||||
|
||||
@Suspendable
|
||||
override fun call() {
|
||||
doCall()
|
||||
Fiber.park()
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
abstract fun doCall()
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -12,6 +12,7 @@ import com.r3corda.core.contracts.UniqueIdentifier
|
||||
import com.r3corda.core.flatMap
|
||||
import com.r3corda.core.map
|
||||
import com.r3corda.core.node.services.linearHeadsOfType
|
||||
import com.r3corda.core.protocols.ProtocolStateMachine
|
||||
import com.r3corda.core.success
|
||||
import com.r3corda.core.transactions.SignedTransaction
|
||||
import com.r3corda.protocols.TwoPartyDealProtocol.Acceptor
|
||||
@ -111,7 +112,9 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten
|
||||
irs.fixedLeg.fixedRatePayer = node1.info.legalIdentity
|
||||
irs.floatingLeg.floatingRatePayer = node2.info.legalIdentity
|
||||
|
||||
val acceptorTx = node2.initiateSingleShotProtocol(Instigator::class) { Acceptor(it) }.flatMap { it.resultFuture }
|
||||
val acceptorTx = node2.initiateSingleShotProtocol(Instigator::class) { Acceptor(it) }.flatMap {
|
||||
(it.psm as ProtocolStateMachine<SignedTransaction>).resultFuture
|
||||
}
|
||||
|
||||
showProgressFor(listOf(node1, node2))
|
||||
showConsensusFor(listOf(node1, node2, regulators[0]))
|
||||
|
@ -11,6 +11,7 @@ import com.r3corda.core.contracts.`issued by`
|
||||
import com.r3corda.core.days
|
||||
import com.r3corda.core.flatMap
|
||||
import com.r3corda.core.node.recordTransactions
|
||||
import com.r3corda.core.protocols.ProtocolStateMachine
|
||||
import com.r3corda.core.seconds
|
||||
import com.r3corda.core.transactions.SignedTransaction
|
||||
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
|
||||
@ -51,7 +52,7 @@ class TradeSimulation(runAsync: Boolean, latencyInjector: InMemoryMessagingNetwo
|
||||
|
||||
val buyerFuture = buyer.initiateSingleShotProtocol(Seller::class) {
|
||||
Buyer(it, notary.info.notaryIdentity, amount, CommercialPaper.State::class.java)
|
||||
}.flatMap { it.resultFuture }
|
||||
}.flatMap { (it.psm as ProtocolStateMachine<SignedTransaction>).resultFuture }
|
||||
|
||||
val sellerKey = seller.services.legalIdentityKey
|
||||
val sellerProtocol = Seller(
|
||||
|
@ -12,11 +12,11 @@ import com.r3corda.core.crypto.SecureHash
|
||||
import com.r3corda.core.crypto.generateKeyPair
|
||||
import com.r3corda.core.node.ServiceHub
|
||||
import com.r3corda.core.protocols.ProtocolLogic
|
||||
import com.r3corda.core.protocols.ProtocolStateMachine
|
||||
import com.r3corda.core.transactions.TransactionBuilder
|
||||
import com.r3corda.core.utilities.DUMMY_NOTARY
|
||||
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
|
||||
import com.r3corda.node.internal.AbstractNode
|
||||
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
|
||||
import com.r3corda.node.services.statemachine.StateMachineManager.Change
|
||||
import com.r3corda.node.utilities.AddOrRemove.ADD
|
||||
import com.r3corda.testing.node.MockIdentityService
|
||||
@ -138,20 +138,20 @@ fun getFreeLocalPorts(hostName: String, numberToAlloc: Int): List<HostAndPort> {
|
||||
/**
|
||||
* The given protocol factory will be used to initiate just one instance of a protocol of type [P] when a counterparty
|
||||
* protocol requests for it using [markerClass].
|
||||
* @return Returns a [ListenableFuture] holding the single [ProtocolStateMachine] created by the request.
|
||||
* @return Returns a [ListenableFuture] holding the single [ProtocolStateMachineImpl] created by the request.
|
||||
*/
|
||||
inline fun <R, reified P : ProtocolLogic<R>> AbstractNode.initiateSingleShotProtocol(
|
||||
inline fun <reified P : ProtocolLogic<*>> AbstractNode.initiateSingleShotProtocol(
|
||||
markerClass: KClass<out ProtocolLogic<*>>,
|
||||
noinline protocolFactory: (Party) -> P): ListenableFuture<ProtocolStateMachine<R>> {
|
||||
noinline protocolFactory: (Party) -> P): ListenableFuture<P> {
|
||||
services.registerProtocolInitiator(markerClass, protocolFactory)
|
||||
|
||||
val future = SettableFuture.create<ProtocolStateMachine<R>>()
|
||||
val future = SettableFuture.create<P>()
|
||||
|
||||
val subscriber = object : Subscriber<Change>() {
|
||||
override fun onNext(change: Change) {
|
||||
if (change.logic is P && change.addOrRemove == ADD) {
|
||||
unsubscribe()
|
||||
future.set(change.logic.psm as ProtocolStateMachine<R>)
|
||||
future.set(change.logic as P)
|
||||
}
|
||||
}
|
||||
override fun onError(e: Throwable) {
|
||||
|
@ -223,7 +223,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
|
||||
private val state = ThreadBox(InnerState())
|
||||
private val processedMessages: MutableSet<UUID> = Collections.synchronizedSet(HashSet<UUID>())
|
||||
|
||||
override val myAddress: SingleMessageRecipient = handle
|
||||
override val myAddress: Handle get() = handle
|
||||
|
||||
private val backgroundThread = if (manuallyPumped) null else
|
||||
thread(isDaemon = true, name = "In-memory message dispatcher") {
|
||||
|
@ -1,5 +1,6 @@
|
||||
package com.r3corda.testing.node
|
||||
|
||||
import com.google.common.jimfs.Configuration.unix
|
||||
import com.google.common.jimfs.Jimfs
|
||||
import com.google.common.util.concurrent.Futures
|
||||
import com.r3corda.core.crypto.Party
|
||||
@ -27,6 +28,7 @@ import com.r3corda.node.services.transactions.SimpleNotaryService
|
||||
import com.r3corda.node.services.transactions.ValidatingNotaryService
|
||||
import com.r3corda.node.utilities.databaseTransaction
|
||||
import org.slf4j.Logger
|
||||
import java.nio.file.FileSystem
|
||||
import java.nio.file.Files
|
||||
import java.nio.file.Path
|
||||
import java.security.KeyPair
|
||||
@ -49,7 +51,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
|
||||
private val threadPerNode: Boolean = false,
|
||||
private val defaultFactory: Factory = MockNetwork.DefaultFactory) {
|
||||
private var counter = 0
|
||||
val filesystem = com.google.common.jimfs.Jimfs.newFileSystem(com.google.common.jimfs.Configuration.unix())
|
||||
val filesystem: FileSystem = Jimfs.newFileSystem(unix())
|
||||
val messagingNetwork = InMemoryMessagingNetwork(networkSendManuallyPumped)
|
||||
|
||||
// A unique identifier for this network to segregate databases with the same nodeID but different networks.
|
||||
@ -138,7 +140,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
|
||||
// It is used from the network visualiser tool.
|
||||
@Suppress("unused") val place: PhysicalLocation get() = findMyLocation()!!
|
||||
|
||||
fun pumpReceive(block: Boolean): InMemoryMessagingNetwork.MessageTransfer? {
|
||||
fun pumpReceive(block: Boolean = false): InMemoryMessagingNetwork.MessageTransfer? {
|
||||
return (net as InMemoryMessagingNetwork.InMemoryMessaging).pumpReceive(block)
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user