Removed session IDs from the send and receive methods of ProtocolLogic and are now partially managed by HandshakeMessage

This commit is contained in:
Shams Asari 2016-09-13 17:37:42 +01:00
parent f314bab6c8
commit 8ea20dd0d2
32 changed files with 539 additions and 519 deletions

View File

@ -9,7 +9,6 @@ import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.signWithECDSA import com.r3corda.core.crypto.signWithECDSA
import com.r3corda.core.node.NodeInfo import com.r3corda.core.node.NodeInfo
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.seconds import com.r3corda.core.seconds
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.transactions.TransactionBuilder import com.r3corda.core.transactions.TransactionBuilder
@ -58,19 +57,17 @@ object TwoPartyTradeProtocol {
class SellerTradeInfo( class SellerTradeInfo(
val assetForSale: StateAndRef<OwnableState>, val assetForSale: StateAndRef<OwnableState>,
val price: Amount<Currency>, val price: Amount<Currency>,
val sellerOwnerKey: PublicKey, val sellerOwnerKey: PublicKey
val sessionID: Long
) )
class SignaturesFromSeller(val sellerSig: DigitalSignature.WithKey, class SignaturesFromSeller(val sellerSig: DigitalSignature.WithKey,
val notarySig: DigitalSignature.LegallyIdentifiable) val notarySig: DigitalSignature.LegallyIdentifiable)
open class Seller(val otherSide: Party, open class Seller(val otherParty: Party,
val notaryNode: NodeInfo, val notaryNode: NodeInfo,
val assetToSell: StateAndRef<OwnableState>, val assetToSell: StateAndRef<OwnableState>,
val price: Amount<Currency>, val price: Amount<Currency>,
val myKeyPair: KeyPair, val myKeyPair: KeyPair,
val buyerSessionID: Long,
override val progressTracker: ProgressTracker = Seller.tracker()) : ProtocolLogic<SignedTransaction>() { override val progressTracker: ProgressTracker = Seller.tracker()) : ProtocolLogic<SignedTransaction>() {
companion object { companion object {
@ -109,12 +106,10 @@ object TwoPartyTradeProtocol {
private fun receiveAndCheckProposedTransaction(): SignedTransaction { private fun receiveAndCheckProposedTransaction(): SignedTransaction {
progressTracker.currentStep = AWAITING_PROPOSAL progressTracker.currentStep = AWAITING_PROPOSAL
val sessionID = random63BitValue()
// Make the first message we'll send to kick off the protocol. // Make the first message we'll send to kick off the protocol.
val hello = SellerTradeInfo(assetToSell, price, myKeyPair.public, sessionID) val hello = SellerTradeInfo(assetToSell, price, myKeyPair.public)
val maybeSTX = sendAndReceive<SignedTransaction>(otherSide, buyerSessionID, sessionID, hello) val maybeSTX = sendAndReceive<SignedTransaction>(otherParty, hello)
progressTracker.currentStep = VERIFYING progressTracker.currentStep = VERIFYING
@ -127,7 +122,7 @@ object TwoPartyTradeProtocol {
// Download and check all the things that this transaction depends on and verify it is contract-valid, // Download and check all the things that this transaction depends on and verify it is contract-valid,
// even though it is missing signatures. // even though it is missing signatures.
subProtocol(ResolveTransactionsProtocol(wtx, otherSide)) subProtocol(ResolveTransactionsProtocol(wtx, otherParty))
if (wtx.outputs.map { it.data }.sumCashBy(myKeyPair.public).withoutIssuer() != price) if (wtx.outputs.map { it.data }.sumCashBy(myKeyPair.public).withoutIssuer() != price)
throw IllegalArgumentException("Transaction is not sending us the right amount of cash") throw IllegalArgumentException("Transaction is not sending us the right amount of cash")
@ -159,16 +154,15 @@ object TwoPartyTradeProtocol {
logger.trace { "Built finished transaction, sending back to secondary!" } logger.trace { "Built finished transaction, sending back to secondary!" }
send(otherSide, buyerSessionID, SignaturesFromSeller(ourSignature, notarySignature)) send(otherParty, SignaturesFromSeller(ourSignature, notarySignature))
return fullySigned return fullySigned
} }
} }
open class Buyer(val otherSide: Party, open class Buyer(val otherParty: Party,
val notary: Party, val notary: Party,
val acceptablePrice: Amount<Currency>, val acceptablePrice: Amount<Currency>,
val typeToBuy: Class<out OwnableState>, val typeToBuy: Class<out OwnableState>) : ProtocolLogic<SignedTransaction>() {
val sessionID: Long) : ProtocolLogic<SignedTransaction>() {
object RECEIVING : ProgressTracker.Step("Waiting for seller trading info") object RECEIVING : ProgressTracker.Step("Waiting for seller trading info")
@ -189,7 +183,7 @@ object TwoPartyTradeProtocol {
val (ptx, cashSigningPubKeys) = assembleSharedTX(tradeRequest) val (ptx, cashSigningPubKeys) = assembleSharedTX(tradeRequest)
val stx = signWithOurKeys(cashSigningPubKeys, ptx) val stx = signWithOurKeys(cashSigningPubKeys, ptx)
val signatures = swapSignaturesWithSeller(stx, tradeRequest.sessionID) val signatures = swapSignaturesWithSeller(stx)
logger.trace { "Got signatures from seller, verifying ... " } logger.trace { "Got signatures from seller, verifying ... " }
@ -204,7 +198,7 @@ object TwoPartyTradeProtocol {
private fun receiveAndValidateTradeRequest(): SellerTradeInfo { private fun receiveAndValidateTradeRequest(): SellerTradeInfo {
progressTracker.currentStep = RECEIVING progressTracker.currentStep = RECEIVING
// Wait for a trade request to come in on our pre-provided session ID. // Wait for a trade request to come in on our pre-provided session ID.
val maybeTradeRequest = receive<SellerTradeInfo>(sessionID) val maybeTradeRequest = receive<SellerTradeInfo>(otherParty)
progressTracker.currentStep = VERIFYING progressTracker.currentStep = VERIFYING
maybeTradeRequest.unwrap { maybeTradeRequest.unwrap {
@ -213,8 +207,6 @@ object TwoPartyTradeProtocol {
val assetTypeName = asset.javaClass.name val assetTypeName = asset.javaClass.name
logger.trace { "Got trade request for a $assetTypeName: ${it.assetForSale}" } logger.trace { "Got trade request for a $assetTypeName: ${it.assetForSale}" }
// Check the start message for acceptability.
check(it.sessionID > 0)
if (it.price > acceptablePrice) if (it.price > acceptablePrice)
throw UnacceptablePriceException(it.price) throw UnacceptablePriceException(it.price)
if (!typeToBuy.isInstance(asset)) if (!typeToBuy.isInstance(asset))
@ -222,20 +214,20 @@ object TwoPartyTradeProtocol {
// Check the transaction that contains the state which is being resolved. // Check the transaction that contains the state which is being resolved.
// We only have a hash here, so if we don't know it already, we have to ask for it. // We only have a hash here, so if we don't know it already, we have to ask for it.
subProtocol(ResolveTransactionsProtocol(setOf(it.assetForSale.ref.txhash), otherSide)) subProtocol(ResolveTransactionsProtocol(setOf(it.assetForSale.ref.txhash), otherParty))
return it return it
} }
} }
@Suspendable @Suspendable
private fun swapSignaturesWithSeller(stx: SignedTransaction, theirSessionID: Long): SignaturesFromSeller { private fun swapSignaturesWithSeller(stx: SignedTransaction): SignaturesFromSeller {
progressTracker.currentStep = SWAPPING_SIGNATURES progressTracker.currentStep = SWAPPING_SIGNATURES
logger.trace { "Sending partially signed transaction to seller" } logger.trace { "Sending partially signed transaction to seller" }
// TODO: Protect against the seller terminating here and leaving us in the lurch without the final tx. // TODO: Protect against the seller terminating here and leaving us in the lurch without the final tx.
return sendAndReceive<SignaturesFromSeller>(otherSide, theirSessionID, sessionID, stx).unwrap { it } return sendAndReceive<SignaturesFromSeller>(otherParty, stx).unwrap { it }
} }
private fun signWithOurKeys(cashSigningPubKeys: List<PublicKey>, ptx: TransactionBuilder): SignedTransaction { private fun signWithOurKeys(cashSigningPubKeys: List<PublicKey>, ptx: TransactionBuilder): SignedTransaction {

View File

@ -2,6 +2,7 @@ package com.r3corda.core
import com.google.common.base.Throwables import com.google.common.base.Throwables
import com.google.common.io.ByteStreams import com.google.common.io.ByteStreams
import com.google.common.util.concurrent.Futures
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.MoreExecutors import com.google.common.util.concurrent.MoreExecutors
import com.google.common.util.concurrent.SettableFuture import com.google.common.util.concurrent.SettableFuture
@ -15,6 +16,7 @@ import java.nio.file.Path
import java.time.Duration import java.time.Duration
import java.time.temporal.Temporal import java.time.temporal.Temporal
import java.util.concurrent.Executor import java.util.concurrent.Executor
import java.util.concurrent.Future
import java.util.concurrent.locks.ReentrantLock import java.util.concurrent.locks.ReentrantLock
import java.util.zip.ZipInputStream import java.util.zip.ZipInputStream
import kotlin.concurrent.withLock import kotlin.concurrent.withLock
@ -67,6 +69,7 @@ fun <T> ListenableFuture<T>.failure(executor: Executor, body: (Throwable) -> Uni
} }
} }
infix fun <F, T> Future<F>.map(mapper: (F) -> T): Future<T> = Futures.lazyTransform(this) { mapper(it!!) }
infix fun <T> ListenableFuture<T>.then(body: () -> Unit): ListenableFuture<T> = apply { then(RunOnCallerThread, body) } infix fun <T> ListenableFuture<T>.then(body: () -> Unit): ListenableFuture<T> = apply { then(RunOnCallerThread, body) }
infix fun <T> ListenableFuture<T>.success(body: (T) -> Unit): ListenableFuture<T> = apply { success(RunOnCallerThread, body) } infix fun <T> ListenableFuture<T>.success(body: (T) -> Unit): ListenableFuture<T> = apply { success(RunOnCallerThread, body) }
infix fun <T> ListenableFuture<T>.failure(body: (Throwable) -> Unit): ListenableFuture<T> = apply { failure(RunOnCallerThread, body) } infix fun <T> ListenableFuture<T>.failure(body: (Throwable) -> Unit): ListenableFuture<T> = apply { failure(RunOnCallerThread, body) }

View File

@ -3,9 +3,13 @@ package com.r3corda.core.protocols
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.node.ServiceHub import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.utilities.ProgressTracker import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.core.utilities.UntrustworthyData import com.r3corda.core.utilities.UntrustworthyData
import com.r3corda.core.utilities.debug
import com.r3corda.protocols.HandshakeMessage
import org.slf4j.Logger import org.slf4j.Logger
import java.util.*
/** /**
* A sub-class of [ProtocolLogic<T>] implements a protocol flow using direct, straight line blocking code. Thus you * A sub-class of [ProtocolLogic<T>] implements a protocol flow using direct, straight line blocking code. Thus you
@ -45,32 +49,78 @@ abstract class ProtocolLogic<out T> {
*/ */
protected abstract val topic: String protected abstract val topic: String
private val sessions = HashMap<Party, Session>()
/**
* If a node receives a [HandshakeMessage] it needs to call this method on the initiated receipt protocol to enable
* communication between it and the sender protocol. Calling this method, and other initiation steps, are already
* handled by AbstractNodeService.addProtocolHandler.
*/
fun registerSession(receivedHandshake: HandshakeMessage) {
// Note that the send and receive session IDs are swapped
addSession(receivedHandshake.replyToParty, receivedHandshake.receiveSessionID, receivedHandshake.sendSessionID)
}
// Kotlin helpers that allow the use of generic types. // Kotlin helpers that allow the use of generic types.
inline fun <reified T : Any> sendAndReceive(destination: Party, inline fun <reified T : Any> sendAndReceive(otherParty: Party, payload: Any): UntrustworthyData<T> {
sessionIDForSend: Long, return sendAndReceive(otherParty, payload, T::class.java)
sessionIDForReceive: Long,
payload: Any): UntrustworthyData<T> {
return psm.sendAndReceive(topic, destination, sessionIDForSend, sessionIDForReceive, payload, T::class.java)
} }
inline fun <reified T : Any> receive(sessionIDForReceive: Long): UntrustworthyData<T> { @Suspendable
return receive(sessionIDForReceive, T::class.java) fun <T : Any> sendAndReceive(otherParty: Party, payload: Any, receiveType: Class<T>): UntrustworthyData<T> {
val sendSessionId = getSendSessionId(otherParty, payload)
val receiveSessionId = getReceiveSessionId(otherParty)
return psm.sendAndReceive(topic, otherParty, sendSessionId, receiveSessionId, payload, receiveType)
} }
@Suspendable fun <T : Any> receive(sessionIDForReceive: Long, receiveType: Class<T>): UntrustworthyData<T> { inline fun <reified T : Any> receive(otherParty: Party): UntrustworthyData<T> = receive(otherParty, T::class.java)
return psm.receive(topic, sessionIDForReceive, receiveType)
@Suspendable
fun <T : Any> receive(otherParty: Party, receiveType: Class<T>): UntrustworthyData<T> {
return psm.receive(topic, getReceiveSessionId(otherParty), receiveType)
} }
@Suspendable fun send(destination: Party, sessionID: Long, payload: Any) { @Suspendable
psm.send(topic, destination, sessionID, payload) fun send(otherParty: Party, payload: Any) {
psm.send(topic, otherParty, getSendSessionId(otherParty, payload), payload)
}
private fun addSession(party: Party, sendSesssionId: Long, receiveSessionId: Long) {
if (party in sessions) {
logger.debug { "Existing session with party $party to be overwritten by new one" }
}
sessions[party] = Session(sendSesssionId, receiveSessionId)
}
private fun getSendSessionId(otherParty: Party, payload: Any): Long {
return if (payload is HandshakeMessage) {
addSession(otherParty, payload.sendSessionID, payload.receiveSessionID)
DEFAULT_SESSION_ID
} else {
sessions[otherParty]?.sendSessionId ?:
throw IllegalStateException("Session with party $otherParty hasn't been established yet")
}
}
private fun getReceiveSessionId(otherParty: Party): Long {
return sessions[otherParty]?.receiveSessionId ?:
throw IllegalStateException("Session with party $otherParty hasn't been established yet")
} }
/** /**
* Invokes the given subprotocol by simply passing through this [ProtocolLogic]s reference to the * Invokes the given subprotocol by simply passing through this [ProtocolLogic]s reference to the
* [ProtocolStateMachine] and then calling the [call] method. * [ProtocolStateMachine] and then calling the [call] method.
* @param inheritParentSessions In certain situations the subprotocol needs to inherit and use the same open
* sessions of the parent. However in most cases this is not desirable as it prevents the subprotocol from
* communicating with the same party on a different topic. For this reason the default value is false.
*/ */
@Suspendable fun <R> subProtocol(subLogic: ProtocolLogic<R>): R { @JvmOverloads
@Suspendable
fun <R> subProtocol(subLogic: ProtocolLogic<R>, inheritParentSessions: Boolean = false): R {
subLogic.psm = psm subLogic.psm = psm
if (inheritParentSessions) {
subLogic.sessions.putAll(sessions)
}
maybeWireUpProgressTracking(subLogic) maybeWireUpProgressTracking(subLogic)
val r = subLogic.call() val r = subLogic.call()
// It's easy to forget this when writing protocols so we just step it to the DONE state when it completes. // It's easy to forget this when writing protocols so we just step it to the DONE state when it completes.
@ -106,4 +156,6 @@ abstract class ProtocolLogic<out T> {
@Suspendable @Suspendable
abstract fun call(): T abstract fun call(): T
private data class Session(val sendSessionId: Long, val receiveSessionId: Long)
} }

View File

@ -7,8 +7,6 @@ import com.r3corda.core.contracts.StateRef
import com.r3corda.core.crypto.DigitalSignature import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.signWithECDSA import com.r3corda.core.crypto.signWithECDSA
import com.r3corda.core.messaging.Ack
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue import com.r3corda.core.random63BitValue
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
@ -36,9 +34,9 @@ abstract class AbstractStateReplacementProtocol<T> {
val stx: SignedTransaction val stx: SignedTransaction
} }
data class Handshake(val sessionIdForSend: Long, data class Handshake(override val replyToParty: Party,
override val replyToParty: Party, override val sendSessionID: Long = random63BitValue(),
override val sessionID: Long) : PartyRequestMessage override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
abstract class Instigator<out S : ContractState, T>(val originalState: StateAndRef<S>, abstract class Instigator<out S : ContractState, T>(val originalState: StateAndRef<S>,
val modification: T, val modification: T,
@ -77,36 +75,31 @@ abstract class AbstractStateReplacementProtocol<T> {
@Suspendable @Suspendable
private fun collectSignatures(participants: List<PublicKey>, stx: SignedTransaction): List<DigitalSignature.WithKey> { private fun collectSignatures(participants: List<PublicKey>, stx: SignedTransaction): List<DigitalSignature.WithKey> {
val sessions = mutableMapOf<NodeInfo, Long>() val parties = participants.map {
val participantSignatures = participants.map {
val participantNode = serviceHub.networkMapCache.getNodeByPublicKey(it) ?: val participantNode = serviceHub.networkMapCache.getNodeByPublicKey(it) ?:
throw IllegalStateException("Participant $it to state $originalState not found on the network") throw IllegalStateException("Participant $it to state $originalState not found on the network")
val sessionIdForSend = random63BitValue() participantNode.identity
sessions[participantNode] = sessionIdForSend
getParticipantSignature(participantNode, stx, sessionIdForSend)
} }
val participantSignatures = parties.map { getParticipantSignature(it, stx) }
val allSignatures = participantSignatures + getNotarySignature(stx) val allSignatures = participantSignatures + getNotarySignature(stx)
sessions.forEach { send(it.key.identity, it.value, allSignatures) } parties.forEach { send(it, allSignatures) }
return allSignatures return allSignatures
} }
@Suspendable @Suspendable
private fun getParticipantSignature(node: NodeInfo, stx: SignedTransaction, sessionIdForSend: Long): DigitalSignature.WithKey { private fun getParticipantSignature(party: Party, stx: SignedTransaction): DigitalSignature.WithKey {
val sessionIdForReceive = random63BitValue()
val proposal = assembleProposal(originalState.ref, modification, stx) val proposal = assembleProposal(originalState.ref, modification, stx)
val handshake = Handshake(sessionIdForSend, serviceHub.storageService.myLegalIdentity, sessionIdForReceive) send(party, Handshake(serviceHub.storageService.myLegalIdentity))
sendAndReceive<Ack>(node.identity, 0, sessionIdForReceive, handshake)
val response = sendAndReceive<Result>(node.identity, sessionIdForSend, sessionIdForReceive, proposal) val response = sendAndReceive<Result>(party, proposal)
val participantSignature = response.unwrap { val participantSignature = response.unwrap {
if (it.sig == null) throw StateReplacementException(it.error!!) if (it.sig == null) throw StateReplacementException(it.error!!)
else { else {
check(it.sig.by == node.identity.owningKey) { "Not signed by the required participant" } check(it.sig.by == party.owningKey) { "Not signed by the required participant" }
it.sig.verifyWithECDSA(stx.txBits) it.sig.verifyWithECDSA(stx.txBits)
it.sig it.sig
} }
@ -123,9 +116,7 @@ abstract class AbstractStateReplacementProtocol<T> {
} }
abstract class Acceptor<T>(val otherSide: Party, abstract class Acceptor<T>(val otherSide: Party,
val sessionIdForSend: Long, override val progressTracker: ProgressTracker = tracker()) : ProtocolLogic<Unit>() {
val sessionIdForReceive: Long,
override val progressTracker: ProgressTracker = tracker()) : ProtocolLogic<Unit>() {
companion object { companion object {
object VERIFYING : ProgressTracker.Step("Verifying state replacement proposal") object VERIFYING : ProgressTracker.Step("Verifying state replacement proposal")
@ -140,7 +131,7 @@ abstract class AbstractStateReplacementProtocol<T> {
@Suspendable @Suspendable
override fun call() { override fun call() {
progressTracker.currentStep = VERIFYING progressTracker.currentStep = VERIFYING
val maybeProposal: UntrustworthyData<Proposal<T>> = receive(sessionIdForReceive) val maybeProposal: UntrustworthyData<Proposal<T>> = receive(otherSide)
try { try {
val stx: SignedTransaction = maybeProposal.unwrap { verifyProposal(maybeProposal).stx } val stx: SignedTransaction = maybeProposal.unwrap { verifyProposal(maybeProposal).stx }
verifyTx(stx) verifyTx(stx)
@ -163,7 +154,7 @@ abstract class AbstractStateReplacementProtocol<T> {
val mySignature = sign(stx) val mySignature = sign(stx)
val response = Result.noError(mySignature) val response = Result.noError(mySignature)
val swapSignatures = sendAndReceive<List<DigitalSignature.WithKey>>(otherSide, sessionIdForSend, sessionIdForReceive, response) val swapSignatures = sendAndReceive<List<DigitalSignature.WithKey>>(otherSide, response)
// TODO: This step should not be necessary, as signatures are re-checked in verifySignatures. // TODO: This step should not be necessary, as signatures are re-checked in verifySignatures.
val allSignatures = swapSignatures.unwrap { signatures -> val allSignatures = swapSignatures.unwrap { signatures ->
@ -180,7 +171,7 @@ abstract class AbstractStateReplacementProtocol<T> {
private fun reject(e: StateReplacementRefused) { private fun reject(e: StateReplacementRefused) {
progressTracker.currentStep = REJECTING progressTracker.currentStep = REJECTING
val response = Result.withError(e) val response = Result.withError(e)
send(otherSide, sessionIdForSend, response) send(otherSide, response)
} }
/** /**

View File

@ -2,12 +2,10 @@ package com.r3corda.protocols
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.contracts.ClientToServiceCommand import com.r3corda.core.contracts.ClientToServiceCommand
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.serialize import com.r3corda.core.transactions.SignedTransaction
import java.util.*
/** /**
@ -33,12 +31,11 @@ class BroadcastTransactionProtocol(val notarisedTransaction: SignedTransaction,
override val topic: String = TOPIC override val topic: String = TOPIC
data class NotifyTxRequestMessage( data class NotifyTxRequestMessage(val tx: SignedTransaction,
val tx: SignedTransaction, val events: Set<ClientToServiceCommand>,
val events: Set<ClientToServiceCommand>, override val replyToParty: Party,
override val replyToParty: Party, override val sendSessionID: Long = random63BitValue(),
override val sessionID: Long override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
) : PartyRequestMessage
@Suspendable @Suspendable
override fun call() { override fun call() {
@ -48,9 +45,11 @@ class BroadcastTransactionProtocol(val notarisedTransaction: SignedTransaction,
// TODO: Messaging layer should handle this broadcast for us (although we need to not be sending // TODO: Messaging layer should handle this broadcast for us (although we need to not be sending
// session ID, for that to work, as well). // session ID, for that to work, as well).
participants.filter { it != serviceHub.storageService.myLegalIdentity }.forEach { participant -> participants.filter { it != serviceHub.storageService.myLegalIdentity }.forEach { participant ->
val sessionID = random63BitValue() val msg = NotifyTxRequestMessage(
val msg = NotifyTxRequestMessage(notarisedTransaction, events, serviceHub.storageService.myLegalIdentity, sessionID) notarisedTransaction,
send(participant, 0, msg) events,
serviceHub.storageService.myLegalIdentity)
send(participant, msg)
} }
} }
} }

View File

@ -7,6 +7,8 @@ import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue import com.r3corda.core.random63BitValue
import com.r3corda.core.utilities.UntrustworthyData import com.r3corda.core.utilities.UntrustworthyData
import com.r3corda.protocols.FetchDataProtocol.DownloadedVsRequestedDataMismatch
import com.r3corda.protocols.FetchDataProtocol.HashNotFound
import java.util.* import java.util.*
/** /**
@ -33,7 +35,10 @@ abstract class FetchDataProtocol<T : NamedByHash, in W : Any>(
class HashNotFound(val requested: SecureHash) : BadAnswer() class HashNotFound(val requested: SecureHash) : BadAnswer()
class DownloadedVsRequestedDataMismatch(val requested: SecureHash, val got: SecureHash) : BadAnswer() class DownloadedVsRequestedDataMismatch(val requested: SecureHash, val got: SecureHash) : BadAnswer()
data class Request(val hashes: List<SecureHash>, override val replyToParty: Party, override val sessionID: Long) : PartyRequestMessage data class Request(val hashes: List<SecureHash>,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
data class Result<out T : NamedByHash>(val fromDisk: List<T>, val downloaded: List<T>) data class Result<out T : NamedByHash>(val fromDisk: List<T>, val downloaded: List<T>)
@Suspendable @Suspendable
@ -46,10 +51,9 @@ abstract class FetchDataProtocol<T : NamedByHash, in W : Any>(
} else { } else {
logger.trace("Requesting ${toFetch.size} dependency(s) for verification") logger.trace("Requesting ${toFetch.size} dependency(s) for verification")
val sid = random63BitValue() val fetchReq = Request(toFetch, serviceHub.storageService.myLegalIdentity)
val fetchReq = Request(toFetch, serviceHub.storageService.myLegalIdentity, sid)
// TODO: Support "large message" response streaming so response sizes are not limited by RAM. // TODO: Support "large message" response streaming so response sizes are not limited by RAM.
val maybeItems = sendAndReceive<ArrayList<W?>>(otherSide, 0, sid, fetchReq) val maybeItems = sendAndReceive<ArrayList<W?>>(otherSide, fetchReq)
// Check for a buggy/malicious peer answering with something that we didn't ask for. // Check for a buggy/malicious peer answering with something that we didn't ask for.
val downloaded = validateFetchResponse(maybeItems, toFetch) val downloaded = validateFetchResponse(maybeItems, toFetch)
maybeWriteToDisk(downloaded) maybeWriteToDisk(downloaded)

View File

@ -53,10 +53,8 @@ object NotaryChangeProtocol: AbstractStateReplacementProtocol<Party>() {
} }
class Acceptor(otherSide: Party, class Acceptor(otherSide: Party,
sessionIdForSend: Long,
sessionIdForReceive: Long,
override val progressTracker: ProgressTracker = tracker()) override val progressTracker: ProgressTracker = tracker())
: AbstractStateReplacementProtocol.Acceptor<Party>(otherSide, sessionIdForSend, sessionIdForReceive) { : AbstractStateReplacementProtocol.Acceptor<Party>(otherSide) {
override val topic: String get() = TOPIC override val topic: String get() = TOPIC

View File

@ -1,10 +1,6 @@
package com.r3corda.protocols package com.r3corda.protocols
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.contracts.StateRef
import com.r3corda.core.contracts.Timestamp
import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.crypto.DigitalSignature import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SignedData import com.r3corda.core.crypto.SignedData
@ -13,11 +9,12 @@ import com.r3corda.core.messaging.Ack
import com.r3corda.core.node.services.TimestampChecker import com.r3corda.core.node.services.TimestampChecker
import com.r3corda.core.node.services.UniquenessException import com.r3corda.core.node.services.UniquenessException
import com.r3corda.core.node.services.UniquenessProvider import com.r3corda.core.node.services.UniquenessProvider
import com.r3corda.core.noneOrSingle
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.SerializedBytes import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.core.serialization.serialize import com.r3corda.core.serialization.serialize
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.utilities.ProgressTracker import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.core.utilities.UntrustworthyData import com.r3corda.core.utilities.UntrustworthyData
import java.security.PublicKey import java.security.PublicKey
@ -56,14 +53,10 @@ object NotaryProtocol {
notaryParty = wtx.notary ?: throw IllegalStateException("Transaction does not specify a Notary") notaryParty = wtx.notary ?: throw IllegalStateException("Transaction does not specify a Notary")
check(wtx.inputs.all { stateRef -> serviceHub.loadState(stateRef).notary == notaryParty }) { "Input states must have the same Notary" } check(wtx.inputs.all { stateRef -> serviceHub.loadState(stateRef).notary == notaryParty }) { "Input states must have the same Notary" }
val sendSessionID = random63BitValue() sendAndReceive<Ack>(notaryParty, Handshake(serviceHub.storageService.myLegalIdentity))
val receiveSessionID = random63BitValue()
val handshake = Handshake(serviceHub.storageService.myLegalIdentity, sendSessionID, receiveSessionID)
sendAndReceive<Ack>(notaryParty, 0, receiveSessionID, handshake)
val request = SignRequest(stx, serviceHub.storageService.myLegalIdentity) val request = SignRequest(stx, serviceHub.storageService.myLegalIdentity)
val response = sendAndReceive<Result>(notaryParty, sendSessionID, receiveSessionID, request) val response = sendAndReceive<Result>(notaryParty, request)
val notaryResult = validateResponse(response) val notaryResult = validateResponse(response)
return notaryResult.sig ?: throw NotaryException(notaryResult.error!!) return notaryResult.sig ?: throw NotaryException(notaryResult.error!!)
@ -96,8 +89,6 @@ object NotaryProtocol {
* TODO: the notary service should only be able to see timestamp commands and inputs * TODO: the notary service should only be able to see timestamp commands and inputs
*/ */
open class Service(val otherSide: Party, open class Service(val otherSide: Party,
val sendSessionID: Long,
val receiveSessionID: Long,
val timestampChecker: TimestampChecker, val timestampChecker: TimestampChecker,
val uniquenessProvider: UniquenessProvider) : ProtocolLogic<Unit>() { val uniquenessProvider: UniquenessProvider) : ProtocolLogic<Unit>() {
@ -105,7 +96,7 @@ object NotaryProtocol {
@Suspendable @Suspendable
override fun call() { override fun call() {
val (stx, reqIdentity) = receive<SignRequest>(receiveSessionID).unwrap { it } val (stx, reqIdentity) = sendAndReceive<SignRequest>(otherSide, Ack).unwrap { it }
val wtx = stx.tx val wtx = stx.tx
val result = try { val result = try {
@ -119,7 +110,7 @@ object NotaryProtocol {
Result.withError(e.error) Result.withError(e.error)
} }
send(otherSide, sendSessionID, result) send(otherSide, result)
} }
private fun validateTimestamp(tx: WireTransaction) { private fun validateTimestamp(tx: WireTransaction) {
@ -157,10 +148,9 @@ object NotaryProtocol {
} }
} }
data class Handshake( data class Handshake(override val replyToParty: Party,
override val replyToParty: Party, override val sendSessionID: Long = random63BitValue(),
val sendSessionID: Long, override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
override val sessionID: Long) : PartyRequestMessage
/** TODO: The caller must authenticate instead of just specifying its identity */ /** TODO: The caller must authenticate instead of just specifying its identity */
data class SignRequest(val tx: SignedTransaction, val callerIdentity: Party) data class SignRequest(val tx: SignedTransaction, val callerIdentity: Party)
@ -174,19 +164,15 @@ object NotaryProtocol {
interface Factory { interface Factory {
fun create(otherSide: Party, fun create(otherSide: Party,
sendSessionID: Long,
receiveSessionID: Long,
timestampChecker: TimestampChecker, timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider): Service uniquenessProvider: UniquenessProvider): Service
} }
object DefaultFactory : Factory { object DefaultFactory : Factory {
override fun create(otherSide: Party, override fun create(otherSide: Party,
sendSessionID: Long,
receiveSessionID: Long,
timestampChecker: TimestampChecker, timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider): Service { uniquenessProvider: UniquenessProvider): Service {
return Service(otherSide, sendSessionID, receiveSessionID, timestampChecker, uniquenessProvider) return Service(otherSide, timestampChecker, uniquenessProvider)
} }
} }
} }

View File

@ -3,12 +3,12 @@ package com.r3corda.protocols
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.contracts.Fix import com.r3corda.core.contracts.Fix
import com.r3corda.core.contracts.FixOf import com.r3corda.core.contracts.FixOf
import com.r3corda.core.transactions.TransactionBuilder
import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.crypto.DigitalSignature import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue import com.r3corda.core.random63BitValue
import com.r3corda.core.transactions.TransactionBuilder
import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.utilities.ProgressTracker import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.core.utilities.suggestInterestRateAnnouncementTimeWindow import com.r3corda.core.utilities.suggestInterestRateAnnouncementTimeWindow
import com.r3corda.protocols.RatesFixProtocol.FixOutOfRange import com.r3corda.protocols.RatesFixProtocol.FixOutOfRange
@ -47,8 +47,16 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
class FixOutOfRange(@Suppress("unused") val byAmount: BigDecimal) : Exception("Fix out of range by $byAmount") class FixOutOfRange(@Suppress("unused") val byAmount: BigDecimal) : Exception("Fix out of range by $byAmount")
data class QueryRequest(val queries: List<FixOf>, override val replyToParty: Party, override val sessionID: Long, val deadline: Instant) : PartyRequestMessage data class QueryRequest(val queries: List<FixOf>,
data class SignRequest(val tx: WireTransaction, override val replyToParty: Party, override val sessionID: Long) : PartyRequestMessage val deadline: Instant,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
data class SignRequest(val tx: WireTransaction,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
@Suspendable @Suspendable
override fun call() { override fun call() {
@ -80,10 +88,9 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
@Suspendable @Suspendable
private fun sign(): DigitalSignature.LegallyIdentifiable { private fun sign(): DigitalSignature.LegallyIdentifiable {
val sessionID = random63BitValue()
val wtx = tx.toWireTransaction() val wtx = tx.toWireTransaction()
val req = SignRequest(wtx, serviceHub.storageService.myLegalIdentity, sessionID) val req = SignRequest(wtx, serviceHub.storageService.myLegalIdentity)
val resp = sendAndReceive<DigitalSignature.LegallyIdentifiable>(oracle, 0, sessionID, req) val resp = sendAndReceive<DigitalSignature.LegallyIdentifiable>(oracle, req)
return resp.unwrap { sig -> return resp.unwrap { sig ->
check(sig.signer == oracle) check(sig.signer == oracle)
@ -94,11 +101,10 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
@Suspendable @Suspendable
private fun query(): Fix { private fun query(): Fix {
val sessionID = random63BitValue()
val deadline = suggestInterestRateAnnouncementTimeWindow(fixOf.name, oracle.name, fixOf.forDay).end val deadline = suggestInterestRateAnnouncementTimeWindow(fixOf.name, oracle.name, fixOf.forDay).end
val req = QueryRequest(listOf(fixOf), serviceHub.storageService.myLegalIdentity, sessionID, deadline) val req = QueryRequest(listOf(fixOf), deadline, serviceHub.storageService.myLegalIdentity)
// TODO: add deadline to receive // TODO: add deadline to receive
val resp = sendAndReceive<ArrayList<Fix>>(oracle, 0, sessionID, req) val resp = sendAndReceive<ArrayList<Fix>>(oracle, req)
return resp.unwrap { return resp.unwrap {
val fix = it.first() val fix = it.first()

View File

@ -32,4 +32,19 @@ interface PartyRequestMessage : ServiceRequestMessage {
override fun getReplyTo(networkMapCache: NetworkMapCache): MessageRecipients { override fun getReplyTo(networkMapCache: NetworkMapCache): MessageRecipients {
return networkMapCache.partyNodes.single { it.identity == replyToParty }.address return networkMapCache.partyNodes.single { it.identity == replyToParty }.address
} }
}
/**
* A Handshake message is sent to initiate communication between two protocol instances. It contains the two session IDs
* the two protocols will need to communicate.
* Note: This is a temperary interface and will be removed once the protocol session work is implemented.
*/
interface HandshakeMessage : PartyRequestMessage {
val sendSessionID: Long
val receiveSessionID: Long
@Deprecated("sessionID functions as receiveSessionID but it's recommended to use the later for clarity",
replaceWith = ReplaceWith("receiveSessionID"))
override val sessionID: Long get() = receiveSessionID
} }

View File

@ -7,7 +7,6 @@ import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.signWithECDSA import com.r3corda.core.crypto.signWithECDSA
import com.r3corda.core.node.NodeInfo import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.node.services.ServiceType import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue import com.r3corda.core.random63BitValue
@ -48,11 +47,7 @@ object TwoPartyDealProtocol {
} }
// This object is serialised to the network and is the first protocol message the seller sends to the buyer. // This object is serialised to the network and is the first protocol message the seller sends to the buyer.
data class Handshake<out T>( data class Handshake<out T>(val payload: T, val publicKey: PublicKey)
val payload: T,
val publicKey: PublicKey,
val sessionID: Long
)
class SignaturesFromPrimary(val sellerSig: DigitalSignature.WithKey, val notarySig: DigitalSignature.LegallyIdentifiable) class SignaturesFromPrimary(val sellerSig: DigitalSignature.WithKey, val notarySig: DigitalSignature.LegallyIdentifiable)
@ -80,19 +75,16 @@ object TwoPartyDealProtocol {
abstract val payload: U abstract val payload: U
abstract val notaryNode: NodeInfo abstract val notaryNode: NodeInfo
abstract val otherSide: Party abstract val otherParty: Party
abstract val otherSessionID: Long
abstract val myKeyPair: KeyPair abstract val myKeyPair: KeyPair
@Suspendable @Suspendable
fun getPartialTransaction(): UntrustworthyData<SignedTransaction> { fun getPartialTransaction(): UntrustworthyData<SignedTransaction> {
progressTracker.currentStep = AWAITING_PROPOSAL progressTracker.currentStep = AWAITING_PROPOSAL
val sessionID = random63BitValue()
// Make the first message we'll send to kick off the protocol. // Make the first message we'll send to kick off the protocol.
val hello = Handshake(payload, myKeyPair.public, sessionID) val hello = Handshake(payload, myKeyPair.public)
val maybeSTX = sendAndReceive<SignedTransaction>(otherSide, otherSessionID, sessionID, hello) val maybeSTX = sendAndReceive<SignedTransaction>(otherParty, hello)
return maybeSTX return maybeSTX
} }
@ -132,7 +124,7 @@ object TwoPartyDealProtocol {
// Download and check all the transactions that this transaction depends on, but do not check this // Download and check all the transactions that this transaction depends on, but do not check this
// transaction itself. // transaction itself.
val dependencyTxIDs = stx.tx.inputs.map { it.txhash }.toSet() val dependencyTxIDs = stx.tx.inputs.map { it.txhash }.toSet()
subProtocol(ResolveTransactionsProtocol(dependencyTxIDs, otherSide)) subProtocol(ResolveTransactionsProtocol(dependencyTxIDs, otherParty))
} }
@Suspendable @Suspendable
@ -156,7 +148,7 @@ object TwoPartyDealProtocol {
if (regulators.isNotEmpty()) { if (regulators.isNotEmpty()) {
// Copy the transaction to every regulator in the network. This is obviously completely bogus, it's // Copy the transaction to every regulator in the network. This is obviously completely bogus, it's
// just for demo purposes. // just for demo purposes.
regulators.forEach { send(it.identity, DEFAULT_SESSION_ID, fullySigned) } regulators.forEach { send(it.identity, fullySigned) }
} }
return fullySigned return fullySigned
@ -181,7 +173,7 @@ object TwoPartyDealProtocol {
logger.trace { "Built finished transaction, sending back to other party!" } logger.trace { "Built finished transaction, sending back to other party!" }
send(otherSide, otherSessionID, SignaturesFromPrimary(ourSignature, notarySignature)) send(otherParty, SignaturesFromPrimary(ourSignature, notarySignature))
return fullySigned return fullySigned
} }
} }
@ -207,8 +199,7 @@ object TwoPartyDealProtocol {
override val topic: String get() = DEAL_TOPIC override val topic: String get() = DEAL_TOPIC
abstract val otherSide: Party abstract val otherParty: Party
abstract val sessionID: Long
@Suspendable @Suspendable
override fun call(): SignedTransaction { override fun call(): SignedTransaction {
@ -218,7 +209,7 @@ object TwoPartyDealProtocol {
val (ptx, additionalSigningPubKeys) = assembleSharedTX(handshake) val (ptx, additionalSigningPubKeys) = assembleSharedTX(handshake)
val stx = signWithOurKeys(additionalSigningPubKeys, ptx) val stx = signWithOurKeys(additionalSigningPubKeys, ptx)
val signatures = swapSignaturesWithPrimary(stx, handshake.sessionID) val signatures = swapSignaturesWithPrimary(stx)
logger.trace { "Got signatures from other party, verifying ... " } logger.trace { "Got signatures from other party, verifying ... " }
@ -238,7 +229,7 @@ object TwoPartyDealProtocol {
private fun receiveAndValidateHandshake(): Handshake<U> { private fun receiveAndValidateHandshake(): Handshake<U> {
progressTracker.currentStep = RECEIVING progressTracker.currentStep = RECEIVING
// Wait for a trade request to come in on our pre-provided session ID. // Wait for a trade request to come in on our pre-provided session ID.
val handshake = receive<Handshake<U>>(sessionID) val handshake = receive<Handshake<U>>(otherParty)
progressTracker.currentStep = VERIFYING progressTracker.currentStep = VERIFYING
handshake.unwrap { handshake.unwrap {
@ -247,13 +238,13 @@ object TwoPartyDealProtocol {
} }
@Suspendable @Suspendable
private fun swapSignaturesWithPrimary(stx: SignedTransaction, theirSessionID: Long): SignaturesFromPrimary { private fun swapSignaturesWithPrimary(stx: SignedTransaction): SignaturesFromPrimary {
progressTracker.currentStep = SWAPPING_SIGNATURES progressTracker.currentStep = SWAPPING_SIGNATURES
logger.trace { "Sending partially signed transaction to other party" } logger.trace { "Sending partially signed transaction to other party" }
// TODO: Protect against the seller terminating here and leaving us in the lurch without the final tx. // TODO: Protect against the seller terminating here and leaving us in the lurch without the final tx.
return sendAndReceive<SignaturesFromPrimary>(otherSide, theirSessionID, sessionID, stx).unwrap { it } return sendAndReceive<SignaturesFromPrimary>(otherParty, stx).unwrap { it }
} }
private fun signWithOurKeys(signingPubKeys: List<PublicKey>, ptx: TransactionBuilder): SignedTransaction { private fun signWithOurKeys(signingPubKeys: List<PublicKey>, ptx: TransactionBuilder): SignedTransaction {
@ -273,11 +264,10 @@ object TwoPartyDealProtocol {
/** /**
* One side of the protocol for inserting a pre-agreed deal. * One side of the protocol for inserting a pre-agreed deal.
*/ */
open class Instigator<out T : DealState>(override val otherSide: Party, open class Instigator<out T : DealState>(override val otherParty: Party,
val notary: Party, val notary: Party,
override val payload: T, override val payload: T,
override val myKeyPair: KeyPair, override val myKeyPair: KeyPair,
override val otherSessionID: Long,
override val progressTracker: ProgressTracker = Primary.tracker()) : Primary<T>() { override val progressTracker: ProgressTracker = Primary.tracker()) : Primary<T>() {
override val notaryNode: NodeInfo get() = override val notaryNode: NodeInfo get() =
@ -287,10 +277,9 @@ object TwoPartyDealProtocol {
/** /**
* One side of the protocol for inserting a pre-agreed deal. * One side of the protocol for inserting a pre-agreed deal.
*/ */
open class Acceptor<T : DealState>(override val otherSide: Party, open class Acceptor<T : DealState>(override val otherParty: Party,
val notary: Party, val notary: Party,
val dealToBuy: T, val dealToBuy: T,
override val sessionID: Long,
override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary<T>() { override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary<T>() {
override fun validateHandshake(handshake: Handshake<T>): Handshake<T> { override fun validateHandshake(handshake: Handshake<T>): Handshake<T> {
@ -299,8 +288,6 @@ object TwoPartyDealProtocol {
val otherKey = handshake.publicKey val otherKey = handshake.publicKey
logger.trace { "Got deal request for: ${handshake.payload.ref}" } logger.trace { "Got deal request for: ${handshake.payload.ref}" }
// Check the start message for acceptability.
check(handshake.sessionID > 0)
check(dealToBuy == deal) check(dealToBuy == deal)
// We need to substitute in the new public keys for the Parties // We need to substitute in the new public keys for the Parties
@ -335,11 +322,9 @@ object TwoPartyDealProtocol {
* of the protocol that is run by the party with the fixed leg of swap deal, which is the basis for deciding * of the protocol that is run by the party with the fixed leg of swap deal, which is the basis for deciding
* who does what in the protocol. * who does what in the protocol.
*/ */
class Fixer(val initiation: FixingSessionInitiation, override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary<StateRef>() { class Fixer(override val otherParty: Party,
val oracleType: ServiceType,
override val sessionID: Long get() = initiation.sessionID override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary<StateRef>() {
override val otherSide: Party get() = initiation.sender
private lateinit var txState: TransactionState<*> private lateinit var txState: TransactionState<*>
private lateinit var deal: FixableDealState private lateinit var deal: FixableDealState
@ -347,16 +332,12 @@ object TwoPartyDealProtocol {
override fun validateHandshake(handshake: Handshake<StateRef>): Handshake<StateRef> { override fun validateHandshake(handshake: Handshake<StateRef>): Handshake<StateRef> {
logger.trace { "Got fixing request for: ${handshake.payload}" } logger.trace { "Got fixing request for: ${handshake.payload}" }
// Check the handshake and initiation for acceptability.
check(handshake.sessionID > 0)
txState = serviceHub.loadState(handshake.payload) txState = serviceHub.loadState(handshake.payload)
deal = txState.data as FixableDealState deal = txState.data as FixableDealState
// validate the party that initiated is the one on the deal and that the recipient corresponds with it. // validate the party that initiated is the one on the deal and that the recipient corresponds with it.
// TODO: this is in no way secure and will be replaced by general session initiation logic in the future // TODO: this is in no way secure and will be replaced by general session initiation logic in the future
val myName = serviceHub.storageService.myLegalIdentity.name val myName = serviceHub.storageService.myLegalIdentity.name
val otherParty = deal.parties.filter { it.name != myName }.single()
check(otherParty == initiation.party)
// Also check we are one of the parties // Also check we are one of the parties
deal.parties.filter { it.name == myName }.single() deal.parties.filter { it.name == myName }.single()
@ -376,7 +357,7 @@ object TwoPartyDealProtocol {
val ptx = TransactionType.General.Builder(txState.notary) val ptx = TransactionType.General.Builder(txState.notary)
val oracle = serviceHub.networkMapCache.get(initiation.oracleType).first() val oracle = serviceHub.networkMapCache.get(oracleType).first()
val addFixing = object : RatesFixProtocol(ptx, oracle.identity, fixOf, BigDecimal.ZERO, BigDecimal.ONE) { val addFixing = object : RatesFixProtocol(ptx, oracle.identity, fixOf, BigDecimal.ZERO, BigDecimal.ONE) {
@Suspendable @Suspendable
@ -401,7 +382,6 @@ object TwoPartyDealProtocol {
* does what in the protocol. * does what in the protocol.
*/ */
class Floater(override val payload: StateRef, class Floater(override val payload: StateRef,
override val otherSessionID: Long,
override val progressTracker: ProgressTracker = Primary.tracker()) : Primary<StateRef>() { override val progressTracker: ProgressTracker = Primary.tracker()) : Primary<StateRef>() {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
internal val dealToFix: StateAndRef<FixableDealState> by TransientProperty { internal val dealToFix: StateAndRef<FixableDealState> by TransientProperty {
@ -415,8 +395,8 @@ object TwoPartyDealProtocol {
return serviceHub.keyManagementService.toKeyPair(publicKey) return serviceHub.keyManagementService.toKeyPair(publicKey)
} }
override val otherSide: Party get() { override val otherParty: Party get() {
// TODO: what happens if there's no node? Move to messaging taking Party and then handled in messaging layer // TODO otherParty is sortedParties[1] from FixingRoleDecider.call and so can be passed in as a c'tor param
val myName = serviceHub.storageService.myLegalIdentity.name val myName = serviceHub.storageService.myLegalIdentity.name
return dealToFix.state.data.parties.filter { it.name != myName }.single() return dealToFix.state.data.parties.filter { it.name != myName }.single()
} }
@ -427,7 +407,11 @@ object TwoPartyDealProtocol {
/** Used to set up the session between [Floater] and [Fixer] */ /** Used to set up the session between [Floater] and [Fixer] */
data class FixingSessionInitiation(val sessionID: Long, val party: Party, val sender: Party, val timeout: Duration, val oracleType: ServiceType) data class FixingSessionInitiation(val timeout: Duration,
val oracleType: ServiceType,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
/** /**
* This protocol looks at the deal and decides whether to be the Fixer or Floater role in agreeing a fixing. * This protocol looks at the deal and decides whether to be the Fixer or Floater role in agreeing a fixing.
@ -459,18 +443,17 @@ object TwoPartyDealProtocol {
val sortedParties = fixableDeal.parties.sortedBy { it.name } val sortedParties = fixableDeal.parties.sortedBy { it.name }
val oracleType = fixableDeal.oracleType val oracleType = fixableDeal.oracleType
if (sortedParties[0].name == serviceHub.storageService.myLegalIdentity.name) { if (sortedParties[0].name == serviceHub.storageService.myLegalIdentity.name) {
// Generate sessionID val initation = FixingSessionInitiation(
val sessionID = random63BitValue() timeout,
val initation = FixingSessionInitiation(sessionID, sortedParties[0], serviceHub.storageService.myLegalIdentity, timeout, oracleType) oracleType,
serviceHub.storageService.myLegalIdentity)
// Send initiation to other side to launch one side of the fixing protocol (the Fixer). // Send initiation to other side to launch one side of the fixing protocol (the Fixer).
send(sortedParties[1], DEFAULT_SESSION_ID, initation) send(sortedParties[1], initation)
// Then start the other side of the fixing protocol. // Then start the other side of the fixing protocol.
val protocol = Floater(ref, sessionID) subProtocol(Floater(ref), inheritParentSessions = true)
subProtocol(protocol)
} }
} }
} }
} }

View File

@ -16,10 +16,10 @@ import java.security.SignatureException
* indeed valid. * indeed valid.
*/ */
class ValidatingNotaryProtocol(otherSide: Party, class ValidatingNotaryProtocol(otherSide: Party,
sessionIdForSend: Long,
sessionIdForReceive: Long,
timestampChecker: TimestampChecker, timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider) : NotaryProtocol.Service(otherSide, sessionIdForSend, sessionIdForReceive, timestampChecker, uniquenessProvider) { uniquenessProvider: UniquenessProvider) :
NotaryProtocol.Service(otherSide, timestampChecker, uniquenessProvider) {
@Suspendable @Suspendable
override fun beforeCommit(stx: SignedTransaction, reqIdentity: Party) { override fun beforeCommit(stx: SignedTransaction, reqIdentity: Party) {
try { try {

View File

@ -24,7 +24,8 @@ class BroadcastTransactionProtocolTest {
tx = SignedTransactionGenerator().generate(random, status), tx = SignedTransactionGenerator().generate(random, status),
events = setOf(), events = setOf(),
replyToParty = PartyGenerator().generate(random, status), replyToParty = PartyGenerator().generate(random, status),
sessionID = random.nextLong() sendSessionID = random.nextLong(),
receiveSessionID = random.nextLong()
) )
} }
} }

View File

@ -1,12 +1,11 @@
package com.r3corda.node.services package com.r3corda.node.services
import com.r3corda.core.messaging.Ack
import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.node.services.api.AbstractNodeService import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.AbstractStateReplacementProtocol import com.r3corda.protocols.AbstractStateReplacementProtocol
import com.r3corda.protocols.NotaryChangeProtocol import com.r3corda.protocols.NotaryChangeProtocol
import com.r3corda.protocols.NotaryChangeProtocol.TOPIC
object NotaryChange { object NotaryChange {
class Plugin : CordaPluginRegistry() { class Plugin : CordaPluginRegistry() {
@ -19,18 +18,9 @@ object NotaryChange {
*/ */
class Service(services: ServiceHubInternal) : AbstractNodeService(services) { class Service(services: ServiceHubInternal) : AbstractNodeService(services) {
init { init {
addMessageHandler(NotaryChangeProtocol.TOPIC, addProtocolHandler(TOPIC, TOPIC) { req: AbstractStateReplacementProtocol.Handshake ->
{ req: AbstractStateReplacementProtocol.Handshake -> handleChangeNotaryRequest(req) } NotaryChangeProtocol.Acceptor(req.replyToParty)
) }
}
private fun handleChangeNotaryRequest(req: AbstractStateReplacementProtocol.Handshake): Ack {
val protocol = NotaryChangeProtocol.Acceptor(
req.replyToParty,
req.sessionID,
req.sessionIdForSend)
services.startProtocol(NotaryChangeProtocol.TOPIC, protocol)
return Ack
} }
} }
} }

View File

@ -1,10 +1,14 @@
package com.r3corda.node.services.api package com.r3corda.node.services.api
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.messaging.Message import com.r3corda.core.messaging.Message
import com.r3corda.core.node.services.DEFAULT_SESSION_ID import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.serialization.SingletonSerializeAsToken import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.serialization.deserialize import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize import com.r3corda.core.serialization.serialize
import com.r3corda.core.utilities.loggerFor
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.protocols.ServiceRequestMessage import com.r3corda.protocols.ServiceRequestMessage
import javax.annotation.concurrent.ThreadSafe import javax.annotation.concurrent.ThreadSafe
@ -14,6 +18,10 @@ import javax.annotation.concurrent.ThreadSafe
@ThreadSafe @ThreadSafe
abstract class AbstractNodeService(val services: ServiceHubInternal) : SingletonSerializeAsToken() { abstract class AbstractNodeService(val services: ServiceHubInternal) : SingletonSerializeAsToken() {
companion object {
val logger = loggerFor<AbstractNodeService>()
}
val net: MessagingServiceInternal get() = services.networkService val net: MessagingServiceInternal get() = services.networkService
/** /**
@ -57,4 +65,37 @@ abstract class AbstractNodeService(val services: ServiceHubInternal) : Singleton
crossinline handler: (Q) -> R) { crossinline handler: (Q) -> R) {
addMessageHandler(topic, handler, { message: Message, exception: Exception -> throw exception }) addMessageHandler(topic, handler, { message: Message, exception: Exception -> throw exception })
} }
/**
* Register a handler to kick-off a protocol when a [HandshakeMessage] is received by the node. This performs the
* necessary steps to enable communication between the two protocols, including calling ProtocolLogic.registerSession.
* @param topic the topic on which the handshake is sent from the other party
* @param loggerName the logger name to use when starting the protocol
* @param protocolFactory a function to create the protocol with the given handshake message
* @param onResultFuture provides access to the [ListenableFuture] when the protocol starts
*/
protected inline fun <reified H : HandshakeMessage, R : Any> addProtocolHandler(
topic: String,
loggerName: String,
crossinline protocolFactory: (H) -> ProtocolLogic<R>,
crossinline onResultFuture: (ListenableFuture<R>, H) -> Unit) {
net.addMessageHandler(topic, DEFAULT_SESSION_ID, null) { message, reg ->
try {
val handshake = message.data.deserialize<H>()
val protocol = protocolFactory(handshake)
protocol.registerSession(handshake)
val resultFuture = services.startProtocol(loggerName, protocol)
onResultFuture(resultFuture, handshake)
} catch (e: Exception) {
logger.error("Unable to process ${H::class.java.name} message", e)
}
}
}
protected inline fun <reified H : HandshakeMessage, R : Any> addProtocolHandler(
topic: String,
loggerName: String,
crossinline protocolFactory: (H) -> ProtocolLogic<R>) {
addProtocolHandler(topic, loggerName, protocolFactory, { future, handshake -> })
}
} }

View File

@ -1,10 +1,11 @@
package com.r3corda.node.services.clientapi package com.r3corda.node.services.clientapi
import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.services.DEFAULT_SESSION_ID import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.core.serialization.deserialize
import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.TwoPartyDealProtocol import com.r3corda.protocols.TwoPartyDealProtocol
import com.r3corda.protocols.TwoPartyDealProtocol.FIX_INITIATE_TOPIC
import com.r3corda.protocols.TwoPartyDealProtocol.FixingSessionInitiation
/** /**
* This is a temporary handler required for establishing random sessionIDs for the [Fixer] and [Floater] as part of * This is a temporary handler required for establishing random sessionIDs for the [Fixer] and [Floater] as part of
@ -17,12 +18,10 @@ object FixingSessionInitiation {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java) override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
} }
class Service(services: ServiceHubInternal) { class Service(services: ServiceHubInternal) : AbstractNodeService(services) {
init { init {
services.networkService.addMessageHandler(TwoPartyDealProtocol.FIX_INITIATE_TOPIC, DEFAULT_SESSION_ID) { msg, registration -> addProtocolHandler(FIX_INITIATE_TOPIC, "fixings") { initiation: FixingSessionInitiation ->
val initiation = msg.data.deserialize<TwoPartyDealProtocol.FixingSessionInitiation>() TwoPartyDealProtocol.Fixer(initiation.replyToParty, initiation.oracleType)
val protocol = TwoPartyDealProtocol.Fixer(initiation)
services.startProtocol("fixings", protocol)
} }
} }
} }

View File

@ -6,7 +6,6 @@ import com.r3corda.core.messaging.MessagingService
import com.r3corda.core.messaging.TopicSession import com.r3corda.core.messaging.TopicSession
import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.NodeInfo import com.r3corda.core.node.NodeInfo
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.serialize import com.r3corda.core.serialization.serialize
import com.r3corda.core.success import com.r3corda.core.success
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
@ -50,8 +49,7 @@ object DataVending {
myIdentity: Party, myIdentity: Party,
recipient: NodeInfo, recipient: NodeInfo,
transaction: SignedTransaction) { transaction: SignedTransaction) {
val sessionID = random63BitValue() val msg = BroadcastTransactionProtocol.NotifyTxRequestMessage(transaction, emptySet(), myIdentity)
val msg = BroadcastTransactionProtocol.NotifyTxRequestMessage(transaction, emptySet(), myIdentity, sessionID)
net.send(net.createMessage(TopicSession(BroadcastTransactionProtocol.TOPIC, 0), msg.serialize().bits), recipient.address) net.send(net.createMessage(TopicSession(BroadcastTransactionProtocol.TOPIC, 0), msg.serialize().bits), recipient.address)
} }
} }
@ -65,29 +63,29 @@ object DataVending {
{ req: FetchDataProtocol.Request -> handleTXRequest(req) }, { req: FetchDataProtocol.Request -> handleTXRequest(req) },
{ message, e -> logger.error("Failure processing data vending request.", e) } { message, e -> logger.error("Failure processing data vending request.", e) }
) )
addMessageHandler(FetchAttachmentsProtocol.TOPIC, addMessageHandler(FetchAttachmentsProtocol.TOPIC,
{ req: FetchDataProtocol.Request -> handleAttachmentRequest(req) }, { req: FetchDataProtocol.Request -> handleAttachmentRequest(req) },
{ message, e -> logger.error("Failure processing data vending request.", e) } { message, e -> logger.error("Failure processing data vending request.", e) }
) )
addMessageHandler(BroadcastTransactionProtocol.TOPIC,
{ req: BroadcastTransactionProtocol.NotifyTxRequestMessage -> handleTXNotification(req) },
{ message, e -> logger.error("Failure processing data vending request.", e) }
)
}
private fun handleTXNotification(req: BroadcastTransactionProtocol.NotifyTxRequestMessage): Unit {
// TODO: We should have a whitelist of contracts we're willing to accept at all, and reject if the transaction // TODO: We should have a whitelist of contracts we're willing to accept at all, and reject if the transaction
// includes us in any outside that list. Potentially just if it includes any outside that list at all. // includes us in any outside that list. Potentially just if it includes any outside that list at all.
// TODO: Do we want to be able to reject specific transactions on more complex rules, for example reject incoming // TODO: Do we want to be able to reject specific transactions on more complex rules, for example reject incoming
// cash without from unknown parties? // cash without from unknown parties?
addProtocolHandler(
services.startProtocol("Resolving transactions", ResolveTransactionsProtocol(req.tx, req.replyToParty)) BroadcastTransactionProtocol.TOPIC,
.success { "Resolving transactions",
services.recordTransactions(req.tx) { req: BroadcastTransactionProtocol.NotifyTxRequestMessage ->
}.failure { throwable -> ResolveTransactionsProtocol(req.tx, req.replyToParty)
logger.warn("Received invalid transaction ${req.tx.id} from ${req.replyToParty}", throwable) },
} { future, req ->
future.success {
services.recordTransactions(req.tx)
}.failure { throwable ->
logger.warn("Received invalid transaction ${req.tx.id} from ${req.replyToParty}", throwable)
}
})
} }
private fun handleTXRequest(req: FetchDataProtocol.Request): List<SignedTransaction?> { private fun handleTXRequest(req: FetchDataProtocol.Request): List<SignedTransaction?> {

View File

@ -1,12 +1,12 @@
package com.r3corda.node.services.transactions package com.r3corda.node.services.transactions
import com.r3corda.core.messaging.Ack
import com.r3corda.core.node.services.ServiceType import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.node.services.TimestampChecker import com.r3corda.core.node.services.TimestampChecker
import com.r3corda.core.node.services.UniquenessProvider import com.r3corda.core.node.services.UniquenessProvider
import com.r3corda.node.services.api.AbstractNodeService import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.NotaryProtocol import com.r3corda.protocols.NotaryProtocol
import com.r3corda.protocols.NotaryProtocol.TOPIC
/** /**
* A Notary service acts as the final signer of a transaction ensuring two things: * A Notary service acts as the final signer of a transaction ensuring two things:
@ -30,19 +30,9 @@ abstract class NotaryService(services: ServiceHubInternal,
abstract val protocolFactory: NotaryProtocol.Factory abstract val protocolFactory: NotaryProtocol.Factory
init { init {
addMessageHandler(NotaryProtocol.TOPIC, addProtocolHandler(TOPIC, TOPIC) { req: NotaryProtocol.Handshake ->
{ req: NotaryProtocol.Handshake -> processRequest(req) } protocolFactory.create(req.replyToParty, timestampChecker, uniquenessProvider)
) }
} }
private fun processRequest(req: NotaryProtocol.Handshake): Ack {
val protocol = protocolFactory.create(
req.replyToParty,
req.sessionID,
req.sendSessionID,
timestampChecker,
uniquenessProvider)
services.startProtocol(NotaryProtocol.TOPIC, protocol)
return Ack
}
} }

View File

@ -19,11 +19,9 @@ class ValidatingNotaryService(services: ServiceHubInternal,
override val protocolFactory = object : NotaryProtocol.Factory { override val protocolFactory = object : NotaryProtocol.Factory {
override fun create(otherSide: Party, override fun create(otherSide: Party,
sendSessionID: Long,
receiveSessionID: Long,
timestampChecker: TimestampChecker, timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider): NotaryProtocol.Service { uniquenessProvider: UniquenessProvider): NotaryProtocol.Service {
return ValidatingNotaryProtocol(otherSide, sendSessionID, receiveSessionID, timestampChecker, uniquenessProvider) return ValidatingNotaryProtocol(otherSide, timestampChecker, uniquenessProvider)
} }
} }
} }

View File

@ -1,6 +1,5 @@
package com.r3corda.node.messaging package com.r3corda.node.messaging
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.contracts.CommercialPaper import com.r3corda.contracts.CommercialPaper
import com.r3corda.contracts.asset.* import com.r3corda.contracts.asset.*
import com.r3corda.contracts.testing.fillWithSomeTestCash import com.r3corda.contracts.testing.fillWithSomeTestCash
@ -9,27 +8,26 @@ import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.days import com.r3corda.core.days
import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.ServiceHub import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.ServiceType import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.node.services.TransactionStorage import com.r3corda.core.node.services.TransactionStorage
import com.r3corda.core.node.services.Wallet import com.r3corda.core.node.services.Wallet
import com.r3corda.core.random63BitValue
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.transactions.WireTransaction import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.utilities.DUMMY_NOTARY import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.core.utilities.LogHelper import com.r3corda.core.utilities.LogHelper
import com.r3corda.core.utilities.TEST_TX_TIME import com.r3corda.core.utilities.TEST_TX_TIME
import com.r3corda.testing.node.MockNetwork
import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.node.services.persistence.NodeAttachmentService import com.r3corda.node.services.persistence.NodeAttachmentService
import com.r3corda.node.services.persistence.PerFileTransactionStorage import com.r3corda.node.services.persistence.PerFileTransactionStorage
import com.r3corda.node.services.persistence.StorageServiceImpl import com.r3corda.node.services.persistence.StorageServiceImpl
import com.r3corda.node.services.statemachine.StateMachineManager import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
import com.r3corda.protocols.TwoPartyTradeProtocol import com.r3corda.protocols.TwoPartyTradeProtocol.Seller
import com.r3corda.protocols.TwoPartyTradeProtocol.TOPIC
import com.r3corda.testing.* import com.r3corda.testing.*
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.MockNetwork
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
@ -42,6 +40,7 @@ import java.security.KeyPair
import java.security.PublicKey import java.security.PublicKey
import java.util.* import java.util.*
import java.util.concurrent.ExecutionException import java.util.concurrent.ExecutionException
import java.util.concurrent.Future
import java.util.jar.JarOutputStream import java.util.jar.JarOutputStream
import java.util.zip.ZipEntry import java.util.zip.ZipEntry
import kotlin.test.assertEquals import kotlin.test.assertEquals
@ -56,21 +55,11 @@ import kotlin.test.assertTrue
* We assume that Alice and Bob already found each other via some market, and have agreed the details already. * We assume that Alice and Bob already found each other via some market, and have agreed the details already.
*/ */
class TwoPartyTradeProtocolTests { class TwoPartyTradeProtocolTests {
lateinit var net: MockNetwork lateinit var net: MockNetwork
lateinit var notaryNode: MockNetwork.MockNode
private fun runSeller(smm: StateMachineManager, notary: NodeInfo, lateinit var aliceNode: MockNetwork.MockNode
otherSide: Party, assetToSell: StateAndRef<OwnableState>, price: Amount<Currency>, lateinit var bobNode: MockNetwork.MockNode
myKeyPair: KeyPair, buyerSessionID: Long): ListenableFuture<SignedTransaction> {
val seller = TwoPartyTradeProtocol.Seller(otherSide, notary, assetToSell, price, myKeyPair, buyerSessionID)
return smm.add("${TwoPartyTradeProtocol.TOPIC}.seller", seller).resultFuture
}
private fun runBuyer(smm: StateMachineManager, notaryNode: NodeInfo,
otherSide: Party, acceptablePrice: Amount<Currency>, typeToBuy: Class<out OwnableState>,
sessionID: Long): ListenableFuture<SignedTransaction> {
val buyer = TwoPartyTradeProtocol.Buyer(otherSide, notaryNode.identity, acceptablePrice, typeToBuy, sessionID)
return smm.add("${TwoPartyTradeProtocol.TOPIC}.buyer", buyer).resultFuture
}
@Before @Before
fun before() { fun before() {
@ -92,10 +81,9 @@ class TwoPartyTradeProtocolTests {
net = MockNetwork(false, true) net = MockNetwork(false, true)
ledger { ledger {
notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
val notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
val aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY) bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
val bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
bobNode.services.fillWithSomeTestCash(2000.DOLLARS) bobNode.services.fillWithSomeTestCash(2000.DOLLARS)
val alicesFakePaper = fillUpForSeller(false, aliceNode.storage.myLegalIdentity.owningKey, val alicesFakePaper = fillUpForSeller(false, aliceNode.storage.myLegalIdentity.owningKey,
@ -103,26 +91,7 @@ class TwoPartyTradeProtocolTests {
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey, notaryNode.storage.myLegalIdentityKey) insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey, notaryNode.storage.myLegalIdentityKey)
val buyerSessionID = random63BitValue() val (bobResult, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
// We start the Buyer first, as the Seller sends the first message
val bobResult = runBuyer(
bobNode.smm,
notaryNode.info,
aliceNode.info.identity,
1000.DOLLARS,
CommercialPaper.State::class.java,
buyerSessionID
)
val aliceResult = runSeller(
aliceNode.smm,
notaryNode.info,
bobNode.info.identity,
"alice's paper".outputStateAndRef(),
1000.DOLLARS,
ALICE_KEY,
buyerSessionID
)
// TODO: Verify that the result was inserted into the transaction database. // TODO: Verify that the result was inserted into the transaction database.
// assertEquals(bobResult.get(), aliceNode.storage.validatedTransactions[aliceResult.get().id]) // assertEquals(bobResult.get(), aliceNode.storage.validatedTransactions[aliceResult.get().id])
@ -139,9 +108,9 @@ class TwoPartyTradeProtocolTests {
@Test @Test
fun `shutdown and restore`() { fun `shutdown and restore`() {
ledger { ledger {
val notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
val aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY) aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
var bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY) bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
val bobAddr = bobNode.net.myAddress as InMemoryMessagingNetwork.Handle val bobAddr = bobNode.net.myAddress as InMemoryMessagingNetwork.Handle
val networkMapAddr = notaryNode.info.address val networkMapAddr = notaryNode.info.address
@ -153,25 +122,7 @@ class TwoPartyTradeProtocolTests {
1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null).second 1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null).second
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey) insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
val buyerSessionID = random63BitValue() val aliceFuture = runBuyerAndSeller("alice's paper".outputStateAndRef()).second
val aliceFuture = runSeller(
aliceNode.smm,
notaryNode.info,
bobNode.info.identity,
"alice's paper".outputStateAndRef(),
1000.DOLLARS,
ALICE_KEY,
buyerSessionID
)
runBuyer(
bobNode.smm,
notaryNode.info,
aliceNode.info.identity,
1000.DOLLARS,
CommercialPaper.State::class.java,
buyerSessionID
)
// Everything is on this thread so we can now step through the protocol one step at a time. // Everything is on this thread so we can now step through the protocol one step at a time.
// Seller Alice already sent a message to Buyer Bob. Pump once: // Seller Alice already sent a message to Buyer Bob. Pump once:
@ -210,7 +161,7 @@ class TwoPartyTradeProtocolTests {
}, true, BOB.name, BOB_KEY) }, true, BOB.name, BOB_KEY)
// Find the future representing the result of this state machine again. // Find the future representing the result of this state machine again.
val bobFuture = bobNode.smm.findStateMachines(TwoPartyTradeProtocol.Buyer::class.java).single().second val bobFuture = bobNode.smm.findStateMachines(Buyer::class.java).single().second
// And off we go again. // And off we go again.
net.runNetwork() net.runNetwork()
@ -218,7 +169,7 @@ class TwoPartyTradeProtocolTests {
// Bob is now finished and has the same transaction as Alice. // Bob is now finished and has the same transaction as Alice.
assertThat(bobFuture.get()).isEqualTo(aliceFuture.get()) assertThat(bobFuture.get()).isEqualTo(aliceFuture.get())
assertThat(bobNode.smm.findStateMachines(TwoPartyTradeProtocol.Buyer::class.java)).isEmpty() assertThat(bobNode.smm.findStateMachines(Buyer::class.java)).isEmpty()
assertThat(bobNode.checkpointStorage.checkpoints).isEmpty() assertThat(bobNode.checkpointStorage.checkpoints).isEmpty()
assertThat(aliceNode.checkpointStorage.checkpoints).isEmpty() assertThat(aliceNode.checkpointStorage.checkpoints).isEmpty()
@ -250,9 +201,9 @@ class TwoPartyTradeProtocolTests {
@Test @Test
fun `check dependencies of sale asset are resolved`() { fun `check dependencies of sale asset are resolved`() {
val notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
val aliceNode = makeNodeWithTracking(notaryNode.info.address, ALICE.name, ALICE_KEY) aliceNode = makeNodeWithTracking(notaryNode.info.address, ALICE.name, ALICE_KEY)
val bobNode = makeNodeWithTracking(notaryNode.info.address, BOB.name, BOB_KEY) bobNode = makeNodeWithTracking(notaryNode.info.address, BOB.name, BOB_KEY)
ledger(aliceNode.services) { ledger(aliceNode.services) {
@ -271,27 +222,9 @@ class TwoPartyTradeProtocolTests {
1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, attachmentID).second 1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, attachmentID).second
val alicesSignedTxns = insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey) val alicesSignedTxns = insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
val buyerSessionID = random63BitValue()
net.runNetwork() // Clear network map registration messages net.runNetwork() // Clear network map registration messages
runSeller( runBuyerAndSeller("alice's paper".outputStateAndRef())
aliceNode.smm,
notaryNode.info,
bobNode.info.identity,
"alice's paper".outputStateAndRef(),
1000.DOLLARS,
ALICE_KEY,
buyerSessionID
)
runBuyer(
bobNode.smm,
notaryNode.info,
aliceNode.info.identity,
1000.DOLLARS,
CommercialPaper.State::class.java,
buyerSessionID
)
net.runNetwork() net.runNetwork()
@ -370,14 +303,25 @@ class TwoPartyTradeProtocolTests {
} }
} }
private fun runBuyerAndSeller(assetToSell: StateAndRef<OwnableState>) : Pair<Future<SignedTransaction>, Future<SignedTransaction>> {
val buyer = Buyer(aliceNode.info.identity, notaryNode.info.identity, 1000.DOLLARS, CommercialPaper.State::class.java)
val seller = Seller(bobNode.info.identity, notaryNode.info, assetToSell, 1000.DOLLARS, ALICE_KEY)
connectProtocols(buyer, seller)
// We start the Buyer first, as the Seller sends the first message
val buyerResult = bobNode.smm.add("$TOPIC.buyer", buyer).resultFuture
val sellerResult = aliceNode.smm.add("$TOPIC.seller", seller).resultFuture
return Pair(buyerResult, sellerResult)
}
private fun LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter>.runWithError( private fun LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter>.runWithError(
bobError: Boolean, bobError: Boolean,
aliceError: Boolean, aliceError: Boolean,
expectedMessageSubstring: String expectedMessageSubstring: String
) { ) {
val notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
val aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY) aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
val bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY) bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
val issuer = MEGA_CORP.ref(1, 2, 3) val issuer = MEGA_CORP.ref(1, 2, 3)
val bobKey = bobNode.keyManagement.freshKey() val bobKey = bobNode.keyManagement.freshKey()
@ -388,27 +332,9 @@ class TwoPartyTradeProtocolTests {
insertFakeTransactions(bobsBadCash, bobNode.services, bobNode.storage.myLegalIdentityKey, bobNode.storage.myLegalIdentityKey) insertFakeTransactions(bobsBadCash, bobNode.services, bobNode.storage.myLegalIdentityKey, bobNode.storage.myLegalIdentityKey)
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey) insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
val buyerSessionID = random63BitValue()
net.runNetwork() // Clear network map registration messages net.runNetwork() // Clear network map registration messages
val aliceResult = runSeller( val (bobResult, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
aliceNode.smm,
notaryNode.info,
bobNode.info.identity,
"alice's paper".outputStateAndRef(),
1000.DOLLARS,
ALICE_KEY,
buyerSessionID
)
val bobResult = runBuyer(
bobNode.smm,
notaryNode.info,
aliceNode.info.identity,
1000.DOLLARS,
CommercialPaper.State::class.java,
buyerSessionID
)
net.runNetwork() net.runNetwork()

View File

@ -1,18 +1,30 @@
package com.r3corda.node.services package com.r3corda.node.services
import co.paralleluniverse.fibers.Suspendable import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.node.NodeInfo import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.map
import com.r3corda.core.messaging.TopicSession
import com.r3corda.core.messaging.runOnNextMessage
import com.r3corda.core.messaging.send
import com.r3corda.core.random63BitValue import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.deserialize
import com.r3corda.node.services.network.InMemoryNetworkMapService import com.r3corda.node.services.network.InMemoryNetworkMapService
import com.r3corda.node.services.network.NetworkMapService import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.network.NetworkMapService.*
import com.r3corda.node.services.network.NetworkMapService.Companion.FETCH_PROTOCOL_TOPIC
import com.r3corda.node.services.network.NetworkMapService.Companion.PUSH_ACK_PROTOCOL_TOPIC
import com.r3corda.node.services.network.NetworkMapService.Companion.REGISTER_PROTOCOL_TOPIC
import com.r3corda.node.services.network.NetworkMapService.Companion.SUBSCRIPTION_PROTOCOL_TOPIC
import com.r3corda.node.services.network.NodeRegistration import com.r3corda.node.services.network.NodeRegistration
import com.r3corda.node.utilities.AddOrRemove import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.protocols.ServiceRequestMessage
import com.r3corda.testing.node.MockNetwork import com.r3corda.testing.node.MockNetwork
import com.r3corda.testing.node.MockNetwork.MockNode
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import java.security.PrivateKey import java.security.PrivateKey
import java.time.Instant import java.time.Instant
import java.util.concurrent.Future
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertNotNull import kotlin.test.assertNotNull
import kotlin.test.assertNull import kotlin.test.assertNull
@ -36,7 +48,7 @@ class InMemoryNetworkMapServiceTest {
// Confirm the service contains only its own node // Confirm the service contains only its own node
assertEquals(1, service.nodes.count()) assertEquals(1, service.nodes.count())
assertNull(service.processQueryRequest(NetworkMapService.QueryIdentityRequest(registerNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node) assertNull(service.processQueryRequest(QueryIdentityRequest(registerNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node)
// Register the second node // Register the second node
var seq = 1L var seq = 1L
@ -44,64 +56,22 @@ class InMemoryNetworkMapServiceTest {
val nodeKey = registerNode.storage.myLegalIdentityKey val nodeKey = registerNode.storage.myLegalIdentityKey
val addChange = NodeRegistration(registerNode.info, seq++, AddOrRemove.ADD, expires) val addChange = NodeRegistration(registerNode.info, seq++, AddOrRemove.ADD, expires)
val addWireChange = addChange.toWire(nodeKey.private) val addWireChange = addChange.toWire(nodeKey.private)
service.processRegistrationChangeRequest(NetworkMapService.RegistrationRequest(addWireChange, mapServiceNode.info.address, Long.MIN_VALUE)) service.processRegistrationChangeRequest(RegistrationRequest(addWireChange, mapServiceNode.info.address, Long.MIN_VALUE))
assertEquals(2, service.nodes.count()) assertEquals(2, service.nodes.count())
assertEquals(mapServiceNode.info, service.processQueryRequest(NetworkMapService.QueryIdentityRequest(mapServiceNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node) assertEquals(mapServiceNode.info, service.processQueryRequest(QueryIdentityRequest(mapServiceNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node)
// Re-registering should be a no-op // Re-registering should be a no-op
service.processRegistrationChangeRequest(NetworkMapService.RegistrationRequest(addWireChange, mapServiceNode.info.address, Long.MIN_VALUE)) service.processRegistrationChangeRequest(RegistrationRequest(addWireChange, mapServiceNode.info.address, Long.MIN_VALUE))
assertEquals(2, service.nodes.count()) assertEquals(2, service.nodes.count())
// Confirm that de-registering the node succeeds and drops it from the node lists // Confirm that de-registering the node succeeds and drops it from the node lists
val removeChange = NodeRegistration(registerNode.info, seq, AddOrRemove.REMOVE, expires) val removeChange = NodeRegistration(registerNode.info, seq, AddOrRemove.REMOVE, expires)
val removeWireChange = removeChange.toWire(nodeKey.private) val removeWireChange = removeChange.toWire(nodeKey.private)
assert(service.processRegistrationChangeRequest(NetworkMapService.RegistrationRequest(removeWireChange, mapServiceNode.info.address, Long.MIN_VALUE)).success) assert(service.processRegistrationChangeRequest(RegistrationRequest(removeWireChange, mapServiceNode.info.address, Long.MIN_VALUE)).success)
assertNull(service.processQueryRequest(NetworkMapService.QueryIdentityRequest(registerNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node) assertNull(service.processQueryRequest(QueryIdentityRequest(registerNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node)
// Trying to de-register a node that doesn't exist should fail // Trying to de-register a node that doesn't exist should fail
assert(!service.processRegistrationChangeRequest(NetworkMapService.RegistrationRequest(removeWireChange, mapServiceNode.info.address, Long.MIN_VALUE)).success) assert(!service.processRegistrationChangeRequest(RegistrationRequest(removeWireChange, mapServiceNode.info.address, Long.MIN_VALUE)).success)
}
class TestAcknowledgePSM(val server: NodeInfo, val mapVersion: Int) : ProtocolLogic<Unit>() {
override val topic: String get() = NetworkMapService.PUSH_ACK_PROTOCOL_TOPIC
@Suspendable
override fun call() {
val req = NetworkMapService.UpdateAcknowledge(mapVersion, serviceHub.networkService.myAddress)
send(server.identity, 0, req)
}
}
class TestFetchPSM(val server: NodeInfo, val subscribe: Boolean, val ifChangedSinceVersion: Int? = null)
: ProtocolLogic<Collection<NodeRegistration>?>() {
override val topic: String get() = NetworkMapService.FETCH_PROTOCOL_TOPIC
@Suspendable
override fun call(): Collection<NodeRegistration>? {
val sessionID = random63BitValue()
val req = NetworkMapService.FetchMapRequest(subscribe, ifChangedSinceVersion, serviceHub.networkService.myAddress, sessionID)
return sendAndReceive<NetworkMapService.FetchMapResponse>(server.identity, 0, sessionID, req).unwrap { it.nodes }
}
}
class TestRegisterPSM(val server: NodeInfo, val reg: NodeRegistration, val privateKey: PrivateKey)
: ProtocolLogic<NetworkMapService.RegistrationResponse>() {
override val topic: String get() = NetworkMapService.REGISTER_PROTOCOL_TOPIC
@Suspendable
override fun call(): NetworkMapService.RegistrationResponse {
val sessionID = random63BitValue()
val req = NetworkMapService.RegistrationRequest(reg.toWire(privateKey), serviceHub.networkService.myAddress, sessionID)
return sendAndReceive<NetworkMapService.RegistrationResponse>(server.identity, 0, sessionID, req).unwrap { it }
}
}
class TestSubscribePSM(val server: NodeInfo, val subscribe: Boolean)
: ProtocolLogic<NetworkMapService.SubscribeResponse>() {
override val topic: String get() = NetworkMapService.SUBSCRIPTION_PROTOCOL_TOPIC
@Suspendable
override fun call(): NetworkMapService.SubscribeResponse {
val sessionID = random63BitValue()
val req = NetworkMapService.SubscribeRequest(subscribe, serviceHub.networkService.myAddress, sessionID)
return sendAndReceive<NetworkMapService.SubscribeResponse>(server.identity, 0, sessionID, req).unwrap { it }
}
} }
@Test @Test
@ -113,7 +83,7 @@ class InMemoryNetworkMapServiceTest {
// Confirm all nodes have registered themselves // Confirm all nodes have registered themselves
network.runNetwork() network.runNetwork()
var fetchPsm = registerNode.services.startProtocol(NetworkMapService.FETCH_PROTOCOL_TOPIC, TestFetchPSM(mapServiceNode.info, false)) var fetchPsm = fetchMap(registerNode, mapServiceNode, false)
network.runNetwork() network.runNetwork()
assertEquals(2, fetchPsm.get()?.count()) assertEquals(2, fetchPsm.get()?.count())
@ -122,12 +92,12 @@ class InMemoryNetworkMapServiceTest {
val expires = Instant.now() + NetworkMapService.DEFAULT_EXPIRATION_PERIOD val expires = Instant.now() + NetworkMapService.DEFAULT_EXPIRATION_PERIOD
val seq = 2L val seq = 2L
val reg = NodeRegistration(registerNode.info, seq, AddOrRemove.REMOVE, expires) val reg = NodeRegistration(registerNode.info, seq, AddOrRemove.REMOVE, expires)
val registerPsm = registerNode.services.startProtocol(NetworkMapService.REGISTER_PROTOCOL_TOPIC, TestRegisterPSM(mapServiceNode.info, reg, nodeKey.private)) val registerPsm = registration(registerNode, mapServiceNode, reg, nodeKey.private)
network.runNetwork() network.runNetwork()
assertTrue(registerPsm.get().success) assertTrue(registerPsm.get().success)
// Now only map service node should be registered // Now only map service node should be registered
fetchPsm = registerNode.services.startProtocol(NetworkMapService.FETCH_PROTOCOL_TOPIC, TestFetchPSM(mapServiceNode.info, false)) fetchPsm = fetchMap(registerNode, mapServiceNode, false)
network.runNetwork() network.runNetwork()
assertEquals(mapServiceNode.info, fetchPsm.get()?.filter { it.type == AddOrRemove.ADD }?.map { it.node }?.single()) assertEquals(mapServiceNode.info, fetchPsm.get()?.filter { it.type == AddOrRemove.ADD }?.map { it.node }?.single())
} }
@ -139,8 +109,7 @@ class InMemoryNetworkMapServiceTest {
// Test subscribing to updates // Test subscribing to updates
network.runNetwork() network.runNetwork()
val subscribePsm = registerNode.services.startProtocol(NetworkMapService.SUBSCRIPTION_PROTOCOL_TOPIC, val subscribePsm = subscribe(registerNode, mapServiceNode, true)
TestSubscribePSM(mapServiceNode.info, true))
network.runNetwork() network.runNetwork()
subscribePsm.get() subscribePsm.get()
@ -161,10 +130,8 @@ class InMemoryNetworkMapServiceTest {
assertEquals(1, service.getUnacknowledgedCount(registerNode.info.address, startingMapVersion + 1)) assertEquals(1, service.getUnacknowledgedCount(registerNode.info.address, startingMapVersion + 1))
// Send in an acknowledgment and verify the count goes down // Send in an acknowledgment and verify the count goes down
val acknowledgePsm = registerNode.services.startProtocol(NetworkMapService.PUSH_ACK_PROTOCOL_TOPIC, updateAcknowlege(registerNode, mapServiceNode, startingMapVersion + 1)
TestAcknowledgePSM(mapServiceNode.info, startingMapVersion + 1))
network.runNetwork() network.runNetwork()
acknowledgePsm.get()
assertEquals(0, service.getUnacknowledgedCount(registerNode.info.address, startingMapVersion + 1)) assertEquals(0, service.getUnacknowledgedCount(registerNode.info.address, startingMapVersion + 1))
@ -181,4 +148,25 @@ class InMemoryNetworkMapServiceTest {
} }
} }
} }
}
private fun registration(registerNode: MockNode, mapServiceNode: MockNode, reg: NodeRegistration, privateKey: PrivateKey): ListenableFuture<RegistrationResponse> {
val req = RegistrationRequest(reg.toWire(privateKey), registerNode.services.networkService.myAddress, random63BitValue())
return registerNode.sendAndReceive<RegistrationResponse>(REGISTER_PROTOCOL_TOPIC, mapServiceNode, req)
}
private fun subscribe(registerNode: MockNode, mapServiceNode: MockNode, subscribe: Boolean): ListenableFuture<SubscribeResponse> {
val req = SubscribeRequest(subscribe, registerNode.services.networkService.myAddress, random63BitValue())
return registerNode.sendAndReceive<SubscribeResponse>(SUBSCRIPTION_PROTOCOL_TOPIC, mapServiceNode, req)
}
private fun updateAcknowlege(registerNode: MockNode, mapServiceNode: MockNode, mapVersion: Int) {
val req = UpdateAcknowledge(mapVersion, registerNode.services.networkService.myAddress)
registerNode.send(PUSH_ACK_PROTOCOL_TOPIC, mapServiceNode, req)
}
private fun fetchMap(registerNode: MockNode, mapServiceNode: MockNode, subscribe: Boolean, ifChangedSinceVersion: Int? = null): Future<Collection<NodeRegistration>?> {
val req = FetchMapRequest(subscribe, ifChangedSinceVersion, registerNode.services.networkService.myAddress, random63BitValue())
return registerNode.sendAndReceive<FetchMapResponse>(FETCH_PROTOCOL_TOPIC, mapServiceNode, req).map { it.nodes }
}
}

View File

@ -1,33 +1,36 @@
package com.r3corda.node.services package com.r3corda.node.services
import co.paralleluniverse.fibers.Suspendable import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.contracts.asset.Cash import com.r3corda.contracts.asset.Cash
import com.r3corda.core.contracts.* import com.r3corda.core.contracts.*
import com.r3corda.core.crypto.SecureHash import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.crypto.newSecureRandom import com.r3corda.core.crypto.newSecureRandom
import com.r3corda.core.messaging.MessageHandlerRegistration
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.DEFAULT_SESSION_ID import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.node.services.Wallet import com.r3corda.core.node.services.Wallet
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.OpaqueBytes import com.r3corda.core.serialization.OpaqueBytes
import com.r3corda.core.serialization.deserialize import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize import com.r3corda.core.serialization.serialize
import com.r3corda.core.utilities.DUMMY_NOTARY import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.core.utilities.DUMMY_PUBKEY_1 import com.r3corda.core.utilities.DUMMY_PUBKEY_1
import com.r3corda.testing.node.MockNetwork
import com.r3corda.node.services.monitor.* import com.r3corda.node.services.monitor.*
import com.r3corda.node.services.monitor.WalletMonitorService.Companion.IN_EVENT_TOPIC
import com.r3corda.node.services.monitor.WalletMonitorService.Companion.REGISTER_TOPIC
import com.r3corda.node.utilities.AddOrRemove import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.testing.* import com.r3corda.testing.expect
import com.r3corda.testing.expectEvents
import com.r3corda.testing.node.MockNetwork
import com.r3corda.testing.node.MockNetwork.MockNode
import com.r3corda.testing.parallel
import com.r3corda.testing.sequence
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import rx.subjects.PublishSubject
import rx.subjects.ReplaySubject import rx.subjects.ReplaySubject
import java.util.*
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import kotlin.test.* import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
/** /**
* Unit tests for the wallet monitoring service. * Unit tests for the wallet monitoring service.
@ -43,32 +46,13 @@ class WalletMonitorServiceTests {
/** /**
* Authenticate the register node with the monitor service node. * Authenticate the register node with the monitor service node.
*/ */
private fun authenticate(monitorServiceNode: MockNetwork.MockNode, registerNode: MockNetwork.MockNode): Long { private fun authenticate(monitorServiceNode: MockNode, registerNode: MockNode): Long {
network.runNetwork() network.runNetwork()
val sessionID = random63BitValue() val sessionId = random63BitValue()
val authenticatePsm = registerNode.services.startProtocol(WalletMonitorService.REGISTER_TOPIC, val authenticatePsm = register(registerNode, monitorServiceNode, sessionId)
TestRegisterPSM(monitorServiceNode.info, sessionID))
network.runNetwork() network.runNetwork()
authenticatePsm.get(1, TimeUnit.SECONDS) authenticatePsm.get(1, TimeUnit.SECONDS)
return sessionID return sessionId
}
class TestReceiveWalletUpdatePSM(val sessionID: Long)
: ProtocolLogic<ServiceToClientEvent.OutputState>() {
override val topic: String get() = WalletMonitorService.IN_EVENT_TOPIC
@Suspendable
override fun call(): ServiceToClientEvent.OutputState
= receive<ServiceToClientEvent.OutputState>(sessionID).unwrap { it }
}
class TestRegisterPSM(val server: NodeInfo, val sessionID: Long)
: ProtocolLogic<RegisterResponse>() {
override val topic: String get() = WalletMonitorService.REGISTER_TOPIC
@Suspendable
override fun call(): RegisterResponse {
val req = RegisterRequest(serviceHub.networkService.myAddress, sessionID)
return sendAndReceive<RegisterResponse>(server.identity, 0, sessionID, req).unwrap { it }
}
} }
/** /**
@ -79,9 +63,7 @@ class WalletMonitorServiceTests {
val (monitorServiceNode, registerNode) = network.createTwoNodes() val (monitorServiceNode, registerNode) = network.createTwoNodes()
network.runNetwork() network.runNetwork()
val sessionID = random63BitValue() val authenticatePsm = register(registerNode, monitorServiceNode, random63BitValue())
val authenticatePsm = registerNode.services.startProtocol(WalletMonitorService.REGISTER_TOPIC,
TestRegisterPSM(monitorServiceNode.info, sessionID))
network.runNetwork() network.runNetwork()
val result = authenticatePsm.get(1, TimeUnit.SECONDS) val result = authenticatePsm.get(1, TimeUnit.SECONDS)
assertTrue(result.success) assertTrue(result.success)
@ -94,8 +76,7 @@ class WalletMonitorServiceTests {
fun `event received`() { fun `event received`() {
val (monitorServiceNode, registerNode) = network.createTwoNodes() val (monitorServiceNode, registerNode) = network.createTwoNodes()
val sessionID = authenticate(monitorServiceNode, registerNode) val sessionID = authenticate(monitorServiceNode, registerNode)
var receivePsm = registerNode.services.startProtocol(WalletMonitorService.IN_EVENT_TOPIC, var receivePsm = receiveWalletUpdate(registerNode, sessionID)
TestReceiveWalletUpdatePSM(sessionID))
var expected = Wallet.Update(emptySet(), emptySet()) var expected = Wallet.Update(emptySet(), emptySet())
monitorServiceNode.inNodeWalletMonitorService!!.notifyWalletUpdate(expected) monitorServiceNode.inNodeWalletMonitorService!!.notifyWalletUpdate(expected)
network.runNetwork() network.runNetwork()
@ -104,8 +85,7 @@ class WalletMonitorServiceTests {
assertEquals(expected.produced, actual.produced) assertEquals(expected.produced, actual.produced)
// Check that states are passed through correctly // Check that states are passed through correctly
receivePsm = registerNode.services.startProtocol(WalletMonitorService.IN_EVENT_TOPIC, receivePsm = receiveWalletUpdate(registerNode, sessionID)
TestReceiveWalletUpdatePSM(sessionID))
val consumed = setOf(StateRef(SecureHash.randomSHA256(), 0)) val consumed = setOf(StateRef(SecureHash.randomSHA256(), 0))
val producedState = TransactionState(DummyContract.SingleOwnerState(newSecureRandom().nextInt(), DUMMY_PUBKEY_1), DUMMY_NOTARY) val producedState = TransactionState(DummyContract.SingleOwnerState(newSecureRandom().nextInt(), DUMMY_PUBKEY_1), DUMMY_NOTARY)
val produced = setOf(StateAndRef(producedState, StateRef(SecureHash.randomSHA256(), 0))) val produced = setOf(StateAndRef(producedState, StateRef(SecureHash.randomSHA256(), 0)))
@ -125,7 +105,7 @@ class WalletMonitorServiceTests {
val events = ReplaySubject.create<ServiceToClientEvent>() val events = ReplaySubject.create<ServiceToClientEvent>()
val ref = OpaqueBytes(ByteArray(1) {1}) val ref = OpaqueBytes(ByteArray(1) {1})
registerNode.net.addMessageHandler(WalletMonitorService.IN_EVENT_TOPIC, sessionID) { msg, reg -> registerNode.net.addMessageHandler(IN_EVENT_TOPIC, sessionID) { msg, reg ->
events.onNext(msg.data.deserialize<ServiceToClientEvent>()) events.onNext(msg.data.deserialize<ServiceToClientEvent>())
} }
@ -178,7 +158,7 @@ class WalletMonitorServiceTests {
val quantity = 1000L val quantity = 1000L
val events = ReplaySubject.create<ServiceToClientEvent>() val events = ReplaySubject.create<ServiceToClientEvent>()
registerNode.net.addMessageHandler(WalletMonitorService.IN_EVENT_TOPIC, sessionID) { msg, reg -> registerNode.net.addMessageHandler(IN_EVENT_TOPIC, sessionID) { msg, reg ->
events.onNext(msg.data.deserialize<ServiceToClientEvent>()) events.onNext(msg.data.deserialize<ServiceToClientEvent>())
} }
@ -240,4 +220,14 @@ class WalletMonitorServiceTests {
) )
} }
} }
private fun register(registerNode: MockNode, monitorServiceNode: MockNode, sessionId: Long): ListenableFuture<RegisterResponse> {
val req = RegisterRequest(registerNode.services.networkService.myAddress, sessionId)
return registerNode.sendAndReceive<RegisterResponse>(REGISTER_TOPIC, monitorServiceNode, req)
}
private fun receiveWalletUpdate(registerNode: MockNode, sessionId: Long): ListenableFuture<ServiceToClientEvent.OutputState> {
return registerNode.receive<ServiceToClientEvent.OutputState>(IN_EVENT_TOPIC, sessionId)
}
} }

View File

@ -6,6 +6,7 @@ import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue import com.r3corda.core.random63BitValue
import com.r3corda.testing.connectProtocols
import com.r3corda.testing.node.MockNetwork import com.r3corda.testing.node.MockNetwork
import com.r3corda.testing.node.MockNetwork.MockNode import com.r3corda.testing.node.MockNetwork.MockNode
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
@ -49,10 +50,12 @@ class StateMachineManagerTests {
@Test @Test
fun `protocol suspended just after receiving payload`() { fun `protocol suspended just after receiving payload`() {
val topic = "send-and-receive" val topic = "send-and-receive"
val sessionID = random63BitValue()
val payload = random63BitValue() val payload = random63BitValue()
node1.smm.add("test", SendProtocol(topic, node2.info.identity, sessionID, payload)) val sendProtocol = SendProtocol(topic, node2.info.identity, payload)
node2.smm.add("test", ReceiveProtocol(topic, sessionID)) val receiveProtocol = ReceiveProtocol(topic, node1.info.identity)
connectProtocols(sendProtocol, receiveProtocol)
node1.smm.add("test", sendProtocol)
node2.smm.add("test", receiveProtocol)
net.runNetwork() net.runNetwork()
node2.stop() node2.stop()
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveProtocol>(node1.info.address) val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveProtocol>(node1.info.address)
@ -90,19 +93,19 @@ class StateMachineManagerTests {
} }
private class SendProtocol(override val topic: String, val destination: Party, val sessionID: Long, val payload: Any) : ProtocolLogic<Unit>() { private class SendProtocol(override val topic: String, val otherParty: Party, val payload: Any) : ProtocolLogic<Unit>() {
@Suspendable @Suspendable
override fun call() = send(destination, sessionID, payload) override fun call() = send(otherParty, payload)
} }
private class ReceiveProtocol(override val topic: String, val sessionID: Long) : NonTerminatingProtocol() { private class ReceiveProtocol(override val topic: String, val otherParty: Party) : NonTerminatingProtocol() {
@Transient var receivedPayload: Any? = null @Transient var receivedPayload: Any? = null
@Suspendable @Suspendable
override fun doCall() { override fun doCall() {
receivedPayload = receive<Any>(sessionID).unwrap { it } receivedPayload = receive<Any>(otherParty).unwrap { it }
} }
} }

View File

@ -13,27 +13,26 @@ import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.crypto.generateKeyPair import com.r3corda.core.crypto.generateKeyPair
import com.r3corda.core.node.NodeInfo import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.node.services.ServiceType import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.Emoji import com.r3corda.core.utilities.Emoji
import com.r3corda.core.utilities.LogHelper import com.r3corda.core.utilities.LogHelper
import com.r3corda.core.utilities.ProgressTracker import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.node.internal.Node import com.r3corda.node.internal.Node
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.config.NodeConfigurationFromConfig import com.r3corda.node.services.config.NodeConfigurationFromConfig
import com.r3corda.node.services.messaging.NodeMessagingClient import com.r3corda.node.services.messaging.NodeMessagingClient
import com.r3corda.node.services.network.NetworkMapService import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.persistence.NodeAttachmentService import com.r3corda.node.services.persistence.NodeAttachmentService
import com.r3corda.node.services.transactions.SimpleNotaryService import com.r3corda.node.services.transactions.SimpleNotaryService
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.protocols.NotaryProtocol import com.r3corda.protocols.NotaryProtocol
import com.r3corda.protocols.TwoPartyTradeProtocol import com.r3corda.protocols.TwoPartyTradeProtocol
import joptsimple.OptionParser import joptsimple.OptionParser
import org.slf4j.Logger import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import java.nio.file.Files
import java.nio.file.Path import java.nio.file.Path
import java.nio.file.Paths import java.nio.file.Paths
import java.security.PublicKey import java.security.PublicKey
@ -208,16 +207,21 @@ private fun runBuyer(node: Node, amount: Amount<Currency>) {
// next stage in our building site, we will just auto-generate fake trades to give our nodes something to do. // next stage in our building site, we will just auto-generate fake trades to give our nodes something to do.
// //
// As the seller initiates the two-party trade protocol, here, we will be the buyer. // As the seller initiates the two-party trade protocol, here, we will be the buyer.
node.services.networkService.addMessageHandler(DEMO_TOPIC, DEFAULT_SESSION_ID) { message, registration -> object : AbstractNodeService(node.services) {
// We use a simple scenario-specific wrapper protocol to make things happen. init {
val otherSide = message.data.deserialize<Party>() addProtocolHandler(DEMO_TOPIC, "demo.buyer") { handshake: TraderDemoHandshake ->
val buyer = TraderDemoProtocolBuyer(otherSide, attachmentsPath, amount) TraderDemoProtocolBuyer(handshake.replyToParty, attachmentsPath, amount)
node.services.startProtocol("demo.buyer", buyer) }
}
} }
} }
// We create a couple of ad-hoc test protocols that wrap the two party trade protocol, to give us the demo logic. // We create a couple of ad-hoc test protocols that wrap the two party trade protocol, to give us the demo logic.
private data class TraderDemoHandshake(override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
private val DEMO_TOPIC = "initiate.demo.trade" private val DEMO_TOPIC = "initiate.demo.trade"
private class TraderDemoProtocolBuyer(val otherSide: Party, private class TraderDemoProtocolBuyer(val otherSide: Party,
@ -231,21 +235,17 @@ private class TraderDemoProtocolBuyer(val otherSide: Party,
@Suspendable @Suspendable
override fun call() { override fun call() {
// The session ID disambiguates the test trade.
val sessionID = random63BitValue()
progressTracker.currentStep = STARTING_BUY progressTracker.currentStep = STARTING_BUY
send(otherSide, 0, sessionID)
val notary: NodeInfo = serviceHub.networkMapCache.notaryNodes[0] val notary: NodeInfo = serviceHub.networkMapCache.notaryNodes[0]
val buyer = TwoPartyTradeProtocol.Buyer( val buyer = TwoPartyTradeProtocol.Buyer(
otherSide, otherSide,
notary.identity, notary.identity,
amount, amount,
CommercialPaper.State::class.java, CommercialPaper.State::class.java)
sessionID)
// This invokes the trading protocol and out pops our finished transaction. // This invokes the trading protocol and out pops our finished transaction.
val tradeTX: SignedTransaction = subProtocol(buyer) val tradeTX: SignedTransaction = subProtocol(buyer, inheritParentSessions = true)
// TODO: This should be moved into the protocol itself. // TODO: This should be moved into the protocol itself.
serviceHub.recordTransactions(listOf(tradeTX)) serviceHub.recordTransactions(listOf(tradeTX))
@ -306,7 +306,7 @@ private class TraderDemoProtocolSeller(val otherSide: Party,
override fun call(): SignedTransaction { override fun call(): SignedTransaction {
progressTracker.currentStep = ANNOUNCING progressTracker.currentStep = ANNOUNCING
val sessionID = sendAndReceive<Long>(otherSide, 0, 0, serviceHub.storageService.myLegalIdentity).unwrap { it } send(otherSide, TraderDemoHandshake(serviceHub.storageService.myLegalIdentity))
progressTracker.currentStep = SELF_ISSUING progressTracker.currentStep = SELF_ISSUING
@ -316,9 +316,14 @@ private class TraderDemoProtocolSeller(val otherSide: Party,
progressTracker.currentStep = TRADING progressTracker.currentStep = TRADING
val seller = TwoPartyTradeProtocol.Seller(otherSide, notary, commercialPaper, amount, cpOwnerKey, val seller = TwoPartyTradeProtocol.Seller(
sessionID, progressTracker.getChildProgressTracker(TRADING)!!) otherSide,
val tradeTX: SignedTransaction = subProtocol(seller) notary,
commercialPaper,
amount,
cpOwnerKey,
progressTracker.getChildProgressTracker(TRADING)!!)
val tradeTX: SignedTransaction = subProtocol(seller, inheritParentSessions = true)
serviceHub.recordTransactions(listOf(tradeTX)) serviceHub.recordTransactions(listOf(tradeTX))
return tradeTX return tradeTX

View File

@ -75,7 +75,10 @@ object NodeInterestRates {
* Interest rates become available when they are uploaded via the web as per [DataUploadServlet], * Interest rates become available when they are uploaded via the web as per [DataUploadServlet],
* if they haven't already been uploaded that way. * if they haven't already been uploaded that way.
*/ */
services.startProtocol("fixing", FixQueryHandler(this, req as RatesFixProtocol.QueryRequest)) req as RatesFixProtocol.QueryRequest
val handler = FixQueryHandler(this, req)
handler.registerSession(req)
services.startProtocol("fixing", handler)
Unit Unit
} }
}, },
@ -102,7 +105,7 @@ object NodeInterestRates {
override fun call(): Unit { override fun call(): Unit {
val answers = service.oracle.query(request.queries, request.deadline) val answers = service.oracle.query(request.queries, request.deadline)
progressTracker.currentStep = SENDING progressTracker.currentStep = SENDING
send(request.replyToParty, request.sessionID, answers) send(request.replyToParty, answers)
} }
} }

View File

@ -2,18 +2,19 @@ package com.r3corda.demos.protocols
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.FutureCallback import com.google.common.util.concurrent.FutureCallback
import com.google.common.util.concurrent.Futures
import com.r3corda.core.contracts.DealState import com.r3corda.core.contracts.DealState
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.deserialize import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.ProgressTracker import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.protocols.TwoPartyDealProtocol import com.r3corda.protocols.TwoPartyDealProtocol
import com.r3corda.protocols.TwoPartyDealProtocol.Acceptor
import com.r3corda.protocols.TwoPartyDealProtocol.DEAL_TOPIC
/** /**
* This whole class is really part of a demo just to initiate the agreement of a deal with a simple * This whole class is really part of a demo just to initiate the agreement of a deal with a simple
@ -25,23 +26,24 @@ import com.r3corda.protocols.TwoPartyDealProtocol
object AutoOfferProtocol { object AutoOfferProtocol {
val TOPIC = "autooffer.topic" val TOPIC = "autooffer.topic"
data class AutoOfferMessage(val otherSide: Party, data class AutoOfferMessage(val notary: Party,
val notary: Party, val dealBeingOffered: DealState,
val otherSessionID: Long, val dealBeingOffered: DealState) override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
class Plugin: CordaPluginRegistry() { class Plugin: CordaPluginRegistry() {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java) override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
} }
class Service(services: ServiceHubInternal) { class Service(services: ServiceHubInternal) : AbstractNodeService(services) {
object RECEIVED : ProgressTracker.Step("Received offer")
object DEALING : ProgressTracker.Step("Starting the deal protocol") { object DEALING : ProgressTracker.Step("Starting the deal protocol") {
override fun childProgressTracker(): ProgressTracker = TwoPartyDealProtocol.Primary.tracker() override fun childProgressTracker(): ProgressTracker = TwoPartyDealProtocol.Primary.tracker()
} }
fun tracker() = ProgressTracker(RECEIVED, DEALING) fun tracker() = ProgressTracker(DEALING)
class Callback(val success: (SignedTransaction) -> Unit) : FutureCallback<SignedTransaction> { class Callback(val success: (SignedTransaction) -> Unit) : FutureCallback<SignedTransaction> {
override fun onFailure(t: Throwable?) { override fun onFailure(t: Throwable?) {
@ -54,20 +56,17 @@ object AutoOfferProtocol {
} }
init { init {
services.networkService.addMessageHandler(TOPIC, DEFAULT_SESSION_ID) { msg, registration -> addProtocolHandler(TOPIC, "$DEAL_TOPIC.seller") { autoOfferMessage: AutoOfferMessage ->
val progressTracker = tracker() val progressTracker = tracker()
progressTracker.currentStep = RECEIVED
val autoOfferMessage = msg.data.deserialize<AutoOfferMessage>()
// Put the deal onto the ledger // Put the deal onto the ledger
progressTracker.currentStep = DEALING progressTracker.currentStep = DEALING
val seller = TwoPartyDealProtocol.Instigator(autoOfferMessage.otherSide, autoOfferMessage.notary, TwoPartyDealProtocol.Instigator(
autoOfferMessage.dealBeingOffered, services.keyManagementService.freshKey(), autoOfferMessage.otherSessionID, progressTracker.getChildProgressTracker(DEALING)!!) autoOfferMessage.replyToParty,
val future = services.startProtocol("${TwoPartyDealProtocol.DEAL_TOPIC}.seller", seller) autoOfferMessage.notary,
// This is required because we are doing child progress outside of a subprotocol. In future, we should just wrap things like this in a protocol to avoid it autoOfferMessage.dealBeingOffered,
Futures.addCallback(future, Callback() { services.keyManagementService.freshKey(),
seller.progressTracker.currentStep = ProgressTracker.DONE progressTracker.getChildProgressTracker(DEALING)!!
progressTracker.currentStep = ProgressTracker.DONE )
})
} }
} }
@ -98,15 +97,15 @@ object AutoOfferProtocol {
@Suspendable @Suspendable
override fun call(): SignedTransaction { override fun call(): SignedTransaction {
require(serviceHub.networkMapCache.notaryNodes.isNotEmpty()) { "No notary nodes registered" } require(serviceHub.networkMapCache.notaryNodes.isNotEmpty()) { "No notary nodes registered" }
val ourSessionID = random63BitValue()
val notary = serviceHub.networkMapCache.notaryNodes.first().identity val notary = serviceHub.networkMapCache.notaryNodes.first().identity
// need to pick which ever party is not us // need to pick which ever party is not us
val otherParty = notUs(dealToBeOffered.parties).single() val otherParty = notUs(dealToBeOffered.parties).single()
progressTracker.currentStep = ANNOUNCING progressTracker.currentStep = ANNOUNCING
send(otherParty, 0, AutoOfferMessage(serviceHub.storageService.myLegalIdentity, notary, ourSessionID, dealToBeOffered)) send(otherParty, AutoOfferMessage(notary, dealToBeOffered, serviceHub.storageService.myLegalIdentity))
progressTracker.currentStep = DEALING progressTracker.currentStep = DEALING
val stx = subProtocol(TwoPartyDealProtocol.Acceptor(otherParty, notary, dealToBeOffered, ourSessionID, progressTracker.getChildProgressTracker(DEALING)!!)) val stx = subProtocol(
Acceptor(otherParty, notary, dealToBeOffered, progressTracker.getChildProgressTracker(DEALING)!!),
inheritParentSessions = true)
return stx return stx
} }

View File

@ -2,12 +2,15 @@ package com.r3corda.demos.protocols
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand import co.paralleluniverse.strands.Strand
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.NodeInfo import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.DEFAULT_SESSION_ID import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.deserialize import com.r3corda.core.serialization.deserialize
import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.testing.node.MockNetworkMapCache import com.r3corda.testing.node.MockNetworkMapCache
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
@ -19,7 +22,12 @@ object ExitServerProtocol {
// Will only be enabled if you install the Handler // Will only be enabled if you install the Handler
@Volatile private var enabled = false @Volatile private var enabled = false
data class ExitMessage(val exitCode: Int) // This is not really a HandshakeMessage but needs to be so that the send uses the default session ID. This will
// resolve itself when the protocol session stuff is done.
data class ExitMessage(val exitCode: Int,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
class Plugin: CordaPluginRegistry() { class Plugin: CordaPluginRegistry() {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java) override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
@ -50,10 +58,8 @@ object ExitServerProtocol {
@Suspendable @Suspendable
override fun call(): Boolean { override fun call(): Boolean {
if (enabled) { if (enabled) {
val message = ExitMessage(exitCode)
for (recipient in serviceHub.networkMapCache.partyNodes) { for (recipient in serviceHub.networkMapCache.partyNodes) {
doNextRecipient(recipient, message) doNextRecipient(recipient)
} }
// Sleep a little in case any async message delivery to other nodes needs to happen // Sleep a little in case any async message delivery to other nodes needs to happen
Strand.sleep(1, TimeUnit.SECONDS) Strand.sleep(1, TimeUnit.SECONDS)
@ -63,11 +69,11 @@ object ExitServerProtocol {
} }
@Suspendable @Suspendable
private fun doNextRecipient(recipient: NodeInfo, message: ExitMessage) { private fun doNextRecipient(recipient: NodeInfo) {
if (recipient.address is MockNetworkMapCache.MockAddress) { if (recipient.address is MockNetworkMapCache.MockAddress) {
// Ignore // Ignore
} else { } else {
send(recipient.identity, 0, message) send(recipient.identity, ExitMessage(exitCode, recipient.identity))
} }
} }
} }

View File

@ -1,15 +1,17 @@
package com.r3corda.demos.protocols package com.r3corda.demos.protocols
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.NodeInfo import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.DEFAULT_SESSION_ID import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.deserialize import com.r3corda.core.serialization.deserialize
import com.r3corda.core.utilities.ProgressTracker import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.demos.DemoClock import com.r3corda.demos.DemoClock
import com.r3corda.node.internal.Node
import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.testing.node.MockNetworkMapCache import com.r3corda.testing.node.MockNetworkMapCache
import java.time.LocalDate import java.time.LocalDate
@ -20,7 +22,12 @@ object UpdateBusinessDayProtocol {
val TOPIC = "businessday.topic" val TOPIC = "businessday.topic"
data class UpdateBusinessDayMessage(val date: LocalDate) // This is not really a HandshakeMessage but needs to be so that the send uses the default session ID. This will
// resolve itself when the protocol session stuff is done.
data class UpdateBusinessDayMessage(val date: LocalDate,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
class Plugin: CordaPluginRegistry() { class Plugin: CordaPluginRegistry() {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java) override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
@ -50,18 +57,17 @@ object UpdateBusinessDayProtocol {
@Suspendable @Suspendable
override fun call(): Unit { override fun call(): Unit {
progressTracker.currentStep = NOTIFYING progressTracker.currentStep = NOTIFYING
val message = UpdateBusinessDayMessage(date)
for (recipient in serviceHub.networkMapCache.partyNodes) { for (recipient in serviceHub.networkMapCache.partyNodes) {
doNextRecipient(recipient, message) doNextRecipient(recipient)
} }
} }
@Suspendable @Suspendable
private fun doNextRecipient(recipient: NodeInfo, message: UpdateBusinessDayMessage) { private fun doNextRecipient(recipient: NodeInfo) {
if (recipient.address is MockNetworkMapCache.MockAddress) { if (recipient.address is MockNetworkMapCache.MockAddress) {
// Ignore // Ignore
} else { } else {
send(recipient.identity, 0, message) send(recipient.identity, UpdateBusinessDayMessage(date, recipient.identity))
} }
} }
} }

View File

@ -7,14 +7,14 @@ import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture import com.google.common.util.concurrent.SettableFuture
import com.r3corda.contracts.InterestRateSwap import com.r3corda.contracts.InterestRateSwap
import com.r3corda.core.RunOnCallerThread import com.r3corda.core.RunOnCallerThread
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.contracts.StateAndRef import com.r3corda.core.contracts.StateAndRef
import com.r3corda.core.contracts.UniqueIdentifier import com.r3corda.core.contracts.UniqueIdentifier
import com.r3corda.core.failure import com.r3corda.core.failure
import com.r3corda.core.node.services.linearHeadsOfType import com.r3corda.core.node.services.linearHeadsOfType
import com.r3corda.core.random63BitValue
import com.r3corda.core.success import com.r3corda.core.success
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.protocols.TwoPartyDealProtocol import com.r3corda.protocols.TwoPartyDealProtocol
import com.r3corda.testing.connectProtocols
import com.r3corda.testing.node.InMemoryMessagingNetwork import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.MockIdentityService import com.r3corda.testing.node.MockIdentityService
import java.security.KeyPair import java.security.KeyPair
@ -121,10 +121,9 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten
irs.fixedLeg.fixedRatePayer = node1.info.identity irs.fixedLeg.fixedRatePayer = node1.info.identity
irs.floatingLeg.floatingRatePayer = node2.info.identity irs.floatingLeg.floatingRatePayer = node2.info.identity
val sessionID = random63BitValue() val instigator = TwoPartyDealProtocol.Instigator(node2.info.identity, notary.info.identity, irs, nodeAKey!!)
val acceptor = TwoPartyDealProtocol.Acceptor(node1.info.identity, notary.info.identity, irs)
val instigator = TwoPartyDealProtocol.Instigator(node2.info.identity, notary.info.identity, irs, nodeAKey!!, sessionID) connectProtocols(instigator, acceptor)
val acceptor = TwoPartyDealProtocol.Acceptor(node1.info.identity, notary.info.identity, irs, sessionID)
showProgressFor(listOf(node1, node2)) showProgressFor(listOf(node1, node2))
showConsensusFor(listOf(node1, node2, regulators[0])) showConsensusFor(listOf(node1, node2, regulators[0]))

View File

@ -7,12 +7,13 @@ import com.r3corda.contracts.asset.DUMMY_CASH_ISSUER
import com.r3corda.contracts.testing.fillWithSomeTestCash import com.r3corda.contracts.testing.fillWithSomeTestCash
import com.r3corda.core.contracts.DOLLARS import com.r3corda.core.contracts.DOLLARS
import com.r3corda.core.contracts.OwnableState import com.r3corda.core.contracts.OwnableState
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.contracts.`issued by` import com.r3corda.core.contracts.`issued by`
import com.r3corda.core.days import com.r3corda.core.days
import com.r3corda.core.random63BitValue
import com.r3corda.core.seconds import com.r3corda.core.seconds
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.protocols.TwoPartyTradeProtocol import com.r3corda.protocols.TwoPartyTradeProtocol
import com.r3corda.protocols.TwoPartyTradeProtocol.TOPIC
import com.r3corda.testing.connectProtocols
import com.r3corda.testing.node.InMemoryMessagingNetwork import com.r3corda.testing.node.InMemoryMessagingNetwork
import java.time.Instant import java.time.Instant
@ -43,26 +44,24 @@ class TradeSimulation(runAsync: Boolean, latencyInjector: InMemoryMessagingNetwo
seller.services.recordTransactions(issuance) seller.services.recordTransactions(issuance)
val amount = 1000.DOLLARS val amount = 1000.DOLLARS
val sessionID = random63BitValue()
val buyerProtocol = TwoPartyTradeProtocol.Buyer( val buyerProtocol = TwoPartyTradeProtocol.Buyer(
seller.info.identity, seller.info.identity,
notary.info.identity, notary.info.identity,
amount, amount,
CommercialPaper.State::class.java, CommercialPaper.State::class.java)
sessionID)
val sellerProtocol = TwoPartyTradeProtocol.Seller( val sellerProtocol = TwoPartyTradeProtocol.Seller(
buyer.info.identity, buyer.info.identity,
notary.info, notary.info,
issuance.tx.outRef<OwnableState>(0), issuance.tx.outRef<OwnableState>(0),
amount, amount,
seller.storage.myLegalIdentityKey, seller.storage.myLegalIdentityKey)
sessionID) connectProtocols(buyerProtocol, sellerProtocol)
showConsensusFor(listOf(buyer, seller, notary)) showConsensusFor(listOf(buyer, seller, notary))
showProgressFor(listOf(buyer, seller)) showProgressFor(listOf(buyer, seller))
val buyerFuture = buyer.services.startProtocol("bank.$buyerBankIndex.${TwoPartyTradeProtocol.TOPIC}.buyer", buyerProtocol) val buyerFuture = buyer.services.startProtocol("bank.$buyerBankIndex.$TOPIC.buyer", buyerProtocol)
val sellerFuture = seller.services.startProtocol("bank.$sellerBankIndex.${TwoPartyTradeProtocol.TOPIC}.seller", sellerProtocol) val sellerFuture = seller.services.startProtocol("bank.$sellerBankIndex.$TOPIC.seller", sellerProtocol)
return Futures.successfulAsList(buyerFuture, sellerFuture) return Futures.successfulAsList(buyerFuture, sellerFuture)
} }

View File

@ -4,15 +4,19 @@ package com.r3corda.testing
import com.google.common.base.Throwables import com.google.common.base.Throwables
import com.google.common.net.HostAndPort import com.google.common.net.HostAndPort
import com.r3corda.testing.*
import com.r3corda.core.contracts.StateRef import com.r3corda.core.contracts.StateRef
import com.r3corda.core.transactions.TransactionBuilder import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.* import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.crypto.generateKeyPair
import com.r3corda.core.node.ServiceHub import com.r3corda.core.node.ServiceHub
import com.r3corda.testing.node.MockIdentityService import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.testing.node.MockServices import com.r3corda.core.random63BitValue
import com.r3corda.core.transactions.TransactionBuilder
import com.r3corda.core.utilities.DUMMY_NOTARY import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.testing.node.MockIdentityService
import com.r3corda.testing.node.MockServices
import java.net.ServerSocket import java.net.ServerSocket
import java.security.KeyPair import java.security.KeyPair
import java.security.PublicKey import java.security.PublicKey
@ -124,3 +128,23 @@ fun getFreeLocalPorts(hostName: String, numberToAlloc: Int): List<HostAndPort> {
transactionBuilder: TransactionBuilder = TransactionBuilder(notary = DUMMY_NOTARY), transactionBuilder: TransactionBuilder = TransactionBuilder(notary = DUMMY_NOTARY),
dsl: TransactionDSL<TransactionDSLInterpreter>.() -> EnforceVerifyOrFail dsl: TransactionDSL<TransactionDSLInterpreter>.() -> EnforceVerifyOrFail
) = ledger { this.transaction(transactionLabel, transactionBuilder, dsl) } ) = ledger { this.transaction(transactionLabel, transactionBuilder, dsl) }
/**
* Connect two protocols together for communication. Both protocols must have a property called otherParty of type Party
* which points to the other party in the communication.
*/
fun connectProtocols(protocol1: ProtocolLogic<*>, protocol2: ProtocolLogic<*>) {
data class Handshake(override val replyToParty: Party,
override val sendSessionID: Long,
override val receiveSessionID: Long) : HandshakeMessage
val sessionId1 = random63BitValue()
val sessionId2 = random63BitValue()
protocol1.registerSession(Handshake(protocol1.otherParty, sessionId1, sessionId2))
protocol2.registerSession(Handshake(protocol2.otherParty, sessionId2, sessionId1))
}
private val ProtocolLogic<*>.otherParty: Party
get() = javaClass.getDeclaredField("otherParty").apply { isAccessible = true }.get(this) as Party

View File

@ -2,12 +2,18 @@ package com.r3corda.testing.node
import com.google.common.jimfs.Jimfs import com.google.common.jimfs.Jimfs
import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.Futures
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.messaging.TopicSession
import com.r3corda.core.messaging.runOnNextMessage
import com.r3corda.core.messaging.send
import com.r3corda.core.node.PhysicalLocation import com.r3corda.core.node.PhysicalLocation
import com.r3corda.core.node.services.KeyManagementService import com.r3corda.core.node.services.KeyManagementService
import com.r3corda.core.node.services.ServiceType import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.node.services.WalletService import com.r3corda.core.node.services.WalletService
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.testing.InMemoryWalletService import com.r3corda.core.testing.InMemoryWalletService
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.core.utilities.loggerFor import com.r3corda.core.utilities.loggerFor
@ -18,6 +24,7 @@ import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.network.NodeRegistration import com.r3corda.node.services.network.NodeRegistration
import com.r3corda.node.services.transactions.InMemoryUniquenessProvider import com.r3corda.node.services.transactions.InMemoryUniquenessProvider
import com.r3corda.node.utilities.AddOrRemove import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.protocols.ServiceRequestMessage
import org.jetbrains.exposed.sql.transactions.TransactionManager import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.slf4j.Logger import org.slf4j.Logger
import java.nio.file.Files import java.nio.file.Files
@ -130,6 +137,25 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
// This does not indirect through the NodeInfo object so it can be called before the node is started. // This does not indirect through the NodeInfo object so it can be called before the node is started.
// It is used from the network visualiser tool. // It is used from the network visualiser tool.
@Suppress("unused") val place: PhysicalLocation get() = findMyLocation()!! @Suppress("unused") val place: PhysicalLocation get() = findMyLocation()!!
fun send(topic: String, target: MockNode, payload: Any) {
services.networkService.send(TopicSession(topic), payload, target.info.address)
}
inline fun <reified T : Any> receive(topic: String, sessionId: Long): ListenableFuture<T> {
val receive = SettableFuture.create<T>()
services.networkService.runOnNextMessage(topic, sessionId) {
receive.set(it.data.deserialize<T>())
}
return receive
}
inline fun <reified T : Any> sendAndReceive(topic: String,
target: MockNode,
payload: ServiceRequestMessage): ListenableFuture<T> {
send(topic, target, payload)
return receive(topic, payload.sessionID)
}
} }
/** Returns a node, optionally created by the passed factory method. */ /** Returns a node, optionally created by the passed factory method. */