Removed session IDs from the send and receive methods of ProtocolLogic and are now partially managed by HandshakeMessage

This commit is contained in:
Shams Asari 2016-09-13 17:37:42 +01:00
parent f314bab6c8
commit 8ea20dd0d2
32 changed files with 539 additions and 519 deletions

View File

@ -9,7 +9,6 @@ import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.signWithECDSA
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.seconds
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.transactions.TransactionBuilder
@ -58,19 +57,17 @@ object TwoPartyTradeProtocol {
class SellerTradeInfo(
val assetForSale: StateAndRef<OwnableState>,
val price: Amount<Currency>,
val sellerOwnerKey: PublicKey,
val sessionID: Long
val sellerOwnerKey: PublicKey
)
class SignaturesFromSeller(val sellerSig: DigitalSignature.WithKey,
val notarySig: DigitalSignature.LegallyIdentifiable)
open class Seller(val otherSide: Party,
open class Seller(val otherParty: Party,
val notaryNode: NodeInfo,
val assetToSell: StateAndRef<OwnableState>,
val price: Amount<Currency>,
val myKeyPair: KeyPair,
val buyerSessionID: Long,
override val progressTracker: ProgressTracker = Seller.tracker()) : ProtocolLogic<SignedTransaction>() {
companion object {
@ -109,12 +106,10 @@ object TwoPartyTradeProtocol {
private fun receiveAndCheckProposedTransaction(): SignedTransaction {
progressTracker.currentStep = AWAITING_PROPOSAL
val sessionID = random63BitValue()
// Make the first message we'll send to kick off the protocol.
val hello = SellerTradeInfo(assetToSell, price, myKeyPair.public, sessionID)
val hello = SellerTradeInfo(assetToSell, price, myKeyPair.public)
val maybeSTX = sendAndReceive<SignedTransaction>(otherSide, buyerSessionID, sessionID, hello)
val maybeSTX = sendAndReceive<SignedTransaction>(otherParty, hello)
progressTracker.currentStep = VERIFYING
@ -127,7 +122,7 @@ object TwoPartyTradeProtocol {
// Download and check all the things that this transaction depends on and verify it is contract-valid,
// even though it is missing signatures.
subProtocol(ResolveTransactionsProtocol(wtx, otherSide))
subProtocol(ResolveTransactionsProtocol(wtx, otherParty))
if (wtx.outputs.map { it.data }.sumCashBy(myKeyPair.public).withoutIssuer() != price)
throw IllegalArgumentException("Transaction is not sending us the right amount of cash")
@ -159,16 +154,15 @@ object TwoPartyTradeProtocol {
logger.trace { "Built finished transaction, sending back to secondary!" }
send(otherSide, buyerSessionID, SignaturesFromSeller(ourSignature, notarySignature))
send(otherParty, SignaturesFromSeller(ourSignature, notarySignature))
return fullySigned
}
}
open class Buyer(val otherSide: Party,
open class Buyer(val otherParty: Party,
val notary: Party,
val acceptablePrice: Amount<Currency>,
val typeToBuy: Class<out OwnableState>,
val sessionID: Long) : ProtocolLogic<SignedTransaction>() {
val typeToBuy: Class<out OwnableState>) : ProtocolLogic<SignedTransaction>() {
object RECEIVING : ProgressTracker.Step("Waiting for seller trading info")
@ -189,7 +183,7 @@ object TwoPartyTradeProtocol {
val (ptx, cashSigningPubKeys) = assembleSharedTX(tradeRequest)
val stx = signWithOurKeys(cashSigningPubKeys, ptx)
val signatures = swapSignaturesWithSeller(stx, tradeRequest.sessionID)
val signatures = swapSignaturesWithSeller(stx)
logger.trace { "Got signatures from seller, verifying ... " }
@ -204,7 +198,7 @@ object TwoPartyTradeProtocol {
private fun receiveAndValidateTradeRequest(): SellerTradeInfo {
progressTracker.currentStep = RECEIVING
// Wait for a trade request to come in on our pre-provided session ID.
val maybeTradeRequest = receive<SellerTradeInfo>(sessionID)
val maybeTradeRequest = receive<SellerTradeInfo>(otherParty)
progressTracker.currentStep = VERIFYING
maybeTradeRequest.unwrap {
@ -213,8 +207,6 @@ object TwoPartyTradeProtocol {
val assetTypeName = asset.javaClass.name
logger.trace { "Got trade request for a $assetTypeName: ${it.assetForSale}" }
// Check the start message for acceptability.
check(it.sessionID > 0)
if (it.price > acceptablePrice)
throw UnacceptablePriceException(it.price)
if (!typeToBuy.isInstance(asset))
@ -222,20 +214,20 @@ object TwoPartyTradeProtocol {
// Check the transaction that contains the state which is being resolved.
// We only have a hash here, so if we don't know it already, we have to ask for it.
subProtocol(ResolveTransactionsProtocol(setOf(it.assetForSale.ref.txhash), otherSide))
subProtocol(ResolveTransactionsProtocol(setOf(it.assetForSale.ref.txhash), otherParty))
return it
}
}
@Suspendable
private fun swapSignaturesWithSeller(stx: SignedTransaction, theirSessionID: Long): SignaturesFromSeller {
private fun swapSignaturesWithSeller(stx: SignedTransaction): SignaturesFromSeller {
progressTracker.currentStep = SWAPPING_SIGNATURES
logger.trace { "Sending partially signed transaction to seller" }
// TODO: Protect against the seller terminating here and leaving us in the lurch without the final tx.
return sendAndReceive<SignaturesFromSeller>(otherSide, theirSessionID, sessionID, stx).unwrap { it }
return sendAndReceive<SignaturesFromSeller>(otherParty, stx).unwrap { it }
}
private fun signWithOurKeys(cashSigningPubKeys: List<PublicKey>, ptx: TransactionBuilder): SignedTransaction {

View File

@ -2,6 +2,7 @@ package com.r3corda.core
import com.google.common.base.Throwables
import com.google.common.io.ByteStreams
import com.google.common.util.concurrent.Futures
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.MoreExecutors
import com.google.common.util.concurrent.SettableFuture
@ -15,6 +16,7 @@ import java.nio.file.Path
import java.time.Duration
import java.time.temporal.Temporal
import java.util.concurrent.Executor
import java.util.concurrent.Future
import java.util.concurrent.locks.ReentrantLock
import java.util.zip.ZipInputStream
import kotlin.concurrent.withLock
@ -67,6 +69,7 @@ fun <T> ListenableFuture<T>.failure(executor: Executor, body: (Throwable) -> Uni
}
}
infix fun <F, T> Future<F>.map(mapper: (F) -> T): Future<T> = Futures.lazyTransform(this) { mapper(it!!) }
infix fun <T> ListenableFuture<T>.then(body: () -> Unit): ListenableFuture<T> = apply { then(RunOnCallerThread, body) }
infix fun <T> ListenableFuture<T>.success(body: (T) -> Unit): ListenableFuture<T> = apply { success(RunOnCallerThread, body) }
infix fun <T> ListenableFuture<T>.failure(body: (Throwable) -> Unit): ListenableFuture<T> = apply { failure(RunOnCallerThread, body) }

View File

@ -3,9 +3,13 @@ package com.r3corda.core.protocols
import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.core.utilities.UntrustworthyData
import com.r3corda.core.utilities.debug
import com.r3corda.protocols.HandshakeMessage
import org.slf4j.Logger
import java.util.*
/**
* A sub-class of [ProtocolLogic<T>] implements a protocol flow using direct, straight line blocking code. Thus you
@ -45,32 +49,78 @@ abstract class ProtocolLogic<out T> {
*/
protected abstract val topic: String
private val sessions = HashMap<Party, Session>()
/**
* If a node receives a [HandshakeMessage] it needs to call this method on the initiated receipt protocol to enable
* communication between it and the sender protocol. Calling this method, and other initiation steps, are already
* handled by AbstractNodeService.addProtocolHandler.
*/
fun registerSession(receivedHandshake: HandshakeMessage) {
// Note that the send and receive session IDs are swapped
addSession(receivedHandshake.replyToParty, receivedHandshake.receiveSessionID, receivedHandshake.sendSessionID)
}
// Kotlin helpers that allow the use of generic types.
inline fun <reified T : Any> sendAndReceive(destination: Party,
sessionIDForSend: Long,
sessionIDForReceive: Long,
payload: Any): UntrustworthyData<T> {
return psm.sendAndReceive(topic, destination, sessionIDForSend, sessionIDForReceive, payload, T::class.java)
inline fun <reified T : Any> sendAndReceive(otherParty: Party, payload: Any): UntrustworthyData<T> {
return sendAndReceive(otherParty, payload, T::class.java)
}
inline fun <reified T : Any> receive(sessionIDForReceive: Long): UntrustworthyData<T> {
return receive(sessionIDForReceive, T::class.java)
@Suspendable
fun <T : Any> sendAndReceive(otherParty: Party, payload: Any, receiveType: Class<T>): UntrustworthyData<T> {
val sendSessionId = getSendSessionId(otherParty, payload)
val receiveSessionId = getReceiveSessionId(otherParty)
return psm.sendAndReceive(topic, otherParty, sendSessionId, receiveSessionId, payload, receiveType)
}
@Suspendable fun <T : Any> receive(sessionIDForReceive: Long, receiveType: Class<T>): UntrustworthyData<T> {
return psm.receive(topic, sessionIDForReceive, receiveType)
inline fun <reified T : Any> receive(otherParty: Party): UntrustworthyData<T> = receive(otherParty, T::class.java)
@Suspendable
fun <T : Any> receive(otherParty: Party, receiveType: Class<T>): UntrustworthyData<T> {
return psm.receive(topic, getReceiveSessionId(otherParty), receiveType)
}
@Suspendable fun send(destination: Party, sessionID: Long, payload: Any) {
psm.send(topic, destination, sessionID, payload)
@Suspendable
fun send(otherParty: Party, payload: Any) {
psm.send(topic, otherParty, getSendSessionId(otherParty, payload), payload)
}
private fun addSession(party: Party, sendSesssionId: Long, receiveSessionId: Long) {
if (party in sessions) {
logger.debug { "Existing session with party $party to be overwritten by new one" }
}
sessions[party] = Session(sendSesssionId, receiveSessionId)
}
private fun getSendSessionId(otherParty: Party, payload: Any): Long {
return if (payload is HandshakeMessage) {
addSession(otherParty, payload.sendSessionID, payload.receiveSessionID)
DEFAULT_SESSION_ID
} else {
sessions[otherParty]?.sendSessionId ?:
throw IllegalStateException("Session with party $otherParty hasn't been established yet")
}
}
private fun getReceiveSessionId(otherParty: Party): Long {
return sessions[otherParty]?.receiveSessionId ?:
throw IllegalStateException("Session with party $otherParty hasn't been established yet")
}
/**
* Invokes the given subprotocol by simply passing through this [ProtocolLogic]s reference to the
* [ProtocolStateMachine] and then calling the [call] method.
* @param inheritParentSessions In certain situations the subprotocol needs to inherit and use the same open
* sessions of the parent. However in most cases this is not desirable as it prevents the subprotocol from
* communicating with the same party on a different topic. For this reason the default value is false.
*/
@Suspendable fun <R> subProtocol(subLogic: ProtocolLogic<R>): R {
@JvmOverloads
@Suspendable
fun <R> subProtocol(subLogic: ProtocolLogic<R>, inheritParentSessions: Boolean = false): R {
subLogic.psm = psm
if (inheritParentSessions) {
subLogic.sessions.putAll(sessions)
}
maybeWireUpProgressTracking(subLogic)
val r = subLogic.call()
// It's easy to forget this when writing protocols so we just step it to the DONE state when it completes.
@ -106,4 +156,6 @@ abstract class ProtocolLogic<out T> {
@Suspendable
abstract fun call(): T
private data class Session(val sendSessionId: Long, val receiveSessionId: Long)
}

View File

@ -7,8 +7,6 @@ import com.r3corda.core.contracts.StateRef
import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.signWithECDSA
import com.r3corda.core.messaging.Ack
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.transactions.SignedTransaction
@ -36,9 +34,9 @@ abstract class AbstractStateReplacementProtocol<T> {
val stx: SignedTransaction
}
data class Handshake(val sessionIdForSend: Long,
override val replyToParty: Party,
override val sessionID: Long) : PartyRequestMessage
data class Handshake(override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
abstract class Instigator<out S : ContractState, T>(val originalState: StateAndRef<S>,
val modification: T,
@ -77,36 +75,31 @@ abstract class AbstractStateReplacementProtocol<T> {
@Suspendable
private fun collectSignatures(participants: List<PublicKey>, stx: SignedTransaction): List<DigitalSignature.WithKey> {
val sessions = mutableMapOf<NodeInfo, Long>()
val participantSignatures = participants.map {
val parties = participants.map {
val participantNode = serviceHub.networkMapCache.getNodeByPublicKey(it) ?:
throw IllegalStateException("Participant $it to state $originalState not found on the network")
val sessionIdForSend = random63BitValue()
sessions[participantNode] = sessionIdForSend
getParticipantSignature(participantNode, stx, sessionIdForSend)
participantNode.identity
}
val participantSignatures = parties.map { getParticipantSignature(it, stx) }
val allSignatures = participantSignatures + getNotarySignature(stx)
sessions.forEach { send(it.key.identity, it.value, allSignatures) }
parties.forEach { send(it, allSignatures) }
return allSignatures
}
@Suspendable
private fun getParticipantSignature(node: NodeInfo, stx: SignedTransaction, sessionIdForSend: Long): DigitalSignature.WithKey {
val sessionIdForReceive = random63BitValue()
private fun getParticipantSignature(party: Party, stx: SignedTransaction): DigitalSignature.WithKey {
val proposal = assembleProposal(originalState.ref, modification, stx)
val handshake = Handshake(sessionIdForSend, serviceHub.storageService.myLegalIdentity, sessionIdForReceive)
sendAndReceive<Ack>(node.identity, 0, sessionIdForReceive, handshake)
send(party, Handshake(serviceHub.storageService.myLegalIdentity))
val response = sendAndReceive<Result>(node.identity, sessionIdForSend, sessionIdForReceive, proposal)
val response = sendAndReceive<Result>(party, proposal)
val participantSignature = response.unwrap {
if (it.sig == null) throw StateReplacementException(it.error!!)
else {
check(it.sig.by == node.identity.owningKey) { "Not signed by the required participant" }
check(it.sig.by == party.owningKey) { "Not signed by the required participant" }
it.sig.verifyWithECDSA(stx.txBits)
it.sig
}
@ -123,9 +116,7 @@ abstract class AbstractStateReplacementProtocol<T> {
}
abstract class Acceptor<T>(val otherSide: Party,
val sessionIdForSend: Long,
val sessionIdForReceive: Long,
override val progressTracker: ProgressTracker = tracker()) : ProtocolLogic<Unit>() {
override val progressTracker: ProgressTracker = tracker()) : ProtocolLogic<Unit>() {
companion object {
object VERIFYING : ProgressTracker.Step("Verifying state replacement proposal")
@ -140,7 +131,7 @@ abstract class AbstractStateReplacementProtocol<T> {
@Suspendable
override fun call() {
progressTracker.currentStep = VERIFYING
val maybeProposal: UntrustworthyData<Proposal<T>> = receive(sessionIdForReceive)
val maybeProposal: UntrustworthyData<Proposal<T>> = receive(otherSide)
try {
val stx: SignedTransaction = maybeProposal.unwrap { verifyProposal(maybeProposal).stx }
verifyTx(stx)
@ -163,7 +154,7 @@ abstract class AbstractStateReplacementProtocol<T> {
val mySignature = sign(stx)
val response = Result.noError(mySignature)
val swapSignatures = sendAndReceive<List<DigitalSignature.WithKey>>(otherSide, sessionIdForSend, sessionIdForReceive, response)
val swapSignatures = sendAndReceive<List<DigitalSignature.WithKey>>(otherSide, response)
// TODO: This step should not be necessary, as signatures are re-checked in verifySignatures.
val allSignatures = swapSignatures.unwrap { signatures ->
@ -180,7 +171,7 @@ abstract class AbstractStateReplacementProtocol<T> {
private fun reject(e: StateReplacementRefused) {
progressTracker.currentStep = REJECTING
val response = Result.withError(e)
send(otherSide, sessionIdForSend, response)
send(otherSide, response)
}
/**

View File

@ -2,12 +2,10 @@ package com.r3corda.protocols
import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.contracts.ClientToServiceCommand
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.crypto.Party
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.serialize
import java.util.*
import com.r3corda.core.transactions.SignedTransaction
/**
@ -33,12 +31,11 @@ class BroadcastTransactionProtocol(val notarisedTransaction: SignedTransaction,
override val topic: String = TOPIC
data class NotifyTxRequestMessage(
val tx: SignedTransaction,
val events: Set<ClientToServiceCommand>,
override val replyToParty: Party,
override val sessionID: Long
) : PartyRequestMessage
data class NotifyTxRequestMessage(val tx: SignedTransaction,
val events: Set<ClientToServiceCommand>,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
@Suspendable
override fun call() {
@ -48,9 +45,11 @@ class BroadcastTransactionProtocol(val notarisedTransaction: SignedTransaction,
// TODO: Messaging layer should handle this broadcast for us (although we need to not be sending
// session ID, for that to work, as well).
participants.filter { it != serviceHub.storageService.myLegalIdentity }.forEach { participant ->
val sessionID = random63BitValue()
val msg = NotifyTxRequestMessage(notarisedTransaction, events, serviceHub.storageService.myLegalIdentity, sessionID)
send(participant, 0, msg)
val msg = NotifyTxRequestMessage(
notarisedTransaction,
events,
serviceHub.storageService.myLegalIdentity)
send(participant, msg)
}
}
}

View File

@ -7,6 +7,8 @@ import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.utilities.UntrustworthyData
import com.r3corda.protocols.FetchDataProtocol.DownloadedVsRequestedDataMismatch
import com.r3corda.protocols.FetchDataProtocol.HashNotFound
import java.util.*
/**
@ -33,7 +35,10 @@ abstract class FetchDataProtocol<T : NamedByHash, in W : Any>(
class HashNotFound(val requested: SecureHash) : BadAnswer()
class DownloadedVsRequestedDataMismatch(val requested: SecureHash, val got: SecureHash) : BadAnswer()
data class Request(val hashes: List<SecureHash>, override val replyToParty: Party, override val sessionID: Long) : PartyRequestMessage
data class Request(val hashes: List<SecureHash>,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
data class Result<out T : NamedByHash>(val fromDisk: List<T>, val downloaded: List<T>)
@Suspendable
@ -46,10 +51,9 @@ abstract class FetchDataProtocol<T : NamedByHash, in W : Any>(
} else {
logger.trace("Requesting ${toFetch.size} dependency(s) for verification")
val sid = random63BitValue()
val fetchReq = Request(toFetch, serviceHub.storageService.myLegalIdentity, sid)
val fetchReq = Request(toFetch, serviceHub.storageService.myLegalIdentity)
// TODO: Support "large message" response streaming so response sizes are not limited by RAM.
val maybeItems = sendAndReceive<ArrayList<W?>>(otherSide, 0, sid, fetchReq)
val maybeItems = sendAndReceive<ArrayList<W?>>(otherSide, fetchReq)
// Check for a buggy/malicious peer answering with something that we didn't ask for.
val downloaded = validateFetchResponse(maybeItems, toFetch)
maybeWriteToDisk(downloaded)

View File

@ -53,10 +53,8 @@ object NotaryChangeProtocol: AbstractStateReplacementProtocol<Party>() {
}
class Acceptor(otherSide: Party,
sessionIdForSend: Long,
sessionIdForReceive: Long,
override val progressTracker: ProgressTracker = tracker())
: AbstractStateReplacementProtocol.Acceptor<Party>(otherSide, sessionIdForSend, sessionIdForReceive) {
: AbstractStateReplacementProtocol.Acceptor<Party>(otherSide) {
override val topic: String get() = TOPIC

View File

@ -1,10 +1,6 @@
package com.r3corda.protocols
import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.contracts.StateRef
import com.r3corda.core.contracts.Timestamp
import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SignedData
@ -13,11 +9,12 @@ import com.r3corda.core.messaging.Ack
import com.r3corda.core.node.services.TimestampChecker
import com.r3corda.core.node.services.UniquenessException
import com.r3corda.core.node.services.UniquenessProvider
import com.r3corda.core.noneOrSingle
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.core.serialization.serialize
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.core.utilities.UntrustworthyData
import java.security.PublicKey
@ -56,14 +53,10 @@ object NotaryProtocol {
notaryParty = wtx.notary ?: throw IllegalStateException("Transaction does not specify a Notary")
check(wtx.inputs.all { stateRef -> serviceHub.loadState(stateRef).notary == notaryParty }) { "Input states must have the same Notary" }
val sendSessionID = random63BitValue()
val receiveSessionID = random63BitValue()
val handshake = Handshake(serviceHub.storageService.myLegalIdentity, sendSessionID, receiveSessionID)
sendAndReceive<Ack>(notaryParty, 0, receiveSessionID, handshake)
sendAndReceive<Ack>(notaryParty, Handshake(serviceHub.storageService.myLegalIdentity))
val request = SignRequest(stx, serviceHub.storageService.myLegalIdentity)
val response = sendAndReceive<Result>(notaryParty, sendSessionID, receiveSessionID, request)
val response = sendAndReceive<Result>(notaryParty, request)
val notaryResult = validateResponse(response)
return notaryResult.sig ?: throw NotaryException(notaryResult.error!!)
@ -96,8 +89,6 @@ object NotaryProtocol {
* TODO: the notary service should only be able to see timestamp commands and inputs
*/
open class Service(val otherSide: Party,
val sendSessionID: Long,
val receiveSessionID: Long,
val timestampChecker: TimestampChecker,
val uniquenessProvider: UniquenessProvider) : ProtocolLogic<Unit>() {
@ -105,7 +96,7 @@ object NotaryProtocol {
@Suspendable
override fun call() {
val (stx, reqIdentity) = receive<SignRequest>(receiveSessionID).unwrap { it }
val (stx, reqIdentity) = sendAndReceive<SignRequest>(otherSide, Ack).unwrap { it }
val wtx = stx.tx
val result = try {
@ -119,7 +110,7 @@ object NotaryProtocol {
Result.withError(e.error)
}
send(otherSide, sendSessionID, result)
send(otherSide, result)
}
private fun validateTimestamp(tx: WireTransaction) {
@ -157,10 +148,9 @@ object NotaryProtocol {
}
}
data class Handshake(
override val replyToParty: Party,
val sendSessionID: Long,
override val sessionID: Long) : PartyRequestMessage
data class Handshake(override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
/** TODO: The caller must authenticate instead of just specifying its identity */
data class SignRequest(val tx: SignedTransaction, val callerIdentity: Party)
@ -174,19 +164,15 @@ object NotaryProtocol {
interface Factory {
fun create(otherSide: Party,
sendSessionID: Long,
receiveSessionID: Long,
timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider): Service
}
object DefaultFactory : Factory {
override fun create(otherSide: Party,
sendSessionID: Long,
receiveSessionID: Long,
timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider): Service {
return Service(otherSide, sendSessionID, receiveSessionID, timestampChecker, uniquenessProvider)
return Service(otherSide, timestampChecker, uniquenessProvider)
}
}
}

View File

@ -3,12 +3,12 @@ package com.r3corda.protocols
import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.contracts.Fix
import com.r3corda.core.contracts.FixOf
import com.r3corda.core.transactions.TransactionBuilder
import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.Party
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.transactions.TransactionBuilder
import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.core.utilities.suggestInterestRateAnnouncementTimeWindow
import com.r3corda.protocols.RatesFixProtocol.FixOutOfRange
@ -47,8 +47,16 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
class FixOutOfRange(@Suppress("unused") val byAmount: BigDecimal) : Exception("Fix out of range by $byAmount")
data class QueryRequest(val queries: List<FixOf>, override val replyToParty: Party, override val sessionID: Long, val deadline: Instant) : PartyRequestMessage
data class SignRequest(val tx: WireTransaction, override val replyToParty: Party, override val sessionID: Long) : PartyRequestMessage
data class QueryRequest(val queries: List<FixOf>,
val deadline: Instant,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
data class SignRequest(val tx: WireTransaction,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
@Suspendable
override fun call() {
@ -80,10 +88,9 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
@Suspendable
private fun sign(): DigitalSignature.LegallyIdentifiable {
val sessionID = random63BitValue()
val wtx = tx.toWireTransaction()
val req = SignRequest(wtx, serviceHub.storageService.myLegalIdentity, sessionID)
val resp = sendAndReceive<DigitalSignature.LegallyIdentifiable>(oracle, 0, sessionID, req)
val req = SignRequest(wtx, serviceHub.storageService.myLegalIdentity)
val resp = sendAndReceive<DigitalSignature.LegallyIdentifiable>(oracle, req)
return resp.unwrap { sig ->
check(sig.signer == oracle)
@ -94,11 +101,10 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
@Suspendable
private fun query(): Fix {
val sessionID = random63BitValue()
val deadline = suggestInterestRateAnnouncementTimeWindow(fixOf.name, oracle.name, fixOf.forDay).end
val req = QueryRequest(listOf(fixOf), serviceHub.storageService.myLegalIdentity, sessionID, deadline)
val req = QueryRequest(listOf(fixOf), deadline, serviceHub.storageService.myLegalIdentity)
// TODO: add deadline to receive
val resp = sendAndReceive<ArrayList<Fix>>(oracle, 0, sessionID, req)
val resp = sendAndReceive<ArrayList<Fix>>(oracle, req)
return resp.unwrap {
val fix = it.first()

View File

@ -32,4 +32,19 @@ interface PartyRequestMessage : ServiceRequestMessage {
override fun getReplyTo(networkMapCache: NetworkMapCache): MessageRecipients {
return networkMapCache.partyNodes.single { it.identity == replyToParty }.address
}
}
/**
* A Handshake message is sent to initiate communication between two protocol instances. It contains the two session IDs
* the two protocols will need to communicate.
* Note: This is a temperary interface and will be removed once the protocol session work is implemented.
*/
interface HandshakeMessage : PartyRequestMessage {
val sendSessionID: Long
val receiveSessionID: Long
@Deprecated("sessionID functions as receiveSessionID but it's recommended to use the later for clarity",
replaceWith = ReplaceWith("receiveSessionID"))
override val sessionID: Long get() = receiveSessionID
}

View File

@ -7,7 +7,6 @@ import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.signWithECDSA
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
@ -48,11 +47,7 @@ object TwoPartyDealProtocol {
}
// This object is serialised to the network and is the first protocol message the seller sends to the buyer.
data class Handshake<out T>(
val payload: T,
val publicKey: PublicKey,
val sessionID: Long
)
data class Handshake<out T>(val payload: T, val publicKey: PublicKey)
class SignaturesFromPrimary(val sellerSig: DigitalSignature.WithKey, val notarySig: DigitalSignature.LegallyIdentifiable)
@ -80,19 +75,16 @@ object TwoPartyDealProtocol {
abstract val payload: U
abstract val notaryNode: NodeInfo
abstract val otherSide: Party
abstract val otherSessionID: Long
abstract val otherParty: Party
abstract val myKeyPair: KeyPair
@Suspendable
fun getPartialTransaction(): UntrustworthyData<SignedTransaction> {
progressTracker.currentStep = AWAITING_PROPOSAL
val sessionID = random63BitValue()
// Make the first message we'll send to kick off the protocol.
val hello = Handshake(payload, myKeyPair.public, sessionID)
val maybeSTX = sendAndReceive<SignedTransaction>(otherSide, otherSessionID, sessionID, hello)
val hello = Handshake(payload, myKeyPair.public)
val maybeSTX = sendAndReceive<SignedTransaction>(otherParty, hello)
return maybeSTX
}
@ -132,7 +124,7 @@ object TwoPartyDealProtocol {
// Download and check all the transactions that this transaction depends on, but do not check this
// transaction itself.
val dependencyTxIDs = stx.tx.inputs.map { it.txhash }.toSet()
subProtocol(ResolveTransactionsProtocol(dependencyTxIDs, otherSide))
subProtocol(ResolveTransactionsProtocol(dependencyTxIDs, otherParty))
}
@Suspendable
@ -156,7 +148,7 @@ object TwoPartyDealProtocol {
if (regulators.isNotEmpty()) {
// Copy the transaction to every regulator in the network. This is obviously completely bogus, it's
// just for demo purposes.
regulators.forEach { send(it.identity, DEFAULT_SESSION_ID, fullySigned) }
regulators.forEach { send(it.identity, fullySigned) }
}
return fullySigned
@ -181,7 +173,7 @@ object TwoPartyDealProtocol {
logger.trace { "Built finished transaction, sending back to other party!" }
send(otherSide, otherSessionID, SignaturesFromPrimary(ourSignature, notarySignature))
send(otherParty, SignaturesFromPrimary(ourSignature, notarySignature))
return fullySigned
}
}
@ -207,8 +199,7 @@ object TwoPartyDealProtocol {
override val topic: String get() = DEAL_TOPIC
abstract val otherSide: Party
abstract val sessionID: Long
abstract val otherParty: Party
@Suspendable
override fun call(): SignedTransaction {
@ -218,7 +209,7 @@ object TwoPartyDealProtocol {
val (ptx, additionalSigningPubKeys) = assembleSharedTX(handshake)
val stx = signWithOurKeys(additionalSigningPubKeys, ptx)
val signatures = swapSignaturesWithPrimary(stx, handshake.sessionID)
val signatures = swapSignaturesWithPrimary(stx)
logger.trace { "Got signatures from other party, verifying ... " }
@ -238,7 +229,7 @@ object TwoPartyDealProtocol {
private fun receiveAndValidateHandshake(): Handshake<U> {
progressTracker.currentStep = RECEIVING
// Wait for a trade request to come in on our pre-provided session ID.
val handshake = receive<Handshake<U>>(sessionID)
val handshake = receive<Handshake<U>>(otherParty)
progressTracker.currentStep = VERIFYING
handshake.unwrap {
@ -247,13 +238,13 @@ object TwoPartyDealProtocol {
}
@Suspendable
private fun swapSignaturesWithPrimary(stx: SignedTransaction, theirSessionID: Long): SignaturesFromPrimary {
private fun swapSignaturesWithPrimary(stx: SignedTransaction): SignaturesFromPrimary {
progressTracker.currentStep = SWAPPING_SIGNATURES
logger.trace { "Sending partially signed transaction to other party" }
// TODO: Protect against the seller terminating here and leaving us in the lurch without the final tx.
return sendAndReceive<SignaturesFromPrimary>(otherSide, theirSessionID, sessionID, stx).unwrap { it }
return sendAndReceive<SignaturesFromPrimary>(otherParty, stx).unwrap { it }
}
private fun signWithOurKeys(signingPubKeys: List<PublicKey>, ptx: TransactionBuilder): SignedTransaction {
@ -273,11 +264,10 @@ object TwoPartyDealProtocol {
/**
* One side of the protocol for inserting a pre-agreed deal.
*/
open class Instigator<out T : DealState>(override val otherSide: Party,
open class Instigator<out T : DealState>(override val otherParty: Party,
val notary: Party,
override val payload: T,
override val myKeyPair: KeyPair,
override val otherSessionID: Long,
override val progressTracker: ProgressTracker = Primary.tracker()) : Primary<T>() {
override val notaryNode: NodeInfo get() =
@ -287,10 +277,9 @@ object TwoPartyDealProtocol {
/**
* One side of the protocol for inserting a pre-agreed deal.
*/
open class Acceptor<T : DealState>(override val otherSide: Party,
open class Acceptor<T : DealState>(override val otherParty: Party,
val notary: Party,
val dealToBuy: T,
override val sessionID: Long,
override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary<T>() {
override fun validateHandshake(handshake: Handshake<T>): Handshake<T> {
@ -299,8 +288,6 @@ object TwoPartyDealProtocol {
val otherKey = handshake.publicKey
logger.trace { "Got deal request for: ${handshake.payload.ref}" }
// Check the start message for acceptability.
check(handshake.sessionID > 0)
check(dealToBuy == deal)
// We need to substitute in the new public keys for the Parties
@ -335,11 +322,9 @@ object TwoPartyDealProtocol {
* of the protocol that is run by the party with the fixed leg of swap deal, which is the basis for deciding
* who does what in the protocol.
*/
class Fixer(val initiation: FixingSessionInitiation, override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary<StateRef>() {
override val sessionID: Long get() = initiation.sessionID
override val otherSide: Party get() = initiation.sender
class Fixer(override val otherParty: Party,
val oracleType: ServiceType,
override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary<StateRef>() {
private lateinit var txState: TransactionState<*>
private lateinit var deal: FixableDealState
@ -347,16 +332,12 @@ object TwoPartyDealProtocol {
override fun validateHandshake(handshake: Handshake<StateRef>): Handshake<StateRef> {
logger.trace { "Got fixing request for: ${handshake.payload}" }
// Check the handshake and initiation for acceptability.
check(handshake.sessionID > 0)
txState = serviceHub.loadState(handshake.payload)
deal = txState.data as FixableDealState
// validate the party that initiated is the one on the deal and that the recipient corresponds with it.
// TODO: this is in no way secure and will be replaced by general session initiation logic in the future
val myName = serviceHub.storageService.myLegalIdentity.name
val otherParty = deal.parties.filter { it.name != myName }.single()
check(otherParty == initiation.party)
// Also check we are one of the parties
deal.parties.filter { it.name == myName }.single()
@ -376,7 +357,7 @@ object TwoPartyDealProtocol {
val ptx = TransactionType.General.Builder(txState.notary)
val oracle = serviceHub.networkMapCache.get(initiation.oracleType).first()
val oracle = serviceHub.networkMapCache.get(oracleType).first()
val addFixing = object : RatesFixProtocol(ptx, oracle.identity, fixOf, BigDecimal.ZERO, BigDecimal.ONE) {
@Suspendable
@ -401,7 +382,6 @@ object TwoPartyDealProtocol {
* does what in the protocol.
*/
class Floater(override val payload: StateRef,
override val otherSessionID: Long,
override val progressTracker: ProgressTracker = Primary.tracker()) : Primary<StateRef>() {
@Suppress("UNCHECKED_CAST")
internal val dealToFix: StateAndRef<FixableDealState> by TransientProperty {
@ -415,8 +395,8 @@ object TwoPartyDealProtocol {
return serviceHub.keyManagementService.toKeyPair(publicKey)
}
override val otherSide: Party get() {
// TODO: what happens if there's no node? Move to messaging taking Party and then handled in messaging layer
override val otherParty: Party get() {
// TODO otherParty is sortedParties[1] from FixingRoleDecider.call and so can be passed in as a c'tor param
val myName = serviceHub.storageService.myLegalIdentity.name
return dealToFix.state.data.parties.filter { it.name != myName }.single()
}
@ -427,7 +407,11 @@ object TwoPartyDealProtocol {
/** Used to set up the session between [Floater] and [Fixer] */
data class FixingSessionInitiation(val sessionID: Long, val party: Party, val sender: Party, val timeout: Duration, val oracleType: ServiceType)
data class FixingSessionInitiation(val timeout: Duration,
val oracleType: ServiceType,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
/**
* This protocol looks at the deal and decides whether to be the Fixer or Floater role in agreeing a fixing.
@ -459,18 +443,17 @@ object TwoPartyDealProtocol {
val sortedParties = fixableDeal.parties.sortedBy { it.name }
val oracleType = fixableDeal.oracleType
if (sortedParties[0].name == serviceHub.storageService.myLegalIdentity.name) {
// Generate sessionID
val sessionID = random63BitValue()
val initation = FixingSessionInitiation(sessionID, sortedParties[0], serviceHub.storageService.myLegalIdentity, timeout, oracleType)
val initation = FixingSessionInitiation(
timeout,
oracleType,
serviceHub.storageService.myLegalIdentity)
// Send initiation to other side to launch one side of the fixing protocol (the Fixer).
send(sortedParties[1], DEFAULT_SESSION_ID, initation)
send(sortedParties[1], initation)
// Then start the other side of the fixing protocol.
val protocol = Floater(ref, sessionID)
subProtocol(protocol)
subProtocol(Floater(ref), inheritParentSessions = true)
}
}
}
}
}

View File

@ -16,10 +16,10 @@ import java.security.SignatureException
* indeed valid.
*/
class ValidatingNotaryProtocol(otherSide: Party,
sessionIdForSend: Long,
sessionIdForReceive: Long,
timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider) : NotaryProtocol.Service(otherSide, sessionIdForSend, sessionIdForReceive, timestampChecker, uniquenessProvider) {
uniquenessProvider: UniquenessProvider) :
NotaryProtocol.Service(otherSide, timestampChecker, uniquenessProvider) {
@Suspendable
override fun beforeCommit(stx: SignedTransaction, reqIdentity: Party) {
try {

View File

@ -24,7 +24,8 @@ class BroadcastTransactionProtocolTest {
tx = SignedTransactionGenerator().generate(random, status),
events = setOf(),
replyToParty = PartyGenerator().generate(random, status),
sessionID = random.nextLong()
sendSessionID = random.nextLong(),
receiveSessionID = random.nextLong()
)
}
}

View File

@ -1,12 +1,11 @@
package com.r3corda.node.services
import com.r3corda.core.messaging.Ack
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.AbstractStateReplacementProtocol
import com.r3corda.protocols.NotaryChangeProtocol
import com.r3corda.protocols.NotaryChangeProtocol.TOPIC
object NotaryChange {
class Plugin : CordaPluginRegistry() {
@ -19,18 +18,9 @@ object NotaryChange {
*/
class Service(services: ServiceHubInternal) : AbstractNodeService(services) {
init {
addMessageHandler(NotaryChangeProtocol.TOPIC,
{ req: AbstractStateReplacementProtocol.Handshake -> handleChangeNotaryRequest(req) }
)
}
private fun handleChangeNotaryRequest(req: AbstractStateReplacementProtocol.Handshake): Ack {
val protocol = NotaryChangeProtocol.Acceptor(
req.replyToParty,
req.sessionID,
req.sessionIdForSend)
services.startProtocol(NotaryChangeProtocol.TOPIC, protocol)
return Ack
addProtocolHandler(TOPIC, TOPIC) { req: AbstractStateReplacementProtocol.Handshake ->
NotaryChangeProtocol.Acceptor(req.replyToParty)
}
}
}
}

View File

@ -1,10 +1,14 @@
package com.r3corda.node.services.api
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.messaging.Message
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize
import com.r3corda.core.utilities.loggerFor
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.protocols.ServiceRequestMessage
import javax.annotation.concurrent.ThreadSafe
@ -14,6 +18,10 @@ import javax.annotation.concurrent.ThreadSafe
@ThreadSafe
abstract class AbstractNodeService(val services: ServiceHubInternal) : SingletonSerializeAsToken() {
companion object {
val logger = loggerFor<AbstractNodeService>()
}
val net: MessagingServiceInternal get() = services.networkService
/**
@ -57,4 +65,37 @@ abstract class AbstractNodeService(val services: ServiceHubInternal) : Singleton
crossinline handler: (Q) -> R) {
addMessageHandler(topic, handler, { message: Message, exception: Exception -> throw exception })
}
/**
* Register a handler to kick-off a protocol when a [HandshakeMessage] is received by the node. This performs the
* necessary steps to enable communication between the two protocols, including calling ProtocolLogic.registerSession.
* @param topic the topic on which the handshake is sent from the other party
* @param loggerName the logger name to use when starting the protocol
* @param protocolFactory a function to create the protocol with the given handshake message
* @param onResultFuture provides access to the [ListenableFuture] when the protocol starts
*/
protected inline fun <reified H : HandshakeMessage, R : Any> addProtocolHandler(
topic: String,
loggerName: String,
crossinline protocolFactory: (H) -> ProtocolLogic<R>,
crossinline onResultFuture: (ListenableFuture<R>, H) -> Unit) {
net.addMessageHandler(topic, DEFAULT_SESSION_ID, null) { message, reg ->
try {
val handshake = message.data.deserialize<H>()
val protocol = protocolFactory(handshake)
protocol.registerSession(handshake)
val resultFuture = services.startProtocol(loggerName, protocol)
onResultFuture(resultFuture, handshake)
} catch (e: Exception) {
logger.error("Unable to process ${H::class.java.name} message", e)
}
}
}
protected inline fun <reified H : HandshakeMessage, R : Any> addProtocolHandler(
topic: String,
loggerName: String,
crossinline protocolFactory: (H) -> ProtocolLogic<R>) {
addProtocolHandler(topic, loggerName, protocolFactory, { future, handshake -> })
}
}

View File

@ -1,10 +1,11 @@
package com.r3corda.node.services.clientapi
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.serialization.deserialize
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.TwoPartyDealProtocol
import com.r3corda.protocols.TwoPartyDealProtocol.FIX_INITIATE_TOPIC
import com.r3corda.protocols.TwoPartyDealProtocol.FixingSessionInitiation
/**
* This is a temporary handler required for establishing random sessionIDs for the [Fixer] and [Floater] as part of
@ -17,12 +18,10 @@ object FixingSessionInitiation {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
}
class Service(services: ServiceHubInternal) {
class Service(services: ServiceHubInternal) : AbstractNodeService(services) {
init {
services.networkService.addMessageHandler(TwoPartyDealProtocol.FIX_INITIATE_TOPIC, DEFAULT_SESSION_ID) { msg, registration ->
val initiation = msg.data.deserialize<TwoPartyDealProtocol.FixingSessionInitiation>()
val protocol = TwoPartyDealProtocol.Fixer(initiation)
services.startProtocol("fixings", protocol)
addProtocolHandler(FIX_INITIATE_TOPIC, "fixings") { initiation: FixingSessionInitiation ->
TwoPartyDealProtocol.Fixer(initiation.replyToParty, initiation.oracleType)
}
}
}

View File

@ -6,7 +6,6 @@ import com.r3corda.core.messaging.MessagingService
import com.r3corda.core.messaging.TopicSession
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.serialize
import com.r3corda.core.success
import com.r3corda.core.transactions.SignedTransaction
@ -50,8 +49,7 @@ object DataVending {
myIdentity: Party,
recipient: NodeInfo,
transaction: SignedTransaction) {
val sessionID = random63BitValue()
val msg = BroadcastTransactionProtocol.NotifyTxRequestMessage(transaction, emptySet(), myIdentity, sessionID)
val msg = BroadcastTransactionProtocol.NotifyTxRequestMessage(transaction, emptySet(), myIdentity)
net.send(net.createMessage(TopicSession(BroadcastTransactionProtocol.TOPIC, 0), msg.serialize().bits), recipient.address)
}
}
@ -65,29 +63,29 @@ object DataVending {
{ req: FetchDataProtocol.Request -> handleTXRequest(req) },
{ message, e -> logger.error("Failure processing data vending request.", e) }
)
addMessageHandler(FetchAttachmentsProtocol.TOPIC,
{ req: FetchDataProtocol.Request -> handleAttachmentRequest(req) },
{ message, e -> logger.error("Failure processing data vending request.", e) }
)
addMessageHandler(BroadcastTransactionProtocol.TOPIC,
{ req: BroadcastTransactionProtocol.NotifyTxRequestMessage -> handleTXNotification(req) },
{ message, e -> logger.error("Failure processing data vending request.", e) }
)
}
private fun handleTXNotification(req: BroadcastTransactionProtocol.NotifyTxRequestMessage): Unit {
// TODO: We should have a whitelist of contracts we're willing to accept at all, and reject if the transaction
// includes us in any outside that list. Potentially just if it includes any outside that list at all.
// TODO: Do we want to be able to reject specific transactions on more complex rules, for example reject incoming
// cash without from unknown parties?
services.startProtocol("Resolving transactions", ResolveTransactionsProtocol(req.tx, req.replyToParty))
.success {
services.recordTransactions(req.tx)
}.failure { throwable ->
logger.warn("Received invalid transaction ${req.tx.id} from ${req.replyToParty}", throwable)
}
addProtocolHandler(
BroadcastTransactionProtocol.TOPIC,
"Resolving transactions",
{ req: BroadcastTransactionProtocol.NotifyTxRequestMessage ->
ResolveTransactionsProtocol(req.tx, req.replyToParty)
},
{ future, req ->
future.success {
services.recordTransactions(req.tx)
}.failure { throwable ->
logger.warn("Received invalid transaction ${req.tx.id} from ${req.replyToParty}", throwable)
}
})
}
private fun handleTXRequest(req: FetchDataProtocol.Request): List<SignedTransaction?> {

View File

@ -1,12 +1,12 @@
package com.r3corda.node.services.transactions
import com.r3corda.core.messaging.Ack
import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.node.services.TimestampChecker
import com.r3corda.core.node.services.UniquenessProvider
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.NotaryProtocol
import com.r3corda.protocols.NotaryProtocol.TOPIC
/**
* A Notary service acts as the final signer of a transaction ensuring two things:
@ -30,19 +30,9 @@ abstract class NotaryService(services: ServiceHubInternal,
abstract val protocolFactory: NotaryProtocol.Factory
init {
addMessageHandler(NotaryProtocol.TOPIC,
{ req: NotaryProtocol.Handshake -> processRequest(req) }
)
addProtocolHandler(TOPIC, TOPIC) { req: NotaryProtocol.Handshake ->
protocolFactory.create(req.replyToParty, timestampChecker, uniquenessProvider)
}
}
private fun processRequest(req: NotaryProtocol.Handshake): Ack {
val protocol = protocolFactory.create(
req.replyToParty,
req.sessionID,
req.sendSessionID,
timestampChecker,
uniquenessProvider)
services.startProtocol(NotaryProtocol.TOPIC, protocol)
return Ack
}
}

View File

@ -19,11 +19,9 @@ class ValidatingNotaryService(services: ServiceHubInternal,
override val protocolFactory = object : NotaryProtocol.Factory {
override fun create(otherSide: Party,
sendSessionID: Long,
receiveSessionID: Long,
timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider): NotaryProtocol.Service {
return ValidatingNotaryProtocol(otherSide, sendSessionID, receiveSessionID, timestampChecker, uniquenessProvider)
return ValidatingNotaryProtocol(otherSide, timestampChecker, uniquenessProvider)
}
}
}

View File

@ -1,6 +1,5 @@
package com.r3corda.node.messaging
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.contracts.CommercialPaper
import com.r3corda.contracts.asset.*
import com.r3corda.contracts.testing.fillWithSomeTestCash
@ -9,27 +8,26 @@ import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.days
import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.node.services.TransactionStorage
import com.r3corda.core.node.services.Wallet
import com.r3corda.core.random63BitValue
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.core.utilities.LogHelper
import com.r3corda.core.utilities.TEST_TX_TIME
import com.r3corda.testing.node.MockNetwork
import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.node.services.persistence.NodeAttachmentService
import com.r3corda.node.services.persistence.PerFileTransactionStorage
import com.r3corda.node.services.persistence.StorageServiceImpl
import com.r3corda.node.services.statemachine.StateMachineManager
import com.r3corda.protocols.TwoPartyTradeProtocol
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
import com.r3corda.protocols.TwoPartyTradeProtocol.Seller
import com.r3corda.protocols.TwoPartyTradeProtocol.TOPIC
import com.r3corda.testing.*
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.MockNetwork
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Before
@ -42,6 +40,7 @@ import java.security.KeyPair
import java.security.PublicKey
import java.util.*
import java.util.concurrent.ExecutionException
import java.util.concurrent.Future
import java.util.jar.JarOutputStream
import java.util.zip.ZipEntry
import kotlin.test.assertEquals
@ -56,21 +55,11 @@ import kotlin.test.assertTrue
* We assume that Alice and Bob already found each other via some market, and have agreed the details already.
*/
class TwoPartyTradeProtocolTests {
lateinit var net: MockNetwork
private fun runSeller(smm: StateMachineManager, notary: NodeInfo,
otherSide: Party, assetToSell: StateAndRef<OwnableState>, price: Amount<Currency>,
myKeyPair: KeyPair, buyerSessionID: Long): ListenableFuture<SignedTransaction> {
val seller = TwoPartyTradeProtocol.Seller(otherSide, notary, assetToSell, price, myKeyPair, buyerSessionID)
return smm.add("${TwoPartyTradeProtocol.TOPIC}.seller", seller).resultFuture
}
private fun runBuyer(smm: StateMachineManager, notaryNode: NodeInfo,
otherSide: Party, acceptablePrice: Amount<Currency>, typeToBuy: Class<out OwnableState>,
sessionID: Long): ListenableFuture<SignedTransaction> {
val buyer = TwoPartyTradeProtocol.Buyer(otherSide, notaryNode.identity, acceptablePrice, typeToBuy, sessionID)
return smm.add("${TwoPartyTradeProtocol.TOPIC}.buyer", buyer).resultFuture
}
lateinit var notaryNode: MockNetwork.MockNode
lateinit var aliceNode: MockNetwork.MockNode
lateinit var bobNode: MockNetwork.MockNode
@Before
fun before() {
@ -92,10 +81,9 @@ class TwoPartyTradeProtocolTests {
net = MockNetwork(false, true)
ledger {
val notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
val aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
val bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
bobNode.services.fillWithSomeTestCash(2000.DOLLARS)
val alicesFakePaper = fillUpForSeller(false, aliceNode.storage.myLegalIdentity.owningKey,
@ -103,26 +91,7 @@ class TwoPartyTradeProtocolTests {
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey, notaryNode.storage.myLegalIdentityKey)
val buyerSessionID = random63BitValue()
// We start the Buyer first, as the Seller sends the first message
val bobResult = runBuyer(
bobNode.smm,
notaryNode.info,
aliceNode.info.identity,
1000.DOLLARS,
CommercialPaper.State::class.java,
buyerSessionID
)
val aliceResult = runSeller(
aliceNode.smm,
notaryNode.info,
bobNode.info.identity,
"alice's paper".outputStateAndRef(),
1000.DOLLARS,
ALICE_KEY,
buyerSessionID
)
val (bobResult, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
// TODO: Verify that the result was inserted into the transaction database.
// assertEquals(bobResult.get(), aliceNode.storage.validatedTransactions[aliceResult.get().id])
@ -139,9 +108,9 @@ class TwoPartyTradeProtocolTests {
@Test
fun `shutdown and restore`() {
ledger {
val notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
val aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
var bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
val bobAddr = bobNode.net.myAddress as InMemoryMessagingNetwork.Handle
val networkMapAddr = notaryNode.info.address
@ -153,25 +122,7 @@ class TwoPartyTradeProtocolTests {
1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null).second
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
val buyerSessionID = random63BitValue()
val aliceFuture = runSeller(
aliceNode.smm,
notaryNode.info,
bobNode.info.identity,
"alice's paper".outputStateAndRef(),
1000.DOLLARS,
ALICE_KEY,
buyerSessionID
)
runBuyer(
bobNode.smm,
notaryNode.info,
aliceNode.info.identity,
1000.DOLLARS,
CommercialPaper.State::class.java,
buyerSessionID
)
val aliceFuture = runBuyerAndSeller("alice's paper".outputStateAndRef()).second
// Everything is on this thread so we can now step through the protocol one step at a time.
// Seller Alice already sent a message to Buyer Bob. Pump once:
@ -210,7 +161,7 @@ class TwoPartyTradeProtocolTests {
}, true, BOB.name, BOB_KEY)
// Find the future representing the result of this state machine again.
val bobFuture = bobNode.smm.findStateMachines(TwoPartyTradeProtocol.Buyer::class.java).single().second
val bobFuture = bobNode.smm.findStateMachines(Buyer::class.java).single().second
// And off we go again.
net.runNetwork()
@ -218,7 +169,7 @@ class TwoPartyTradeProtocolTests {
// Bob is now finished and has the same transaction as Alice.
assertThat(bobFuture.get()).isEqualTo(aliceFuture.get())
assertThat(bobNode.smm.findStateMachines(TwoPartyTradeProtocol.Buyer::class.java)).isEmpty()
assertThat(bobNode.smm.findStateMachines(Buyer::class.java)).isEmpty()
assertThat(bobNode.checkpointStorage.checkpoints).isEmpty()
assertThat(aliceNode.checkpointStorage.checkpoints).isEmpty()
@ -250,9 +201,9 @@ class TwoPartyTradeProtocolTests {
@Test
fun `check dependencies of sale asset are resolved`() {
val notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
val aliceNode = makeNodeWithTracking(notaryNode.info.address, ALICE.name, ALICE_KEY)
val bobNode = makeNodeWithTracking(notaryNode.info.address, BOB.name, BOB_KEY)
notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
aliceNode = makeNodeWithTracking(notaryNode.info.address, ALICE.name, ALICE_KEY)
bobNode = makeNodeWithTracking(notaryNode.info.address, BOB.name, BOB_KEY)
ledger(aliceNode.services) {
@ -271,27 +222,9 @@ class TwoPartyTradeProtocolTests {
1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, attachmentID).second
val alicesSignedTxns = insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
val buyerSessionID = random63BitValue()
net.runNetwork() // Clear network map registration messages
runSeller(
aliceNode.smm,
notaryNode.info,
bobNode.info.identity,
"alice's paper".outputStateAndRef(),
1000.DOLLARS,
ALICE_KEY,
buyerSessionID
)
runBuyer(
bobNode.smm,
notaryNode.info,
aliceNode.info.identity,
1000.DOLLARS,
CommercialPaper.State::class.java,
buyerSessionID
)
runBuyerAndSeller("alice's paper".outputStateAndRef())
net.runNetwork()
@ -370,14 +303,25 @@ class TwoPartyTradeProtocolTests {
}
}
private fun runBuyerAndSeller(assetToSell: StateAndRef<OwnableState>) : Pair<Future<SignedTransaction>, Future<SignedTransaction>> {
val buyer = Buyer(aliceNode.info.identity, notaryNode.info.identity, 1000.DOLLARS, CommercialPaper.State::class.java)
val seller = Seller(bobNode.info.identity, notaryNode.info, assetToSell, 1000.DOLLARS, ALICE_KEY)
connectProtocols(buyer, seller)
// We start the Buyer first, as the Seller sends the first message
val buyerResult = bobNode.smm.add("$TOPIC.buyer", buyer).resultFuture
val sellerResult = aliceNode.smm.add("$TOPIC.seller", seller).resultFuture
return Pair(buyerResult, sellerResult)
}
private fun LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter>.runWithError(
bobError: Boolean,
aliceError: Boolean,
expectedMessageSubstring: String
) {
val notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
val aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
val bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
val issuer = MEGA_CORP.ref(1, 2, 3)
val bobKey = bobNode.keyManagement.freshKey()
@ -388,27 +332,9 @@ class TwoPartyTradeProtocolTests {
insertFakeTransactions(bobsBadCash, bobNode.services, bobNode.storage.myLegalIdentityKey, bobNode.storage.myLegalIdentityKey)
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
val buyerSessionID = random63BitValue()
net.runNetwork() // Clear network map registration messages
val aliceResult = runSeller(
aliceNode.smm,
notaryNode.info,
bobNode.info.identity,
"alice's paper".outputStateAndRef(),
1000.DOLLARS,
ALICE_KEY,
buyerSessionID
)
val bobResult = runBuyer(
bobNode.smm,
notaryNode.info,
aliceNode.info.identity,
1000.DOLLARS,
CommercialPaper.State::class.java,
buyerSessionID
)
val (bobResult, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
net.runNetwork()

View File

@ -1,18 +1,30 @@
package com.r3corda.node.services
import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.protocols.ProtocolLogic
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.map
import com.r3corda.core.messaging.TopicSession
import com.r3corda.core.messaging.runOnNextMessage
import com.r3corda.core.messaging.send
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.deserialize
import com.r3corda.node.services.network.InMemoryNetworkMapService
import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.network.NetworkMapService.*
import com.r3corda.node.services.network.NetworkMapService.Companion.FETCH_PROTOCOL_TOPIC
import com.r3corda.node.services.network.NetworkMapService.Companion.PUSH_ACK_PROTOCOL_TOPIC
import com.r3corda.node.services.network.NetworkMapService.Companion.REGISTER_PROTOCOL_TOPIC
import com.r3corda.node.services.network.NetworkMapService.Companion.SUBSCRIPTION_PROTOCOL_TOPIC
import com.r3corda.node.services.network.NodeRegistration
import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.protocols.ServiceRequestMessage
import com.r3corda.testing.node.MockNetwork
import com.r3corda.testing.node.MockNetwork.MockNode
import org.junit.Before
import org.junit.Test
import java.security.PrivateKey
import java.time.Instant
import java.util.concurrent.Future
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertNull
@ -36,7 +48,7 @@ class InMemoryNetworkMapServiceTest {
// Confirm the service contains only its own node
assertEquals(1, service.nodes.count())
assertNull(service.processQueryRequest(NetworkMapService.QueryIdentityRequest(registerNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node)
assertNull(service.processQueryRequest(QueryIdentityRequest(registerNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node)
// Register the second node
var seq = 1L
@ -44,64 +56,22 @@ class InMemoryNetworkMapServiceTest {
val nodeKey = registerNode.storage.myLegalIdentityKey
val addChange = NodeRegistration(registerNode.info, seq++, AddOrRemove.ADD, expires)
val addWireChange = addChange.toWire(nodeKey.private)
service.processRegistrationChangeRequest(NetworkMapService.RegistrationRequest(addWireChange, mapServiceNode.info.address, Long.MIN_VALUE))
service.processRegistrationChangeRequest(RegistrationRequest(addWireChange, mapServiceNode.info.address, Long.MIN_VALUE))
assertEquals(2, service.nodes.count())
assertEquals(mapServiceNode.info, service.processQueryRequest(NetworkMapService.QueryIdentityRequest(mapServiceNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node)
assertEquals(mapServiceNode.info, service.processQueryRequest(QueryIdentityRequest(mapServiceNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node)
// Re-registering should be a no-op
service.processRegistrationChangeRequest(NetworkMapService.RegistrationRequest(addWireChange, mapServiceNode.info.address, Long.MIN_VALUE))
service.processRegistrationChangeRequest(RegistrationRequest(addWireChange, mapServiceNode.info.address, Long.MIN_VALUE))
assertEquals(2, service.nodes.count())
// Confirm that de-registering the node succeeds and drops it from the node lists
val removeChange = NodeRegistration(registerNode.info, seq, AddOrRemove.REMOVE, expires)
val removeWireChange = removeChange.toWire(nodeKey.private)
assert(service.processRegistrationChangeRequest(NetworkMapService.RegistrationRequest(removeWireChange, mapServiceNode.info.address, Long.MIN_VALUE)).success)
assertNull(service.processQueryRequest(NetworkMapService.QueryIdentityRequest(registerNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node)
assert(service.processRegistrationChangeRequest(RegistrationRequest(removeWireChange, mapServiceNode.info.address, Long.MIN_VALUE)).success)
assertNull(service.processQueryRequest(QueryIdentityRequest(registerNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node)
// Trying to de-register a node that doesn't exist should fail
assert(!service.processRegistrationChangeRequest(NetworkMapService.RegistrationRequest(removeWireChange, mapServiceNode.info.address, Long.MIN_VALUE)).success)
}
class TestAcknowledgePSM(val server: NodeInfo, val mapVersion: Int) : ProtocolLogic<Unit>() {
override val topic: String get() = NetworkMapService.PUSH_ACK_PROTOCOL_TOPIC
@Suspendable
override fun call() {
val req = NetworkMapService.UpdateAcknowledge(mapVersion, serviceHub.networkService.myAddress)
send(server.identity, 0, req)
}
}
class TestFetchPSM(val server: NodeInfo, val subscribe: Boolean, val ifChangedSinceVersion: Int? = null)
: ProtocolLogic<Collection<NodeRegistration>?>() {
override val topic: String get() = NetworkMapService.FETCH_PROTOCOL_TOPIC
@Suspendable
override fun call(): Collection<NodeRegistration>? {
val sessionID = random63BitValue()
val req = NetworkMapService.FetchMapRequest(subscribe, ifChangedSinceVersion, serviceHub.networkService.myAddress, sessionID)
return sendAndReceive<NetworkMapService.FetchMapResponse>(server.identity, 0, sessionID, req).unwrap { it.nodes }
}
}
class TestRegisterPSM(val server: NodeInfo, val reg: NodeRegistration, val privateKey: PrivateKey)
: ProtocolLogic<NetworkMapService.RegistrationResponse>() {
override val topic: String get() = NetworkMapService.REGISTER_PROTOCOL_TOPIC
@Suspendable
override fun call(): NetworkMapService.RegistrationResponse {
val sessionID = random63BitValue()
val req = NetworkMapService.RegistrationRequest(reg.toWire(privateKey), serviceHub.networkService.myAddress, sessionID)
return sendAndReceive<NetworkMapService.RegistrationResponse>(server.identity, 0, sessionID, req).unwrap { it }
}
}
class TestSubscribePSM(val server: NodeInfo, val subscribe: Boolean)
: ProtocolLogic<NetworkMapService.SubscribeResponse>() {
override val topic: String get() = NetworkMapService.SUBSCRIPTION_PROTOCOL_TOPIC
@Suspendable
override fun call(): NetworkMapService.SubscribeResponse {
val sessionID = random63BitValue()
val req = NetworkMapService.SubscribeRequest(subscribe, serviceHub.networkService.myAddress, sessionID)
return sendAndReceive<NetworkMapService.SubscribeResponse>(server.identity, 0, sessionID, req).unwrap { it }
}
assert(!service.processRegistrationChangeRequest(RegistrationRequest(removeWireChange, mapServiceNode.info.address, Long.MIN_VALUE)).success)
}
@Test
@ -113,7 +83,7 @@ class InMemoryNetworkMapServiceTest {
// Confirm all nodes have registered themselves
network.runNetwork()
var fetchPsm = registerNode.services.startProtocol(NetworkMapService.FETCH_PROTOCOL_TOPIC, TestFetchPSM(mapServiceNode.info, false))
var fetchPsm = fetchMap(registerNode, mapServiceNode, false)
network.runNetwork()
assertEquals(2, fetchPsm.get()?.count())
@ -122,12 +92,12 @@ class InMemoryNetworkMapServiceTest {
val expires = Instant.now() + NetworkMapService.DEFAULT_EXPIRATION_PERIOD
val seq = 2L
val reg = NodeRegistration(registerNode.info, seq, AddOrRemove.REMOVE, expires)
val registerPsm = registerNode.services.startProtocol(NetworkMapService.REGISTER_PROTOCOL_TOPIC, TestRegisterPSM(mapServiceNode.info, reg, nodeKey.private))
val registerPsm = registration(registerNode, mapServiceNode, reg, nodeKey.private)
network.runNetwork()
assertTrue(registerPsm.get().success)
// Now only map service node should be registered
fetchPsm = registerNode.services.startProtocol(NetworkMapService.FETCH_PROTOCOL_TOPIC, TestFetchPSM(mapServiceNode.info, false))
fetchPsm = fetchMap(registerNode, mapServiceNode, false)
network.runNetwork()
assertEquals(mapServiceNode.info, fetchPsm.get()?.filter { it.type == AddOrRemove.ADD }?.map { it.node }?.single())
}
@ -139,8 +109,7 @@ class InMemoryNetworkMapServiceTest {
// Test subscribing to updates
network.runNetwork()
val subscribePsm = registerNode.services.startProtocol(NetworkMapService.SUBSCRIPTION_PROTOCOL_TOPIC,
TestSubscribePSM(mapServiceNode.info, true))
val subscribePsm = subscribe(registerNode, mapServiceNode, true)
network.runNetwork()
subscribePsm.get()
@ -161,10 +130,8 @@ class InMemoryNetworkMapServiceTest {
assertEquals(1, service.getUnacknowledgedCount(registerNode.info.address, startingMapVersion + 1))
// Send in an acknowledgment and verify the count goes down
val acknowledgePsm = registerNode.services.startProtocol(NetworkMapService.PUSH_ACK_PROTOCOL_TOPIC,
TestAcknowledgePSM(mapServiceNode.info, startingMapVersion + 1))
updateAcknowlege(registerNode, mapServiceNode, startingMapVersion + 1)
network.runNetwork()
acknowledgePsm.get()
assertEquals(0, service.getUnacknowledgedCount(registerNode.info.address, startingMapVersion + 1))
@ -181,4 +148,25 @@ class InMemoryNetworkMapServiceTest {
}
}
}
}
private fun registration(registerNode: MockNode, mapServiceNode: MockNode, reg: NodeRegistration, privateKey: PrivateKey): ListenableFuture<RegistrationResponse> {
val req = RegistrationRequest(reg.toWire(privateKey), registerNode.services.networkService.myAddress, random63BitValue())
return registerNode.sendAndReceive<RegistrationResponse>(REGISTER_PROTOCOL_TOPIC, mapServiceNode, req)
}
private fun subscribe(registerNode: MockNode, mapServiceNode: MockNode, subscribe: Boolean): ListenableFuture<SubscribeResponse> {
val req = SubscribeRequest(subscribe, registerNode.services.networkService.myAddress, random63BitValue())
return registerNode.sendAndReceive<SubscribeResponse>(SUBSCRIPTION_PROTOCOL_TOPIC, mapServiceNode, req)
}
private fun updateAcknowlege(registerNode: MockNode, mapServiceNode: MockNode, mapVersion: Int) {
val req = UpdateAcknowledge(mapVersion, registerNode.services.networkService.myAddress)
registerNode.send(PUSH_ACK_PROTOCOL_TOPIC, mapServiceNode, req)
}
private fun fetchMap(registerNode: MockNode, mapServiceNode: MockNode, subscribe: Boolean, ifChangedSinceVersion: Int? = null): Future<Collection<NodeRegistration>?> {
val req = FetchMapRequest(subscribe, ifChangedSinceVersion, registerNode.services.networkService.myAddress, random63BitValue())
return registerNode.sendAndReceive<FetchMapResponse>(FETCH_PROTOCOL_TOPIC, mapServiceNode, req).map { it.nodes }
}
}

View File

@ -1,33 +1,36 @@
package com.r3corda.node.services
import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.contracts.asset.Cash
import com.r3corda.core.contracts.*
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.crypto.newSecureRandom
import com.r3corda.core.messaging.MessageHandlerRegistration
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.node.services.Wallet
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.OpaqueBytes
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize
import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.core.utilities.DUMMY_PUBKEY_1
import com.r3corda.testing.node.MockNetwork
import com.r3corda.node.services.monitor.*
import com.r3corda.node.services.monitor.WalletMonitorService.Companion.IN_EVENT_TOPIC
import com.r3corda.node.services.monitor.WalletMonitorService.Companion.REGISTER_TOPIC
import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.testing.*
import com.r3corda.testing.expect
import com.r3corda.testing.expectEvents
import com.r3corda.testing.node.MockNetwork
import com.r3corda.testing.node.MockNetwork.MockNode
import com.r3corda.testing.parallel
import com.r3corda.testing.sequence
import org.junit.Before
import org.junit.Test
import rx.subjects.PublishSubject
import rx.subjects.ReplaySubject
import java.util.*
import java.util.concurrent.TimeUnit
import kotlin.test.*
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
/**
* Unit tests for the wallet monitoring service.
@ -43,32 +46,13 @@ class WalletMonitorServiceTests {
/**
* Authenticate the register node with the monitor service node.
*/
private fun authenticate(monitorServiceNode: MockNetwork.MockNode, registerNode: MockNetwork.MockNode): Long {
private fun authenticate(monitorServiceNode: MockNode, registerNode: MockNode): Long {
network.runNetwork()
val sessionID = random63BitValue()
val authenticatePsm = registerNode.services.startProtocol(WalletMonitorService.REGISTER_TOPIC,
TestRegisterPSM(monitorServiceNode.info, sessionID))
val sessionId = random63BitValue()
val authenticatePsm = register(registerNode, monitorServiceNode, sessionId)
network.runNetwork()
authenticatePsm.get(1, TimeUnit.SECONDS)
return sessionID
}
class TestReceiveWalletUpdatePSM(val sessionID: Long)
: ProtocolLogic<ServiceToClientEvent.OutputState>() {
override val topic: String get() = WalletMonitorService.IN_EVENT_TOPIC
@Suspendable
override fun call(): ServiceToClientEvent.OutputState
= receive<ServiceToClientEvent.OutputState>(sessionID).unwrap { it }
}
class TestRegisterPSM(val server: NodeInfo, val sessionID: Long)
: ProtocolLogic<RegisterResponse>() {
override val topic: String get() = WalletMonitorService.REGISTER_TOPIC
@Suspendable
override fun call(): RegisterResponse {
val req = RegisterRequest(serviceHub.networkService.myAddress, sessionID)
return sendAndReceive<RegisterResponse>(server.identity, 0, sessionID, req).unwrap { it }
}
return sessionId
}
/**
@ -79,9 +63,7 @@ class WalletMonitorServiceTests {
val (monitorServiceNode, registerNode) = network.createTwoNodes()
network.runNetwork()
val sessionID = random63BitValue()
val authenticatePsm = registerNode.services.startProtocol(WalletMonitorService.REGISTER_TOPIC,
TestRegisterPSM(monitorServiceNode.info, sessionID))
val authenticatePsm = register(registerNode, monitorServiceNode, random63BitValue())
network.runNetwork()
val result = authenticatePsm.get(1, TimeUnit.SECONDS)
assertTrue(result.success)
@ -94,8 +76,7 @@ class WalletMonitorServiceTests {
fun `event received`() {
val (monitorServiceNode, registerNode) = network.createTwoNodes()
val sessionID = authenticate(monitorServiceNode, registerNode)
var receivePsm = registerNode.services.startProtocol(WalletMonitorService.IN_EVENT_TOPIC,
TestReceiveWalletUpdatePSM(sessionID))
var receivePsm = receiveWalletUpdate(registerNode, sessionID)
var expected = Wallet.Update(emptySet(), emptySet())
monitorServiceNode.inNodeWalletMonitorService!!.notifyWalletUpdate(expected)
network.runNetwork()
@ -104,8 +85,7 @@ class WalletMonitorServiceTests {
assertEquals(expected.produced, actual.produced)
// Check that states are passed through correctly
receivePsm = registerNode.services.startProtocol(WalletMonitorService.IN_EVENT_TOPIC,
TestReceiveWalletUpdatePSM(sessionID))
receivePsm = receiveWalletUpdate(registerNode, sessionID)
val consumed = setOf(StateRef(SecureHash.randomSHA256(), 0))
val producedState = TransactionState(DummyContract.SingleOwnerState(newSecureRandom().nextInt(), DUMMY_PUBKEY_1), DUMMY_NOTARY)
val produced = setOf(StateAndRef(producedState, StateRef(SecureHash.randomSHA256(), 0)))
@ -125,7 +105,7 @@ class WalletMonitorServiceTests {
val events = ReplaySubject.create<ServiceToClientEvent>()
val ref = OpaqueBytes(ByteArray(1) {1})
registerNode.net.addMessageHandler(WalletMonitorService.IN_EVENT_TOPIC, sessionID) { msg, reg ->
registerNode.net.addMessageHandler(IN_EVENT_TOPIC, sessionID) { msg, reg ->
events.onNext(msg.data.deserialize<ServiceToClientEvent>())
}
@ -178,7 +158,7 @@ class WalletMonitorServiceTests {
val quantity = 1000L
val events = ReplaySubject.create<ServiceToClientEvent>()
registerNode.net.addMessageHandler(WalletMonitorService.IN_EVENT_TOPIC, sessionID) { msg, reg ->
registerNode.net.addMessageHandler(IN_EVENT_TOPIC, sessionID) { msg, reg ->
events.onNext(msg.data.deserialize<ServiceToClientEvent>())
}
@ -240,4 +220,14 @@ class WalletMonitorServiceTests {
)
}
}
private fun register(registerNode: MockNode, monitorServiceNode: MockNode, sessionId: Long): ListenableFuture<RegisterResponse> {
val req = RegisterRequest(registerNode.services.networkService.myAddress, sessionId)
return registerNode.sendAndReceive<RegisterResponse>(REGISTER_TOPIC, monitorServiceNode, req)
}
private fun receiveWalletUpdate(registerNode: MockNode, sessionId: Long): ListenableFuture<ServiceToClientEvent.OutputState> {
return registerNode.receive<ServiceToClientEvent.OutputState>(IN_EVENT_TOPIC, sessionId)
}
}

View File

@ -6,6 +6,7 @@ import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.testing.connectProtocols
import com.r3corda.testing.node.MockNetwork
import com.r3corda.testing.node.MockNetwork.MockNode
import org.assertj.core.api.Assertions.assertThat
@ -49,10 +50,12 @@ class StateMachineManagerTests {
@Test
fun `protocol suspended just after receiving payload`() {
val topic = "send-and-receive"
val sessionID = random63BitValue()
val payload = random63BitValue()
node1.smm.add("test", SendProtocol(topic, node2.info.identity, sessionID, payload))
node2.smm.add("test", ReceiveProtocol(topic, sessionID))
val sendProtocol = SendProtocol(topic, node2.info.identity, payload)
val receiveProtocol = ReceiveProtocol(topic, node1.info.identity)
connectProtocols(sendProtocol, receiveProtocol)
node1.smm.add("test", sendProtocol)
node2.smm.add("test", receiveProtocol)
net.runNetwork()
node2.stop()
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveProtocol>(node1.info.address)
@ -90,19 +93,19 @@ class StateMachineManagerTests {
}
private class SendProtocol(override val topic: String, val destination: Party, val sessionID: Long, val payload: Any) : ProtocolLogic<Unit>() {
private class SendProtocol(override val topic: String, val otherParty: Party, val payload: Any) : ProtocolLogic<Unit>() {
@Suspendable
override fun call() = send(destination, sessionID, payload)
override fun call() = send(otherParty, payload)
}
private class ReceiveProtocol(override val topic: String, val sessionID: Long) : NonTerminatingProtocol() {
private class ReceiveProtocol(override val topic: String, val otherParty: Party) : NonTerminatingProtocol() {
@Transient var receivedPayload: Any? = null
@Suspendable
override fun doCall() {
receivedPayload = receive<Any>(sessionID).unwrap { it }
receivedPayload = receive<Any>(otherParty).unwrap { it }
}
}

View File

@ -13,27 +13,26 @@ import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.crypto.generateKeyPair
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.Emoji
import com.r3corda.core.utilities.LogHelper
import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.node.internal.Node
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.config.NodeConfigurationFromConfig
import com.r3corda.node.services.messaging.NodeMessagingClient
import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.persistence.NodeAttachmentService
import com.r3corda.node.services.transactions.SimpleNotaryService
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.protocols.NotaryProtocol
import com.r3corda.protocols.TwoPartyTradeProtocol
import joptsimple.OptionParser
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.Paths
import java.security.PublicKey
@ -208,16 +207,21 @@ private fun runBuyer(node: Node, amount: Amount<Currency>) {
// next stage in our building site, we will just auto-generate fake trades to give our nodes something to do.
//
// As the seller initiates the two-party trade protocol, here, we will be the buyer.
node.services.networkService.addMessageHandler(DEMO_TOPIC, DEFAULT_SESSION_ID) { message, registration ->
// We use a simple scenario-specific wrapper protocol to make things happen.
val otherSide = message.data.deserialize<Party>()
val buyer = TraderDemoProtocolBuyer(otherSide, attachmentsPath, amount)
node.services.startProtocol("demo.buyer", buyer)
object : AbstractNodeService(node.services) {
init {
addProtocolHandler(DEMO_TOPIC, "demo.buyer") { handshake: TraderDemoHandshake ->
TraderDemoProtocolBuyer(handshake.replyToParty, attachmentsPath, amount)
}
}
}
}
// We create a couple of ad-hoc test protocols that wrap the two party trade protocol, to give us the demo logic.
private data class TraderDemoHandshake(override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
private val DEMO_TOPIC = "initiate.demo.trade"
private class TraderDemoProtocolBuyer(val otherSide: Party,
@ -231,21 +235,17 @@ private class TraderDemoProtocolBuyer(val otherSide: Party,
@Suspendable
override fun call() {
// The session ID disambiguates the test trade.
val sessionID = random63BitValue()
progressTracker.currentStep = STARTING_BUY
send(otherSide, 0, sessionID)
val notary: NodeInfo = serviceHub.networkMapCache.notaryNodes[0]
val buyer = TwoPartyTradeProtocol.Buyer(
otherSide,
notary.identity,
amount,
CommercialPaper.State::class.java,
sessionID)
CommercialPaper.State::class.java)
// This invokes the trading protocol and out pops our finished transaction.
val tradeTX: SignedTransaction = subProtocol(buyer)
val tradeTX: SignedTransaction = subProtocol(buyer, inheritParentSessions = true)
// TODO: This should be moved into the protocol itself.
serviceHub.recordTransactions(listOf(tradeTX))
@ -306,7 +306,7 @@ private class TraderDemoProtocolSeller(val otherSide: Party,
override fun call(): SignedTransaction {
progressTracker.currentStep = ANNOUNCING
val sessionID = sendAndReceive<Long>(otherSide, 0, 0, serviceHub.storageService.myLegalIdentity).unwrap { it }
send(otherSide, TraderDemoHandshake(serviceHub.storageService.myLegalIdentity))
progressTracker.currentStep = SELF_ISSUING
@ -316,9 +316,14 @@ private class TraderDemoProtocolSeller(val otherSide: Party,
progressTracker.currentStep = TRADING
val seller = TwoPartyTradeProtocol.Seller(otherSide, notary, commercialPaper, amount, cpOwnerKey,
sessionID, progressTracker.getChildProgressTracker(TRADING)!!)
val tradeTX: SignedTransaction = subProtocol(seller)
val seller = TwoPartyTradeProtocol.Seller(
otherSide,
notary,
commercialPaper,
amount,
cpOwnerKey,
progressTracker.getChildProgressTracker(TRADING)!!)
val tradeTX: SignedTransaction = subProtocol(seller, inheritParentSessions = true)
serviceHub.recordTransactions(listOf(tradeTX))
return tradeTX

View File

@ -75,7 +75,10 @@ object NodeInterestRates {
* Interest rates become available when they are uploaded via the web as per [DataUploadServlet],
* if they haven't already been uploaded that way.
*/
services.startProtocol("fixing", FixQueryHandler(this, req as RatesFixProtocol.QueryRequest))
req as RatesFixProtocol.QueryRequest
val handler = FixQueryHandler(this, req)
handler.registerSession(req)
services.startProtocol("fixing", handler)
Unit
}
},
@ -102,7 +105,7 @@ object NodeInterestRates {
override fun call(): Unit {
val answers = service.oracle.query(request.queries, request.deadline)
progressTracker.currentStep = SENDING
send(request.replyToParty, request.sessionID, answers)
send(request.replyToParty, answers)
}
}

View File

@ -2,18 +2,19 @@ package com.r3corda.demos.protocols
import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.FutureCallback
import com.google.common.util.concurrent.Futures
import com.r3corda.core.contracts.DealState
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.protocols.TwoPartyDealProtocol
import com.r3corda.protocols.TwoPartyDealProtocol.Acceptor
import com.r3corda.protocols.TwoPartyDealProtocol.DEAL_TOPIC
/**
* This whole class is really part of a demo just to initiate the agreement of a deal with a simple
@ -25,23 +26,24 @@ import com.r3corda.protocols.TwoPartyDealProtocol
object AutoOfferProtocol {
val TOPIC = "autooffer.topic"
data class AutoOfferMessage(val otherSide: Party,
val notary: Party,
val otherSessionID: Long, val dealBeingOffered: DealState)
data class AutoOfferMessage(val notary: Party,
val dealBeingOffered: DealState,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
class Plugin: CordaPluginRegistry() {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
}
class Service(services: ServiceHubInternal) {
class Service(services: ServiceHubInternal) : AbstractNodeService(services) {
object RECEIVED : ProgressTracker.Step("Received offer")
object DEALING : ProgressTracker.Step("Starting the deal protocol") {
override fun childProgressTracker(): ProgressTracker = TwoPartyDealProtocol.Primary.tracker()
}
fun tracker() = ProgressTracker(RECEIVED, DEALING)
fun tracker() = ProgressTracker(DEALING)
class Callback(val success: (SignedTransaction) -> Unit) : FutureCallback<SignedTransaction> {
override fun onFailure(t: Throwable?) {
@ -54,20 +56,17 @@ object AutoOfferProtocol {
}
init {
services.networkService.addMessageHandler(TOPIC, DEFAULT_SESSION_ID) { msg, registration ->
addProtocolHandler(TOPIC, "$DEAL_TOPIC.seller") { autoOfferMessage: AutoOfferMessage ->
val progressTracker = tracker()
progressTracker.currentStep = RECEIVED
val autoOfferMessage = msg.data.deserialize<AutoOfferMessage>()
// Put the deal onto the ledger
progressTracker.currentStep = DEALING
val seller = TwoPartyDealProtocol.Instigator(autoOfferMessage.otherSide, autoOfferMessage.notary,
autoOfferMessage.dealBeingOffered, services.keyManagementService.freshKey(), autoOfferMessage.otherSessionID, progressTracker.getChildProgressTracker(DEALING)!!)
val future = services.startProtocol("${TwoPartyDealProtocol.DEAL_TOPIC}.seller", seller)
// This is required because we are doing child progress outside of a subprotocol. In future, we should just wrap things like this in a protocol to avoid it
Futures.addCallback(future, Callback() {
seller.progressTracker.currentStep = ProgressTracker.DONE
progressTracker.currentStep = ProgressTracker.DONE
})
TwoPartyDealProtocol.Instigator(
autoOfferMessage.replyToParty,
autoOfferMessage.notary,
autoOfferMessage.dealBeingOffered,
services.keyManagementService.freshKey(),
progressTracker.getChildProgressTracker(DEALING)!!
)
}
}
@ -98,15 +97,15 @@ object AutoOfferProtocol {
@Suspendable
override fun call(): SignedTransaction {
require(serviceHub.networkMapCache.notaryNodes.isNotEmpty()) { "No notary nodes registered" }
val ourSessionID = random63BitValue()
val notary = serviceHub.networkMapCache.notaryNodes.first().identity
// need to pick which ever party is not us
val otherParty = notUs(dealToBeOffered.parties).single()
progressTracker.currentStep = ANNOUNCING
send(otherParty, 0, AutoOfferMessage(serviceHub.storageService.myLegalIdentity, notary, ourSessionID, dealToBeOffered))
send(otherParty, AutoOfferMessage(notary, dealToBeOffered, serviceHub.storageService.myLegalIdentity))
progressTracker.currentStep = DEALING
val stx = subProtocol(TwoPartyDealProtocol.Acceptor(otherParty, notary, dealToBeOffered, ourSessionID, progressTracker.getChildProgressTracker(DEALING)!!))
val stx = subProtocol(
Acceptor(otherParty, notary, dealToBeOffered, progressTracker.getChildProgressTracker(DEALING)!!),
inheritParentSessions = true)
return stx
}

View File

@ -2,12 +2,15 @@ package com.r3corda.demos.protocols
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.deserialize
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.testing.node.MockNetworkMapCache
import java.util.concurrent.TimeUnit
@ -19,7 +22,12 @@ object ExitServerProtocol {
// Will only be enabled if you install the Handler
@Volatile private var enabled = false
data class ExitMessage(val exitCode: Int)
// This is not really a HandshakeMessage but needs to be so that the send uses the default session ID. This will
// resolve itself when the protocol session stuff is done.
data class ExitMessage(val exitCode: Int,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
class Plugin: CordaPluginRegistry() {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
@ -50,10 +58,8 @@ object ExitServerProtocol {
@Suspendable
override fun call(): Boolean {
if (enabled) {
val message = ExitMessage(exitCode)
for (recipient in serviceHub.networkMapCache.partyNodes) {
doNextRecipient(recipient, message)
doNextRecipient(recipient)
}
// Sleep a little in case any async message delivery to other nodes needs to happen
Strand.sleep(1, TimeUnit.SECONDS)
@ -63,11 +69,11 @@ object ExitServerProtocol {
}
@Suspendable
private fun doNextRecipient(recipient: NodeInfo, message: ExitMessage) {
private fun doNextRecipient(recipient: NodeInfo) {
if (recipient.address is MockNetworkMapCache.MockAddress) {
// Ignore
} else {
send(recipient.identity, 0, message)
send(recipient.identity, ExitMessage(exitCode, recipient.identity))
}
}
}

View File

@ -1,15 +1,17 @@
package com.r3corda.demos.protocols
import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.demos.DemoClock
import com.r3corda.node.internal.Node
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.testing.node.MockNetworkMapCache
import java.time.LocalDate
@ -20,7 +22,12 @@ object UpdateBusinessDayProtocol {
val TOPIC = "businessday.topic"
data class UpdateBusinessDayMessage(val date: LocalDate)
// This is not really a HandshakeMessage but needs to be so that the send uses the default session ID. This will
// resolve itself when the protocol session stuff is done.
data class UpdateBusinessDayMessage(val date: LocalDate,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
class Plugin: CordaPluginRegistry() {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
@ -50,18 +57,17 @@ object UpdateBusinessDayProtocol {
@Suspendable
override fun call(): Unit {
progressTracker.currentStep = NOTIFYING
val message = UpdateBusinessDayMessage(date)
for (recipient in serviceHub.networkMapCache.partyNodes) {
doNextRecipient(recipient, message)
doNextRecipient(recipient)
}
}
@Suspendable
private fun doNextRecipient(recipient: NodeInfo, message: UpdateBusinessDayMessage) {
private fun doNextRecipient(recipient: NodeInfo) {
if (recipient.address is MockNetworkMapCache.MockAddress) {
// Ignore
} else {
send(recipient.identity, 0, message)
send(recipient.identity, UpdateBusinessDayMessage(date, recipient.identity))
}
}
}

View File

@ -7,14 +7,14 @@ import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import com.r3corda.contracts.InterestRateSwap
import com.r3corda.core.RunOnCallerThread
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.contracts.StateAndRef
import com.r3corda.core.contracts.UniqueIdentifier
import com.r3corda.core.failure
import com.r3corda.core.node.services.linearHeadsOfType
import com.r3corda.core.random63BitValue
import com.r3corda.core.success
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.protocols.TwoPartyDealProtocol
import com.r3corda.testing.connectProtocols
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.MockIdentityService
import java.security.KeyPair
@ -121,10 +121,9 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten
irs.fixedLeg.fixedRatePayer = node1.info.identity
irs.floatingLeg.floatingRatePayer = node2.info.identity
val sessionID = random63BitValue()
val instigator = TwoPartyDealProtocol.Instigator(node2.info.identity, notary.info.identity, irs, nodeAKey!!, sessionID)
val acceptor = TwoPartyDealProtocol.Acceptor(node1.info.identity, notary.info.identity, irs, sessionID)
val instigator = TwoPartyDealProtocol.Instigator(node2.info.identity, notary.info.identity, irs, nodeAKey!!)
val acceptor = TwoPartyDealProtocol.Acceptor(node1.info.identity, notary.info.identity, irs)
connectProtocols(instigator, acceptor)
showProgressFor(listOf(node1, node2))
showConsensusFor(listOf(node1, node2, regulators[0]))

View File

@ -7,12 +7,13 @@ import com.r3corda.contracts.asset.DUMMY_CASH_ISSUER
import com.r3corda.contracts.testing.fillWithSomeTestCash
import com.r3corda.core.contracts.DOLLARS
import com.r3corda.core.contracts.OwnableState
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.contracts.`issued by`
import com.r3corda.core.days
import com.r3corda.core.random63BitValue
import com.r3corda.core.seconds
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.protocols.TwoPartyTradeProtocol
import com.r3corda.protocols.TwoPartyTradeProtocol.TOPIC
import com.r3corda.testing.connectProtocols
import com.r3corda.testing.node.InMemoryMessagingNetwork
import java.time.Instant
@ -43,26 +44,24 @@ class TradeSimulation(runAsync: Boolean, latencyInjector: InMemoryMessagingNetwo
seller.services.recordTransactions(issuance)
val amount = 1000.DOLLARS
val sessionID = random63BitValue()
val buyerProtocol = TwoPartyTradeProtocol.Buyer(
seller.info.identity,
notary.info.identity,
amount,
CommercialPaper.State::class.java,
sessionID)
CommercialPaper.State::class.java)
val sellerProtocol = TwoPartyTradeProtocol.Seller(
buyer.info.identity,
notary.info,
issuance.tx.outRef<OwnableState>(0),
amount,
seller.storage.myLegalIdentityKey,
sessionID)
seller.storage.myLegalIdentityKey)
connectProtocols(buyerProtocol, sellerProtocol)
showConsensusFor(listOf(buyer, seller, notary))
showProgressFor(listOf(buyer, seller))
val buyerFuture = buyer.services.startProtocol("bank.$buyerBankIndex.${TwoPartyTradeProtocol.TOPIC}.buyer", buyerProtocol)
val sellerFuture = seller.services.startProtocol("bank.$sellerBankIndex.${TwoPartyTradeProtocol.TOPIC}.seller", sellerProtocol)
val buyerFuture = buyer.services.startProtocol("bank.$buyerBankIndex.$TOPIC.buyer", buyerProtocol)
val sellerFuture = seller.services.startProtocol("bank.$sellerBankIndex.$TOPIC.seller", sellerProtocol)
return Futures.successfulAsList(buyerFuture, sellerFuture)
}

View File

@ -4,15 +4,19 @@ package com.r3corda.testing
import com.google.common.base.Throwables
import com.google.common.net.HostAndPort
import com.r3corda.testing.*
import com.r3corda.core.contracts.StateRef
import com.r3corda.core.transactions.TransactionBuilder
import com.r3corda.core.crypto.*
import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.crypto.generateKeyPair
import com.r3corda.core.node.ServiceHub
import com.r3corda.testing.node.MockIdentityService
import com.r3corda.testing.node.MockServices
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.transactions.TransactionBuilder
import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.testing.node.MockIdentityService
import com.r3corda.testing.node.MockServices
import java.net.ServerSocket
import java.security.KeyPair
import java.security.PublicKey
@ -124,3 +128,23 @@ fun getFreeLocalPorts(hostName: String, numberToAlloc: Int): List<HostAndPort> {
transactionBuilder: TransactionBuilder = TransactionBuilder(notary = DUMMY_NOTARY),
dsl: TransactionDSL<TransactionDSLInterpreter>.() -> EnforceVerifyOrFail
) = ledger { this.transaction(transactionLabel, transactionBuilder, dsl) }
/**
* Connect two protocols together for communication. Both protocols must have a property called otherParty of type Party
* which points to the other party in the communication.
*/
fun connectProtocols(protocol1: ProtocolLogic<*>, protocol2: ProtocolLogic<*>) {
data class Handshake(override val replyToParty: Party,
override val sendSessionID: Long,
override val receiveSessionID: Long) : HandshakeMessage
val sessionId1 = random63BitValue()
val sessionId2 = random63BitValue()
protocol1.registerSession(Handshake(protocol1.otherParty, sessionId1, sessionId2))
protocol2.registerSession(Handshake(protocol2.otherParty, sessionId2, sessionId1))
}
private val ProtocolLogic<*>.otherParty: Party
get() = javaClass.getDeclaredField("otherParty").apply { isAccessible = true }.get(this) as Party

View File

@ -2,12 +2,18 @@ package com.r3corda.testing.node
import com.google.common.jimfs.Jimfs
import com.google.common.util.concurrent.Futures
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.messaging.TopicSession
import com.r3corda.core.messaging.runOnNextMessage
import com.r3corda.core.messaging.send
import com.r3corda.core.node.PhysicalLocation
import com.r3corda.core.node.services.KeyManagementService
import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.node.services.WalletService
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.testing.InMemoryWalletService
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.core.utilities.loggerFor
@ -18,6 +24,7 @@ import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.network.NodeRegistration
import com.r3corda.node.services.transactions.InMemoryUniquenessProvider
import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.protocols.ServiceRequestMessage
import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.slf4j.Logger
import java.nio.file.Files
@ -130,6 +137,25 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
// This does not indirect through the NodeInfo object so it can be called before the node is started.
// It is used from the network visualiser tool.
@Suppress("unused") val place: PhysicalLocation get() = findMyLocation()!!
fun send(topic: String, target: MockNode, payload: Any) {
services.networkService.send(TopicSession(topic), payload, target.info.address)
}
inline fun <reified T : Any> receive(topic: String, sessionId: Long): ListenableFuture<T> {
val receive = SettableFuture.create<T>()
services.networkService.runOnNextMessage(topic, sessionId) {
receive.set(it.data.deserialize<T>())
}
return receive
}
inline fun <reified T : Any> sendAndReceive(topic: String,
target: MockNode,
payload: ServiceRequestMessage): ListenableFuture<T> {
send(topic, target, payload)
return receive(topic, payload.sessionID)
}
}
/** Returns a node, optionally created by the passed factory method. */