Session handshake optimised to carry the first send payload in the init message

This commit is contained in:
Shams Asari 2016-10-10 15:29:45 +01:00
parent d2983d6a7a
commit e48e09f04e
9 changed files with 231 additions and 137 deletions

View File

@ -122,9 +122,14 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
payload: Any,
receiveType: Class<T>,
sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T> {
val session = getSession(otherParty, sessionProtocol)
val sendSessionData = createSessionData(session, payload)
val receivedSessionData = sendAndReceiveInternal(session, sendSessionData, SessionData::class.java)
val (session, new) = getSession(otherParty, sessionProtocol, payload)
val receivedSessionData = if (new) {
// Only do a receive here as the session init has carried the payload
receiveInternal<SessionData>(session)
} else {
val sendSessionData = createSessionData(session, payload)
sendAndReceiveInternal<SessionData>(session, sendSessionData)
}
return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
}
@ -132,15 +137,18 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
override fun <T : Any> receive(otherParty: Party,
receiveType: Class<T>,
sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T> {
val receivedSessionData = receiveInternal(getSession(otherParty, sessionProtocol), SessionData::class.java)
val session = getSession(otherParty, sessionProtocol, null).first
val receivedSessionData = receiveInternal<SessionData>(session)
return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
}
@Suspendable
override fun send(otherParty: Party, payload: Any, sessionProtocol: ProtocolLogic<*>) {
val session = getSession(otherParty, sessionProtocol)
val sendSessionData = createSessionData(session, payload)
sendInternal(session, sendSessionData)
val (session, new) = getSession(otherParty, sessionProtocol, payload)
if (!new) {
// Don't send the payload again if it was already piggy-backed on a session init
sendInternal(session, createSessionData(session, payload))
}
}
private fun createSessionData(session: ProtocolSession, payload: Any): SessionData {
@ -155,27 +163,31 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
}
@Suspendable
private fun <T : SessionMessage> receiveInternal(session: ProtocolSession, receiveType: Class<T>): T {
return suspendAndExpectReceive(ReceiveOnly(session, receiveType))
private inline fun <reified M : SessionMessage> receiveInternal(session: ProtocolSession): M {
return suspendAndExpectReceive(ReceiveOnly(session, M::class.java))
}
private inline fun <reified M : SessionMessage> sendAndReceiveInternal(session: ProtocolSession, message: SessionMessage): M {
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java))
}
@Suspendable
private fun <T : SessionMessage> sendAndReceiveInternal(session: ProtocolSession, message: SessionMessage, receiveType: Class<T>): T {
return suspendAndExpectReceive(SendAndReceive(session, message, receiveType))
private fun getSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>, firstPayload: Any?): Pair<ProtocolSession, Boolean> {
val session = openSessions[Pair(sessionProtocol, otherParty)]
return if (session != null) {
Pair(session, false)
} else {
Pair(startNewSession(otherParty, sessionProtocol, firstPayload), true)
}
}
@Suspendable
private fun getSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>): ProtocolSession {
return openSessions[Pair(sessionProtocol, otherParty)] ?: startNewSession(otherParty, sessionProtocol)
}
@Suspendable
private fun startNewSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>) : ProtocolSession {
private fun startNewSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>, firstPayload: Any?) : ProtocolSession {
val session = ProtocolSession(sessionProtocol, otherParty, random63BitValue(), null)
openSessions[Pair(sessionProtocol, otherParty)] = session
val counterpartyProtocol = sessionProtocol.getCounterpartyMarker(otherParty).name
val sessionInit = SessionInit(session.ourSessionId, serviceHub.myInfo.legalIdentity, counterpartyProtocol)
val sessionInitResponse = sendAndReceiveInternal(session, sessionInit, SessionInitResponse::class.java)
val sessionInit = SessionInit(session.ourSessionId, serviceHub.myInfo.legalIdentity, counterpartyProtocol, firstPayload)
val sessionInitResponse = sendAndReceiveInternal<SessionInitResponse>(session, sessionInit)
if (sessionInitResponse is SessionConfirm) {
session.otherPartySessionId = sessionInitResponse.initiatedSessionId
return session
@ -186,21 +198,26 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
}
@Suspendable
private fun <T : SessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<T>): T {
private fun <M : SessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<M>): M {
fun getReceivedMessage(): ExistingSessionMessage? = receiveRequest.session.receivedMessages.poll()
val receivedMessage = getReceivedMessage() ?: run {
// Suspend while we wait for the receive
receiveRequest.session.waitingForResponse = true
val polledMessage = getReceivedMessage()
val receivedMessage = if (polledMessage != null) {
if (receiveRequest is SendAndReceive) {
// We've already received a message but we suspend so that the send can be performed
suspend(receiveRequest)
}
polledMessage
} else {
// Suspend while we wait for a receive
suspend(receiveRequest)
receiveRequest.session.waitingForResponse = false
getReceivedMessage()
?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $receiveRequest")
}
if (receivedMessage is SessionEnd) {
openSessions.values.remove(receiveRequest.session)
throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurely ended")
throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurely ended on $receiveRequest")
} else if (receiveRequest.receiveType.isInstance(receivedMessage)) {
return receiveRequest.receiveType.cast(receivedMessage)
} else {
@ -213,6 +230,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
// we have to pass the Thread local Transaction across via a transient field as the Fiber Park swaps them out.
txTrampoline = TransactionManager.currentOrNull()
StrandLocalTransactionManager.setThreadLocalTx(null)
ioRequest.session.waitingForResponse = true
parkAndSerialize { fiber, serializer ->
logger.trace { "Suspended on $ioRequest" }
// restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB
@ -228,6 +246,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
processException(t)
}
}
ioRequest.session.waitingForResponse = false
createTransaction()
}

View File

@ -233,6 +233,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
val protocol = protocolFactory(otherParty)
val psm = createFiber(protocol)
val session = ProtocolSession(protocol, otherParty, random63BitValue(), otherPartySessionId)
if (sessionInit.firstPayload != null) {
session.receivedMessages += SessionData(session.ourSessionId, sessionInit.firstPayload)
}
openSessions[session.ourSessionId] = session
psm.openSessions[Pair(protocol, otherParty)] = session
updateCheckpoint(psm)
@ -400,7 +403,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
val recipientSessionId: Long
}
data class SessionInit(val initiatorSessionId: Long, val initiatorParty: Party, val protocolName: String) : SessionMessage
data class SessionInit(val initiatorSessionId: Long,
val initiatorParty: Party,
val protocolName: String,
val firstPayload: Any?) : SessionMessage
interface SessionInitResponse : ExistingSessionMessage

View File

@ -7,6 +7,7 @@ import com.r3corda.core.contracts.*
import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.days
import com.r3corda.core.map
import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.node.services.*
import com.r3corda.core.protocols.ProtocolStateMachine
@ -143,15 +144,15 @@ class TwoPartyTradeProtocolTests {
// Everything is on this thread so we can now step through the protocol one step at a time.
// Seller Alice already sent a message to Buyer Bob. Pump once:
bobNode.pumpReceive(false)
bobNode.pumpReceive()
// Bob sends a couple of queries for the dependencies back to Alice. Alice reponds.
aliceNode.pumpReceive(false)
bobNode.pumpReceive(false)
aliceNode.pumpReceive(false)
bobNode.pumpReceive(false)
aliceNode.pumpReceive(false)
bobNode.pumpReceive(false)
aliceNode.pumpReceive()
bobNode.pumpReceive()
aliceNode.pumpReceive()
bobNode.pumpReceive()
aliceNode.pumpReceive()
bobNode.pumpReceive()
// OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature.
assertThat(bobNode.checkpointStorage.checkpoints()).hasSize(1)
@ -164,7 +165,7 @@ class TwoPartyTradeProtocolTests {
// Alice doesn't know that and carries on: she wants to know about the cash transactions he's trying to use.
// She will wait around until Bob comes back.
assertThat(aliceNode.pumpReceive(false)).isNotNull()
assertThat(aliceNode.pumpReceive()).isNotNull()
// ... bring the node back up ... the act of constructing the SMM will re-register the message handlers
// that Bob was waiting on before the reboot occurred.
@ -386,7 +387,7 @@ class TwoPartyTradeProtocolTests {
private data class RunResult(
// The buyer is not created immediately, only when the seller starts running
val buyer: Future<ProtocolStateMachine<SignedTransaction>>,
val buyer: Future<ProtocolStateMachine<*>>,
val sellerResult: Future<SignedTransaction>,
val sellerId: StateMachineRunId
)
@ -394,7 +395,7 @@ class TwoPartyTradeProtocolTests {
private fun runBuyerAndSeller(assetToSell: StateAndRef<OwnableState>) : RunResult {
val buyerFuture = bobNode.initiateSingleShotProtocol(Seller::class) { otherParty ->
Buyer(otherParty, notaryNode.info.notaryIdentity, 1000.DOLLARS, CommercialPaper.State::class.java)
}
}.map { it.psm }
val seller = Seller(bobNode.info.legalIdentity, notaryNode.info, assetToSell, 1000.DOLLARS, ALICE_KEY)
val sellerResultFuture = aliceNode.smm.add(seller).resultFuture
return RunResult(buyerFuture, sellerResultFuture, seller.psm.id)

View File

@ -4,15 +4,15 @@ import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolSessionException
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.deserialize
import com.r3corda.node.services.persistence.checkpoints
import com.r3corda.node.services.statemachine.StateMachineManager.SessionData
import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage
import com.r3corda.node.services.statemachine.StateMachineManager.*
import com.r3corda.testing.initiateSingleShotProtocol
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.InMemoryMessagingNetwork.MessageTransfer
import com.r3corda.testing.node.MockNetwork
import com.r3corda.testing.node.MockNetwork.MockNode
import org.assertj.core.api.Assertions.assertThat
@ -20,20 +20,25 @@ import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.After
import org.junit.Before
import org.junit.Test
import rx.Observable
import java.util.*
import kotlin.reflect.KClass
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class StateMachineManagerTests {
val net = MockNetwork()
lateinit var node1: MockNode
lateinit var node2: MockNode
private val net = MockNetwork()
private val sessionTransfers = ArrayList<SessionTransfer>()
private lateinit var node1: MockNode
private lateinit var node2: MockNode
@Before
fun start() {
val nodes = net.createTwoNodes()
node1 = nodes.first
node2 = nodes.second
net.messagingNetwork.receivedMessages.toSessionTransfers().forEach { sessionTransfers += it }
net.runNetwork()
}
@ -44,14 +49,18 @@ class StateMachineManagerTests {
@Test
fun `newly added protocol is preserved on restart`() {
node1.smm.add(ProtocolWithoutCheckpoints())
val restoredProtocol = node1.restartAndGetRestoredProtocol<ProtocolWithoutCheckpoints>()
node1.smm.add(NoOpProtocol(nonTerminating = true))
val restoredProtocol = node1.restartAndGetRestoredProtocol<NoOpProtocol>()
assertThat(restoredProtocol.protocolStarted).isTrue()
}
@Test
fun `protocol can lazily use the serviceHub in its constructor`() {
val protocol = ProtocolWithLazyServiceHub()
val protocol = object : ProtocolLogic<Unit>() {
val lazyTime by lazy { serviceHub.clock.instant() }
@Suspendable
override fun call() = Unit
}
node1.smm.add(protocol)
assertThat(protocol.lazyTime).isNotNull()
}
@ -62,19 +71,18 @@ class StateMachineManagerTests {
val payload = random63BitValue()
node1.smm.add(SendProtocol(payload, node2.info.legalIdentity))
// We push through just enough messages to get only the SessionData sent
// TODO We should be able to give runNetwork a predicate for when to stop
net.runNetwork(2)
// We push through just enough messages to get only the payload sent
node2.pumpReceive()
node2.stop()
net.runNetwork()
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1.info.address)
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1)
assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload)
}
@Test
fun `protocol added before network map does run after init`() {
val node3 = net.createNode(node1.info.address) //create vanilla node
val protocol = ProtocolNoBlocking()
val protocol = NoOpProtocol()
node3.smm.add(protocol)
assertEquals(false, protocol.protocolStarted) // Not started yet as no network activity has been allowed yet
net.runNetwork() // Allow network map messages to flow
@ -84,13 +92,13 @@ class StateMachineManagerTests {
@Test
fun `protocol added before network map will be init checkpointed`() {
var node3 = net.createNode(node1.info.address) //create vanilla node
val protocol = ProtocolNoBlocking()
val protocol = NoOpProtocol()
node3.smm.add(protocol)
assertEquals(false, protocol.protocolStarted) // Not started yet as no network activity has been allowed yet
node3.stop()
node3 = net.createNode(node1.info.address, forcedID = node3.id)
val restoredProtocol = node3.getSingleProtocol<ProtocolNoBlocking>().first
val restoredProtocol = node3.getSingleProtocol<NoOpProtocol>().first
assertEquals(false, restoredProtocol.protocolStarted) // Not started yet as no network activity has been allowed yet
net.runNetwork() // Allow network map messages to flow
node3.smm.executor.flush()
@ -101,17 +109,16 @@ class StateMachineManagerTests {
node3 = net.createNode(node1.info.address, forcedID = node3.id)
net.runNetwork() // Allow network map messages to flow
node3.smm.executor.flush()
assertTrue(node3.smm.findStateMachines(ProtocolNoBlocking::class.java).isEmpty())
assertTrue(node3.smm.findStateMachines(NoOpProtocol::class.java).isEmpty())
}
@Test
fun `protocol loaded from checkpoint will respond to messages from before start`() {
val payload = random63BitValue()
node1.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(payload, it) }
val receiveProtocol = ReceiveThenSuspendProtocol(node1.info.legalIdentity)
node2.smm.add(receiveProtocol) // Prepare checkpointed receive protocol
node2.smm.add(ReceiveThenSuspendProtocol(node1.info.legalIdentity)) // Prepare checkpointed receive protocol
node2.stop() // kill receiver
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1.info.address)
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1)
assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload)
}
@ -121,41 +128,36 @@ class StateMachineManagerTests {
val payload2 = random63BitValue()
var sentCount = 0
var receivedCount = 0
net.messagingNetwork.sentMessages.subscribe { if (isDataMessage(it)) sentCount++ }
net.messagingNetwork.receivedMessages.subscribe { if (isDataMessage(it)) receivedCount++ }
val node3 = net.createNode(node1.info.address)
net.runNetwork()
net.messagingNetwork.sentMessages.toSessionTransfers().filter { it.isPayloadTransfer }.forEach { sentCount++ }
var secondProtocol: PingPongProtocol? = null
node3.services.registerProtocolInitiator(PingPongProtocol::class) {
val protocol = PingPongProtocol(it, payload2)
secondProtocol = protocol
protocol
}
val node3 = net.createNode(node1.info.address)
val secondProtocol = node3.initiateSingleShotProtocol(PingPongProtocol::class) { PingPongProtocol(it, payload2) }
net.runNetwork()
// Kick off first send and receive
node2.smm.add(PingPongProtocol(node3.info.legalIdentity, payload))
assertEquals(1, node2.checkpointStorage.checkpoints().count())
assertEquals(1, node2.checkpointStorage.checkpoints().size)
// Restart node and thus reload the checkpoint and resend the message with same UUID
node2.stop()
val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray())
val (firstAgain, fut1) = node2b.getSingleProtocol<PingPongProtocol>()
net.runNetwork()
assertEquals(1, node2.checkpointStorage.checkpoints().count())
// Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync.
net.runNetwork()
assertEquals(1, node2.checkpointStorage.checkpoints().size)
node2b.smm.executor.flush()
fut1.get()
val receivedCount = sessionTransfers.count { it.isPayloadTransfer }
// Check protocols completed cleanly and didn't get out of phase
assertEquals(4, receivedCount, "Protocol should have exchanged 4 unique messages")// Two messages each way
assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages") // can't give a precise value as every addMessageHandler re-runs the undelivered messages
assertEquals(0, node2b.checkpointStorage.checkpoints().count(), "Checkpoints left after restored protocol should have ended")
assertEquals(0, node3.checkpointStorage.checkpoints().count(), "Checkpoints left after restored protocol should have ended")
// can't give a precise value as every addMessageHandler re-runs the undelivered messages
assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages")
assertEquals(0, node2b.checkpointStorage.checkpoints().size, "Checkpoints left after restored protocol should have ended")
assertEquals(0, node3.checkpointStorage.checkpoints().size, "Checkpoints left after restored protocol should have ended")
assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3")
assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3")
assertEquals(payload, secondProtocol!!.receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
assertEquals(payload + 1, secondProtocol!!.receivedPayload2, "Received payload does not match the expected second value on Node 2")
assertEquals(payload, secondProtocol.get().receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
assertEquals(payload + 1, secondProtocol.get().receivedPayload2, "Received payload does not match the expected second value on Node 2")
}
@Test
@ -171,6 +173,20 @@ class StateMachineManagerTests {
val node3Protocol = node3.getSingleProtocol<ReceiveThenSuspendProtocol>().first
assertThat(node2Protocol.receivedPayloads[0]).isEqualTo(payload)
assertThat(node3Protocol.receivedPayloads[0]).isEqualTo(payload)
assertSessionTransfers(node2,
node1 sent sessionInit(node1, SendProtocol::class, payload) to node2,
node2 sent sessionConfirm() to node1,
node1 sent sessionEnd() to node2
//There's no session end from the other protocols as they're manually suspended
)
assertSessionTransfers(node3,
node1 sent sessionInit(node1, SendProtocol::class, payload) to node3,
node3 sent sessionConfirm() to node1,
node1 sent sessionEnd() to node3
//There's no session end from the other protocols as they're manually suspended
)
}
@Test
@ -183,13 +199,39 @@ class StateMachineManagerTests {
node3.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node3Payload, it) }
val multiReceiveProtocol = ReceiveThenSuspendProtocol(node2.info.legalIdentity, node3.info.legalIdentity)
node1.smm.add(multiReceiveProtocol)
net.runNetwork(1) // session handshaking
// have the messages arrive in reverse order of receive
node3.pumpReceive(false)
node2.pumpReceive(false)
net.runNetwork() // pump remaining messages
net.runNetwork()
assertThat(multiReceiveProtocol.receivedPayloads[0]).isEqualTo(node2Payload)
assertThat(multiReceiveProtocol.receivedPayloads[1]).isEqualTo(node3Payload)
assertSessionTransfers(node2,
node1 sent sessionInit(node1, ReceiveThenSuspendProtocol::class) to node2,
node2 sent sessionConfirm() to node1,
node2 sent sessionData(node2Payload) to node1,
node2 sent sessionEnd() to node1
)
assertSessionTransfers(node3,
node1 sent sessionInit(node1, ReceiveThenSuspendProtocol::class) to node3,
node3 sent sessionConfirm() to node1,
node3 sent sessionData(node3Payload) to node1,
node3 sent sessionEnd() to node1
)
}
@Test
fun `both sides do a send as their first IO request`() {
node2.services.registerProtocolInitiator(PingPongProtocol::class) { PingPongProtocol(it, 20L) }
node1.smm.add(PingPongProtocol(node2.info.legalIdentity, 10L))
net.runNetwork()
assertSessionTransfers(
node1 sent sessionInit(node1, PingPongProtocol::class, 10L) to node2,
node2 sent sessionConfirm() to node1,
node2 sent sessionData(20L) to node1,
node1 sent sessionData(11L) to node2,
node2 sent sessionData(21L) to node1,
node1 sent sessionEnd() to node2
)
}
@Test
@ -198,16 +240,17 @@ class StateMachineManagerTests {
val future = node1.smm.add(ReceiveThenSuspendProtocol(node2.info.legalIdentity)).resultFuture
net.runNetwork()
assertThatThrownBy { future.get() }.hasCauseInstanceOf(ProtocolSessionException::class.java)
assertSessionTransfers(
node1 sent sessionInit(node1, ReceiveThenSuspendProtocol::class) to node2,
node2 sent sessionConfirm() to node1,
node2 sent sessionEnd() to node1
)
}
private fun isDataMessage(transfer: InMemoryMessagingNetwork.MessageTransfer): Boolean {
return transfer.message.topicSession == StateMachineManager.sessionTopic
&& transfer.message.data.deserialize<SessionMessage>() is SessionData
}
private inline fun <reified P : NonTerminatingProtocol> MockNode.restartAndGetRestoredProtocol(networkMapAddress: SingleMessageRecipient? = null): P {
private inline fun <reified P : ProtocolLogic<*>> MockNode.restartAndGetRestoredProtocol(
networkMapNode: MockNode? = null): P {
stop()
val newNode = mockNet.createNode(networkMapAddress, id, advertisedServices = *advertisedServices.toTypedArray())
val newNode = mockNet.createNode(networkMapNode?.info?.address, id, advertisedServices = *advertisedServices.toTypedArray())
mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine
return newNode.getSingleProtocol<P>().first
}
@ -216,35 +259,66 @@ class StateMachineManagerTests {
return smm.findStateMachines(P::class.java).single()
}
private fun sessionInit(initiatorNode: MockNode, protocolMarker: KClass<*>, payload: Any? = null): SessionInit {
return SessionInit(0, initiatorNode.info.legalIdentity, protocolMarker.java.name, payload)
}
private fun sessionConfirm() = SessionConfirm(0, 0)
private fun sessionData(payload: Any) = SessionData(0, payload)
private fun sessionEnd() = SessionEnd(0)
private fun assertSessionTransfers(vararg expected: SessionTransfer) {
assertThat(sessionTransfers).containsExactly(*expected)
}
private fun assertSessionTransfers(node: MockNode, vararg expected: SessionTransfer) {
val actualForNode = sessionTransfers.filter { it.from == node.id || it.to == node.id }
assertThat(actualForNode).containsExactly(*expected)
}
private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: Int) {
val isPayloadTransfer: Boolean get() = message is SessionData || message is SessionInit && message.firstPayload != null
override fun toString(): String = "$from sent $message to $to"
}
private fun Observable<MessageTransfer>.toSessionTransfers(): Observable<SessionTransfer> {
return filter { it.message.topicSession == StateMachineManager.sessionTopic }.map {
val from = it.sender.myAddress.id
val message = it.message.data.deserialize<SessionMessage>()
val to = (it.recipients as InMemoryMessagingNetwork.Handle).id
SessionTransfer(from, sanitise(message), to)
}
}
private fun sanitise(message: SessionMessage): SessionMessage {
return when (message) {
is SessionData -> message.copy(recipientSessionId = 0)
is SessionInit -> message.copy(initiatorSessionId = 0)
is SessionConfirm -> message.copy(initiatorSessionId = 0, initiatedSessionId = 0)
is SessionEnd -> message.copy(recipientSessionId = 0)
else -> message
}
}
private infix fun MockNode.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(id, message)
private infix fun Pair<Int, SessionMessage>.to(node: MockNode): SessionTransfer = SessionTransfer(first, second, node.id)
private class NoOpProtocol(val nonTerminating: Boolean = false) : ProtocolLogic<Unit>() {
private class ProtocolNoBlocking : ProtocolLogic<Unit>() {
@Transient var protocolStarted = false
@Suspendable
override fun call() {
protocolStarted = true
if (nonTerminating) {
Fiber.park()
}
}
}
private class ProtocolWithoutCheckpoints : NonTerminatingProtocol() {
@Transient var protocolStarted = false
@Suspendable
override fun doCall() {
protocolStarted = true
}
}
private class ProtocolWithLazyServiceHub : ProtocolLogic<Unit>() {
val lazyTime by lazy { serviceHub.clock.instant() }
@Suspendable
override fun call() = Unit
}
private class SendProtocol(val payload: Any, vararg val otherParties: Party) : ProtocolLogic<Unit>() {
@ -257,7 +331,7 @@ class StateMachineManagerTests {
}
private class ReceiveThenSuspendProtocol(vararg val otherParties: Party) : NonTerminatingProtocol() {
private class ReceiveThenSuspendProtocol(vararg val otherParties: Party) : ProtocolLogic<Unit>() {
init {
require(otherParties.isNotEmpty())
@ -266,8 +340,10 @@ class StateMachineManagerTests {
@Transient var receivedPayloads: List<Any> = emptyList()
@Suspendable
override fun doCall() {
override fun call() {
receivedPayloads = otherParties.map { receive<Any>(it).unwrap { it } }
println(receivedPayloads)
Fiber.park()
}
}
@ -279,7 +355,9 @@ class StateMachineManagerTests {
@Suspendable
override fun call() {
receivedPayload = sendAndReceive<Long>(otherParty, payload).unwrap { it }
receivedPayload2 = sendAndReceive<Long>(otherParty, (payload + 1)).unwrap { it }
println("${psm.id} Received $receivedPayload")
receivedPayload2 = sendAndReceive<Long>(otherParty, payload + 1).unwrap { it }
println("${psm.id} Received $receivedPayload2")
}
}
@ -287,20 +365,4 @@ class StateMachineManagerTests {
override fun call(): Nothing = throw Exception()
}
/**
* A protocol that suspends forever after doing some work. This is to allow it to be retrieved from the SMM after
* restart for testing checkpoint restoration. Store any results as @Transient fields.
*/
private abstract class NonTerminatingProtocol : ProtocolLogic<Unit>() {
@Suspendable
override fun call() {
doCall()
Fiber.park()
}
@Suspendable
abstract fun doCall()
}
}

View File

@ -12,6 +12,7 @@ import com.r3corda.core.contracts.UniqueIdentifier
import com.r3corda.core.flatMap
import com.r3corda.core.map
import com.r3corda.core.node.services.linearHeadsOfType
import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.success
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.protocols.TwoPartyDealProtocol.Acceptor
@ -111,7 +112,9 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten
irs.fixedLeg.fixedRatePayer = node1.info.legalIdentity
irs.floatingLeg.floatingRatePayer = node2.info.legalIdentity
val acceptorTx = node2.initiateSingleShotProtocol(Instigator::class) { Acceptor(it) }.flatMap { it.resultFuture }
val acceptorTx = node2.initiateSingleShotProtocol(Instigator::class) { Acceptor(it) }.flatMap {
(it.psm as ProtocolStateMachine<SignedTransaction>).resultFuture
}
showProgressFor(listOf(node1, node2))
showConsensusFor(listOf(node1, node2, regulators[0]))

View File

@ -11,6 +11,7 @@ import com.r3corda.core.contracts.`issued by`
import com.r3corda.core.days
import com.r3corda.core.flatMap
import com.r3corda.core.node.recordTransactions
import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.seconds
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
@ -51,7 +52,7 @@ class TradeSimulation(runAsync: Boolean, latencyInjector: InMemoryMessagingNetwo
val buyerFuture = buyer.initiateSingleShotProtocol(Seller::class) {
Buyer(it, notary.info.notaryIdentity, amount, CommercialPaper.State::class.java)
}.flatMap { it.resultFuture }
}.flatMap { (it.psm as ProtocolStateMachine<SignedTransaction>).resultFuture }
val sellerKey = seller.services.legalIdentityKey
val sellerProtocol = Seller(

View File

@ -12,11 +12,11 @@ import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.crypto.generateKeyPair
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.transactions.TransactionBuilder
import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.node.internal.AbstractNode
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
import com.r3corda.node.services.statemachine.StateMachineManager.Change
import com.r3corda.node.utilities.AddOrRemove.ADD
import com.r3corda.testing.node.MockIdentityService
@ -138,20 +138,20 @@ fun getFreeLocalPorts(hostName: String, numberToAlloc: Int): List<HostAndPort> {
/**
* The given protocol factory will be used to initiate just one instance of a protocol of type [P] when a counterparty
* protocol requests for it using [markerClass].
* @return Returns a [ListenableFuture] holding the single [ProtocolStateMachine] created by the request.
* @return Returns a [ListenableFuture] holding the single [ProtocolStateMachineImpl] created by the request.
*/
inline fun <R, reified P : ProtocolLogic<R>> AbstractNode.initiateSingleShotProtocol(
inline fun <reified P : ProtocolLogic<*>> AbstractNode.initiateSingleShotProtocol(
markerClass: KClass<out ProtocolLogic<*>>,
noinline protocolFactory: (Party) -> P): ListenableFuture<ProtocolStateMachine<R>> {
noinline protocolFactory: (Party) -> P): ListenableFuture<P> {
services.registerProtocolInitiator(markerClass, protocolFactory)
val future = SettableFuture.create<ProtocolStateMachine<R>>()
val future = SettableFuture.create<P>()
val subscriber = object : Subscriber<Change>() {
override fun onNext(change: Change) {
if (change.logic is P && change.addOrRemove == ADD) {
unsubscribe()
future.set(change.logic.psm as ProtocolStateMachine<R>)
future.set(change.logic as P)
}
}
override fun onError(e: Throwable) {

View File

@ -223,7 +223,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
private val state = ThreadBox(InnerState())
private val processedMessages: MutableSet<UUID> = Collections.synchronizedSet(HashSet<UUID>())
override val myAddress: SingleMessageRecipient = handle
override val myAddress: Handle get() = handle
private val backgroundThread = if (manuallyPumped) null else
thread(isDaemon = true, name = "In-memory message dispatcher") {

View File

@ -1,5 +1,6 @@
package com.r3corda.testing.node
import com.google.common.jimfs.Configuration.unix
import com.google.common.jimfs.Jimfs
import com.google.common.util.concurrent.Futures
import com.r3corda.core.crypto.Party
@ -27,6 +28,7 @@ import com.r3corda.node.services.transactions.SimpleNotaryService
import com.r3corda.node.services.transactions.ValidatingNotaryService
import com.r3corda.node.utilities.databaseTransaction
import org.slf4j.Logger
import java.nio.file.FileSystem
import java.nio.file.Files
import java.nio.file.Path
import java.security.KeyPair
@ -49,7 +51,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
private val threadPerNode: Boolean = false,
private val defaultFactory: Factory = MockNetwork.DefaultFactory) {
private var counter = 0
val filesystem = com.google.common.jimfs.Jimfs.newFileSystem(com.google.common.jimfs.Configuration.unix())
val filesystem: FileSystem = Jimfs.newFileSystem(unix())
val messagingNetwork = InMemoryMessagingNetwork(networkSendManuallyPumped)
// A unique identifier for this network to segregate databases with the same nodeID but different networks.
@ -138,7 +140,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
// It is used from the network visualiser tool.
@Suppress("unused") val place: PhysicalLocation get() = findMyLocation()!!
fun pumpReceive(block: Boolean): InMemoryMessagingNetwork.MessageTransfer? {
fun pumpReceive(block: Boolean = false): InMemoryMessagingNetwork.MessageTransfer? {
return (net as InMemoryMessagingNetwork.InMemoryMessaging).pumpReceive(block)
}