Automatic session management between two protocols, and removal of explict topics

This commit is contained in:
Shams Asari 2016-09-27 18:25:26 +01:00
parent 4da73e28c7
commit 67fdf9b2ff
54 changed files with 1055 additions and 1113 deletions

2
.gitignore vendored
View File

@ -5,7 +5,6 @@
tags tags
.DS_Store .DS_Store
*.log *.log
*.log.gz
*.orig *.orig
# Created by .ignore support plugin (hsz.mobi) # Created by .ignore support plugin (hsz.mobi)
@ -100,3 +99,4 @@ crashlytics-build.properties
# docs related # docs related
docs/virtualenv/ docs/virtualenv/
/logs/

View File

@ -675,8 +675,8 @@ class InterestRateSwap() : Contract {
val nextFixingOf = nextFixingOf() ?: return null val nextFixingOf = nextFixingOf() ?: return null
// This is perhaps not how we should determine the time point in the business day, but instead expect the schedule to detail some of these aspects // This is perhaps not how we should determine the time point in the business day, but instead expect the schedule to detail some of these aspects
val (instant, duration) = suggestInterestRateAnnouncementTimeWindow(index = nextFixingOf.name, source = floatingLeg.indexSource, date = nextFixingOf.forDay) val instant = suggestInterestRateAnnouncementTimeWindow(index = nextFixingOf.name, source = floatingLeg.indexSource, date = nextFixingOf.forDay).start
return ScheduledActivity(protocolLogicRefFactory.create(TwoPartyDealProtocol.FixingRoleDecider::class.java, thisStateRef, duration), instant) return ScheduledActivity(protocolLogicRefFactory.create(TwoPartyDealProtocol.FixingRoleDecider::class.java, thisStateRef), instant)
} }
override fun generateAgreement(notary: Party): TransactionBuilder = InterestRateSwap().generateAgreement(floatingLeg, fixedLeg, calculation, common, notary) override fun generateAgreement(notary: Party): TransactionBuilder = InterestRateSwap().generateAgreement(floatingLeg, fixedLeg, calculation, common, notary)

View File

@ -46,21 +46,19 @@ import java.util.*
// and [AbstractStateReplacementProtocol]. // and [AbstractStateReplacementProtocol].
object TwoPartyTradeProtocol { object TwoPartyTradeProtocol {
val TOPIC = "platform.trade"
class UnacceptablePriceException(val givenPrice: Amount<Currency>) : Exception("Unacceptable price: $givenPrice") class UnacceptablePriceException(val givenPrice: Amount<Currency>) : Exception("Unacceptable price: $givenPrice")
class AssetMismatchException(val expectedTypeName: String, val typeName: String) : Exception() { class AssetMismatchException(val expectedTypeName: String, val typeName: String) : Exception() {
override fun toString() = "The submitted asset didn't match the expected type: $expectedTypeName vs $typeName" override fun toString() = "The submitted asset didn't match the expected type: $expectedTypeName vs $typeName"
} }
// This object is serialised to the network and is the first protocol message the seller sends to the buyer. // This object is serialised to the network and is the first protocol message the seller sends to the buyer.
class SellerTradeInfo( data class SellerTradeInfo(
val assetForSale: StateAndRef<OwnableState>, val assetForSale: StateAndRef<OwnableState>,
val price: Amount<Currency>, val price: Amount<Currency>,
val sellerOwnerKey: PublicKey val sellerOwnerKey: PublicKey
) )
class SignaturesFromSeller(val sellerSig: DigitalSignature.WithKey, data class SignaturesFromSeller(val sellerSig: DigitalSignature.WithKey,
val notarySig: DigitalSignature.LegallyIdentifiable) val notarySig: DigitalSignature.LegallyIdentifiable)
open class Seller(val otherParty: Party, open class Seller(val otherParty: Party,
@ -84,8 +82,6 @@ object TwoPartyTradeProtocol {
fun tracker() = ProgressTracker(AWAITING_PROPOSAL, VERIFYING, SIGNING, NOTARY, SENDING_SIGS) fun tracker() = ProgressTracker(AWAITING_PROPOSAL, VERIFYING, SIGNING, NOTARY, SENDING_SIGS)
} }
override val topic: String get() = TOPIC
@Suspendable @Suspendable
override fun call(): SignedTransaction { override fun call(): SignedTransaction {
val partialTX: SignedTransaction = receiveAndCheckProposedTransaction() val partialTX: SignedTransaction = receiveAndCheckProposedTransaction()
@ -172,7 +168,6 @@ object TwoPartyTradeProtocol {
object SWAPPING_SIGNATURES : ProgressTracker.Step("Swapping signatures with the seller") object SWAPPING_SIGNATURES : ProgressTracker.Step("Swapping signatures with the seller")
override val topic: String get() = TOPIC
override val progressTracker = ProgressTracker(RECEIVING, VERIFYING, SIGNING, SWAPPING_SIGNATURES) override val progressTracker = ProgressTracker(RECEIVING, VERIFYING, SIGNING, SWAPPING_SIGNATURES)
@Suspendable @Suspendable
@ -197,7 +192,7 @@ object TwoPartyTradeProtocol {
@Suspendable @Suspendable
private fun receiveAndValidateTradeRequest(): SellerTradeInfo { private fun receiveAndValidateTradeRequest(): SellerTradeInfo {
progressTracker.currentStep = RECEIVING progressTracker.currentStep = RECEIVING
// Wait for a trade request to come in on our pre-provided session ID. // Wait for a trade request to come in from the other side
val maybeTradeRequest = receive<SellerTradeInfo>(otherParty) val maybeTradeRequest = receive<SellerTradeInfo>(otherParty)
progressTracker.currentStep = VERIFYING progressTracker.currentStep = VERIFYING
@ -243,8 +238,8 @@ object TwoPartyTradeProtocol {
private fun assembleSharedTX(tradeRequest: SellerTradeInfo): Pair<TransactionBuilder, List<PublicKey>> { private fun assembleSharedTX(tradeRequest: SellerTradeInfo): Pair<TransactionBuilder, List<PublicKey>> {
val ptx = TransactionType.General.Builder(notary) val ptx = TransactionType.General.Builder(notary)
// Add input and output states for the movement of cash, by using the Cash contract to generate the states. // Add input and output states for the movement of cash, by using the Cash contract to generate the states.
val wallet = serviceHub.vaultService.currentVault val vault = serviceHub.vaultService.currentVault
val cashStates = wallet.statesOfType<Cash.State>() val cashStates = vault.statesOfType<Cash.State>()
val cashSigningPubKeys = Cash().generateSpend(ptx, tradeRequest.price, tradeRequest.sellerOwnerKey, cashStates) val cashSigningPubKeys = Cash().generateSpend(ptx, tradeRequest.price, tradeRequest.sellerOwnerKey, cashStates)
// Add inputs/outputs/a command for the movement of the asset. // Add inputs/outputs/a command for the movement of the asset.
ptx.addInputState(tradeRequest.assetForSale) ptx.addInputState(tradeRequest.assetForSale)

View File

@ -1,5 +1,6 @@
package com.r3corda.core package com.r3corda.core
import com.google.common.base.Function
import com.google.common.base.Throwables import com.google.common.base.Throwables
import com.google.common.io.ByteStreams import com.google.common.io.ByteStreams
import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.Futures
@ -17,8 +18,8 @@ import java.nio.file.Files
import java.nio.file.Path import java.nio.file.Path
import java.time.Duration import java.time.Duration
import java.time.temporal.Temporal import java.time.temporal.Temporal
import java.util.concurrent.ExecutionException
import java.util.concurrent.Executor import java.util.concurrent.Executor
import java.util.concurrent.Future
import java.util.concurrent.locks.ReentrantLock import java.util.concurrent.locks.ReentrantLock
import java.util.zip.ZipInputStream import java.util.zip.ZipInputStream
import kotlin.concurrent.withLock import kotlin.concurrent.withLock
@ -66,19 +67,20 @@ fun <T> ListenableFuture<T>.success(executor: Executor, body: (T) -> Unit) = the
fun <T> ListenableFuture<T>.failure(executor: Executor, body: (Throwable) -> Unit) = then(executor) { fun <T> ListenableFuture<T>.failure(executor: Executor, body: (Throwable) -> Unit) = then(executor) {
try { try {
get() get()
} catch(e: Throwable) { } catch (e: ExecutionException) {
body(e) body(e.cause!!)
} catch (t: Throwable) {
body(t)
} }
} }
infix fun <F, T> Future<F>.map(mapper: (F) -> T): Future<T> = Futures.lazyTransform(this) { mapper(it!!) }
infix fun <T> ListenableFuture<T>.then(body: () -> Unit): ListenableFuture<T> = apply { then(RunOnCallerThread, body) } infix fun <T> ListenableFuture<T>.then(body: () -> Unit): ListenableFuture<T> = apply { then(RunOnCallerThread, body) }
infix fun <T> ListenableFuture<T>.success(body: (T) -> Unit): ListenableFuture<T> = apply { success(RunOnCallerThread, body) } infix fun <T> ListenableFuture<T>.success(body: (T) -> Unit): ListenableFuture<T> = apply { success(RunOnCallerThread, body) }
infix fun <T> ListenableFuture<T>.failure(body: (Throwable) -> Unit): ListenableFuture<T> = apply { failure(RunOnCallerThread, body) } infix fun <T> ListenableFuture<T>.failure(body: (Throwable) -> Unit): ListenableFuture<T> = apply { failure(RunOnCallerThread, body) }
infix fun <F, T> ListenableFuture<F>.map(mapper: (F) -> T): ListenableFuture<T> = Futures.transform(this, Function { mapper(it!!) })
fun <R> Path.use(block: (InputStream) -> R): R = Files.newInputStream(this).use(block) infix fun <F, T> ListenableFuture<F>.flatMap(mapper: (F) -> ListenableFuture<T>): ListenableFuture<T> = Futures.transformAsync(this) { mapper(it!!) }
/** Executes the given block and sets the future to either the result, or any exception that was thrown. */ /** Executes the given block and sets the future to either the result, or any exception that was thrown. */
// TODO This is not used but there's existing code that can be replaced by this
fun <T> SettableFuture<T>.setFrom(logger: Logger? = null, block: () -> T): SettableFuture<T> { fun <T> SettableFuture<T>.setFrom(logger: Logger? = null, block: () -> T): SettableFuture<T> {
try { try {
set(block()) set(block())
@ -89,6 +91,8 @@ fun <T> SettableFuture<T>.setFrom(logger: Logger? = null, block: () -> T): Setta
return this return this
} }
fun <R> Path.use(block: (InputStream) -> R): R = Files.newInputStream(this).use(block)
// Simple infix function to add back null safety that the JDK lacks: timeA until timeB // Simple infix function to add back null safety that the JDK lacks: timeA until timeB
infix fun Temporal.until(endExclusive: Temporal) = Duration.between(this, endExclusive) infix fun Temporal.until(endExclusive: Temporal) = Duration.between(this, endExclusive)

View File

@ -2,19 +2,11 @@ package com.r3corda.core.protocols
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.Message
import com.r3corda.core.messaging.runOnNextMessage
import com.r3corda.core.node.ServiceHub import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.utilities.ProgressTracker import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.core.utilities.UntrustworthyData import com.r3corda.core.utilities.UntrustworthyData
import com.r3corda.core.utilities.debug
import com.r3corda.protocols.HandshakeMessage
import org.slf4j.Logger import org.slf4j.Logger
import rx.Observable import rx.Observable
import java.util.*
import java.util.concurrent.CompletableFuture
/** /**
* A sub-class of [ProtocolLogic<T>] implements a protocol flow using direct, straight line blocking code. Thus you * A sub-class of [ProtocolLogic<T>] implements a protocol flow using direct, straight line blocking code. Thus you
@ -48,23 +40,14 @@ abstract class ProtocolLogic<out T> {
*/ */
val serviceHub: ServiceHub get() = psm.serviceHub val serviceHub: ServiceHub get() = psm.serviceHub
/** private var sessionProtocol: ProtocolLogic<*> = this
* The topic to use when communicating with other parties. If more than one topic is required then use sub-protocols.
* Note that this is temporary until protocol sessions are properly implemented.
*/
protected abstract val topic: String
private val sessions = HashMap<Party, Session>()
/** /**
* If a node receives a [HandshakeMessage] it needs to call this method on the initiated receipt protocol to enable * Return the marker [Class] which [party] has used to register the counterparty protocol that is to execute on the
* communication between it and the sender protocol. Calling this method, and other initiation steps, are already * other side. The default implementation returns the class object of this ProtocolLogic, but any [Class] instance
* handled by AbstractNodeService.addProtocolHandler. * will do as long as the other side registers with it.
*/ */
fun registerSession(receivedHandshake: HandshakeMessage) { open fun getCounterpartyMarker(party: Party): Class<*> = javaClass
// Note that the send and receive session IDs are swapped
addSession(receivedHandshake.replyToParty, receivedHandshake.receiveSessionID, receivedHandshake.sendSessionID)
}
// Kotlin helpers that allow the use of generic types. // Kotlin helpers that allow the use of generic types.
inline fun <reified T : Any> sendAndReceive(otherParty: Party, payload: Any): UntrustworthyData<T> { inline fun <reified T : Any> sendAndReceive(otherParty: Party, payload: Any): UntrustworthyData<T> {
@ -73,69 +56,41 @@ abstract class ProtocolLogic<out T> {
@Suspendable @Suspendable
fun <T : Any> sendAndReceive(otherParty: Party, payload: Any, receiveType: Class<T>): UntrustworthyData<T> { fun <T : Any> sendAndReceive(otherParty: Party, payload: Any, receiveType: Class<T>): UntrustworthyData<T> {
val sendSessionId = getSendSessionId(otherParty, payload) return psm.sendAndReceive(otherParty, payload, receiveType, sessionProtocol)
val receiveSessionId = getReceiveSessionId(otherParty)
return psm.sendAndReceive(topic, otherParty, sendSessionId, receiveSessionId, payload, receiveType)
} }
inline fun <reified T : Any> receive(otherParty: Party): UntrustworthyData<T> = receive(otherParty, T::class.java) inline fun <reified T : Any> receive(otherParty: Party): UntrustworthyData<T> = receive(otherParty, T::class.java)
@Suspendable @Suspendable
fun <T : Any> receive(otherParty: Party, receiveType: Class<T>): UntrustworthyData<T> { fun <T : Any> receive(otherParty: Party, receiveType: Class<T>): UntrustworthyData<T> {
return psm.receive(topic, getReceiveSessionId(otherParty), receiveType) return psm.receive(otherParty, receiveType, sessionProtocol)
} }
@Suspendable @Suspendable
fun send(otherParty: Party, payload: Any) { fun send(otherParty: Party, payload: Any) {
psm.send(topic, otherParty, getSendSessionId(otherParty, payload), payload) psm.send(otherParty, payload, sessionProtocol)
} }
private fun addSession(party: Party, sendSesssionId: Long, receiveSessionId: Long) {
if (party in sessions) {
logger.debug { "Existing session with party $party to be overwritten by new one" }
}
sessions[party] = Session(sendSesssionId, receiveSessionId)
}
private fun getSendSessionId(otherParty: Party, payload: Any): Long {
return if (payload is HandshakeMessage) {
addSession(otherParty, payload.sendSessionID, payload.receiveSessionID)
DEFAULT_SESSION_ID
} else {
sessions[otherParty]?.sendSessionId ?:
throw IllegalStateException("Session with party $otherParty hasn't been established yet")
}
}
private fun getReceiveSessionId(otherParty: Party): Long {
return sessions[otherParty]?.receiveSessionId ?:
throw IllegalStateException("Session with party $otherParty hasn't been established yet")
}
/**
* Check if we already have a session with this party
*/
protected fun hasSession(otherParty: Party) = sessions.containsKey(otherParty)
/** /**
* Invokes the given subprotocol by simply passing through this [ProtocolLogic]s reference to the * Invokes the given subprotocol by simply passing through this [ProtocolLogic]s reference to the
* [ProtocolStateMachine] and then calling the [call] method. * [ProtocolStateMachine] and then calling the [call] method.
* @param inheritParentSessions In certain situations the subprotocol needs to inherit and use the same open * @param shareParentSessions In certain situations the need arises to use the same sessions the parent protocol has
* sessions of the parent. However in most cases this is not desirable as it prevents the subprotocol from * already established. However this also prevents the subprotocol from creating new sessions with those parties.
* communicating with the same party on a different topic. For this reason the default value is false. * For this reason the default value is false.
*/ */
@JvmOverloads // TODO Rethink the default value for shareParentSessions
// TODO shareParentSessions is a bit too low-level and perhaps can be expresed in a better way
@Suspendable @Suspendable
fun <R> subProtocol(subLogic: ProtocolLogic<R>, inheritParentSessions: Boolean = false): R { fun <R> subProtocol(subLogic: ProtocolLogic<R>, shareParentSessions: Boolean = false): R {
subLogic.psm = psm subLogic.psm = psm
if (inheritParentSessions) {
subLogic.sessions.putAll(sessions)
}
maybeWireUpProgressTracking(subLogic) maybeWireUpProgressTracking(subLogic)
val r = subLogic.call() if (shareParentSessions) {
subLogic.sessionProtocol = this
}
val result = subLogic.call()
// It's easy to forget this when writing protocols so we just step it to the DONE state when it completes. // It's easy to forget this when writing protocols so we just step it to the DONE state when it completes.
subLogic.progressTracker?.currentStep = ProgressTracker.DONE subLogic.progressTracker?.currentStep = ProgressTracker.DONE
return r return result
} }
private fun maybeWireUpProgressTracking(subLogic: ProtocolLogic<*>) { private fun maybeWireUpProgressTracking(subLogic: ProtocolLogic<*>) {
@ -166,12 +121,11 @@ abstract class ProtocolLogic<out T> {
@Suspendable @Suspendable
abstract fun call(): T abstract fun call(): T
private data class Session(val sendSessionId: Long, val receiveSessionId: Long)
// TODO this is not threadsafe, needs an atomic get-step-and-subscribe // TODO this is not threadsafe, needs an atomic get-step-and-subscribe
fun track(): Pair<String, Observable<String>>? { fun track(): Pair<String, Observable<String>>? {
return progressTracker?.let { return progressTracker?.let {
Pair(it.currentStep.toString(), it.changes.map { it.toString() }) Pair(it.currentStep.toString(), it.changes.map { it.toString() })
} }
} }
} }

View File

@ -1,7 +1,6 @@
package com.r3corda.core.protocols package com.r3corda.core.protocols
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.node.ServiceHub import com.r3corda.core.node.ServiceHub
@ -10,9 +9,12 @@ import org.slf4j.Logger
import java.util.* import java.util.*
data class StateMachineRunId private constructor(val uuid: UUID) { data class StateMachineRunId private constructor(val uuid: UUID) {
companion object { companion object {
fun createRandom(): StateMachineRunId = StateMachineRunId(UUID.randomUUID()) fun createRandom(): StateMachineRunId = StateMachineRunId(UUID.randomUUID())
} }
override fun toString(): String = "${javaClass.simpleName}($uuid)"
} }
/** /**
@ -20,18 +22,16 @@ data class StateMachineRunId private constructor(val uuid: UUID) {
*/ */
interface ProtocolStateMachine<R> { interface ProtocolStateMachine<R> {
@Suspendable @Suspendable
fun <T : Any> sendAndReceive(topic: String, fun <T : Any> sendAndReceive(otherParty: Party,
destination: Party,
sessionIDForSend: Long,
sessionIDForReceive: Long,
payload: Any, payload: Any,
receiveType: Class<T>): UntrustworthyData<T> receiveType: Class<T>,
sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T>
@Suspendable @Suspendable
fun <T : Any> receive(topic: String, sessionIDForReceive: Long, receiveType: Class<T>): UntrustworthyData<T> fun <T : Any> receive(otherParty: Party, receiveType: Class<T>, sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T>
@Suspendable @Suspendable
fun send(topic: String, destination: Party, sessionID: Long, payload: Any) fun send(otherParty: Party, payload: Any, sessionProtocol: ProtocolLogic<*>)
val serviceHub: ServiceHub val serviceHub: ServiceHub
val logger: Logger val logger: Logger
@ -41,3 +41,5 @@ interface ProtocolStateMachine<R> {
/** This future will complete when the call method returns. */ /** This future will complete when the call method returns. */
val resultFuture: ListenableFuture<R> val resultFuture: ListenableFuture<R>
} }
class ProtocolSessionException(message: String) : Exception(message)

View File

@ -9,7 +9,6 @@ import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.signWithECDSA import com.r3corda.core.crypto.signWithECDSA
import com.r3corda.core.node.recordTransactions import com.r3corda.core.node.recordTransactions
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.transactions.WireTransaction import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.utilities.ProgressTracker import com.r3corda.core.utilities.ProgressTracker
@ -35,10 +34,6 @@ abstract class AbstractStateReplacementProtocol<T> {
val stx: SignedTransaction val stx: SignedTransaction
} }
data class Handshake(override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
abstract class Instigator<out S : ContractState, T>(val originalState: StateAndRef<S>, abstract class Instigator<out S : ContractState, T>(val originalState: StateAndRef<S>,
val modification: T, val modification: T,
override val progressTracker: ProgressTracker = tracker()) : ProtocolLogic<StateAndRef<S>>() { override val progressTracker: ProgressTracker = tracker()) : ProtocolLogic<StateAndRef<S>>() {
@ -94,11 +89,6 @@ abstract class AbstractStateReplacementProtocol<T> {
private fun getParticipantSignature(party: Party, stx: SignedTransaction): DigitalSignature.WithKey { private fun getParticipantSignature(party: Party, stx: SignedTransaction): DigitalSignature.WithKey {
val proposal = assembleProposal(originalState.ref, modification, stx) val proposal = assembleProposal(originalState.ref, modification, stx)
// TODO: Move this into protocol logic as a func on the lines of handshake(Party, HandshakeMessage)
if (!hasSession(party)) {
send(party, Handshake(serviceHub.storageService.myLegalIdentity))
}
val response = sendAndReceive<Result>(party, proposal) val response = sendAndReceive<Result>(party, proposal)
val participantSignature = response.unwrap { val participantSignature = response.unwrap {
if (it.sig == null) throw StateReplacementException(it.error!!) if (it.sig == null) throw StateReplacementException(it.error!!)

View File

@ -5,7 +5,6 @@ import com.r3corda.core.contracts.ClientToServiceCommand
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.node.recordTransactions import com.r3corda.core.node.recordTransactions
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
@ -25,31 +24,17 @@ import com.r3corda.core.transactions.SignedTransaction
class BroadcastTransactionProtocol(val notarisedTransaction: SignedTransaction, class BroadcastTransactionProtocol(val notarisedTransaction: SignedTransaction,
val events: Set<ClientToServiceCommand>, val events: Set<ClientToServiceCommand>,
val participants: Set<Party>) : ProtocolLogic<Unit>() { val participants: Set<Party>) : ProtocolLogic<Unit>() {
companion object {
/** Topic for messages notifying a node of a new transaction */
val TOPIC = "platform.wallet.notify_tx"
}
override val topic: String = TOPIC data class NotifyTxRequest(val tx: SignedTransaction, val events: Set<ClientToServiceCommand>)
data class NotifyTxRequestMessage(val tx: SignedTransaction,
val events: Set<ClientToServiceCommand>,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
@Suspendable @Suspendable
override fun call() { override fun call() {
// Record it locally // Record it locally
serviceHub.recordTransactions(notarisedTransaction) serviceHub.recordTransactions(notarisedTransaction)
// TODO: Messaging layer should handle this broadcast for us (although we need to not be sending // TODO: Messaging layer should handle this broadcast for us
// session ID, for that to work, as well). val msg = NotifyTxRequest(notarisedTransaction, events)
participants.filter { it != serviceHub.storageService.myLegalIdentity }.forEach { participant -> participants.filter { it != serviceHub.storageService.myLegalIdentity }.forEach { participant ->
val msg = NotifyTxRequestMessage(
notarisedTransaction,
events,
serviceHub.storageService.myLegalIdentity)
send(participant, msg) send(participant, msg)
} }
} }

View File

@ -14,12 +14,6 @@ import java.io.InputStream
class FetchAttachmentsProtocol(requests: Set<SecureHash>, class FetchAttachmentsProtocol(requests: Set<SecureHash>,
otherSide: Party) : FetchDataProtocol<Attachment, ByteArray>(requests, otherSide) { otherSide: Party) : FetchDataProtocol<Attachment, ByteArray>(requests, otherSide) {
companion object {
const val TOPIC = "platform.fetch.attachment"
}
override val topic: String get() = TOPIC
override fun load(txid: SecureHash): Attachment? = serviceHub.storageService.attachments.openAttachment(txid) override fun load(txid: SecureHash): Attachment? = serviceHub.storageService.attachments.openAttachment(txid)
override fun convert(wire: ByteArray): Attachment { override fun convert(wire: ByteArray): Attachment {

View File

@ -5,7 +5,6 @@ import com.r3corda.core.contracts.NamedByHash
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.utilities.UntrustworthyData import com.r3corda.core.utilities.UntrustworthyData
import com.r3corda.protocols.FetchDataProtocol.DownloadedVsRequestedDataMismatch import com.r3corda.protocols.FetchDataProtocol.DownloadedVsRequestedDataMismatch
import com.r3corda.protocols.FetchDataProtocol.HashNotFound import com.r3corda.protocols.FetchDataProtocol.HashNotFound
@ -21,8 +20,8 @@ import java.util.*
* [HashNotFound] exception being thrown. * [HashNotFound] exception being thrown.
* *
* By default this class does not insert data into any local database, if you want to do that after missing items were * By default this class does not insert data into any local database, if you want to do that after missing items were
* fetched then override [maybeWriteToDisk]. You *must* override [load] and [queryTopic]. If the wire type is not the * fetched then override [maybeWriteToDisk]. You *must* override [load]. If the wire type is not the same as the
* same as the ultimate type, you must also override [convert]. * ultimate type, you must also override [convert].
* *
* @param T The ultimate type of the data being fetched. * @param T The ultimate type of the data being fetched.
* @param W The wire type of the data being fetched, for when it isn't the same as the ultimate type. * @param W The wire type of the data being fetched, for when it isn't the same as the ultimate type.
@ -35,10 +34,7 @@ abstract class FetchDataProtocol<T : NamedByHash, in W : Any>(
class HashNotFound(val requested: SecureHash) : BadAnswer() class HashNotFound(val requested: SecureHash) : BadAnswer()
class DownloadedVsRequestedDataMismatch(val requested: SecureHash, val got: SecureHash) : BadAnswer() class DownloadedVsRequestedDataMismatch(val requested: SecureHash, val got: SecureHash) : BadAnswer()
data class Request(val hashes: List<SecureHash>, data class Request(val hashes: List<SecureHash>)
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
data class Result<out T : NamedByHash>(val fromDisk: List<T>, val downloaded: List<T>) data class Result<out T : NamedByHash>(val fromDisk: List<T>, val downloaded: List<T>)
@Suspendable @Suspendable
@ -51,9 +47,8 @@ abstract class FetchDataProtocol<T : NamedByHash, in W : Any>(
} else { } else {
logger.trace("Requesting ${toFetch.size} dependency(s) for verification") logger.trace("Requesting ${toFetch.size} dependency(s) for verification")
val fetchReq = Request(toFetch, serviceHub.storageService.myLegalIdentity)
// TODO: Support "large message" response streaming so response sizes are not limited by RAM. // TODO: Support "large message" response streaming so response sizes are not limited by RAM.
val maybeItems = sendAndReceive<ArrayList<W?>>(otherSide, fetchReq) val maybeItems = sendAndReceive<ArrayList<W?>>(otherSide, Request(toFetch))
// Check for a buggy/malicious peer answering with something that we didn't ask for. // Check for a buggy/malicious peer answering with something that we didn't ask for.
val downloaded = validateFetchResponse(maybeItems, toFetch) val downloaded = validateFetchResponse(maybeItems, toFetch)
maybeWriteToDisk(downloaded) maybeWriteToDisk(downloaded)

View File

@ -1,8 +1,8 @@
package com.r3corda.protocols package com.r3corda.protocols
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.transactions.SignedTransaction
/** /**
* Given a set of tx hashes (IDs), either loads them from local disk or asks the remote peer to provide them. * Given a set of tx hashes (IDs), either loads them from local disk or asks the remote peer to provide them.
@ -15,11 +15,5 @@ import com.r3corda.core.crypto.SecureHash
class FetchTransactionsProtocol(requests: Set<SecureHash>, otherSide: Party) : class FetchTransactionsProtocol(requests: Set<SecureHash>, otherSide: Party) :
FetchDataProtocol<SignedTransaction, SignedTransaction>(requests, otherSide) { FetchDataProtocol<SignedTransaction, SignedTransaction>(requests, otherSide) {
companion object {
const val TOPIC = "platform.fetch.tx"
}
override val topic: String get() = TOPIC
override fun load(txid: SecureHash): SignedTransaction? = serviceHub.storageService.validatedTransactions.getTransaction(txid) override fun load(txid: SecureHash): SignedTransaction? = serviceHub.storageService.validatedTransactions.getTransaction(txid)
} }

View File

@ -31,9 +31,6 @@ class FinalityProtocol(val transaction: SignedTransaction,
fun tracker() = ProgressTracker(NOTARISING, BROADCASTING) fun tracker() = ProgressTracker(NOTARISING, BROADCASTING)
} }
override val topic: String
get() = throw UnsupportedOperationException()
@Suspendable @Suspendable
override fun call() { override fun call() {
// TODO: Resolve the tx here: it's probably already been done, but re-resolution is a no-op and it'll make the API more forgiving. // TODO: Resolve the tx here: it's probably already been done, but re-resolution is a no-op and it'll make the API more forgiving.

View File

@ -24,8 +24,6 @@ import java.security.PublicKey
*/ */
object NotaryChangeProtocol: AbstractStateReplacementProtocol<Party>() { object NotaryChangeProtocol: AbstractStateReplacementProtocol<Party>() {
val TOPIC = "platform.notary.change"
data class Proposal(override val stateRef: StateRef, data class Proposal(override val stateRef: StateRef,
override val modification: Party, override val modification: Party,
override val stx: SignedTransaction) : AbstractStateReplacementProtocol.Proposal<Party> override val stx: SignedTransaction) : AbstractStateReplacementProtocol.Proposal<Party>
@ -35,8 +33,6 @@ object NotaryChangeProtocol: AbstractStateReplacementProtocol<Party>() {
progressTracker: ProgressTracker = tracker()) progressTracker: ProgressTracker = tracker())
: AbstractStateReplacementProtocol.Instigator<T, Party>(originalState, newNotary, progressTracker) { : AbstractStateReplacementProtocol.Instigator<T, Party>(originalState, newNotary, progressTracker) {
override val topic: String get() = TOPIC
override fun assembleProposal(stateRef: StateRef, modification: Party, stx: SignedTransaction): AbstractStateReplacementProtocol.Proposal<Party> override fun assembleProposal(stateRef: StateRef, modification: Party, stx: SignedTransaction): AbstractStateReplacementProtocol.Proposal<Party>
= Proposal(stateRef, modification, stx) = Proposal(stateRef, modification, stx)
@ -56,8 +52,6 @@ object NotaryChangeProtocol: AbstractStateReplacementProtocol<Party>() {
override val progressTracker: ProgressTracker = tracker()) override val progressTracker: ProgressTracker = tracker())
: AbstractStateReplacementProtocol.Acceptor<Party>(otherSide) { : AbstractStateReplacementProtocol.Acceptor<Party>(otherSide) {
override val topic: String get() = TOPIC
/** /**
* Check the notary change proposal. * Check the notary change proposal.
* *

View File

@ -5,12 +5,10 @@ import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SignedData import com.r3corda.core.crypto.SignedData
import com.r3corda.core.crypto.signWithECDSA import com.r3corda.core.crypto.signWithECDSA
import com.r3corda.core.messaging.Ack
import com.r3corda.core.node.services.TimestampChecker import com.r3corda.core.node.services.TimestampChecker
import com.r3corda.core.node.services.UniquenessException import com.r3corda.core.node.services.UniquenessException
import com.r3corda.core.node.services.UniquenessProvider import com.r3corda.core.node.services.UniquenessProvider
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.SerializedBytes import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.core.serialization.serialize import com.r3corda.core.serialization.serialize
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
@ -21,8 +19,6 @@ import java.security.PublicKey
object NotaryProtocol { object NotaryProtocol {
val TOPIC = "platform.notary"
/** /**
* A protocol to be used for obtaining a signature from a [NotaryService] ascertaining the transaction * A protocol to be used for obtaining a signature from a [NotaryService] ascertaining the transaction
* timestamp is correct and none of its inputs have been used in another completed transaction. * timestamp is correct and none of its inputs have been used in another completed transaction.
@ -30,7 +26,7 @@ object NotaryProtocol {
* @throws NotaryException in case the any of the inputs to the transaction have been consumed * @throws NotaryException in case the any of the inputs to the transaction have been consumed
* by another transaction or the timestamp is invalid. * by another transaction or the timestamp is invalid.
*/ */
class Client(private val stx: SignedTransaction, open class Client(private val stx: SignedTransaction,
override val progressTracker: ProgressTracker = Client.tracker()) : ProtocolLogic<DigitalSignature.LegallyIdentifiable>() { override val progressTracker: ProgressTracker = Client.tracker()) : ProtocolLogic<DigitalSignature.LegallyIdentifiable>() {
companion object { companion object {
@ -42,8 +38,6 @@ object NotaryProtocol {
fun tracker() = ProgressTracker(REQUESTING, VALIDATING) fun tracker() = ProgressTracker(REQUESTING, VALIDATING)
} }
override val topic: String get() = TOPIC
lateinit var notaryParty: Party lateinit var notaryParty: Party
@Suspendable @Suspendable
@ -51,9 +45,9 @@ object NotaryProtocol {
progressTracker.currentStep = REQUESTING progressTracker.currentStep = REQUESTING
val wtx = stx.tx val wtx = stx.tx
notaryParty = wtx.notary ?: throw IllegalStateException("Transaction does not specify a Notary") notaryParty = wtx.notary ?: throw IllegalStateException("Transaction does not specify a Notary")
check(wtx.inputs.all { stateRef -> serviceHub.loadState(stateRef).notary == notaryParty }) { "Input states must have the same Notary" } check(wtx.inputs.all { stateRef -> serviceHub.loadState(stateRef).notary == notaryParty }) {
"Input states must have the same Notary"
sendAndReceive<Ack>(notaryParty, Handshake(serviceHub.storageService.myLegalIdentity)) }
val request = SignRequest(stx, serviceHub.storageService.myLegalIdentity) val request = SignRequest(stx, serviceHub.storageService.myLegalIdentity)
val response = sendAndReceive<Result>(notaryParty, request) val response = sendAndReceive<Result>(notaryParty, request)
@ -80,6 +74,10 @@ object NotaryProtocol {
} }
} }
class ValidatingClient(stx: SignedTransaction) : Client(stx)
/** /**
* Checks that the timestamp command is valid (if present) and commits the input state, or returns a conflict * Checks that the timestamp command is valid (if present) and commits the input state, or returns a conflict
* if any of the input states have been previously committed. * if any of the input states have been previously committed.
@ -92,11 +90,9 @@ object NotaryProtocol {
val timestampChecker: TimestampChecker, val timestampChecker: TimestampChecker,
val uniquenessProvider: UniquenessProvider) : ProtocolLogic<Unit>() { val uniquenessProvider: UniquenessProvider) : ProtocolLogic<Unit>() {
override val topic: String get() = TOPIC
@Suspendable @Suspendable
override fun call() { override fun call() {
val (stx, reqIdentity) = sendAndReceive<SignRequest>(otherSide, Ack).unwrap { it } val (stx, reqIdentity) = receive<SignRequest>(otherSide).unwrap { it }
val wtx = stx.tx val wtx = stx.tx
val result = try { val result = try {
@ -148,10 +144,6 @@ object NotaryProtocol {
} }
} }
data class Handshake(override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
/** TODO: The caller must authenticate instead of just specifying its identity */ /** TODO: The caller must authenticate instead of just specifying its identity */
data class SignRequest(val tx: SignedTransaction, val callerIdentity: Party) data class SignRequest(val tx: SignedTransaction, val callerIdentity: Party)
@ -162,23 +154,10 @@ object NotaryProtocol {
} }
} }
interface Factory {
fun create(otherSide: Party,
timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider): Service
}
object DefaultFactory : Factory {
override fun create(otherSide: Party,
timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider): Service {
return Service(otherSide, timestampChecker, uniquenessProvider)
}
}
} }
class NotaryException(val error: NotaryError) : Exception() { class NotaryException(val error: NotaryError) : Exception() {
override fun toString() = "${super.toString()}: Error response from Notary - ${error.toString()}" override fun toString() = "${super.toString()}: Error response from Notary - $error"
} }
sealed class NotaryError { sealed class NotaryError {

View File

@ -6,7 +6,6 @@ import com.r3corda.core.contracts.FixOf
import com.r3corda.core.crypto.DigitalSignature import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.transactions.TransactionBuilder import com.r3corda.core.transactions.TransactionBuilder
import com.r3corda.core.transactions.WireTransaction import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.utilities.ProgressTracker import com.r3corda.core.utilities.ProgressTracker
@ -34,8 +33,6 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
override val progressTracker: ProgressTracker = RatesFixProtocol.tracker(fixOf.name)) : ProtocolLogic<Unit>() { override val progressTracker: ProgressTracker = RatesFixProtocol.tracker(fixOf.name)) : ProtocolLogic<Unit>() {
companion object { companion object {
val TOPIC = "platform.rates.interest.fix"
class QUERYING(val name: String) : ProgressTracker.Step("Querying oracle for $name interest rate") class QUERYING(val name: String) : ProgressTracker.Step("Querying oracle for $name interest rate")
object WORKING : ProgressTracker.Step("Working with data returned by oracle") object WORKING : ProgressTracker.Step("Working with data returned by oracle")
object SIGNING : ProgressTracker.Step("Requesting confirmation signature from interest rate oracle") object SIGNING : ProgressTracker.Step("Requesting confirmation signature from interest rate oracle")
@ -43,31 +40,22 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
fun tracker(fixName: String) = ProgressTracker(QUERYING(fixName), WORKING, SIGNING) fun tracker(fixName: String) = ProgressTracker(QUERYING(fixName), WORKING, SIGNING)
} }
override val topic: String get() = TOPIC
class FixOutOfRange(@Suppress("unused") val byAmount: BigDecimal) : Exception("Fix out of range by $byAmount") class FixOutOfRange(@Suppress("unused") val byAmount: BigDecimal) : Exception("Fix out of range by $byAmount")
data class QueryRequest(val queries: List<FixOf>, data class QueryRequest(val queries: List<FixOf>, val deadline: Instant)
val deadline: Instant, data class SignRequest(val tx: WireTransaction)
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
data class SignRequest(val tx: WireTransaction,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
@Suspendable @Suspendable
override fun call() { override fun call() {
progressTracker.currentStep = progressTracker.steps[1] progressTracker.currentStep = progressTracker.steps[1]
val fix = query() val fix = subProtocol(FixQueryProtocol(fixOf, oracle))
progressTracker.currentStep = WORKING progressTracker.currentStep = WORKING
checkFixIsNearExpected(fix) checkFixIsNearExpected(fix)
tx.addCommand(fix, oracle.owningKey) tx.addCommand(fix, oracle.owningKey)
beforeSigning(fix) beforeSigning(fix)
progressTracker.currentStep = SIGNING progressTracker.currentStep = SIGNING
tx.addSignatureUnchecked(sign()) val signature = subProtocol(FixSignProtocol(tx, oracle))
tx.addSignatureUnchecked(signature)
} }
/** /**
@ -86,25 +74,13 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
} }
} }
@Suspendable
private fun sign(): DigitalSignature.LegallyIdentifiable {
val wtx = tx.toWireTransaction()
val req = SignRequest(wtx, serviceHub.storageService.myLegalIdentity)
val resp = sendAndReceive<DigitalSignature.LegallyIdentifiable>(oracle, req)
return resp.unwrap { sig ->
check(sig.signer == oracle)
tx.checkSignature(sig)
sig
}
}
class FixQueryProtocol(val fixOf: FixOf, val oracle: Party) : ProtocolLogic<Fix>() {
@Suspendable @Suspendable
private fun query(): Fix { override fun call(): Fix {
val deadline = suggestInterestRateAnnouncementTimeWindow(fixOf.name, oracle.name, fixOf.forDay).end val deadline = suggestInterestRateAnnouncementTimeWindow(fixOf.name, oracle.name, fixOf.forDay).end
val req = QueryRequest(listOf(fixOf), deadline, serviceHub.storageService.myLegalIdentity)
// TODO: add deadline to receive // TODO: add deadline to receive
val resp = sendAndReceive<ArrayList<Fix>>(oracle, req) val resp = sendAndReceive<ArrayList<Fix>>(oracle, QueryRequest(listOf(fixOf), deadline))
return resp.unwrap { return resp.unwrap {
val fix = it.first() val fix = it.first()
@ -114,3 +90,20 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
} }
} }
} }
class FixSignProtocol(val tx: TransactionBuilder, val oracle: Party) : ProtocolLogic<DigitalSignature.LegallyIdentifiable>() {
@Suspendable
override fun call(): DigitalSignature.LegallyIdentifiable {
val wtx = tx.toWireTransaction()
val resp = sendAndReceive<DigitalSignature.LegallyIdentifiable>(oracle, SignRequest(wtx))
return resp.unwrap { sig ->
check(sig.signer == oracle)
tx.checkSignature(sig)
sig
}
}
}
}

View File

@ -127,8 +127,6 @@ class ResolveTransactionsProtocol(private val txHashes: Set<SecureHash>,
return result return result
} }
override val topic: String get() = throw UnsupportedOperationException()
@Suspendable @Suspendable
private fun downloadDependencies(depsToCheck: Set<SecureHash>): Collection<SignedTransaction> { private fun downloadDependencies(depsToCheck: Set<SecureHash>): Collection<SignedTransaction> {
// Maintain a work queue of all hashes to load/download, initialised with our starting set. Then do a breadth // Maintain a work queue of all hashes to load/download, initialised with our starting set. Then do a breadth

View File

@ -33,18 +33,3 @@ interface PartyRequestMessage : ServiceRequestMessage {
return networkMapCache.partyNodes.single { it.identity == replyToParty }.address return networkMapCache.partyNodes.single { it.identity == replyToParty }.address
} }
} }
/**
* A Handshake message is sent to initiate communication between two protocol instances. It contains the two session IDs
* the two protocols will need to communicate.
* Note: This is a temperary interface and will be removed once the protocol session work is implemented.
*/
interface HandshakeMessage : PartyRequestMessage {
val sendSessionID: Long
val receiveSessionID: Long
@Deprecated("sessionID functions as receiveSessionID but it's recommended to use the later for clarity",
replaceWith = ReplaceWith("receiveSessionID"))
override val sessionID: Long get() = receiveSessionID
}

View File

@ -6,12 +6,10 @@ import com.r3corda.core.contracts.*
import com.r3corda.core.crypto.DigitalSignature import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.signWithECDSA import com.r3corda.core.crypto.signWithECDSA
import com.r3corda.core.crypto.toBase58String
import com.r3corda.core.node.NodeInfo import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.recordTransactions import com.r3corda.core.node.recordTransactions
import com.r3corda.core.node.services.ServiceType import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.seconds import com.r3corda.core.seconds
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.transactions.TransactionBuilder import com.r3corda.core.transactions.TransactionBuilder
@ -22,7 +20,6 @@ import com.r3corda.core.utilities.trace
import java.math.BigDecimal import java.math.BigDecimal
import java.security.KeyPair import java.security.KeyPair
import java.security.PublicKey import java.security.PublicKey
import java.time.Duration
/** /**
* Classes for manipulating a two party deal or agreement. * Classes for manipulating a two party deal or agreement.
@ -36,10 +33,6 @@ import java.time.Duration
*/ */
object TwoPartyDealProtocol { object TwoPartyDealProtocol {
val DEAL_TOPIC = "platform.deal"
/** This topic exists purely for [FixingSessionInitiation] to be sent from [FixingRoleDecider] to [FixingSessionInitiationHandler] */
val FIX_INITIATE_TOPIC = "platform.fix.initiate"
class DealMismatchException(val expectedDeal: ContractState, val actualDeal: ContractState) : Exception() { class DealMismatchException(val expectedDeal: ContractState, val actualDeal: ContractState) : Exception() {
override fun toString() = "The submitted deal didn't match the expected: $expectedDeal vs $actualDeal" override fun toString() = "The submitted deal didn't match the expected: $expectedDeal vs $actualDeal"
} }
@ -53,13 +46,19 @@ object TwoPartyDealProtocol {
class SignaturesFromPrimary(val sellerSig: DigitalSignature.WithKey, val notarySig: DigitalSignature.LegallyIdentifiable) class SignaturesFromPrimary(val sellerSig: DigitalSignature.WithKey, val notarySig: DigitalSignature.LegallyIdentifiable)
/**
* [Primary] at the end sends the signed tx to all the regulator parties. This a seperate workflow which needs a
* sepearate session with the regulator. This interface is used to do that in [Primary.getCounterpartyMarker].
*/
interface MarkerForBogusRegulatorProtocol
/** /**
* Abstracted bilateral deal protocol participant that initiates communication/handshake. * Abstracted bilateral deal protocol participant that initiates communication/handshake.
* *
* There's a good chance we can push at least some of this logic down into core protocol logic * There's a good chance we can push at least some of this logic down into core protocol logic
* and helper methods etc. * and helper methods etc.
*/ */
abstract class Primary<out U>(override val progressTracker: ProgressTracker = Primary.tracker()) : ProtocolLogic<SignedTransaction>() { abstract class Primary(override val progressTracker: ProgressTracker = Primary.tracker()) : ProtocolLogic<SignedTransaction>() {
companion object { companion object {
object AWAITING_PROPOSAL : ProgressTracker.Step("Handshaking and awaiting transaction proposal") object AWAITING_PROPOSAL : ProgressTracker.Step("Handshaking and awaiting transaction proposal")
@ -73,13 +72,19 @@ object TwoPartyDealProtocol {
fun tracker() = ProgressTracker(AWAITING_PROPOSAL, VERIFYING, SIGNING, NOTARY, SENDING_SIGS, RECORDING, COPYING_TO_REGULATOR) fun tracker() = ProgressTracker(AWAITING_PROPOSAL, VERIFYING, SIGNING, NOTARY, SENDING_SIGS, RECORDING, COPYING_TO_REGULATOR)
} }
override val topic: String get() = DEAL_TOPIC abstract val payload: Any
abstract val payload: U
abstract val notaryNode: NodeInfo abstract val notaryNode: NodeInfo
abstract val otherParty: Party abstract val otherParty: Party
abstract val myKeyPair: KeyPair abstract val myKeyPair: KeyPair
override fun getCounterpartyMarker(party: Party): Class<*> {
return if (serviceHub.networkMapCache.regulators.any { it.identity == party }) {
MarkerForBogusRegulatorProtocol::class.java
} else {
super.getCounterpartyMarker(party)
}
}
@Suspendable @Suspendable
fun getPartialTransaction(): UntrustworthyData<SignedTransaction> { fun getPartialTransaction(): UntrustworthyData<SignedTransaction> {
progressTracker.currentStep = AWAITING_PROPOSAL progressTracker.currentStep = AWAITING_PROPOSAL
@ -199,8 +204,6 @@ object TwoPartyDealProtocol {
fun tracker() = ProgressTracker(RECEIVING, VERIFYING, SIGNING, SWAPPING_SIGNATURES, RECORDING) fun tracker() = ProgressTracker(RECEIVING, VERIFYING, SIGNING, SWAPPING_SIGNATURES, RECORDING)
} }
override val topic: String get() = DEAL_TOPIC
abstract val otherParty: Party abstract val otherParty: Party
@Suspendable @Suspendable
@ -234,9 +237,7 @@ object TwoPartyDealProtocol {
val handshake = receive<Handshake<U>>(otherParty) val handshake = receive<Handshake<U>>(otherParty)
progressTracker.currentStep = VERIFYING progressTracker.currentStep = VERIFYING
handshake.unwrap { return handshake.unwrap { validateHandshake(it) }
return validateHandshake(it)
}
} }
@Suspendable @Suspendable
@ -263,47 +264,45 @@ object TwoPartyDealProtocol {
@Suspendable protected abstract fun assembleSharedTX(handshake: Handshake<U>): Pair<TransactionBuilder, List<PublicKey>> @Suspendable protected abstract fun assembleSharedTX(handshake: Handshake<U>): Pair<TransactionBuilder, List<PublicKey>>
} }
data class AutoOffer(val notary: Party, val dealBeingOffered: DealState)
/** /**
* One side of the protocol for inserting a pre-agreed deal. * One side of the protocol for inserting a pre-agreed deal.
*/ */
open class Instigator<out T : DealState>(override val otherParty: Party, open class Instigator(override val otherParty: Party,
val notary: Party, override val payload: AutoOffer,
override val payload: T,
override val myKeyPair: KeyPair, override val myKeyPair: KeyPair,
override val progressTracker: ProgressTracker = Primary.tracker()) : Primary<T>() { override val progressTracker: ProgressTracker = Primary.tracker()) : Primary() {
override val notaryNode: NodeInfo get() = override val notaryNode: NodeInfo get() =
serviceHub.networkMapCache.notaryNodes.filter { it.identity == notary }.single() serviceHub.networkMapCache.notaryNodes.filter { it.identity == payload.notary }.single()
} }
/** /**
* One side of the protocol for inserting a pre-agreed deal. * One side of the protocol for inserting a pre-agreed deal.
*/ */
open class Acceptor<T : DealState>(override val otherParty: Party, open class Acceptor(override val otherParty: Party,
val notary: Party, override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary<AutoOffer>() {
val dealToBuy: T,
override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary<T>() {
override fun validateHandshake(handshake: Handshake<T>): Handshake<T> { override fun validateHandshake(handshake: Handshake<AutoOffer>): Handshake<AutoOffer> {
// What is the seller trying to sell us? // What is the seller trying to sell us?
val deal: T = handshake.payload val autoOffer = handshake.payload
logger.trace { "Got deal request for: ${handshake.payload.ref}" } val deal = autoOffer.dealBeingOffered
logger.trace { "Got deal request for: ${deal.ref}" }
check(dealToBuy == deal) return handshake.copy(payload = autoOffer.copy(dealBeingOffered = deal))
return handshake.copy(payload = deal)
} }
override fun assembleSharedTX(handshake: Handshake<T>): Pair<TransactionBuilder, List<PublicKey>> { override fun assembleSharedTX(handshake: Handshake<AutoOffer>): Pair<TransactionBuilder, List<PublicKey>> {
val ptx = handshake.payload.generateAgreement(notary) val deal = handshake.payload.dealBeingOffered
val ptx = deal.generateAgreement(handshake.payload.notary)
// And add a request for timestamping: it may be that none of the contracts need this! But it can't hurt // And add a request for timestamping: it may be that none of the contracts need this! But it can't hurt
// to have one. // to have one.
ptx.setTime(serviceHub.clock.instant(), 30.seconds) ptx.setTime(serviceHub.clock.instant(), 30.seconds)
return Pair(ptx, arrayListOf(handshake.payload.parties.single { it.name == serviceHub.storageService.myLegalIdentity.name }.owningKey)) return Pair(ptx, arrayListOf(deal.parties.single { it.name == serviceHub.storageService.myLegalIdentity.name }.owningKey))
} }
} }
/** /**
@ -314,16 +313,15 @@ object TwoPartyDealProtocol {
* who does what in the protocol. * who does what in the protocol.
*/ */
class Fixer(override val otherParty: Party, class Fixer(override val otherParty: Party,
val oracleType: ServiceType, override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary<FixingSession>() {
override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary<StateRef>() {
private lateinit var txState: TransactionState<*> private lateinit var txState: TransactionState<*>
private lateinit var deal: FixableDealState private lateinit var deal: FixableDealState
override fun validateHandshake(handshake: Handshake<StateRef>): Handshake<StateRef> { override fun validateHandshake(handshake: Handshake<FixingSession>): Handshake<FixingSession> {
logger.trace { "Got fixing request for: ${handshake.payload}" } logger.trace { "Got fixing request for: ${handshake.payload}" }
txState = serviceHub.loadState(handshake.payload) txState = serviceHub.loadState(handshake.payload.ref)
deal = txState.data as FixableDealState deal = txState.data as FixableDealState
// validate the party that initiated is the one on the deal and that the recipient corresponds with it. // validate the party that initiated is the one on the deal and that the recipient corresponds with it.
@ -336,7 +334,7 @@ object TwoPartyDealProtocol {
} }
@Suspendable @Suspendable
override fun assembleSharedTX(handshake: Handshake<StateRef>): Pair<TransactionBuilder, List<PublicKey>> { override fun assembleSharedTX(handshake: Handshake<FixingSession>): Pair<TransactionBuilder, List<PublicKey>> {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
val fixOf = deal.nextFixingOf()!! val fixOf = deal.nextFixingOf()!!
@ -348,12 +346,12 @@ object TwoPartyDealProtocol {
val ptx = TransactionType.General.Builder(txState.notary) val ptx = TransactionType.General.Builder(txState.notary)
val oracle = serviceHub.networkMapCache.get(oracleType).first() val oracle = serviceHub.networkMapCache.get(handshake.payload.oracleType).first()
val addFixing = object : RatesFixProtocol(ptx, oracle.identity, fixOf, BigDecimal.ZERO, BigDecimal.ONE) { val addFixing = object : RatesFixProtocol(ptx, oracle.identity, fixOf, BigDecimal.ZERO, BigDecimal.ONE) {
@Suspendable @Suspendable
override fun beforeSigning(fix: Fix) { override fun beforeSigning(fix: Fix) {
newDeal.generateFix(ptx, StateAndRef(txState, handshake.payload), fix) newDeal.generateFix(ptx, StateAndRef(txState, handshake.payload.ref), fix)
// And add a request for timestamping: it may be that none of the contracts need this! But it can't hurt // And add a request for timestamping: it may be that none of the contracts need this! But it can't hurt
// to have one. // to have one.
@ -373,12 +371,13 @@ object TwoPartyDealProtocol {
* does what in the protocol. * does what in the protocol.
*/ */
class Floater(override val otherParty: Party, class Floater(override val otherParty: Party,
override val payload: StateRef, override val payload: FixingSession,
override val progressTracker: ProgressTracker = Primary.tracker()) : Primary<StateRef>() { override val progressTracker: ProgressTracker = Primary.tracker()) : Primary() {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
internal val dealToFix: StateAndRef<FixableDealState> by TransientProperty { internal val dealToFix: StateAndRef<FixableDealState> by TransientProperty {
val state = serviceHub.loadState(payload) as TransactionState<FixableDealState> val state = serviceHub.loadState(payload.ref) as TransactionState<FixableDealState>
StateAndRef(state, payload) StateAndRef(state, payload.ref)
} }
override val myKeyPair: KeyPair get() { override val myKeyPair: KeyPair get() {
@ -393,23 +392,18 @@ object TwoPartyDealProtocol {
/** Used to set up the session between [Floater] and [Fixer] */ /** Used to set up the session between [Floater] and [Fixer] */
data class FixingSessionInitiation(val timeout: Duration, data class FixingSession(val ref: StateRef, val oracleType: ServiceType)
val oracleType: ServiceType,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
/** /**
* This protocol looks at the deal and decides whether to be the Fixer or Floater role in agreeing a fixing. * This protocol looks at the deal and decides whether to be the Fixer or Floater role in agreeing a fixing.
* *
* It is kicked off as an activity on both participant nodes by the scheduler when it's time for a fixing. If the * It is kicked off as an activity on both participant nodes by the scheduler when it's time for a fixing. If the
* Fixer role is chosen, then that will be initiated by the [FixingSessionInitiation] message sent from the other party and * Fixer role is chosen, then that will be initiated by the [FixingSession] message sent from the other party and
* handled by the [FixingSessionInitiationHandler]. * handled by the [FixingSessionInitiationHandler].
* *
* TODO: Replace [FixingSessionInitiation] and [FixingSessionInitiationHandler] with generic session initiation logic once it exists. * TODO: Replace [FixingSession] and [FixingSessionInitiationHandler] with generic session initiation logic once it exists.
*/ */
class FixingRoleDecider(val ref: StateRef, class FixingRoleDecider(val ref: StateRef,
val timeout: Duration,
override val progressTracker: ProgressTracker = tracker()) : ProtocolLogic<Unit>() { override val progressTracker: ProgressTracker = tracker()) : ProtocolLogic<Unit>() {
companion object { companion object {
@ -418,8 +412,6 @@ object TwoPartyDealProtocol {
fun tracker() = ProgressTracker(LOADING()) fun tracker() = ProgressTracker(LOADING())
} }
override val topic: String get() = FIX_INITIATE_TOPIC
@Suspendable @Suspendable
override fun call(): Unit { override fun call(): Unit {
progressTracker.nextStep() progressTracker.nextStep()
@ -427,17 +419,10 @@ object TwoPartyDealProtocol {
// TODO: this is not the eventual mechanism for identifying the parties // TODO: this is not the eventual mechanism for identifying the parties
val fixableDeal = (dealToFix.data as FixableDealState) val fixableDeal = (dealToFix.data as FixableDealState)
val sortedParties = fixableDeal.parties.sortedBy { it.name } val sortedParties = fixableDeal.parties.sortedBy { it.name }
val oracleType = fixableDeal.oracleType
if (sortedParties[0].name == serviceHub.storageService.myLegalIdentity.name) { if (sortedParties[0].name == serviceHub.storageService.myLegalIdentity.name) {
val initation = FixingSessionInitiation( val fixing = FixingSession(ref, fixableDeal.oracleType)
timeout, // Start the Floater which will then kick-off the Fixer
oracleType, subProtocol(Floater(sortedParties[1], fixing))
serviceHub.storageService.myLegalIdentity)
// Send initiation to other side to launch one side of the fixing protocol (the Fixer).
send(sortedParties[1], initation)
// Then start the other side of the fixing protocol.
subProtocol(Floater(sortedParties[1], ref), inheritParentSessions = true)
} }
} }
} }

View File

@ -1,10 +1,11 @@
package com.r3corda.core.protocols; package com.r3corda.core.protocols;
import org.junit.Test;
import org.jetbrains.annotations.*; import java.util.HashMap;
import org.junit.*; import java.util.HashSet;
import java.util.Map;
import java.util.*; import java.util.Set;
public class ProtocolLogicRefFromJavaTest { public class ProtocolLogicRefFromJavaTest {
@ -33,12 +34,6 @@ public class ProtocolLogicRefFromJavaTest {
public Void call() { public Void call() {
return null; return null;
} }
@NotNull
@Override
protected String getTopic() {
throw new UnsupportedOperationException();
}
} }
private static class JavaNoArgProtocolLogic extends ProtocolLogic<Void> { private static class JavaNoArgProtocolLogic extends ProtocolLogic<Void> {
@ -50,12 +45,6 @@ public class ProtocolLogicRefFromJavaTest {
public Void call() { public Void call() {
return null; return null;
} }
@NotNull
@Override
protected String getTopic() {
throw new UnsupportedOperationException();
}
} }
@Test @Test

View File

@ -10,28 +10,21 @@ import com.pholser.junit.quickcheck.runner.JUnitQuickcheck
import com.r3corda.contracts.testing.SignedTransactionGenerator import com.r3corda.contracts.testing.SignedTransactionGenerator
import com.r3corda.core.serialization.createKryo import com.r3corda.core.serialization.createKryo
import com.r3corda.core.serialization.serialize import com.r3corda.core.serialization.serialize
import com.r3corda.core.testing.PartyGenerator import com.r3corda.protocols.BroadcastTransactionProtocol.NotifyTxRequest
import com.r3corda.protocols.BroadcastTransactionProtocol
import org.junit.runner.RunWith import org.junit.runner.RunWith
import kotlin.test.assertEquals import kotlin.test.assertEquals
@RunWith(JUnitQuickcheck::class) @RunWith(JUnitQuickcheck::class)
class BroadcastTransactionProtocolTest { class BroadcastTransactionProtocolTest {
class NotifyTxRequestMessageGenerator : Generator<BroadcastTransactionProtocol.NotifyTxRequestMessage>(BroadcastTransactionProtocol.NotifyTxRequestMessage::class.java) { class NotifyTxRequestMessageGenerator : Generator<NotifyTxRequest>(NotifyTxRequest::class.java) {
override fun generate(random: SourceOfRandomness, status: GenerationStatus): BroadcastTransactionProtocol.NotifyTxRequestMessage { override fun generate(random: SourceOfRandomness, status: GenerationStatus): NotifyTxRequest {
return BroadcastTransactionProtocol.NotifyTxRequestMessage( return NotifyTxRequest(tx = SignedTransactionGenerator().generate(random, status), events = setOf())
tx = SignedTransactionGenerator().generate(random, status),
events = setOf(),
replyToParty = PartyGenerator().generate(random, status),
sendSessionID = random.nextLong(),
receiveSessionID = random.nextLong()
)
} }
} }
@Property @Property
fun serialiseDeserialiseOfNotifyMessageWorks(@From(NotifyTxRequestMessageGenerator::class) message: BroadcastTransactionProtocol.NotifyTxRequestMessage) { fun serialiseDeserialiseOfNotifyMessageWorks(@From(NotifyTxRequestMessageGenerator::class) message: NotifyTxRequest) {
val kryo = createKryo() val kryo = createKryo()
val serialized = message.serialize().bits val serialized = message.serialize().bits
val deserialized = kryo.readClassAndObject(Input(serialized)) val deserialized = kryo.readClassAndObject(Input(serialized))

View File

@ -23,18 +23,15 @@ class ProtocolLogicRefTest {
constructor(kotlinType: Int) : this(ParamType1(kotlinType), ParamType2("b")) constructor(kotlinType: Int) : this(ParamType1(kotlinType), ParamType2("b"))
override fun call() = Unit override fun call() = Unit
override val topic: String get() = throw UnsupportedOperationException()
} }
class KotlinNoArgProtocolLogic : ProtocolLogic<Unit>() { class KotlinNoArgProtocolLogic : ProtocolLogic<Unit>() {
override fun call() = Unit override fun call() = Unit
override val topic: String get() = throw UnsupportedOperationException()
} }
@Suppress("UNUSED_PARAMETER") // We will never use A or b @Suppress("UNUSED_PARAMETER") // We will never use A or b
class NotWhiteListedKotlinProtocolLogic(A: Int, b: String) : ProtocolLogic<Unit>() { class NotWhiteListedKotlinProtocolLogic(A: Int, b: String) : ProtocolLogic<Unit>() {
override fun call() = Unit override fun call() = Unit
override val topic: String get() = throw UnsupportedOperationException()
} }
lateinit var factory: ProtocolLogicRefFactory lateinit var factory: ProtocolLogicRefFactory

View File

@ -91,7 +91,7 @@ Our protocol has two parties (B and S for buyer and seller) and will proceed as
it lacks a signature from S authorising movement of the asset. it lacks a signature from S authorising movement of the asset.
3. S signs it and hands the now finalised ``SignedTransaction`` back to B. 3. S signs it and hands the now finalised ``SignedTransaction`` back to B.
You can find the implementation of this protocol in the file ``contracts/protocols/TwoPartyTradeProtocol.kt``. You can find the implementation of this protocol in the file ``contracts/src/main/kotlin/com/r3corda/protocols/TwoPartyTradeProtocol.kt``.
Assuming no malicious termination, they both end the protocol being in posession of a valid, signed transaction that Assuming no malicious termination, they both end the protocol being in posession of a valid, signed transaction that
represents an atomic asset swap. represents an atomic asset swap.
@ -110,7 +110,6 @@ each side.
.. sourcecode:: kotlin .. sourcecode:: kotlin
object TwoPartyTradeProtocol { object TwoPartyTradeProtocol {
val TOPIC = "platform.trade"
class UnacceptablePriceException(val givenPrice: Amount<Currency>) : Exception("Unacceptable price: $givenPrice") class UnacceptablePriceException(val givenPrice: Amount<Currency>) : Exception("Unacceptable price: $givenPrice")
class AssetMismatchException(val expectedTypeName: String, val typeName: String) : Exception() { class AssetMismatchException(val expectedTypeName: String, val typeName: String) : Exception() {
@ -118,21 +117,20 @@ each side.
} }
// This object is serialised to the network and is the first protocol message the seller sends to the buyer. // This object is serialised to the network and is the first protocol message the seller sends to the buyer.
class SellerTradeInfo( data class SellerTradeInfo(
val assetForSale: StateAndRef<OwnableState>, val assetForSale: StateAndRef<OwnableState>,
val price: Amount, val price: Amount<Currency>,
val sellerOwnerKey: PublicKey, val sellerOwnerKey: PublicKey
val sessionID: Long
) )
class SignaturesFromSeller(val timestampAuthoritySig: DigitalSignature.WithKey, val sellerSig: DigitalSignature.WithKey) data class SignaturesFromSeller(val sellerSig: DigitalSignature.WithKey,
val notarySig: DigitalSignature.LegallyIdentifiable)
open class Seller(val otherSide: Party, open class Seller(val otherSide: Party,
val notaryNode: NodeInfo, val notaryNode: NodeInfo,
val assetToSell: StateAndRef<OwnableState>, val assetToSell: StateAndRef<OwnableState>,
val price: Amount<Currency>, val price: Amount<Currency>,
val myKeyPair: KeyPair, val myKeyPair: KeyPair,
val buyerSessionID: Long,
override val progressTracker: ProgressTracker = Seller.tracker()) : ProtocolLogic<SignedTransaction>() { override val progressTracker: ProgressTracker = Seller.tracker()) : ProtocolLogic<SignedTransaction>() {
@Suspendable @Suspendable
override fun call(): SignedTransaction { override fun call(): SignedTransaction {
@ -143,8 +141,7 @@ each side.
open class Buyer(val otherSide: Party, open class Buyer(val otherSide: Party,
val notary: Party, val notary: Party,
val acceptablePrice: Amount<Currency>, val acceptablePrice: Amount<Currency>,
val typeToBuy: Class<out OwnableState>, val typeToBuy: Class<out OwnableState>) : ProtocolLogic<SignedTransaction>() {
val sessionID: Long) : ProtocolLogic<SignedTransaction>() {
@Suspendable @Suspendable
override fun call(): SignedTransaction { override fun call(): SignedTransaction {
TODO() TODO()
@ -152,25 +149,17 @@ each side.
} }
} }
Let's unpack what this code does: This code defines several classes nested inside the main ``TwoPartyTradeProtocol`` singleton. Some of the classes are
simply protocol messages or exceptions. The other two represent the buyer and seller side of the protocol.
- It defines a several classes nested inside the main ``TwoPartyTradeProtocol`` singleton. Some of the classes
are simply protocol messages or exceptions. The other two represent the buyer and seller side of the protocol.
- It defines the "trade topic", which is just a string that namespaces this protocol. The prefix "platform." is reserved
by Corda, but you can define your own protocol namespaces using standard Java-style reverse DNS notation.
Going through the data needed to become a seller, we have: Going through the data needed to become a seller, we have:
- ``otherSide: SingleMessageRecipient`` - the network address of the node with which you are trading. - ``otherSide: Party`` - the party with which you are trading.
- ``notaryNode: NodeInfo`` - the entry in the network map for the chosen notary. See ":doc:`consensus`" for more - ``notaryNode: NodeInfo`` - the entry in the network map for the chosen notary. See ":doc:`consensus`" for more
information on notaries. information on notaries.
- ``assetToSell: StateAndRef<OwnableState>`` - a pointer to the ledger entry that represents the thing being sold. - ``assetToSell: StateAndRef<OwnableState>`` - a pointer to the ledger entry that represents the thing being sold.
- ``price: Amount<Currency>`` - the agreed on price that the asset is being sold for (without an issuer constraint). - ``price: Amount<Currency>`` - the agreed on price that the asset is being sold for (without an issuer constraint).
- ``myKeyPair: KeyPair`` - the key pair that controls the asset being sold. It will be used to sign the transaction. - ``myKeyPair: KeyPair`` - the key pair that controls the asset being sold. It will be used to sign the transaction.
- ``buyerSessionID: Long`` - a unique number that identifies this trade to the buyer. It is expected that the buyer
knows that the trade is going to take place and has sent you such a number already.
.. note:: Session IDs will be automatically handled in a future version of the framework.
And for the buyer: And for the buyer:
@ -178,7 +167,6 @@ And for the buyer:
a price less than or equal to this, then the trade will go ahead. a price less than or equal to this, then the trade will go ahead.
- ``typeToBuy: Class<out OwnableState>`` - the type of state that is being purchased. This is used to check that the - ``typeToBuy: Class<out OwnableState>`` - the type of state that is being purchased. This is used to check that the
sell side of the protocol isn't trying to sell us the wrong thing, whether by accident or on purpose. sell side of the protocol isn't trying to sell us the wrong thing, whether by accident or on purpose.
- ``sessionID: Long`` - the session ID that was handed to the seller in order to start the protocol.
Alright, so using this protocol shouldn't be too hard: in the simplest case we can just create a Buyer or Seller Alright, so using this protocol shouldn't be too hard: in the simplest case we can just create a Buyer or Seller
with the details of the trade, depending on who we are. We then have to start the protocol in some way. Just with the details of the trade, depending on who we are. We then have to start the protocol in some way. Just
@ -221,6 +209,27 @@ protocol are checked against a whitelist, which can be extended by apps themselv
The process of starting a protocol returns a ``ListenableFuture`` that you can use to either block waiting for The process of starting a protocol returns a ``ListenableFuture`` that you can use to either block waiting for
the result, or register a callback that will be invoked when the result is ready. the result, or register a callback that will be invoked when the result is ready.
In a two party protocol only one side is to be manually started using ``ServiceHub.invokeProtocolAsync``. The other side
has to be registered by its node to respond to the initiating protocol via ``ServiceHubInternal.registerProtocolInitiator``.
In our example it doesn't matter which protocol is the initiator and which is the initiated. For example, if we are to
take the seller as the initiator then we would register the buyer as such:
.. container:: codeset
.. sourcecode:: kotlin
val services: ServiceHubInternal = TODO()
services.registerProtocolInitiator(Seller::class) { otherParty ->
val notary = services.networkMapCache.notaryNodes[0]
val acceptablePrice = TODO()
val typeToBuy = TODO()
Buyer(otherParty, notary, acceptablePrice, typeToBuy)
}
This is telling the buyer node to fire up an instance of ``Buyer`` (the code in the lambda) when the initiating protocol
is a seller (``Seller::class``).
Implementing the seller Implementing the seller
----------------------- -----------------------
@ -253,12 +262,10 @@ Let's fill out the ``receiveAndCheckProposedTransaction()`` method.
@Suspendable @Suspendable
private fun receiveAndCheckProposedTransaction(): SignedTransaction { private fun receiveAndCheckProposedTransaction(): SignedTransaction {
val sessionID = random63BitValue()
// Make the first message we'll send to kick off the protocol. // Make the first message we'll send to kick off the protocol.
val hello = SellerTradeInfo(assetToSell, price, myKeyPair.public, sessionID) val hello = SellerTradeInfo(assetToSell, price, myKeyPair.public)
val maybeSTX = sendAndReceive<SignedTransaction>(otherSide, buyerSessionID, sessionID, hello) val maybeSTX = sendAndReceive<SignedTransaction>(otherSide, hello)
maybeSTX.unwrap { maybeSTX.unwrap {
// Check that the tx proposed by the buyer is valid. // Check that the tx proposed by the buyer is valid.
@ -281,11 +288,10 @@ Let's fill out the ``receiveAndCheckProposedTransaction()`` method.
} }
} }
Let's break this down. We generate a session ID to identify what's happening on the seller side, fill out Let's break this down. We fill out the initial protocol message with the trade info, and then call ``sendAndReceive``.
the initial protocol message, and then call ``sendAndReceive``. This function takes a few arguments: This function takes a few arguments:
- The topic string that ensures the message is routed to the right bit of code in the other side's node. - The party on the other side.
- The session IDs that ensure the messages don't get mixed up with other simultaneous trades.
- The thing to send. It'll be serialised and sent automatically. - The thing to send. It'll be serialised and sent automatically.
- Finally a type argument, which is the kind of object we're expecting to receive from the other side. If we get - Finally a type argument, which is the kind of object we're expecting to receive from the other side. If we get
back something else an exception is thrown. back something else an exception is thrown.
@ -370,7 +376,7 @@ Here's the rest of the code:
notarySignature: DigitalSignature.LegallyIdentifiable): SignedTransaction { notarySignature: DigitalSignature.LegallyIdentifiable): SignedTransaction {
val fullySigned = partialTX + ourSignature + notarySignature val fullySigned = partialTX + ourSignature + notarySignature
logger.trace { "Built finished transaction, sending back to secondary!" } logger.trace { "Built finished transaction, sending back to secondary!" }
send(otherSide, buyerSessionID, SignaturesFromSeller(ourSignature, notarySignature)) send(otherSide, SignaturesFromSeller(ourSignature, notarySignature))
return fullySigned return fullySigned
} }
@ -406,7 +412,7 @@ OK, let's do the same for the buyer side:
val (ptx, cashSigningPubKeys) = assembleSharedTX(tradeRequest) val (ptx, cashSigningPubKeys) = assembleSharedTX(tradeRequest)
val stx = signWithOurKeys(cashSigningPubKeys, ptx) val stx = signWithOurKeys(cashSigningPubKeys, ptx)
val signatures = swapSignaturesWithSeller(stx, tradeRequest.sessionID) val signatures = swapSignaturesWithSeller(stx)
logger.trace { "Got signatures from seller, verifying ... " } logger.trace { "Got signatures from seller, verifying ... " }
@ -419,16 +425,14 @@ OK, let's do the same for the buyer side:
@Suspendable @Suspendable
private fun receiveAndValidateTradeRequest(): SellerTradeInfo { private fun receiveAndValidateTradeRequest(): SellerTradeInfo {
// Wait for a trade request to come in on our pre-provided session ID. // Wait for a trade request to come in from the other side
val maybeTradeRequest = receive<SellerTradeInfo>(sessionID) val maybeTradeRequest = receive<SellerTradeInfo>(otherParty)
maybeTradeRequest.unwrap { maybeTradeRequest.unwrap {
// What is the seller trying to sell us? // What is the seller trying to sell us?
val asset = it.assetForSale.state.data val asset = it.assetForSale.state.data
val assetTypeName = asset.javaClass.name val assetTypeName = asset.javaClass.name
logger.trace { "Got trade request for a $assetTypeName: ${it.assetForSale}" } logger.trace { "Got trade request for a $assetTypeName: ${it.assetForSale}" }
// Check the start message for acceptability.
check(it.sessionID > 0)
if (it.price > acceptablePrice) if (it.price > acceptablePrice)
throw UnacceptablePriceException(it.price) throw UnacceptablePriceException(it.price)
if (!typeToBuy.isInstance(asset)) if (!typeToBuy.isInstance(asset))
@ -443,13 +447,13 @@ OK, let's do the same for the buyer side:
} }
@Suspendable @Suspendable
private fun swapSignaturesWithSeller(stx: SignedTransaction, theirSessionID: Long): SignaturesFromSeller { private fun swapSignaturesWithSeller(stx: SignedTransaction): SignaturesFromSeller {
progressTracker.currentStep = SWAPPING_SIGNATURES progressTracker.currentStep = SWAPPING_SIGNATURES
logger.trace { "Sending partially signed transaction to seller" } logger.trace { "Sending partially signed transaction to seller" }
// TODO: Protect against the seller terminating here and leaving us in the lurch without the final tx. // TODO: Protect against the seller terminating here and leaving us in the lurch without the final tx.
return sendAndReceive<SignaturesFromSeller>(otherSide, theirSessionID, sessionID, stx).unwrap { it } return sendAndReceive<SignaturesFromSeller>(otherSide, stx).unwrap { it }
} }
private fun signWithOurKeys(cashSigningPubKeys: List<PublicKey>, ptx: TransactionBuilder): SignedTransaction { private fun signWithOurKeys(cashSigningPubKeys: List<PublicKey>, ptx: TransactionBuilder): SignedTransaction {
@ -676,7 +680,6 @@ Future features
The protocol framework is a key part of the platform and will be extended in major ways in future. Here are some of The protocol framework is a key part of the platform and will be extended in major ways in future. Here are some of
the features we have planned: the features we have planned:
* Automatic session ID management
* Identity based addressing * Identity based addressing
* Exposing progress trackers to local (inside the firewall) clients using message queues and/or WebSockets * Exposing progress trackers to local (inside the firewall) clients using message queues and/or WebSockets
* Exception propagation and management, with a "protocol hospital" tool to manually provide solutions to unavoidable * Exception propagation and management, with a "protocol hospital" tool to manually provide solutions to unavoidable

View File

@ -24,6 +24,7 @@ import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.serialization.deserialize import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize import com.r3corda.core.serialization.serialize
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.debug
import com.r3corda.node.api.APIServer import com.r3corda.node.api.APIServer
import com.r3corda.node.services.api.* import com.r3corda.node.services.api.*
import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.config.NodeConfiguration
@ -54,8 +55,10 @@ import java.nio.file.Path
import java.security.KeyPair import java.security.KeyPair
import java.time.Clock import java.time.Clock
import java.util.* import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService import java.util.concurrent.ExecutorService
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import kotlin.reflect.KClass
/** /**
* A base node implementation that can be customised either for production (with real implementations that do real * A base node implementation that can be customised either for production (with real implementations that do real
@ -91,6 +94,8 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
protected val _servicesThatAcceptUploads = ArrayList<AcceptsFileUpload>() protected val _servicesThatAcceptUploads = ArrayList<AcceptsFileUpload>()
val servicesThatAcceptUploads: List<AcceptsFileUpload> = _servicesThatAcceptUploads val servicesThatAcceptUploads: List<AcceptsFileUpload> = _servicesThatAcceptUploads
private val protocolFactories = ConcurrentHashMap<Class<*>, (Party) -> ProtocolLogic<*>>()
val services = object : ServiceHubInternal() { val services = object : ServiceHubInternal() {
override val networkService: MessagingServiceInternal get() = net override val networkService: MessagingServiceInternal get() = net
override val networkMapCache: NetworkMapCache get() = netMapCache override val networkMapCache: NetworkMapCache get() = netMapCache
@ -109,6 +114,16 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
return smm.add(loggerName, logic).resultFuture return smm.add(loggerName, logic).resultFuture
} }
override fun registerProtocolInitiator(markerClass: KClass<*>, protocolFactory: (Party) -> ProtocolLogic<*>) {
require(markerClass !in protocolFactories) { "${markerClass.java.name} has already been used to register a protocol" }
log.debug { "Registering ${markerClass.java.name}" }
protocolFactories[markerClass.java] = protocolFactory
}
override fun getProtocolFactory(markerClass: Class<*>): ((Party) -> ProtocolLogic<*>)? {
return protocolFactories[markerClass]
}
override fun recordTransactions(txs: Iterable<SignedTransaction>) = recordTransactionsInternal(storage, txs) override fun recordTransactions(txs: Iterable<SignedTransaction>) = recordTransactionsInternal(storage, txs)
} }

View File

@ -1,11 +1,9 @@
package com.r3corda.node.services package com.r3corda.node.services
import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.node.services.api.AbstractNodeService import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.AbstractStateReplacementProtocol
import com.r3corda.protocols.NotaryChangeProtocol import com.r3corda.protocols.NotaryChangeProtocol
import com.r3corda.protocols.NotaryChangeProtocol.TOPIC
object NotaryChange { object NotaryChange {
class Plugin : CordaPluginRegistry() { class Plugin : CordaPluginRegistry() {
@ -16,11 +14,9 @@ object NotaryChange {
* A service that monitors the network for requests for changing the notary of a state, * A service that monitors the network for requests for changing the notary of a state,
* and immediately runs the [NotaryChangeProtocol] if the auto-accept criteria are met. * and immediately runs the [NotaryChangeProtocol] if the auto-accept criteria are met.
*/ */
class Service(services: ServiceHubInternal) : AbstractNodeService(services) { class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() {
init { init {
addProtocolHandler(TOPIC, TOPIC) { req: AbstractStateReplacementProtocol.Handshake -> services.registerProtocolInitiator(NotaryChangeProtocol.Instigator::class) { NotaryChangeProtocol.Acceptor(it) }
NotaryChangeProtocol.Acceptor(req.replyToParty)
}
} }
} }
} }

View File

@ -1,16 +1,12 @@
package com.r3corda.node.services.api package com.r3corda.node.services.api
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.messaging.Message import com.r3corda.core.messaging.Message
import com.r3corda.core.messaging.MessageHandlerRegistration import com.r3corda.core.messaging.MessageHandlerRegistration
import com.r3corda.core.messaging.createMessage import com.r3corda.core.messaging.createMessage
import com.r3corda.core.node.services.DEFAULT_SESSION_ID import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.serialization.SingletonSerializeAsToken import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.serialization.deserialize import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize import com.r3corda.core.serialization.serialize
import com.r3corda.core.utilities.loggerFor
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.protocols.ServiceRequestMessage import com.r3corda.protocols.ServiceRequestMessage
import javax.annotation.concurrent.ThreadSafe import javax.annotation.concurrent.ThreadSafe
@ -20,10 +16,6 @@ import javax.annotation.concurrent.ThreadSafe
@ThreadSafe @ThreadSafe
abstract class AbstractNodeService(val services: ServiceHubInternal) : SingletonSerializeAsToken() { abstract class AbstractNodeService(val services: ServiceHubInternal) : SingletonSerializeAsToken() {
companion object {
val logger = loggerFor<AbstractNodeService>()
}
val net: MessagingServiceInternal get() = services.networkService val net: MessagingServiceInternal get() = services.networkService
/** /**
@ -68,36 +60,4 @@ abstract class AbstractNodeService(val services: ServiceHubInternal) : Singleton
return addMessageHandler(topic, handler, { message: Message, exception: Exception -> throw exception }) return addMessageHandler(topic, handler, { message: Message, exception: Exception -> throw exception })
} }
/**
* Register a handler to kick-off a protocol when a [HandshakeMessage] is received by the node. This performs the
* necessary steps to enable communication between the two protocols, including calling ProtocolLogic.registerSession.
* @param topic the topic on which the handshake is sent from the other party
* @param loggerName the logger name to use when starting the protocol
* @param protocolFactory a function to create the protocol with the given handshake message
* @param onResultFuture provides access to the [ListenableFuture] when the protocol starts
*/
protected inline fun <reified H : HandshakeMessage, R : Any> addProtocolHandler(
topic: String,
loggerName: String,
crossinline protocolFactory: (H) -> ProtocolLogic<R>,
crossinline onResultFuture: ProtocolLogic<R>.(ListenableFuture<R>, H) -> Unit) {
net.addMessageHandler(topic, DEFAULT_SESSION_ID, null) { message, reg ->
try {
val handshake = message.data.deserialize<H>()
val protocol = protocolFactory(handshake)
protocol.registerSession(handshake)
val resultFuture = services.startProtocol(loggerName, protocol)
protocol.onResultFuture(resultFuture, handshake)
} catch (e: Exception) {
logger.error("Unable to process ${H::class.java.name} message", e)
}
}
}
protected inline fun <reified H : HandshakeMessage, R : Any> addProtocolHandler(
topic: String,
loggerName: String,
crossinline protocolFactory: (H) -> ProtocolLogic<R>) {
addProtocolHandler(topic, loggerName, protocolFactory, { future, handshake -> })
}
} }

View File

@ -1,7 +1,7 @@
package com.r3corda.node.services.api package com.r3corda.node.services.api
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.serialization.SerializedBytes import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.node.services.statemachine.ProtocolIORequest
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
/** /**
@ -30,14 +30,13 @@ interface CheckpointStorage {
} }
// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo). // This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
data class Checkpoint( class Checkpoint(val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>) {
val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
val request: ProtocolIORequest?, val id: SecureHash get() = serialisedFiber.hash
val receivedPayload: Any?
) { override fun equals(other: Any?): Boolean = other === this || other is Checkpoint && other.id == this.id
// This flag is always false when loaded from storage as it isn't serialised.
// It is used to track when the associated fiber has been created, but not necessarily started when override fun hashCode(): Int = id.hashCode()
// messages for protocols arrive before the system has fully loaded at startup.
@Transient override fun toString(): String = "${javaClass.simpleName}(id=$id)"
var fiberCreated: Boolean = false
} }

View File

@ -1,14 +1,16 @@
package com.r3corda.node.services.api package com.r3corda.node.services.api
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.MessagingService import com.r3corda.core.messaging.MessagingService
import com.r3corda.core.node.ServiceHub import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.TxWritableStorageService import com.r3corda.core.node.services.TxWritableStorageService
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolLogicRefFactory import com.r3corda.core.protocols.ProtocolLogicRefFactory
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import kotlin.reflect.KClass
interface MessagingServiceInternal : MessagingService { interface MessagingServiceInternal : MessagingService {
/** /**
@ -49,7 +51,7 @@ abstract class ServiceHubInternal : ServiceHub {
* @param txs The transactions to record. * @param txs The transactions to record.
*/ */
internal fun recordTransactionsInternal(writableStorageService: TxWritableStorageService, txs: Iterable<SignedTransaction>) { internal fun recordTransactionsInternal(writableStorageService: TxWritableStorageService, txs: Iterable<SignedTransaction>) {
val stateMachineRunId = ProtocolStateMachineImpl.retrieveCurrentStateMachine()?.id val stateMachineRunId = ProtocolStateMachineImpl.currentStateMachine()?.id
if (stateMachineRunId != null) { if (stateMachineRunId != null) {
txs.forEach { txs.forEach {
storageService.stateMachineRecordedTransactionMapping.addMapping(stateMachineRunId, it.id) storageService.stateMachineRecordedTransactionMapping.addMapping(stateMachineRunId, it.id)
@ -68,6 +70,23 @@ abstract class ServiceHubInternal : ServiceHub {
*/ */
abstract fun <T> startProtocol(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> abstract fun <T> startProtocol(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T>
/**
* Register the protocol factory we wish to use when a initiating party attempts to communicate with us. The
* registration is done against a marker [KClass] which is sent in the session handsake by the other party. If this
* marker class has been registered then the corresponding factory will be used to create the protocol which will
* communicate with the other side. If there is no mapping then the session attempt is rejected.
* @param markerClass The marker [KClass] present in a session initiation attempt, which is a 1:1 mapping to a [Class]
* using the <pre>::class</pre> construct. Any marker class can be used, with the default being the class of the initiating
* protocol. This enables the registration to be of the form: registerProtocolInitiator(InitiatorProtocol::class, ::InitiatedProtocol)
* @param protocolFactory The protocol factory generating the initiated protocol.
*/
abstract fun registerProtocolInitiator(markerClass: KClass<*>, protocolFactory: (Party) -> ProtocolLogic<*>)
/**
* Return the protocol factory that has been registered with [markerClass], or null if no factory is found.
*/
abstract fun getProtocolFactory(markerClass: Class<*>): ((Party) -> ProtocolLogic<*>)?
override fun <T : Any> invokeProtocolAsync(logicType: Class<out ProtocolLogic<T>>, vararg args: Any?): ListenableFuture<T> { override fun <T : Any> invokeProtocolAsync(logicType: Class<out ProtocolLogic<T>>, vararg args: Any?): ListenableFuture<T> {
val logicRef = protocolLogicRefFactory.create(logicType, *args) val logicRef = protocolLogicRefFactory.create(logicType, *args)
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")

View File

@ -1,28 +1,25 @@
package com.r3corda.node.services.clientapi package com.r3corda.node.services.clientapi
import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.node.services.api.AbstractNodeService import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.TwoPartyDealProtocol import com.r3corda.protocols.TwoPartyDealProtocol.Fixer
import com.r3corda.protocols.TwoPartyDealProtocol.FIX_INITIATE_TOPIC import com.r3corda.protocols.TwoPartyDealProtocol.Floater
import com.r3corda.protocols.TwoPartyDealProtocol.FixingSessionInitiation
/** /**
* This is a temporary handler required for establishing random sessionIDs for the [Fixer] and [Floater] as part of * This is a temporary handler required for establishing random sessionIDs for the [Fixer] and [Floater] as part of
* running scheduled fixings for the [InterestRateSwap] contract. * running scheduled fixings for the [InterestRateSwap] contract.
* *
* TODO: This will be replaced with the automatic sessionID / session setup work. * TODO: This will be replaced with the symmetric session work
*/ */
object FixingSessionInitiation { object FixingSessionInitiation {
class Plugin: CordaPluginRegistry() { class Plugin: CordaPluginRegistry() {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java) override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
} }
class Service(services: ServiceHubInternal) : AbstractNodeService(services) { class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() {
init { init {
addProtocolHandler(FIX_INITIATE_TOPIC, "fixings") { initiation: FixingSessionInitiation -> services.registerProtocolInitiator(Floater::class) { Fixer(it) }
TwoPartyDealProtocol.Fixer(initiation.replyToParty, initiation.oracleType)
}
} }
} }
} }

View File

@ -169,7 +169,7 @@ class NodeMonitorService(services: ServiceHubInternal, val smm: StateMachineMana
val tx = builder.toSignedTransaction(checkSufficientSignatures = false) val tx = builder.toSignedTransaction(checkSufficientSignatures = false)
val protocol = FinalityProtocol(tx, setOf(req), setOf(req.recipient)) val protocol = FinalityProtocol(tx, setOf(req), setOf(req.recipient))
return TransactionBuildResult.ProtocolStarted( return TransactionBuildResult.ProtocolStarted(
smm.add(BroadcastTransactionProtocol.TOPIC, protocol).id, smm.add("broadcast", protocol).id,
tx, tx,
"Cash payment transaction generated" "Cash payment transaction generated"
) )
@ -203,7 +203,7 @@ class NodeMonitorService(services: ServiceHubInternal, val smm: StateMachineMana
val tx = builder.toSignedTransaction(checkSufficientSignatures = false) val tx = builder.toSignedTransaction(checkSufficientSignatures = false)
val protocol = FinalityProtocol(tx, setOf(req), participants) val protocol = FinalityProtocol(tx, setOf(req), participants)
return TransactionBuildResult.ProtocolStarted( return TransactionBuildResult.ProtocolStarted(
smm.add(BroadcastTransactionProtocol.TOPIC, protocol).id, smm.add("broadcast", protocol).id,
tx, tx,
"Cash destruction transaction generated" "Cash destruction transaction generated"
) )
@ -222,7 +222,7 @@ class NodeMonitorService(services: ServiceHubInternal, val smm: StateMachineMana
// Issuance transactions do not need to be notarised, so we can skip directly to broadcasting it // Issuance transactions do not need to be notarised, so we can skip directly to broadcasting it
val protocol = BroadcastTransactionProtocol(tx, setOf(req), setOf(req.recipient)) val protocol = BroadcastTransactionProtocol(tx, setOf(req), setOf(req.recipient))
return TransactionBuildResult.ProtocolStarted( return TransactionBuildResult.ProtocolStarted(
smm.add(BroadcastTransactionProtocol.TOPIC, protocol).id, smm.add("broadcast", protocol).id,
tx, tx,
"Cash issuance completed" "Cash issuance completed"
) )

View File

@ -1,17 +1,12 @@
package com.r3corda.node.services.persistence package com.r3corda.node.services.persistence
import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.failure
import com.r3corda.core.messaging.MessagingService
import com.r3corda.core.messaging.TopicSession
import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.recordTransactions import com.r3corda.core.node.recordTransactions
import com.r3corda.core.serialization.serialize import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.success import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.loggerFor import com.r3corda.core.utilities.loggerFor
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.* import com.r3corda.protocols.*
import java.io.InputStream import java.io.InputStream
@ -39,71 +34,49 @@ object DataVending {
// TODO: I don't like that this needs ServiceHubInternal, but passing in a state machine breaks MockServices because // TODO: I don't like that this needs ServiceHubInternal, but passing in a state machine breaks MockServices because
// the state machine isn't set when this is constructed. [NodeSchedulerService] has the same problem, and both // the state machine isn't set when this is constructed. [NodeSchedulerService] has the same problem, and both
// should be fixed at the same time. // should be fixed at the same time.
class Service(services: ServiceHubInternal) : AbstractNodeService(services) { class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() {
companion object { companion object {
val logger = loggerFor<DataVending.Service>() val logger = loggerFor<DataVending.Service>()
/**
* Notify a node of a transaction. Normally any notarisation required would happen before this is called.
*/
fun notify(net: MessagingService,
myIdentity: Party,
recipient: NodeInfo,
transaction: SignedTransaction) {
val msg = BroadcastTransactionProtocol.NotifyTxRequestMessage(transaction, emptySet(), myIdentity)
net.send(net.createMessage(TopicSession(BroadcastTransactionProtocol.TOPIC, 0), msg.serialize().bits), recipient.address)
} }
}
val storage = services.storageService
class TransactionRejectedError(msg: String) : Exception(msg) class TransactionRejectedError(msg: String) : Exception(msg)
init { init {
addMessageHandler(FetchTransactionsProtocol.TOPIC, services.registerProtocolInitiator(FetchTransactionsProtocol::class, ::FetchTransactionsHandler)
{ req: FetchDataProtocol.Request -> handleTXRequest(req) }, services.registerProtocolInitiator(FetchAttachmentsProtocol::class, ::FetchAttachmentsHandler)
{ message, e -> logger.error("Failure processing data vending request.", e) } services.registerProtocolInitiator(BroadcastTransactionProtocol::class, ::NotifyTransactionHandler)
)
addMessageHandler(FetchAttachmentsProtocol.TOPIC,
{ req: FetchDataProtocol.Request -> handleAttachmentRequest(req) },
{ message, e -> logger.error("Failure processing data vending request.", e) }
)
// TODO: We should have a whitelist of contracts we're willing to accept at all, and reject if the transaction
// includes us in any outside that list. Potentially just if it includes any outside that list at all.
// TODO: Do we want to be able to reject specific transactions on more complex rules, for example reject incoming
// cash without from unknown parties?
addProtocolHandler(
BroadcastTransactionProtocol.TOPIC,
"Resolving transactions",
{ req: BroadcastTransactionProtocol.NotifyTxRequestMessage ->
ResolveTransactionsProtocol(req.tx, req.replyToParty)
},
{ future, req ->
future.success {
serviceHub.recordTransactions(req.tx)
}.failure { throwable ->
logger.warn("Received invalid transaction ${req.tx.id} from ${req.replyToParty}", throwable)
}
})
} }
private fun handleTXRequest(req: FetchDataProtocol.Request): List<SignedTransaction?> {
require(req.hashes.isNotEmpty()) private class FetchTransactionsHandler(val otherParty: Party) : ProtocolLogic<Unit>() {
return req.hashes.map { @Suspendable
val tx = storage.validatedTransactions.getTransaction(it) override fun call() {
val request = receive<FetchDataProtocol.Request>(otherParty).unwrap {
require(it.hashes.isNotEmpty())
it
}
val txs = request.hashes.map {
val tx = serviceHub.storageService.validatedTransactions.getTransaction(it)
if (tx == null) if (tx == null)
logger.info("Got request for unknown tx $it") logger.info("Got request for unknown tx $it")
tx tx
} }
send(otherParty, txs)
}
} }
private fun handleAttachmentRequest(req: FetchDataProtocol.Request): List<ByteArray?> {
// TODO: Use Artemis message streaming support here, called "large messages". This avoids the need to buffer. // TODO: Use Artemis message streaming support here, called "large messages". This avoids the need to buffer.
require(req.hashes.isNotEmpty()) private class FetchAttachmentsHandler(val otherParty: Party) : ProtocolLogic<Unit>() {
return req.hashes.map { @Suspendable
val jar: InputStream? = storage.attachments.openAttachment(it)?.open() override fun call() {
val request = receive<FetchDataProtocol.Request>(otherParty).unwrap {
require(it.hashes.isNotEmpty())
it
}
val attachments = request.hashes.map {
val jar: InputStream? = serviceHub.storageService.attachments.openAttachment(it)?.open()
if (jar == null) { if (jar == null) {
logger.info("Got request for unknown attachment $it") logger.info("Got request for unknown attachment $it")
null null
@ -111,6 +84,23 @@ object DataVending {
jar.readBytes() jar.readBytes()
} }
} }
send(otherParty, attachments)
}
}
// TODO: We should have a whitelist of contracts we're willing to accept at all, and reject if the transaction
// includes us in any outside that list. Potentially just if it includes any outside that list at all.
// TODO: Do we want to be able to reject specific transactions on more complex rules, for example reject incoming
// cash without from unknown parties?
class NotifyTransactionHandler(val otherParty: Party) : ProtocolLogic<Unit>() {
@Suspendable
override fun call() {
val request = receive<BroadcastTransactionProtocol.NotifyTxRequest>(otherParty).unwrap { it }
subProtocol(ResolveTransactionsProtocol(request.tx, otherParty), shareParentSessions = true)
serviceHub.recordTransactions(request.tx)
} }
} }
} }
}

View File

@ -1,53 +1,38 @@
package com.r3corda.node.services.statemachine package com.r3corda.node.services.statemachine
import com.r3corda.core.crypto.Party import com.r3corda.node.services.statemachine.StateMachineManager.ProtocolSession
import com.r3corda.core.messaging.TopicSession import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage
import java.util.*
// TODO revisit when Kotlin 1.1 is released and data classes can extend other classes // TODO revisit when Kotlin 1.1 is released and data classes can extend other classes
interface ProtocolIORequest { interface ProtocolIORequest {
// This is used to identify where we suspended, in case of message mismatch errors and other things where we // This is used to identify where we suspended, in case of message mismatch errors and other things where we
// don't have the original stack trace because it's in a suspended fiber. // don't have the original stack trace because it's in a suspended fiber.
val stackTraceInCaseOfProblems: StackSnapshot val stackTraceInCaseOfProblems: StackSnapshot
val topic: String val session: ProtocolSession
} }
interface SendRequest : ProtocolIORequest { interface SendRequest : ProtocolIORequest {
val destination: Party val message: SessionMessage
val payload: Any
val sendSessionID: Long
val uniqueMessageId: UUID
} }
interface ReceiveRequest<T> : ProtocolIORequest { interface ReceiveRequest<T : SessionMessage> : ProtocolIORequest {
val receiveType: Class<T> val receiveType: Class<T>
val receiveSessionID: Long
val receiveTopicSession: TopicSession get() = TopicSession(topic, receiveSessionID)
} }
data class SendAndReceive<T>(override val topic: String, data class SendAndReceive<T : SessionMessage>(override val session: ProtocolSession,
override val destination: Party, override val message: SessionMessage,
override val payload: Any, override val receiveType: Class<T>) : SendRequest, ReceiveRequest<T> {
override val sendSessionID: Long,
override val uniqueMessageId: UUID,
override val receiveType: Class<T>,
override val receiveSessionID: Long) : SendRequest, ReceiveRequest<T> {
@Transient @Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
} }
data class ReceiveOnly<T>(override val topic: String, data class ReceiveOnly<T : SessionMessage>(override val session: ProtocolSession,
override val receiveType: Class<T>, override val receiveType: Class<T>) : ReceiveRequest<T> {
override val receiveSessionID: Long) : ReceiveRequest<T> {
@Transient @Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
} }
data class SendOnly(override val destination: Party, data class SendOnly(override val session: ProtocolSession, override val message: SessionMessage) : SendRequest {
override val topic: String,
override val payload: Any,
override val sendSessionID: Long,
override val uniqueMessageId: UUID) : SendRequest {
@Transient @Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
} }

View File

@ -8,16 +8,22 @@ import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolSessionException
import com.r3corda.core.protocols.ProtocolStateMachine import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.protocols.StateMachineRunId import com.r3corda.core.protocols.StateMachineRunId
import com.r3corda.core.random63BitValue
import com.r3corda.core.rootCause
import com.r3corda.core.utilities.UntrustworthyData import com.r3corda.core.utilities.UntrustworthyData
import com.r3corda.core.utilities.trace import com.r3corda.core.utilities.trace
import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.node.services.statemachine.StateMachineManager.*
import com.r3corda.node.utilities.createDatabaseTransaction import com.r3corda.node.utilities.createDatabaseTransaction
import org.jetbrains.exposed.sql.Database import org.jetbrains.exposed.sql.Database
import org.jetbrains.exposed.sql.transactions.TransactionManager import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.slf4j.Logger import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import java.io.PrintWriter
import java.io.StringWriter
import java.sql.SQLException import java.sql.SQLException
import java.util.* import java.util.*
import java.util.concurrent.ExecutionException import java.util.concurrent.ExecutionException
@ -36,12 +42,26 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
private val loggerName: String) private val loggerName: String)
: Fiber<R>("protocol", scheduler), ProtocolStateMachine<R> { : Fiber<R>("protocol", scheduler), ProtocolStateMachine<R> {
companion object {
// Used to work around a small limitation in Quasar.
private val QUASAR_UNBLOCKER = run {
val field = Fiber::class.java.getDeclaredField("SERIALIZER_BLOCKER")
field.isAccessible = true
field.get(null)
}
/**
* Return the current [ProtocolStateMachineImpl] or null if executing outside of one.
*/
fun currentStateMachine(): ProtocolStateMachineImpl<*>? = Strand.currentStrand() as? ProtocolStateMachineImpl<*>
}
// These fields shouldn't be serialised, so they are marked @Transient. // These fields shouldn't be serialised, so they are marked @Transient.
@Transient lateinit override var serviceHub: ServiceHubInternal @Transient lateinit override var serviceHub: ServiceHubInternal
@Transient internal lateinit var suspendAction: (ProtocolIORequest) -> Unit @Transient internal lateinit var actionOnSuspend: (ProtocolIORequest) -> Unit
@Transient internal lateinit var actionOnEnd: () -> Unit @Transient internal lateinit var actionOnEnd: () -> Unit
@Transient internal var receivedPayload: Any? = null
@Transient internal lateinit var database: Database @Transient internal lateinit var database: Database
@Transient internal var fromCheckpoint: Boolean = false
@Transient private var _logger: Logger? = null @Transient private var _logger: Logger? = null
override val logger: Logger get() { override val logger: Logger get() {
@ -62,18 +82,20 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
} }
} }
internal val openSessions = HashMap<Pair<ProtocolLogic<*>, Party>, ProtocolSession>()
init { init {
logic.psm = this logic.psm = this
name = id.toString()
} }
@Suspendable @Suppress("UNCHECKED_CAST") @Suspendable
override fun run(): R { override fun run(): R {
createTransaction() createTransaction()
val result = try { val result = try {
logic.call() logic.call()
} catch (t: Throwable) { } catch (t: Throwable) {
actionOnEnd() processException(t)
_resultFuture?.setException(t)
commitTransaction() commitTransaction()
throw ExecutionException(t) throw ExecutionException(t)
} }
@ -106,56 +128,140 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
} }
@Suspendable @Suspendable
private fun <T : Any> suspendAndExpectReceive(receiveRequest: ReceiveRequest<T>): UntrustworthyData<T> { override fun <T : Any> sendAndReceive(otherParty: Party,
suspend(receiveRequest)
check(receivedPayload != null) { "Expected to receive something" }
val untrustworthy = UntrustworthyData(receiveRequest.receiveType.cast(receivedPayload))
receivedPayload = null
return untrustworthy
}
@Suspendable
override fun <T : Any> sendAndReceive(topic: String,
destination: Party,
sessionIDForSend: Long,
sessionIDForReceive: Long,
payload: Any, payload: Any,
receiveType: Class<T>): UntrustworthyData<T> { receiveType: Class<T>,
return suspendAndExpectReceive(SendAndReceive(topic, destination, payload, sessionIDForSend, UUID.randomUUID(), receiveType, sessionIDForReceive)) sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T> {
val session = getSession(otherParty, sessionProtocol)
val sendSessionData = createSessionData(session, payload)
val receivedSessionData = sendAndReceiveInternal(session, sendSessionData, SessionData::class.java)
return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
} }
@Suspendable @Suspendable
override fun <T : Any> receive(topic: String, sessionIDForReceive: Long, receiveType: Class<T>): UntrustworthyData<T> { override fun <T : Any> receive(otherParty: Party,
return suspendAndExpectReceive(ReceiveOnly(topic, receiveType, sessionIDForReceive)) receiveType: Class<T>,
sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T> {
val receivedSessionData = receiveInternal(getSession(otherParty, sessionProtocol), SessionData::class.java)
return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
} }
@Suspendable @Suspendable
override fun send(topic: String, destination: Party, sessionID: Long, payload: Any) { override fun send(otherParty: Party, payload: Any, sessionProtocol: ProtocolLogic<*>) {
suspend(SendOnly(destination, topic, payload, sessionID, UUID.randomUUID())) val session = getSession(otherParty, sessionProtocol)
val sendSessionData = createSessionData(session, payload)
sendInternal(session, sendSessionData)
}
private fun createSessionData(session: ProtocolSession, payload: Any): SessionData {
val otherPartySessionId = session.otherPartySessionId
?: throw IllegalStateException("We've somehow held onto an unconfirmed session: $session")
return SessionData(otherPartySessionId, payload)
} }
@Suspendable @Suspendable
private fun suspend(protocolIORequest: ProtocolIORequest) { private fun sendInternal(session: ProtocolSession, message: SessionMessage) {
suspend(SendOnly(session, message))
}
@Suspendable
private fun <T : SessionMessage> receiveInternal(session: ProtocolSession, receiveType: Class<T>): T {
return suspendAndExpectReceive(ReceiveOnly(session, receiveType))
}
@Suspendable
private fun <T : SessionMessage> sendAndReceiveInternal(session: ProtocolSession, message: SessionMessage, receiveType: Class<T>): T {
return suspendAndExpectReceive(SendAndReceive(session, message, receiveType))
}
@Suspendable
private fun getSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>): ProtocolSession {
return openSessions[Pair(sessionProtocol, otherParty)] ?: startNewSession(otherParty, sessionProtocol)
}
@Suspendable
private fun startNewSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>) : ProtocolSession {
val session = ProtocolSession(sessionProtocol, otherParty, random63BitValue(), null)
openSessions[Pair(sessionProtocol, otherParty)] = session
val counterpartyProtocol = sessionProtocol.getCounterpartyMarker(otherParty).name
val sessionInit = SessionInit(session.ourSessionId, serviceHub.storageService.myLegalIdentity, counterpartyProtocol)
val sessionInitResponse = sendAndReceiveInternal(session, sessionInit, SessionInitResponse::class.java)
if (sessionInitResponse is SessionConfirm) {
session.otherPartySessionId = sessionInitResponse.initiatedSessionId
return session
} else {
sessionInitResponse as SessionReject
throw ProtocolSessionException("Party $otherParty rejected session attempt: ${sessionInitResponse.errorMessage}")
}
}
@Suspendable
private fun <T : SessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<T>): T {
fun getReceivedMessage(): ExistingSessionMessage? = receiveRequest.session.receivedMessages.poll()
val receivedMessage = getReceivedMessage() ?: run {
// Suspend while we wait for the receive
receiveRequest.session.waitingForResponse = true
suspend(receiveRequest)
receiveRequest.session.waitingForResponse = false
getReceivedMessage()
?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $id $receiveRequest")
}
if (receivedMessage is SessionEnd) {
openSessions.values.remove(receiveRequest.session)
throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurly ended")
} else if (receiveRequest.receiveType.isInstance(receivedMessage)) {
return receiveRequest.receiveType.cast(receivedMessage)
} else {
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got $receivedMessage: $id $receiveRequest")
}
}
@Suspendable
private fun suspend(ioRequest: ProtocolIORequest) {
commitTransaction() commitTransaction()
parkAndSerialize { fiber, serializer -> parkAndSerialize { fiber, serializer ->
logger.trace { "Suspended $id on $ioRequest" }
try { try {
suspendAction(protocolIORequest) actionOnSuspend(ioRequest)
} catch (t: Throwable) { } catch (t: Throwable) {
// Do not throw exception again - Quasar completely bins it. // Do not throw exception again - Quasar completely bins it.
logger.warn("Captured exception which was swallowed by Quasar", t) logger.warn("Captured exception which was swallowed by Quasar", t)
actionOnEnd() // TODO When error handling is introduced, look into whether we should be deleting the checkpoint and
_resultFuture?.setException(t) // completing the Future
processException(t)
} }
} }
createTransaction() createTransaction()
} }
companion object { private fun processException(t: Throwable) {
/** actionOnEnd()
* Retrieves our state machine id if we are running a [ProtocolStateMachineImpl]. _resultFuture?.setException(t)
*/ }
fun retrieveCurrentStateMachine(): ProtocolStateMachineImpl<*>? {
return Strand.currentStrand() as? ProtocolStateMachineImpl<*> internal fun resume(scheduler: FiberScheduler) {
try {
if (fromCheckpoint) {
logger.info("$id resumed from checkpoint")
fromCheckpoint = false
Fiber.unparkDeserialized(this, scheduler)
} else if (state == State.NEW) {
logger.trace { "$id started" }
start()
} else {
logger.trace { "$id resumed" }
Fiber.unpark(this, QUASAR_UNBLOCKER)
}
} catch (t: Throwable) {
logger.error("$id threw '${t.rootCause}'")
logger.trace {
val s = StringWriter()
t.rootCause.printStackTrace(PrintWriter(s))
"Stack trace of protocol error: $s"
} }
} }
} }
}

View File

@ -3,34 +3,38 @@ package com.r3corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.FiberExecutorScheduler import co.paralleluniverse.fibers.FiberExecutorScheduler
import co.paralleluniverse.io.serialization.kryo.KryoSerializer import co.paralleluniverse.io.serialization.kryo.KryoSerializer
import co.paralleluniverse.strands.Strand
import com.codahale.metrics.Gauge import com.codahale.metrics.Gauge
import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.Kryo
import com.google.common.base.Throwables
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.ThreadBox import com.r3corda.core.ThreadBox
import com.r3corda.core.abbreviate import com.r3corda.core.abbreviate
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.TopicSession import com.r3corda.core.messaging.TopicSession
import com.r3corda.core.messaging.runOnNextMessage
import com.r3corda.core.messaging.send import com.r3corda.core.messaging.send
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolStateMachine import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.protocols.StateMachineRunId import com.r3corda.core.protocols.StateMachineRunId
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.* import com.r3corda.core.serialization.*
import com.r3corda.core.then import com.r3corda.core.then
import com.r3corda.core.utilities.ProgressTracker import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.core.utilities.debug
import com.r3corda.core.utilities.loggerFor
import com.r3corda.core.utilities.trace import com.r3corda.core.utilities.trace
import com.r3corda.node.services.api.Checkpoint import com.r3corda.node.services.api.Checkpoint
import com.r3corda.node.services.api.CheckpointStorage import com.r3corda.node.services.api.CheckpointStorage
import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.node.utilities.AddOrRemove import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.node.utilities.AffinityExecutor import com.r3corda.node.utilities.AffinityExecutor
import kotlinx.support.jdk8.collections.removeIf
import org.jetbrains.exposed.sql.Database import org.jetbrains.exposed.sql.Database
import rx.Observable import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import rx.subjects.UnicastSubject import rx.subjects.UnicastSubject
import java.io.PrintWriter
import java.io.StringWriter
import java.util.* import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.ExecutionException import java.util.concurrent.ExecutionException
import javax.annotation.concurrent.ThreadSafe import javax.annotation.concurrent.ThreadSafe
@ -48,7 +52,6 @@ import javax.annotation.concurrent.ThreadSafe
* The SMM will always invoke the protocol fibers on the given [AffinityExecutor], regardless of which thread actually * The SMM will always invoke the protocol fibers on the given [AffinityExecutor], regardless of which thread actually
* starts them via [add]. * starts them via [add].
* *
* TODO: Session IDs should be set up and propagated automatically, on demand.
* TODO: Consider the issue of continuation identity more deeply: is it a safe assumption that a serialised * TODO: Consider the issue of continuation identity more deeply: is it a safe assumption that a serialised
* continuation is always unique? * continuation is always unique?
* TODO: Think about how to bring the system to a clean stop so it can be upgraded without any serialised stacks on disk * TODO: Think about how to bring the system to a clean stop so it can be upgraded without any serialised stacks on disk
@ -58,12 +61,19 @@ import javax.annotation.concurrent.ThreadSafe
* TODO: Implement stub/skel classes that provide a basic RPC framework on top of this. * TODO: Implement stub/skel classes that provide a basic RPC framework on top of this.
*/ */
@ThreadSafe @ThreadSafe
class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableServices: List<Any>, class StateMachineManager(val serviceHub: ServiceHubInternal,
tokenizableServices: List<Any>,
val checkpointStorage: CheckpointStorage, val checkpointStorage: CheckpointStorage,
val executor: AffinityExecutor, val executor: AffinityExecutor,
val database: Database) { val database: Database) {
inner class FiberScheduler : FiberExecutorScheduler("Same thread scheduler", executor) inner class FiberScheduler : FiberExecutorScheduler("Same thread scheduler", executor)
companion object {
private val logger = loggerFor<StateMachineManager>()
internal val sessionTopic = TopicSession("platform.session")
}
val scheduler = FiberScheduler() val scheduler = FiberScheduler()
data class Change( data class Change(
@ -95,6 +105,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
private val totalStartedProtocols = metrics.counter("Protocols.Started") private val totalStartedProtocols = metrics.counter("Protocols.Started")
private val totalFinishedProtocols = metrics.counter("Protocols.Finished") private val totalFinishedProtocols = metrics.counter("Protocols.Finished")
private val openSessions = ConcurrentHashMap<Long, ProtocolSession>()
private val recentlyClosedSessions = ConcurrentHashMap<Long, Party>()
// Context for tokenized services in checkpoints // Context for tokenized services in checkpoints
private val serializationContext = SerializeAsTokenContext(tokenizableServices, quasarKryo()) private val serializationContext = SerializeAsTokenContext(tokenizableServices, quasarKryo())
@ -119,6 +132,17 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
val changes: Observable<Change> val changes: Observable<Change>
get() = mutex.content.changesPublisher get() = mutex.content.changesPublisher
init {
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
(fiber as ProtocolStateMachineImpl<*>).logger.error("Caught exception from protocol", throwable)
}
}
fun start() {
restoreFibersFromCheckpoints()
serviceHub.networkMapCache.mapServiceRegistered.then(executor) { resumeRestoredFibers() }
}
/** /**
* Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and * Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and
* calls to [allStateMachines] * calls to [allStateMachines]
@ -131,69 +155,99 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
} }
} }
// Used to work around a small limitation in Quasar. private fun restoreFibersFromCheckpoints() {
private val QUASAR_UNBLOCKER = run { mutex.locked {
val field = Fiber::class.java.getDeclaredField("SERIALIZER_BLOCKER") checkpointStorage.checkpoints.forEach {
field.isAccessible = true // If a protocol is added before start() then don't attempt to restore it
field.get(null) if (!stateMachines.containsValue(it)) {
val fiber = deserializeFiber(it.serialisedFiber)
initFiber(fiber)
stateMachines[fiber] = it
}
} }
init {
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
(fiber as ProtocolStateMachineImpl<*>).logger.error("Caught exception from protocol", throwable)
} }
} }
fun start() { private fun resumeRestoredFibers() {
checkpointStorage.checkpoints.forEach { createFiberForCheckpoint(it) }
serviceHub.networkMapCache.mapServiceRegistered.then(executor) {
mutex.locked { mutex.locked {
started = true started = true
stateMachines.forEach { restartFiber(it.key, it.value) } stateMachines.keys.forEach { resumeRestoredFiber(it) }
}
serviceHub.networkService.addMessageHandler(sessionTopic, executor) { message, reg ->
executor.checkOnThread()
val sessionMessage = message.data.deserialize<SessionMessage>()
when (sessionMessage) {
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage)
is SessionInit -> onSessionInit(sessionMessage)
} }
} }
} }
private fun createFiberForCheckpoint(checkpoint: Checkpoint) { private fun resumeRestoredFiber(fiber: ProtocolStateMachineImpl<*>) {
if (!checkpoint.fiberCreated) { fiber.openSessions.values.forEach { openSessions[it.ourSessionId] = it }
val fiber = deserializeFiber(checkpoint.serialisedFiber) if (fiber.openSessions.values.any { it.waitingForResponse }) {
initFiber(fiber, { checkpoint }) fiber.logger.info("Restored fiber pending on receive ${fiber.id}}")
} else {
resumeFiber(fiber)
} }
} }
private fun restartFiber(fiber: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint) { private fun onExistingSessionMessage(message: ExistingSessionMessage) {
if (checkpoint.request is ReceiveRequest<*>) { val session = openSessions[message.recipientSessionId]
val topicSession = checkpoint.request.receiveTopicSession if (session != null) {
fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${checkpoint.request.receiveType.name} on $topicSession") session.psm.logger.trace { "${session.psm.id} received $message on $session" }
iterateOnResponse(fiber, checkpoint.serialisedFiber, checkpoint.request) { if (message is SessionEnd) {
try { openSessions.remove(message.recipientSessionId)
Fiber.unparkDeserialized(fiber, scheduler)
} catch (e: Throwable) {
logError(e, it, topicSession, fiber)
} }
} session.receivedMessages += message
if (checkpoint.request is SendRequest) { if (session.waitingForResponse) {
sendMessage(fiber, checkpoint.request) updateCheckpoint(session.psm)
resumeFiber(session.psm)
} }
} else { } else {
fiber.logger.info("Restored ${fiber.logic} - it was not waiting on any message; received payload: ${checkpoint.receivedPayload.toString().abbreviate(50)}") val otherParty = recentlyClosedSessions.remove(message.recipientSessionId)
executor.executeASAP { if (otherParty != null) {
if (checkpoint.request is SendRequest) { if (message is SessionConfirm) {
sendMessage(fiber, checkpoint.request) logger.debug { "Received session confirmation but associated fiber has already terminated, so sending session end" }
sendSessionMessage(otherParty, SessionEnd(message.initiatedSessionId), null)
} else {
logger.trace { "Ignoring session end message for already closed session: $message" }
} }
iterateStateMachine(fiber, checkpoint.receivedPayload) { } else {
logger.warn("Received a session message for unknown session: $message")
}
}
}
private fun onSessionInit(sessionInit: SessionInit) {
logger.trace { "Received $sessionInit" }
//TODO Verify the other party are who they say they are from the TLS subsystem
val otherParty = sessionInit.initiatorParty
val otherPartySessionId = sessionInit.initiatorSessionId
try { try {
Fiber.unparkDeserialized(fiber, scheduler) val markerClass = Class.forName(sessionInit.protocolName)
} catch (e: Throwable) { val protocolFactory = serviceHub.getProtocolFactory(markerClass)
logError(e, it, null, fiber) if (protocolFactory != null) {
} val protocol = protocolFactory(otherParty)
} val psm = createFiber(sessionInit.protocolName, protocol)
val session = ProtocolSession(protocol, otherParty, random63BitValue(), otherPartySessionId)
openSessions[session.ourSessionId] = session
psm.openSessions[Pair(protocol, otherParty)] = session
updateCheckpoint(psm)
sendSessionMessage(otherParty, SessionConfirm(otherPartySessionId, session.ourSessionId), psm)
psm.logger.debug { "Starting new ${psm.id} from $sessionInit on $session" }
startFiber(psm)
} else {
logger.warn("Unknown protocol marker class in $sessionInit")
sendSessionMessage(otherParty, SessionReject(otherPartySessionId, "Don't know ${markerClass.name}"), null)
} }
} catch (e: Exception) {
logger.warn("Received invalid $sessionInit", e)
sendSessionMessage(otherParty, SessionReject(otherPartySessionId, "Unable to establish session"), null)
} }
} }
private fun serializeFiber(fiber: ProtocolStateMachineImpl<*>): SerializedBytes<ProtocolStateMachineImpl<*>> { private fun serializeFiber(fiber: ProtocolStateMachineImpl<*>): SerializedBytes<ProtocolStateMachineImpl<*>> {
// We don't use the passed-in serializer here, because we need to use our own augmented Kryo.
val kryo = quasarKryo() val kryo = quasarKryo()
// add the map of tokens -> tokenizedServices to the kyro context // add the map of tokens -> tokenizedServices to the kyro context
SerializeAsTokenSerializer.setContext(kryo, serializationContext) SerializeAsTokenSerializer.setContext(kryo, serializationContext)
@ -204,7 +258,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
val kryo = quasarKryo() val kryo = quasarKryo()
// put the map of token -> tokenized into the kryo context // put the map of token -> tokenized into the kryo context
SerializeAsTokenSerializer.setContext(kryo, serializationContext) SerializeAsTokenSerializer.setContext(kryo, serializationContext)
return serialisedFiber.deserialize(kryo) return serialisedFiber.deserialize(kryo).apply { fromCheckpoint = true }
} }
private fun quasarKryo(): Kryo { private fun quasarKryo(): Kryo {
@ -212,70 +266,51 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
return createKryo(serializer.kryo) return createKryo(serializer.kryo)
} }
private fun logError(e: Throwable, payload: Any?, topicSession: TopicSession?, psm: ProtocolStateMachineImpl<*>) { private fun <T> createFiber(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachineImpl<T> {
psm.logger.error("Protocol state machine ${psm.javaClass.name} threw '${Throwables.getRootCause(e)}' " + val id = StateMachineRunId.createRandom()
"when handling a message of type ${payload?.javaClass?.name} on queue $topicSession") return ProtocolStateMachineImpl(id, logic, scheduler, loggerName).apply { initFiber(this) }
if (psm.logger.isTraceEnabled) {
val s = StringWriter()
Throwables.getRootCause(e).printStackTrace(PrintWriter(s))
psm.logger.trace("Stack trace of protocol error is: $s")
}
} }
private fun initFiber(psm: ProtocolStateMachineImpl<*>, startingCheckpoint: () -> Checkpoint): Checkpoint { private fun initFiber(psm: ProtocolStateMachineImpl<*>) {
psm.database = database psm.database = database
psm.serviceHub = serviceHub psm.serviceHub = serviceHub
psm.suspendAction = { request -> psm.actionOnSuspend = { ioRequest ->
psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" } updateCheckpoint(psm)
onNextSuspend(psm, request) processIORequest(ioRequest)
} }
psm.actionOnEnd = { psm.actionOnEnd = {
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
mutex.locked { mutex.locked {
val finalCheckpoint = stateMachines.remove(psm) stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) }
if (finalCheckpoint != null) {
checkpointStorage.removeCheckpoint(finalCheckpoint)
}
totalFinishedProtocols.inc() totalFinishedProtocols.inc()
notifyChangeObservers(psm, AddOrRemove.REMOVE) notifyChangeObservers(psm, AddOrRemove.REMOVE)
} }
endAllFiberSessions(psm)
} }
val checkpoint = startingCheckpoint()
checkpoint.fiberCreated = true
totalStartedProtocols.inc()
mutex.locked { mutex.locked {
stateMachines[psm] = checkpoint totalStartedProtocols.inc()
notifyChangeObservers(psm, AddOrRemove.ADD) notifyChangeObservers(psm, AddOrRemove.ADD)
} }
return checkpoint
} }
/** private fun endAllFiberSessions(psm: ProtocolStateMachineImpl<*>) {
* Kicks off a brand new state machine of the given class. It will log with the named logger. openSessions.values.removeIf { session ->
* The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is if (session.psm == psm) {
* restarted with checkpointed state machines in the storage service. val otherPartySessionId = session.otherPartySessionId
*/ if (otherPartySessionId != null) {
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachine<T> { sendSessionMessage(session.otherParty, SessionEnd(otherPartySessionId), psm)
val id = StateMachineRunId.createRandom() }
val fiber = ProtocolStateMachineImpl(id, logic, scheduler, loggerName) recentlyClosedSessions[session.ourSessionId] = session.otherParty
// Need to add before iterating in case of immediate completion true
val checkpoint = initFiber(fiber) { } else {
val checkpoint = Checkpoint(serializeFiber(fiber), null, null) false
checkpoint
} }
checkpointStorage.addCheckpoint(checkpoint)
mutex.locked { // If we are not started then our checkpoint will be picked up during start
if (!started) {
return fiber
} }
} }
private fun startFiber(fiber: ProtocolStateMachineImpl<*>) {
try { try {
executor.executeASAP { resumeFiber(fiber)
iterateStateMachine(fiber, null) {
fiber.start()
}
}
} catch (e: ExecutionException) { } catch (e: ExecutionException) {
// There are two ways we can take exceptions in this method: // There are two ways we can take exceptions in this method:
// //
@ -290,17 +325,29 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
if (e.cause !is ExecutionException) if (e.cause !is ExecutionException)
throw e throw e
} }
}
/**
* Kicks off a brand new state machine of the given class. It will log with the named logger.
* The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is
* restarted with checkpointed state machines in the storage service.
*/
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachine<T> {
val fiber = createFiber(loggerName, logic)
updateCheckpoint(fiber)
// If we are not started then our checkpoint will be picked up during start
mutex.locked {
if (started) {
startFiber(fiber)
}
}
return fiber return fiber
} }
private fun updateCheckpoint(psm: ProtocolStateMachineImpl<*>, private fun updateCheckpoint(psm: ProtocolStateMachineImpl<*>) {
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>, check(psm.state != Strand.State.RUNNING) { "Fiber cannot be running when checkpointing" }
request: ProtocolIORequest?, val newCheckpoint = Checkpoint(serializeFiber(psm))
receivedPayload: Any?) { val previousCheckpoint = mutex.locked { stateMachines.put(psm, newCheckpoint) }
val newCheckpoint = Checkpoint(serialisedFiber, request, receivedPayload)
val previousCheckpoint = mutex.locked {
stateMachines.put(psm, newCheckpoint)
}
if (previousCheckpoint != null) { if (previousCheckpoint != null) {
checkpointStorage.removeCheckpoint(previousCheckpoint) checkpointStorage.removeCheckpoint(previousCheckpoint)
} }
@ -308,90 +355,70 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
checkpointingMeter.mark() checkpointingMeter.mark()
} }
private fun iterateStateMachine(psm: ProtocolStateMachineImpl<*>, private fun resumeFiber(psm: ProtocolStateMachineImpl<*>) {
receivedPayload: Any?, executor.executeASAP {
resumeAction: (Any?) -> Unit) { psm.resume(scheduler)
executor.checkOnThread()
psm.receivedPayload = receivedPayload
psm.logger.trace { "Waking up fiber ${psm.id} ${psm.logic}" }
resumeAction(receivedPayload)
}
private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, request: ProtocolIORequest) {
val serialisedFiber = serializeFiber(psm)
updateCheckpoint(psm, serialisedFiber, request, null)
// We have a request to do something: send, receive, or send-and-receive.
if (request is ReceiveRequest<*>) {
// Prepare a listener on the network that runs in the background thread when we receive a message.
prepareToReceiveForRequest(psm, serialisedFiber, request)
}
if (request is SendRequest) {
performSendRequest(psm, request)
} }
} }
private fun prepareToReceiveForRequest(psm: ProtocolStateMachineImpl<*>, serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>, request: ReceiveRequest<*>) { private fun processIORequest(ioRequest: ProtocolIORequest) {
executor.checkOnThread() if (ioRequest is SendRequest) {
val queueID = request.receiveTopicSession if (ioRequest.message is SessionInit) {
psm.logger.trace { "Preparing to receive message of type ${request.receiveType.name} on queue $queueID" } openSessions[ioRequest.session.ourSessionId] = ioRequest.session
iterateOnResponse(psm, serialisedFiber, request) {
try {
Fiber.unpark(psm, QUASAR_UNBLOCKER)
} catch(e: Throwable) {
logError(e, it, queueID, psm)
} }
} sendSessionMessage(ioRequest.session.otherParty, ioRequest.message, ioRequest.session.psm)
} if (ioRequest !is ReceiveRequest<*>) {
private fun performSendRequest(psm: ProtocolStateMachineImpl<*>, request: SendRequest) {
val topicSession = sendMessage(psm, request)
if (request is SendOnly) {
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going. // We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
iterateStateMachine(psm, null) { resumeFiber(ioRequest.session.psm)
try {
Fiber.unpark(psm, QUASAR_UNBLOCKER)
} catch(e: Throwable) {
logError(e, request.payload, topicSession, psm)
}
} }
} }
} }
private fun sendMessage(psm: ProtocolStateMachineImpl<*>, request: SendRequest): TopicSession { private fun sendSessionMessage(party: Party, message: SessionMessage, psm: ProtocolStateMachineImpl<*>?) {
val topicSession = TopicSession(request.topic, request.sendSessionID) val node = serviceHub.networkMapCache.getNodeByLegalName(party.name)
val payload = request.payload ?: throw IllegalArgumentException("Don't know about party $party")
psm.logger.trace { "Sending message of type ${payload.javaClass.name} using queue $topicSession to ${request.destination} (${payload.toString().abbreviate(50)})" } val logger = psm?.logger ?: logger
val node = serviceHub.networkMapCache.getNodeByLegalName(request.destination.name) ?: logger.trace { "${psm?.id} sending $message to party $party" }
throw IllegalArgumentException("Don't know about ${request.destination} but trying to send a message of type ${payload.javaClass.name} on $topicSession (${payload.toString().abbreviate(50)})", request.stackTraceInCaseOfProblems) serviceHub.networkService.send(sessionTopic, message, node.address)
serviceHub.networkService.send(topicSession, payload, node.address, request.uniqueMessageId)
return topicSession
} }
/**
* Add a trigger to the [MessagingService] to deserialize the fiber and pass message content to it, once a message is interface SessionMessage
* received.
*/ interface ExistingSessionMessage: SessionMessage {
private fun iterateOnResponse(psm: ProtocolStateMachineImpl<*>, val recipientSessionId: Long
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>, }
request: ReceiveRequest<*>,
resumeAction: (Any?) -> Unit) { data class SessionInit(val initiatorSessionId: Long, val initiatorParty: Party, val protocolName: String) : SessionMessage
val topicSession = request.receiveTopicSession
serviceHub.networkService.runOnNextMessage(topicSession, executor) { netMsg -> interface SessionInitResponse : ExistingSessionMessage
// Assertion to ensure we don't execute on the wrong thread.
executor.checkOnThread() data class SessionConfirm(val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse {
// TODO: This is insecure: we should not deserialise whatever we find and *then* check. override val recipientSessionId: Long get() = initiatorSessionId
// We should instead verify as we read the data that it's what we are expecting and throw as early as }
// possible. We only do it this way for convenience during the prototyping stage. Note that this means
// we could simply not require the programmer to specify the expected return type at all, and catch it data class SessionReject(val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse {
// at the last moment when we do the downcast. However this would make protocol code harder to read and override val recipientSessionId: Long get() = initiatorSessionId
// make it more difficult to migrate to a more explicit serialisation scheme later. }
val payload = netMsg.data.deserialize<Any>()
check(request.receiveType.isInstance(payload)) { "Expected message of type ${request.receiveType.name} but got ${payload.javaClass.name}" } data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage {
// Update the fiber's checkpoint so that it's no longer waiting on a response, but rather has the received payload override fun toString(): String {
updateCheckpoint(psm, serialisedFiber, null, payload) return "${javaClass.simpleName}(recipientSessionId=$recipientSessionId, payload=${payload.toString().abbreviate(100)})"
psm.logger.trace { "Received message of type ${payload.javaClass.name} on $topicSession (${payload.toString().abbreviate(50)})" }
iterateStateMachine(psm, payload, resumeAction)
} }
} }
data class SessionEnd(override val recipientSessionId: Long) : ExistingSessionMessage
data class ProtocolSession(val protocol: ProtocolLogic<*>,
val otherParty: Party,
val ourSessionId: Long,
var otherPartySessionId: Long?,
@Volatile var waitingForResponse: Boolean = false) {
val receivedMessages = ConcurrentLinkedQueue<ExistingSessionMessage>()
val psm: ProtocolStateMachineImpl<*> get() = protocol.psm as ProtocolStateMachineImpl<*>
}
} }

View File

@ -1,12 +1,11 @@
package com.r3corda.node.services.transactions package com.r3corda.node.services.transactions
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.services.ServiceType import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.node.services.TimestampChecker import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.node.services.UniquenessProvider
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.NotaryProtocol import com.r3corda.protocols.NotaryProtocol
import com.r3corda.protocols.NotaryProtocol.TOPIC import kotlin.reflect.KClass
/** /**
* A Notary service acts as the final signer of a transaction ensuring two things: * A Notary service acts as the final signer of a transaction ensuring two things:
@ -17,22 +16,18 @@ import com.r3corda.protocols.NotaryProtocol.TOPIC
* *
* This is the base implementation that can be customised with specific Notary transaction commit protocol. * This is the base implementation that can be customised with specific Notary transaction commit protocol.
*/ */
abstract class NotaryService(services: ServiceHubInternal, abstract class NotaryService(markerClass: KClass<out NotaryProtocol.Client>, services: ServiceHubInternal) : SingletonSerializeAsToken() {
val timestampChecker: TimestampChecker,
val uniquenessProvider: UniquenessProvider) : AbstractNodeService(services) {
// Do not specify this as an advertised service. Use a concrete implementation. // Do not specify this as an advertised service. Use a concrete implementation.
// TODO: We do not want a service type that cannot be used. Fix the type system abuse here. // TODO: We do not want a service type that cannot be used. Fix the type system abuse here.
object Type : ServiceType("corda.notary") object Type : ServiceType("corda.notary")
abstract val logger: org.slf4j.Logger abstract val logger: org.slf4j.Logger
/** Implement a factory that specifies the transaction commit protocol for the notary service to use */
abstract val protocolFactory: NotaryProtocol.Factory
init { init {
addProtocolHandler(TOPIC, TOPIC) { req: NotaryProtocol.Handshake -> services.registerProtocolInitiator(markerClass) { createProtocol(it) }
protocolFactory.create(req.replyToParty, timestampChecker, uniquenessProvider)
}
} }
/** Implement a factory that specifies the transaction commit protocol for the notary service to use */
abstract fun createProtocol(otherParty: Party): NotaryProtocol.Service
} }

View File

@ -1,5 +1,6 @@
package com.r3corda.node.services.transactions package com.r3corda.node.services.transactions
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.services.ServiceType import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.node.services.TimestampChecker import com.r3corda.core.node.services.TimestampChecker
import com.r3corda.core.node.services.UniquenessProvider import com.r3corda.core.node.services.UniquenessProvider
@ -9,11 +10,13 @@ import com.r3corda.protocols.NotaryProtocol
/** A simple Notary service that does not perform transaction validation */ /** A simple Notary service that does not perform transaction validation */
class SimpleNotaryService(services: ServiceHubInternal, class SimpleNotaryService(services: ServiceHubInternal,
timestampChecker: TimestampChecker, val timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider) : NotaryService(services, timestampChecker, uniquenessProvider) { val uniquenessProvider: UniquenessProvider) : NotaryService(NotaryProtocol.Client::class, services) {
object Type : ServiceType("corda.notary.simple") object Type : ServiceType("corda.notary.simple")
override val logger = loggerFor<SimpleNotaryService>() override val logger = loggerFor<SimpleNotaryService>()
override val protocolFactory = NotaryProtocol.DefaultFactory override fun createProtocol(otherParty: Party): NotaryProtocol.Service {
return NotaryProtocol.Service(otherParty, timestampChecker, uniquenessProvider)
}
} }

View File

@ -11,17 +11,13 @@ import com.r3corda.protocols.ValidatingNotaryProtocol
/** A Notary service that validates the transaction chain of he submitted transaction before committing it */ /** A Notary service that validates the transaction chain of he submitted transaction before committing it */
class ValidatingNotaryService(services: ServiceHubInternal, class ValidatingNotaryService(services: ServiceHubInternal,
timestampChecker: TimestampChecker, val timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider) : NotaryService(services, timestampChecker, uniquenessProvider) { val uniquenessProvider: UniquenessProvider) : NotaryService(NotaryProtocol.ValidatingClient::class, services) {
object Type : ServiceType("corda.notary.validating") object Type : ServiceType("corda.notary.validating")
override val logger = loggerFor<ValidatingNotaryService>() override val logger = loggerFor<ValidatingNotaryService>()
override val protocolFactory = object : NotaryProtocol.Factory { override fun createProtocol(otherParty: Party): ValidatingNotaryProtocol {
override fun create(otherSide: Party, return ValidatingNotaryProtocol(otherParty, timestampChecker, uniquenessProvider)
timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider): NotaryProtocol.Service {
return ValidatingNotaryProtocol(otherSide, timestampChecker, uniquenessProvider)
}
} }
} }

View File

@ -10,6 +10,7 @@ import com.r3corda.core.days
import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.node.ServiceHub import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.* import com.r3corda.core.node.services.*
import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.protocols.StateMachineRunId import com.r3corda.core.protocols.StateMachineRunId
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.transactions.WireTransaction import com.r3corda.core.transactions.WireTransaction
@ -23,7 +24,6 @@ import com.r3corda.node.services.persistence.PerFileTransactionStorage
import com.r3corda.node.services.persistence.StorageServiceImpl import com.r3corda.node.services.persistence.StorageServiceImpl
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
import com.r3corda.protocols.TwoPartyTradeProtocol.Seller import com.r3corda.protocols.TwoPartyTradeProtocol.Seller
import com.r3corda.protocols.TwoPartyTradeProtocol.TOPIC
import com.r3corda.testing.* import com.r3corda.testing.*
import com.r3corda.testing.node.InMemoryMessagingNetwork import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.MockNetwork import com.r3corda.testing.node.MockNetwork
@ -89,11 +89,11 @@ class TwoPartyTradeProtocolTests {
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey, notaryNode.storage.myLegalIdentityKey) insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey, notaryNode.storage.myLegalIdentityKey)
val (bobResult, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef()) val (bobPsm, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
// TODO: Verify that the result was inserted into the transaction database. // TODO: Verify that the result was inserted into the transaction database.
// assertEquals(bobResult.get(), aliceNode.storage.validatedTransactions[aliceResult.get().id]) // assertEquals(bobResult.get(), aliceNode.storage.validatedTransactions[aliceResult.get().id])
assertEquals(aliceResult.get(), bobResult.get()) assertEquals(aliceResult.get(), bobPsm.get().resultFuture.get())
aliceNode.stop() aliceNode.stop()
bobNode.stop() bobNode.stop()
@ -120,21 +120,19 @@ class TwoPartyTradeProtocolTests {
1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null).second 1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null).second
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey) insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
val aliceFuture = runBuyerAndSeller("alice's paper".outputStateAndRef()).sellerFuture val aliceFuture = runBuyerAndSeller("alice's paper".outputStateAndRef()).sellerResult
// Everything is on this thread so we can now step through the protocol one step at a time. // Everything is on this thread so we can now step through the protocol one step at a time.
// Seller Alice already sent a message to Buyer Bob. Pump once: // Seller Alice already sent a message to Buyer Bob. Pump once:
fun pumpAlice() = (aliceNode.net as InMemoryMessagingNetwork.InMemoryMessaging).pumpReceive(false) bobNode.pumpReceive(false)
fun pumpBob() = (bobNode.net as InMemoryMessagingNetwork.InMemoryMessaging).pumpReceive(false)
pumpBob()
// Bob sends a couple of queries for the dependencies back to Alice. Alice reponds. // Bob sends a couple of queries for the dependencies back to Alice. Alice reponds.
pumpAlice() aliceNode.pumpReceive(false)
pumpBob() bobNode.pumpReceive(false)
pumpAlice() aliceNode.pumpReceive(false)
pumpBob() bobNode.pumpReceive(false)
aliceNode.pumpReceive(false)
bobNode.pumpReceive(false)
// OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature. // OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature.
assertThat(bobNode.checkpointStorage.checkpoints).hasSize(1) assertThat(bobNode.checkpointStorage.checkpoints).hasSize(1)
@ -147,7 +145,7 @@ class TwoPartyTradeProtocolTests {
// Alice doesn't know that and carries on: she wants to know about the cash transactions he's trying to use. // Alice doesn't know that and carries on: she wants to know about the cash transactions he's trying to use.
// She will wait around until Bob comes back. // She will wait around until Bob comes back.
assertThat(pumpAlice()).isNotNull() assertThat(aliceNode.pumpReceive(false)).isNotNull()
// ... bring the node back up ... the act of constructing the SMM will re-register the message handlers // ... bring the node back up ... the act of constructing the SMM will re-register the message handlers
// that Bob was waiting on before the reboot occurred. // that Bob was waiting on before the reboot occurred.
@ -309,16 +307,16 @@ class TwoPartyTradeProtocolTests {
val attachmentID = attachment(ByteArrayInputStream(stream.toByteArray())) val attachmentID = attachment(ByteArrayInputStream(stream.toByteArray()))
val bobsFakeCash = fillUpForBuyer(false, bobNode.keyManagement.freshKey().public).second val bobsFakeCash = fillUpForBuyer(false, bobNode.keyManagement.freshKey().public).second
val bobsSignedTxns = insertFakeTransactions(bobsFakeCash, bobNode.services) insertFakeTransactions(bobsFakeCash, bobNode.services)
val alicesFakePaper = fillUpForSeller(false, aliceNode.storage.myLegalIdentity.owningKey, val alicesFakePaper = fillUpForSeller(false, aliceNode.storage.myLegalIdentity.owningKey,
1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, attachmentID).second 1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, attachmentID).second
val alicesSignedTxns = insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey) insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
net.runNetwork() // Clear network map registration messages net.runNetwork() // Clear network map registration messages
val aliceTxStream = aliceNode.storage.validatedTransactions.track().second val aliceTxStream = aliceNode.storage.validatedTransactions.track().second
val aliceTxMappings = aliceNode.storage.stateMachineRecordedTransactionMapping.track().second val aliceTxMappings = aliceNode.storage.stateMachineRecordedTransactionMapping.track().second
val (bobResult, aliceResult, bobSmId, aliceSmId) = runBuyerAndSeller("alice's paper".outputStateAndRef()) val aliceSmId = runBuyerAndSeller("alice's paper".outputStateAndRef()).sellerId
net.runNetwork() net.runNetwork()
@ -367,21 +365,20 @@ class TwoPartyTradeProtocolTests {
} }
} }
data class RunResult( private data class RunResult(
val buyerFuture: Future<SignedTransaction>, // The buyer is not created immediately, only when the seller starts running
val sellerFuture: Future<SignedTransaction>, val buyer: Future<ProtocolStateMachine<SignedTransaction>>,
val buyerSmId: StateMachineRunId, val sellerResult: Future<SignedTransaction>,
val sellerSmId: StateMachineRunId val sellerId: StateMachineRunId
) )
private fun runBuyerAndSeller(assetToSell: StateAndRef<OwnableState>) : RunResult { private fun runBuyerAndSeller(assetToSell: StateAndRef<OwnableState>) : RunResult {
val buyer = Buyer(aliceNode.info.identity, notaryNode.info.identity, 1000.DOLLARS, CommercialPaper.State::class.java) val buyerFuture = bobNode.initiateSingleShotProtocol(Seller::class) { otherParty ->
Buyer(otherParty, notaryNode.info.identity, 1000.DOLLARS, CommercialPaper.State::class.java)
}
val seller = Seller(bobNode.info.identity, notaryNode.info, assetToSell, 1000.DOLLARS, ALICE_KEY) val seller = Seller(bobNode.info.identity, notaryNode.info, assetToSell, 1000.DOLLARS, ALICE_KEY)
connectProtocols(buyer, seller) val sellerResultFuture = aliceNode.smm.add("seller", seller).resultFuture
// We start the Buyer first, as the Seller sends the first message return RunResult(buyerFuture, sellerResultFuture, seller.psm.id)
val buyerPsm = bobNode.smm.add("$TOPIC.buyer", buyer)
val sellerPsm = aliceNode.smm.add("$TOPIC.seller", seller)
return RunResult(buyerPsm.resultFuture, sellerPsm.resultFuture, buyerPsm.id, sellerPsm.id)
} }
private fun LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter>.runWithError( private fun LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter>.runWithError(
@ -404,7 +401,7 @@ class TwoPartyTradeProtocolTests {
net.runNetwork() // Clear network map registration messages net.runNetwork() // Clear network map registration messages
val (bobResult, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef()) val (bobPsm, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
net.runNetwork() net.runNetwork()
@ -412,7 +409,7 @@ class TwoPartyTradeProtocolTests {
if (bobError) if (bobError)
aliceResult.get() aliceResult.get()
else else
bobResult.get() bobPsm.get().resultFuture.get()
} }
assertTrue(e.cause is TransactionVerificationException) assertTrue(e.cause is TransactionVerificationException)
assertNotNull(e.cause!!.cause) assertNotNull(e.cause!!.cause)
@ -506,6 +503,7 @@ class TwoPartyTradeProtocolTests {
return Pair(vault, listOf(ap)) return Pair(vault, listOf(ap))
} }
class RecordingTransactionStorage(val delegate: TransactionStorage) : TransactionStorage { class RecordingTransactionStorage(val delegate: TransactionStorage) : TransactionStorage {
override fun track(): Pair<List<SignedTransaction>, Observable<SignedTransaction>> { override fun track(): Pair<List<SignedTransaction>, Observable<SignedTransaction>> {
return delegate.track() return delegate.track()
@ -530,4 +528,5 @@ class TwoPartyTradeProtocolTests {
data class Add(val transaction: SignedTransaction) : TxRecord data class Add(val transaction: SignedTransaction) : TxRecord
data class Get(val id: SecureHash) : TxRecord data class Get(val id: SecureHash) : TxRecord
} }
} }

View File

@ -2,23 +2,24 @@ package com.r3corda.node.services
import com.codahale.metrics.MetricRegistry import com.codahale.metrics.MetricRegistry
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.crypto.Party
import com.r3corda.core.node.services.* import com.r3corda.core.node.services.*
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolLogicRefFactory import com.r3corda.core.protocols.ProtocolLogicRefFactory
import com.r3corda.core.protocols.StateMachineRunId import com.r3corda.core.testing.InMemoryVaultService
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.node.serialization.NodeClock import com.r3corda.node.serialization.NodeClock
import com.r3corda.node.services.api.MessagingServiceInternal import com.r3corda.node.services.api.MessagingServiceInternal
import com.r3corda.node.services.api.MonitoringService import com.r3corda.node.services.api.MonitoringService
import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.testing.node.MockNetworkMapCache
import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.persistence.DataVending import com.r3corda.node.services.persistence.DataVending
import com.r3corda.node.services.statemachine.StateMachineManager import com.r3corda.node.services.statemachine.StateMachineManager
import com.r3corda.core.testing.InMemoryVaultService
import com.r3corda.testing.node.MockStorageService
import com.r3corda.testing.MOCK_IDENTITY_SERVICE import com.r3corda.testing.MOCK_IDENTITY_SERVICE
import com.r3corda.testing.node.MockNetworkMapCache
import com.r3corda.testing.node.MockStorageService
import java.time.Clock import java.time.Clock
import java.util.concurrent.ConcurrentHashMap
import kotlin.reflect.KClass
@Suppress("LeakingThis") @Suppress("LeakingThis")
open class MockServiceHubInternal( open class MockServiceHubInternal(
@ -28,7 +29,6 @@ open class MockServiceHubInternal(
val identity: IdentityService? = MOCK_IDENTITY_SERVICE, val identity: IdentityService? = MOCK_IDENTITY_SERVICE,
val storage: TxWritableStorageService? = MockStorageService(), val storage: TxWritableStorageService? = MockStorageService(),
val mapCache: NetworkMapCache? = MockNetworkMapCache(), val mapCache: NetworkMapCache? = MockNetworkMapCache(),
val mapService: NetworkMapService? = null,
val scheduler: SchedulerService? = null, val scheduler: SchedulerService? = null,
val overrideClock: Clock? = NodeClock(), val overrideClock: Clock? = NodeClock(),
val protocolFactory: ProtocolLogicRefFactory? = ProtocolLogicRefFactory() val protocolFactory: ProtocolLogicRefFactory? = ProtocolLogicRefFactory()
@ -57,14 +57,10 @@ open class MockServiceHubInternal(
private val txStorageService: TxWritableStorageService private val txStorageService: TxWritableStorageService
get() = storage ?: throw UnsupportedOperationException() get() = storage ?: throw UnsupportedOperationException()
override fun recordTransactions(txs: Iterable<SignedTransaction>) = recordTransactionsInternal(txStorageService, txs) private val protocolFactories = ConcurrentHashMap<Class<*>, (Party) -> ProtocolLogic<*>>()
lateinit var smm: StateMachineManager lateinit var smm: StateMachineManager
override fun <T> startProtocol(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> {
return smm.add(loggerName, logic).resultFuture
}
init { init {
if (net != null && storage != null) { if (net != null && storage != null) {
// Creating this class is sufficient, we don't have to store it anywhere, because it registers a listener // Creating this class is sufficient, we don't have to store it anywhere, because it registers a listener
@ -72,4 +68,18 @@ open class MockServiceHubInternal(
DataVending.Service(this) DataVending.Service(this)
} }
} }
override fun recordTransactions(txs: Iterable<SignedTransaction>) = recordTransactionsInternal(txStorageService, txs)
override fun <T> startProtocol(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> {
return smm.add(loggerName, logic).resultFuture
}
override fun registerProtocolInitiator(markerClass: KClass<*>, protocolFactory: (Party) -> ProtocolLogic<*>) {
protocolFactories[markerClass.java] = protocolFactory
}
override fun getProtocolFactory(markerClass: Class<*>): ((Party) -> ProtocolLogic<*>)? {
return protocolFactories[markerClass]
}
} }

View File

@ -3,7 +3,6 @@ package com.r3corda.node.services
import com.google.common.jimfs.Configuration import com.google.common.jimfs.Configuration
import com.google.common.jimfs.Jimfs import com.google.common.jimfs.Jimfs
import com.r3corda.core.contracts.* import com.r3corda.core.contracts.*
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.days import com.r3corda.core.days
import com.r3corda.core.node.ServiceHub import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.recordTransactions import com.r3corda.core.node.recordTransactions
@ -12,18 +11,16 @@ import com.r3corda.core.protocols.ProtocolLogicRef
import com.r3corda.core.protocols.ProtocolLogicRefFactory import com.r3corda.core.protocols.ProtocolLogicRefFactory
import com.r3corda.core.serialization.SingletonSerializeAsToken import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.utilities.DUMMY_NOTARY import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.LogHelper
import com.r3corda.testing.node.TestClock
import com.r3corda.node.services.events.NodeSchedulerService import com.r3corda.node.services.events.NodeSchedulerService
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.node.services.persistence.PerFileCheckpointStorage import com.r3corda.node.services.persistence.PerFileCheckpointStorage
import com.r3corda.node.services.statemachine.StateMachineManager import com.r3corda.node.services.statemachine.StateMachineManager
import com.r3corda.node.services.vault.NodeVaultService
import com.r3corda.node.utilities.AddOrRemove import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.node.utilities.AffinityExecutor import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.node.utilities.configureDatabase import com.r3corda.node.utilities.configureDatabase
import com.r3corda.testing.ALICE_KEY import com.r3corda.testing.ALICE_KEY
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.MockKeyManagementService import com.r3corda.testing.node.MockKeyManagementService
import com.r3corda.testing.node.TestClock
import com.r3corda.testing.node.makeTestDataSourceProperties import com.r3corda.testing.node.makeTestDataSourceProperties
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.After import org.junit.After
@ -34,7 +31,9 @@ import java.nio.file.FileSystem
import java.security.PublicKey import java.security.PublicKey
import java.time.Clock import java.time.Clock
import java.time.Instant import java.time.Instant
import java.util.concurrent.* import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import kotlin.test.assertTrue import kotlin.test.assertTrue
class NodeSchedulerServiceTest : SingletonSerializeAsToken() { class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
@ -128,8 +127,6 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
(serviceHub as TestReference).testReference.calls += increment (serviceHub as TestReference).testReference.calls += increment
(serviceHub as TestReference).testReference.countDown.countDown() (serviceHub as TestReference).testReference.countDown.countDown()
} }
override val topic: String get() = throw UnsupportedOperationException()
} }
class Command : TypeOnlyCommandData() class Command : TypeOnlyCommandData()

View File

@ -9,7 +9,6 @@ import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.node.internal.AbstractNode import com.r3corda.node.internal.AbstractNode
import com.r3corda.node.services.network.NetworkMapService import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.transactions.SimpleNotaryService import com.r3corda.node.services.transactions.SimpleNotaryService
import com.r3corda.protocols.NotaryChangeProtocol
import com.r3corda.protocols.NotaryChangeProtocol.Instigator import com.r3corda.protocols.NotaryChangeProtocol.Instigator
import com.r3corda.protocols.StateReplacementException import com.r3corda.protocols.StateReplacementException
import com.r3corda.protocols.StateReplacementRefused import com.r3corda.protocols.StateReplacementRefused
@ -49,7 +48,7 @@ class NotaryChangeTests {
val state = issueState(clientNodeA) val state = issueState(clientNodeA)
val newNotary = newNotaryNode.info.identity val newNotary = newNotaryNode.info.identity
val protocol = Instigator(state, newNotary) val protocol = Instigator(state, newNotary)
val future = clientNodeA.services.startProtocol(NotaryChangeProtocol.TOPIC, protocol) val future = clientNodeA.services.startProtocol("notary-change", protocol)
net.runNetwork() net.runNetwork()
@ -62,7 +61,7 @@ class NotaryChangeTests {
val state = issueMultiPartyState(clientNodeA, clientNodeB) val state = issueMultiPartyState(clientNodeA, clientNodeB)
val newNotary = newNotaryNode.info.identity val newNotary = newNotaryNode.info.identity
val protocol = Instigator(state, newNotary) val protocol = Instigator(state, newNotary)
val future = clientNodeA.services.startProtocol(NotaryChangeProtocol.TOPIC, protocol) val future = clientNodeA.services.startProtocol("notary-change", protocol)
net.runNetwork() net.runNetwork()
@ -78,7 +77,7 @@ class NotaryChangeTests {
val state = issueMultiPartyState(clientNodeA, clientNodeB) val state = issueMultiPartyState(clientNodeA, clientNodeB)
val newEvilNotary = Party("Evil Notary", generateKeyPair().public) val newEvilNotary = Party("Evil Notary", generateKeyPair().public)
val protocol = Instigator(state, newEvilNotary) val protocol = Instigator(state, newEvilNotary)
val future = clientNodeA.services.startProtocol(NotaryChangeProtocol.TOPIC, protocol) val future = clientNodeA.services.startProtocol("notary-change", protocol)
net.runNetwork() net.runNetwork()

View File

@ -1,17 +1,19 @@
package com.r3corda.node.services package com.r3corda.node.services
import com.r3corda.core.contracts.Timestamp import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.contracts.TransactionType import com.r3corda.core.contracts.TransactionType
import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.seconds import com.r3corda.core.seconds
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.DUMMY_NOTARY import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.testing.node.MockNetwork
import com.r3corda.node.services.network.NetworkMapService import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.transactions.SimpleNotaryService import com.r3corda.node.services.transactions.SimpleNotaryService
import com.r3corda.protocols.NotaryError import com.r3corda.protocols.NotaryError
import com.r3corda.protocols.NotaryException import com.r3corda.protocols.NotaryException
import com.r3corda.protocols.NotaryProtocol import com.r3corda.protocols.NotaryProtocol
import com.r3corda.testing.MINI_CORP_KEY import com.r3corda.testing.MINI_CORP_KEY
import com.r3corda.testing.node.MockNetwork
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import java.time.Instant import java.time.Instant
@ -45,10 +47,7 @@ class NotaryServiceTests {
tx.toSignedTransaction(false) tx.toSignedTransaction(false)
} }
val protocol = NotaryProtocol.Client(stx) val future = runNotaryClient(stx)
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
net.runNetwork()
val signature = future.get() val signature = future.get()
signature.verifyWithECDSA(stx.txBits) signature.verifyWithECDSA(stx.txBits)
} }
@ -61,10 +60,7 @@ class NotaryServiceTests {
tx.toSignedTransaction(false) tx.toSignedTransaction(false)
} }
val protocol = NotaryProtocol.Client(stx) val future = runNotaryClient(stx)
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
net.runNetwork()
val signature = future.get() val signature = future.get()
signature.verifyWithECDSA(stx.txBits) signature.verifyWithECDSA(stx.txBits)
} }
@ -78,16 +74,13 @@ class NotaryServiceTests {
tx.toSignedTransaction(false) tx.toSignedTransaction(false)
} }
val protocol = NotaryProtocol.Client(stx) val future = runNotaryClient(stx)
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
net.runNetwork()
val ex = assertFailsWith(ExecutionException::class) { future.get() } val ex = assertFailsWith(ExecutionException::class) { future.get() }
val error = (ex.cause as NotaryException).error val error = (ex.cause as NotaryException).error
assertTrue(error is NotaryError.TimestampInvalid) assertTrue(error is NotaryError.TimestampInvalid)
} }
@Test fun `should report conflict for a duplicate transaction`() { @Test fun `should report conflict for a duplicate transaction`() {
val stx = run { val stx = run {
val inputState = issueState(clientNode) val inputState = issueState(clientNode)
@ -98,8 +91,8 @@ class NotaryServiceTests {
val firstSpend = NotaryProtocol.Client(stx) val firstSpend = NotaryProtocol.Client(stx)
val secondSpend = NotaryProtocol.Client(stx) val secondSpend = NotaryProtocol.Client(stx)
clientNode.services.startProtocol("${NotaryProtocol.TOPIC}.first", firstSpend) clientNode.services.startProtocol("notary.first", firstSpend)
val future = clientNode.services.startProtocol("${NotaryProtocol.TOPIC}.second", secondSpend) val future = clientNode.services.startProtocol("notary.second", secondSpend)
net.runNetwork() net.runNetwork()
@ -108,4 +101,12 @@ class NotaryServiceTests {
assertEquals(notaryError.tx, stx.tx) assertEquals(notaryError.tx, stx.tx)
notaryError.conflict.verified() notaryError.conflict.verified()
} }
private fun runNotaryClient(stx: SignedTransaction): ListenableFuture<DigitalSignature.LegallyIdentifiable> {
val protocol = NotaryProtocol.Client(stx)
val future = clientNode.services.startProtocol("notary-test", protocol)
net.runNetwork()
return future
}
} }

View File

@ -1,8 +1,11 @@
package com.r3corda.node.services package com.r3corda.node.services
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.contracts.Command import com.r3corda.core.contracts.Command
import com.r3corda.core.contracts.DummyContract import com.r3corda.core.contracts.DummyContract
import com.r3corda.core.contracts.TransactionType import com.r3corda.core.contracts.TransactionType
import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.DUMMY_NOTARY import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.node.services.network.NetworkMapService import com.r3corda.node.services.network.NetworkMapService
@ -44,9 +47,7 @@ class ValidatingNotaryServiceTests {
tx.toSignedTransaction(false) tx.toSignedTransaction(false)
} }
val protocol = NotaryProtocol.Client(stx) val future = runValidatingClient(stx)
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
net.runNetwork()
val ex = assertFailsWith(ExecutionException::class) { future.get() } val ex = assertFailsWith(ExecutionException::class) { future.get() }
val notaryError = (ex.cause as NotaryException).error val notaryError = (ex.cause as NotaryException).error
@ -64,9 +65,7 @@ class ValidatingNotaryServiceTests {
tx.toSignedTransaction(false) tx.toSignedTransaction(false)
} }
val protocol = NotaryProtocol.Client(stx) val future = runValidatingClient(stx)
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
net.runNetwork()
val ex = assertFailsWith(ExecutionException::class) { future.get() } val ex = assertFailsWith(ExecutionException::class) { future.get() }
val notaryError = (ex.cause as NotaryException).error val notaryError = (ex.cause as NotaryException).error
@ -75,4 +74,11 @@ class ValidatingNotaryServiceTests {
val missingKeys = (notaryError as NotaryError.SignaturesMissing).missingSigners val missingKeys = (notaryError as NotaryError.SignaturesMissing).missingSigners
assertEquals(setOf(expectedMissingKey), missingKeys) assertEquals(setOf(expectedMissingKey), missingKeys)
} }
private fun runValidatingClient(stx: SignedTransaction): ListenableFuture<DigitalSignature.LegallyIdentifiable> {
val protocol = NotaryProtocol.ValidatingClient(stx)
val future = clientNode.services.startProtocol("notary", protocol)
net.runNetwork()
return future
}
} }

View File

@ -1,13 +1,20 @@
package com.r3corda.node.services.persistence package com.r3corda.node.services.persistence
import co.paralleluniverse.fibers.Suspendable
import com.r3corda.contracts.asset.Cash import com.r3corda.contracts.asset.Cash
import com.r3corda.core.contracts.Amount import com.r3corda.core.contracts.Amount
import com.r3corda.core.contracts.Issued import com.r3corda.core.contracts.Issued
import com.r3corda.core.contracts.TransactionType import com.r3corda.core.contracts.TransactionType
import com.r3corda.core.contracts.USD import com.r3corda.core.contracts.USD
import com.r3corda.core.crypto.Party
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.DUMMY_NOTARY import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.testing.node.MockNetwork import com.r3corda.node.services.persistence.DataVending.Service.NotifyTransactionHandler
import com.r3corda.protocols.BroadcastTransactionProtocol.NotifyTxRequest
import com.r3corda.testing.MEGA_CORP import com.r3corda.testing.MEGA_CORP
import com.r3corda.testing.node.MockNetwork
import com.r3corda.testing.node.MockNetwork.MockNode
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@ -38,9 +45,8 @@ class DataVendingServiceTests {
ptx.signWith(registerNode.services.storageService.myLegalIdentityKey) ptx.signWith(registerNode.services.storageService.myLegalIdentityKey)
val tx = ptx.toSignedTransaction() val tx = ptx.toSignedTransaction()
assertEquals(0, vaultServiceNode.services.vaultService.currentVault.states.toList().size) assertEquals(0, vaultServiceNode.services.vaultService.currentVault.states.toList().size)
DataVending.Service.notify(registerNode.net, registerNode.services.storageService.myLegalIdentity,
vaultServiceNode.info, tx) registerNode.sendNotifyTx(tx, vaultServiceNode)
network.runNetwork()
// Check the transaction is in the receiving node // Check the transaction is in the receiving node
val actual = vaultServiceNode.services.vaultService.currentVault.states.singleOrNull() val actual = vaultServiceNode.services.vaultService.currentVault.states.singleOrNull()
@ -67,11 +73,23 @@ class DataVendingServiceTests {
ptx.signWith(registerNode.services.storageService.myLegalIdentityKey) ptx.signWith(registerNode.services.storageService.myLegalIdentityKey)
val tx = ptx.toSignedTransaction(false) val tx = ptx.toSignedTransaction(false)
assertEquals(0, vaultServiceNode.services.vaultService.currentVault.states.toList().size) assertEquals(0, vaultServiceNode.services.vaultService.currentVault.states.toList().size)
DataVending.Service.notify(registerNode.net, registerNode.services.storageService.myLegalIdentity,
vaultServiceNode.info, tx) registerNode.sendNotifyTx(tx, vaultServiceNode)
network.runNetwork()
// Check the transaction is not in the receiving node // Check the transaction is not in the receiving node
assertEquals(0, vaultServiceNode.services.vaultService.currentVault.states.toList().size) assertEquals(0, vaultServiceNode.services.vaultService.currentVault.states.toList().size)
} }
private fun MockNode.sendNotifyTx(tx: SignedTransaction, walletServiceNode: MockNode) {
walletServiceNode.services.registerProtocolInitiator(NotifyTxProtocol::class, ::NotifyTransactionHandler)
services.startProtocol("notify-tx", NotifyTxProtocol(walletServiceNode.info.identity, tx))
network.runNetwork()
}
private class NotifyTxProtocol(val otherParty: Party, val stx: SignedTransaction) : ProtocolLogic<Unit>() {
@Suspendable
override fun call() = send(otherParty, NotifyTxRequest(stx, emptySet()))
}
} }

View File

@ -10,12 +10,14 @@ import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import java.nio.file.FileSystem
import java.nio.file.Files import java.nio.file.Files
import java.nio.file.Path
class PerFileCheckpointStorageTests { class PerFileCheckpointStorageTests {
val fileSystem = Jimfs.newFileSystem(unix()) val fileSystem: FileSystem = Jimfs.newFileSystem(unix())
val storeDir = fileSystem.getPath("store") val storeDir: Path = fileSystem.getPath("store")
lateinit var checkpointStorage: PerFileCheckpointStorage lateinit var checkpointStorage: PerFileCheckpointStorage
@Before @Before
@ -92,6 +94,6 @@ class PerFileCheckpointStorageTests {
} }
private var checkpointCount = 1 private var checkpointCount = 1
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), null, null) private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)))
} }

View File

@ -2,14 +2,20 @@ package com.r3corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolSessionException
import com.r3corda.core.random63BitValue import com.r3corda.core.random63BitValue
import com.r3corda.testing.connectProtocols import com.r3corda.core.serialization.deserialize
import com.r3corda.node.services.statemachine.StateMachineManager.SessionData
import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.MockNetwork import com.r3corda.testing.node.MockNetwork
import com.r3corda.testing.node.MockNetwork.MockNode import com.r3corda.testing.node.MockNetwork.MockNode
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
@ -50,18 +56,18 @@ class StateMachineManagerTests {
} }
@Test @Test
fun `protocol suspended just after receiving payload`() { fun `protocol restarted just after receiving payload`() {
val topic = "send-and-receive" node2.services.registerProtocolInitiator(SendProtocol::class) { ReceiveThenSuspendProtocol(it) }
val payload = random63BitValue() val payload = random63BitValue()
val sendProtocol = SendProtocol(topic, node2.info.identity, payload) node1.smm.add("test", SendProtocol(payload, node2.info.identity))
val receiveProtocol = ReceiveProtocol(topic, node1.info.identity)
connectProtocols(sendProtocol, receiveProtocol) // We push through just enough messages to get only the SessionData sent
node1.smm.add("test", sendProtocol) // TODO We should be able to give runNetwork a predicate for when to stop
node2.smm.add("test", receiveProtocol) net.runNetwork(2)
net.runNetwork()
node2.stop() node2.stop()
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveProtocol>(node1.info.address) net.runNetwork()
assertThat(restoredProtocol.receivedPayload).isEqualTo(payload) val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1.info.address)
assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload)
} }
@Test @Test
@ -83,7 +89,7 @@ class StateMachineManagerTests {
node3.stop() node3.stop()
node3 = net.createNode(node1.info.address, forcedID = node3.id) node3 = net.createNode(node1.info.address, forcedID = node3.id)
val restoredProtocol = node3.smm.findStateMachines(ProtocolNoBlocking::class.java).single().first val restoredProtocol = node3.getSingleProtocol<ProtocolNoBlocking>().first
assertEquals(false, restoredProtocol.protocolStarted) // Not started yet as no network activity has been allowed yet assertEquals(false, restoredProtocol.protocolStarted) // Not started yet as no network activity has been allowed yet
net.runNetwork() // Allow network map messages to flow net.runNetwork() // Allow network map messages to flow
node3.smm.executor.flush() node3.smm.executor.flush()
@ -99,43 +105,44 @@ class StateMachineManagerTests {
@Test @Test
fun `protocol loaded from checkpoint will respond to messages from before start`() { fun `protocol loaded from checkpoint will respond to messages from before start`() {
val topic = "send-and-receive"
val payload = random63BitValue() val payload = random63BitValue()
val sendProtocol = SendProtocol(topic, node2.info.identity, payload) node1.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(payload, it) }
val receiveProtocol = ReceiveProtocol(topic, node1.info.identity) val receiveProtocol = ReceiveThenSuspendProtocol(node1.info.identity)
connectProtocols(sendProtocol, receiveProtocol)
node2.smm.add("test", receiveProtocol) // Prepare checkpointed receive protocol node2.smm.add("test", receiveProtocol) // Prepare checkpointed receive protocol
node2.stop() // kill receiver node2.stop() // kill receiver
node1.smm.add("test", sendProtocol) // now generate message to spool up and thus come in ahead of messages for NetworkMapService val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1.info.address)
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveProtocol>(node1.info.address) assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload)
assertThat(restoredProtocol.receivedPayload).isEqualTo(payload)
} }
@Test @Test
fun `protocol with send will resend on interrupted restart`() { fun `protocol with send will resend on interrupted restart`() {
val topic = "send-and-receive"
val payload = random63BitValue() val payload = random63BitValue()
val payload2 = random63BitValue() val payload2 = random63BitValue()
var sentCount = 0 var sentCount = 0
var receivedCount = 0 var receivedCount = 0
net.messagingNetwork.sentMessages.subscribe { if (it.message.topicSession.topic == topic) sentCount++ } net.messagingNetwork.sentMessages.subscribe { if (isDataMessage(it)) sentCount++ }
net.messagingNetwork.receivedMessages.subscribe { if (it.message.topicSession.topic == topic) receivedCount++ } net.messagingNetwork.receivedMessages.subscribe { if (isDataMessage(it)) receivedCount++ }
val node3 = net.createNode(node1.info.address) val node3 = net.createNode(node1.info.address)
net.runNetwork() net.runNetwork()
val firstProtocol = PingPongProtocol(topic, node3.info.identity, payload)
val secondProtocol = PingPongProtocol(topic, node2.info.identity, payload2) var secondProtocol: PingPongProtocol? = null
connectProtocols(firstProtocol, secondProtocol) node3.services.registerProtocolInitiator(PingPongProtocol::class) {
val protocol = PingPongProtocol(it, payload2)
secondProtocol = protocol
protocol
}
// Kick off first send and receive // Kick off first send and receive
node2.smm.add("test", firstProtocol) node2.smm.add("test", PingPongProtocol(node3.info.identity, payload))
assertEquals(1, node2.checkpointStorage.checkpoints.count()) assertEquals(1, node2.checkpointStorage.checkpoints.count())
// Restart node and thus reload the checkpoint and resend the message with same UUID // Restart node and thus reload the checkpoint and resend the message with same UUID
node2.stop() node2.stop()
val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray()) val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray())
val (firstAgain, fut1) = node2b.smm.findStateMachines(PingPongProtocol::class.java).single() val (firstAgain, fut1) = node2b.getSingleProtocol<PingPongProtocol>()
net.runNetwork() net.runNetwork()
assertEquals(1, node2.checkpointStorage.checkpoints.count()) assertEquals(1, node2.checkpointStorage.checkpoints.count())
// Now add in the other half of the protocol. First message should get deduped. So message data stays in sync. // Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync.
node3.smm.add("test", secondProtocol)
net.runNetwork() net.runNetwork()
node2b.smm.executor.flush() node2b.smm.executor.flush()
fut1.get() fut1.get()
@ -146,15 +153,66 @@ class StateMachineManagerTests {
assertEquals(0, node3.checkpointStorage.checkpoints.count(), "Checkpoints left after restored protocol should have ended") assertEquals(0, node3.checkpointStorage.checkpoints.count(), "Checkpoints left after restored protocol should have ended")
assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3") assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3")
assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3") assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3")
assertEquals(payload, secondProtocol.receivedPayload, "Received payload does not match the (restarted) first value on Node 2") assertEquals(payload, secondProtocol!!.receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
assertEquals(payload + 1, secondProtocol.receivedPayload2, "Received payload does not match the expected second value on Node 2") assertEquals(payload + 1, secondProtocol!!.receivedPayload2, "Received payload does not match the expected second value on Node 2")
}
@Test
fun `sending to multiple parties`() {
val node3 = net.createNode(node1.info.address)
net.runNetwork()
node2.services.registerProtocolInitiator(SendProtocol::class) { ReceiveThenSuspendProtocol(it) }
node3.services.registerProtocolInitiator(SendProtocol::class) { ReceiveThenSuspendProtocol(it) }
val payload = random63BitValue()
node1.smm.add("multiple-send", SendProtocol(payload, node2.info.identity, node3.info.identity))
net.runNetwork()
val node2Protocol = node2.getSingleProtocol<ReceiveThenSuspendProtocol>().first
val node3Protocol = node3.getSingleProtocol<ReceiveThenSuspendProtocol>().first
assertThat(node2Protocol.receivedPayloads[0]).isEqualTo(payload)
assertThat(node3Protocol.receivedPayloads[0]).isEqualTo(payload)
}
@Test
fun `receiving from multiple parties`() {
val node3 = net.createNode(node1.info.address)
net.runNetwork()
val node2Payload = random63BitValue()
val node3Payload = random63BitValue()
node2.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node2Payload, it) }
node3.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node3Payload, it) }
val multiReceiveProtocol = ReceiveThenSuspendProtocol(node2.info.identity, node3.info.identity)
node1.smm.add("multiple-receive", multiReceiveProtocol)
net.runNetwork(1) // session handshaking
// have the messages arrive in reverse order of receive
node3.pumpReceive(false)
node2.pumpReceive(false)
net.runNetwork() // pump remaining messages
assertThat(multiReceiveProtocol.receivedPayloads[0]).isEqualTo(node2Payload)
assertThat(multiReceiveProtocol.receivedPayloads[1]).isEqualTo(node3Payload)
}
@Test
fun `exception thrown on other side`() {
node2.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { ExceptionProtocol }
val future = node1.smm.add("exception", ReceiveThenSuspendProtocol(node2.info.identity)).resultFuture
net.runNetwork()
assertThatThrownBy { future.get() }.hasCauseInstanceOf(ProtocolSessionException::class.java)
}
private fun isDataMessage(transfer: InMemoryMessagingNetwork.MessageTransfer): Boolean {
return transfer.message.topicSession == StateMachineManager.sessionTopic
&& transfer.message.data.deserialize<SessionMessage>() is SessionData
} }
private inline fun <reified P : NonTerminatingProtocol> MockNode.restartAndGetRestoredProtocol(networkMapAddress: SingleMessageRecipient? = null): P { private inline fun <reified P : NonTerminatingProtocol> MockNode.restartAndGetRestoredProtocol(networkMapAddress: SingleMessageRecipient? = null): P {
val servicesArray = advertisedServices.toTypedArray() stop()
val node = mockNet.createNode(networkMapAddress, id, advertisedServices = *servicesArray) val newNode = mockNet.createNode(networkMapAddress, id, advertisedServices = *advertisedServices.toTypedArray())
mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine
return node.smm.findStateMachines(P::class.java).single().first return newNode.getSingleProtocol<P>().first
}
private inline fun <reified P : ProtocolLogic<*>> MockNode.getSingleProtocol(): Pair<P, ListenableFuture<*>> {
return smm.findStateMachines(P::class.java).single()
} }
@ -165,8 +223,6 @@ class StateMachineManagerTests {
override fun call() { override fun call() {
protocolStarted = true protocolStarted = true
} }
override val topic: String get() = throw UnsupportedOperationException()
} }
private class ProtocolWithoutCheckpoints : NonTerminatingProtocol() { private class ProtocolWithoutCheckpoints : NonTerminatingProtocol() {
@ -177,8 +233,6 @@ class StateMachineManagerTests {
override fun doCall() { override fun doCall() {
protocolStarted = true protocolStarted = true
} }
override val topic: String get() = throw UnsupportedOperationException()
} }
@ -187,30 +241,37 @@ class StateMachineManagerTests {
val lazyTime by lazy { serviceHub.clock.instant() } val lazyTime by lazy { serviceHub.clock.instant() }
@Suspendable @Suspendable
override fun call() { override fun call() = Unit
}
override val topic: String get() = throw UnsupportedOperationException()
} }
private class SendProtocol(override val topic: String, val otherParty: Party, val payload: Any) : ProtocolLogic<Unit>() { private class SendProtocol(val payload: Any, vararg val otherParties: Party) : ProtocolLogic<Unit>() {
init {
require(otherParties.isNotEmpty())
}
@Suspendable @Suspendable
override fun call() = send(otherParty, payload) override fun call() = otherParties.forEach { send(it, payload) }
} }
private class ReceiveProtocol(override val topic: String, val otherParty: Party) : NonTerminatingProtocol() { private class ReceiveThenSuspendProtocol(vararg val otherParties: Party) : NonTerminatingProtocol() {
@Transient var receivedPayload: Any? = null init {
require(otherParties.isNotEmpty())
}
@Transient var receivedPayloads: List<Any> = emptyList()
@Suspendable @Suspendable
override fun doCall() { override fun doCall() {
receivedPayload = receive<Any>(otherParty).unwrap { it } receivedPayloads = otherParties.map { receive<Any>(it).unwrap { it } }
} }
} }
private class PingPongProtocol(override val topic: String, val otherParty: Party, val payload: Long) : ProtocolLogic<Unit>() { private class PingPongProtocol(val otherParty: Party, val payload: Long) : ProtocolLogic<Unit>() {
@Transient var receivedPayload: Long? = null @Transient var receivedPayload: Long? = null
@Transient var receivedPayload2: Long? = null @Transient var receivedPayload2: Long? = null
@ -219,7 +280,10 @@ class StateMachineManagerTests {
receivedPayload = sendAndReceive<Long>(otherParty, payload).unwrap { it } receivedPayload = sendAndReceive<Long>(otherParty, payload).unwrap { it }
receivedPayload2 = sendAndReceive<Long>(otherParty, (payload + 1)).unwrap { it } receivedPayload2 = sendAndReceive<Long>(otherParty, (payload + 1)).unwrap { it }
} }
}
private object ExceptionProtocol : ProtocolLogic<Nothing>() {
override fun call(): Nothing = throw Exception()
} }
/** /**

View File

@ -7,20 +7,22 @@ import com.r3corda.contracts.CommercialPaper
import com.r3corda.contracts.asset.DUMMY_CASH_ISSUER import com.r3corda.contracts.asset.DUMMY_CASH_ISSUER
import com.r3corda.contracts.asset.cashBalances import com.r3corda.contracts.asset.cashBalances
import com.r3corda.contracts.testing.fillWithSomeTestCash import com.r3corda.contracts.testing.fillWithSomeTestCash
import com.r3corda.core.*
import com.r3corda.core.contracts.* import com.r3corda.core.contracts.*
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.crypto.generateKeyPair import com.r3corda.core.crypto.generateKeyPair
import com.r3corda.core.days
import com.r3corda.core.logElapsedTime
import com.r3corda.core.node.NodeInfo import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.ServiceType import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.seconds
import com.r3corda.core.success
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.Emoji import com.r3corda.core.utilities.Emoji
import com.r3corda.core.utilities.LogHelper import com.r3corda.core.utilities.LogHelper
import com.r3corda.core.utilities.ProgressTracker import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.node.internal.Node import com.r3corda.node.internal.Node
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.config.NodeConfigurationFromConfig import com.r3corda.node.services.config.NodeConfigurationFromConfig
import com.r3corda.node.services.messaging.NodeMessagingClient import com.r3corda.node.services.messaging.NodeMessagingClient
@ -28,7 +30,6 @@ import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.persistence.NodeAttachmentService import com.r3corda.node.services.persistence.NodeAttachmentService
import com.r3corda.node.services.transactions.SimpleNotaryService import com.r3corda.node.services.transactions.SimpleNotaryService
import com.r3corda.node.utilities.databaseTransaction import com.r3corda.node.utilities.databaseTransaction
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.protocols.NotaryProtocol import com.r3corda.protocols.NotaryProtocol
import com.r3corda.protocols.TwoPartyTradeProtocol import com.r3corda.protocols.TwoPartyTradeProtocol
import joptsimple.OptionParser import joptsimple.OptionParser
@ -210,23 +211,13 @@ private fun runBuyer(node: Node, amount: Amount<Currency>) {
// next stage in our building site, we will just auto-generate fake trades to give our nodes something to do. // next stage in our building site, we will just auto-generate fake trades to give our nodes something to do.
// //
// As the seller initiates the two-party trade protocol, here, we will be the buyer. // As the seller initiates the two-party trade protocol, here, we will be the buyer.
object : AbstractNodeService(node.services) { node.services.registerProtocolInitiator(TraderDemoProtocolSeller::class) { otherParty ->
init { TraderDemoProtocolBuyer(otherParty, attachmentsPath, amount)
addProtocolHandler(DEMO_TOPIC, "demo.buyer") { handshake: TraderDemoHandshake ->
TraderDemoProtocolBuyer(handshake.replyToParty, attachmentsPath, amount)
}
}
} }
} }
// We create a couple of ad-hoc test protocols that wrap the two party trade protocol, to give us the demo logic. // We create a couple of ad-hoc test protocols that wrap the two party trade protocol, to give us the demo logic.
private data class TraderDemoHandshake(override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
private val DEMO_TOPIC = "initiate.demo.trade"
private class TraderDemoProtocolBuyer(val otherSide: Party, private class TraderDemoProtocolBuyer(val otherSide: Party,
private val attachmentsPath: Path, private val attachmentsPath: Path,
val amount: Amount<Currency>, val amount: Amount<Currency>,
@ -234,8 +225,6 @@ private class TraderDemoProtocolBuyer(val otherSide: Party,
object STARTING_BUY : ProgressTracker.Step("Seller connected, purchasing commercial paper asset") object STARTING_BUY : ProgressTracker.Step("Seller connected, purchasing commercial paper asset")
override val topic: String get() = DEMO_TOPIC
@Suspendable @Suspendable
override fun call() { override fun call() {
progressTracker.currentStep = STARTING_BUY progressTracker.currentStep = STARTING_BUY
@ -248,7 +237,7 @@ private class TraderDemoProtocolBuyer(val otherSide: Party,
CommercialPaper.State::class.java) CommercialPaper.State::class.java)
// This invokes the trading protocol and out pops our finished transaction. // This invokes the trading protocol and out pops our finished transaction.
val tradeTX: SignedTransaction = subProtocol(buyer, inheritParentSessions = true) val tradeTX: SignedTransaction = subProtocol(buyer, shareParentSessions = true)
// TODO: This should be moved into the protocol itself. // TODO: This should be moved into the protocol itself.
serviceHub.recordTransactions(listOf(tradeTX)) serviceHub.recordTransactions(listOf(tradeTX))
@ -289,8 +278,6 @@ private class TraderDemoProtocolSeller(val otherSide: Party,
companion object { companion object {
val PROSPECTUS_HASH = SecureHash.parse("decd098666b9657314870e192ced0c3519c2c9d395507a238338f8d003929de9") val PROSPECTUS_HASH = SecureHash.parse("decd098666b9657314870e192ced0c3519c2c9d395507a238338f8d003929de9")
object ANNOUNCING : ProgressTracker.Step("Announcing to the buyer node")
object SELF_ISSUING : ProgressTracker.Step("Got session ID back, issuing and timestamping some commercial paper") object SELF_ISSUING : ProgressTracker.Step("Got session ID back, issuing and timestamping some commercial paper")
object TRADING : ProgressTracker.Step("Starting the trade protocol") { object TRADING : ProgressTracker.Step("Starting the trade protocol") {
@ -300,17 +287,11 @@ private class TraderDemoProtocolSeller(val otherSide: Party,
// We vend a progress tracker that already knows there's going to be a TwoPartyTradingProtocol involved at some // We vend a progress tracker that already knows there's going to be a TwoPartyTradingProtocol involved at some
// point: by setting up the tracker in advance, the user can see what's coming in more detail, instead of being // point: by setting up the tracker in advance, the user can see what's coming in more detail, instead of being
// surprised when it appears as a new set of tasks below the current one. // surprised when it appears as a new set of tasks below the current one.
fun tracker() = ProgressTracker(ANNOUNCING, SELF_ISSUING, TRADING) fun tracker() = ProgressTracker(SELF_ISSUING, TRADING)
} }
override val topic: String get() = DEMO_TOPIC
@Suspendable @Suspendable
override fun call(): SignedTransaction { override fun call(): SignedTransaction {
progressTracker.currentStep = ANNOUNCING
send(otherSide, TraderDemoHandshake(serviceHub.storageService.myLegalIdentity))
progressTracker.currentStep = SELF_ISSUING progressTracker.currentStep = SELF_ISSUING
val notary: NodeInfo = serviceHub.networkMapCache.notaryNodes[0] val notary: NodeInfo = serviceHub.networkMapCache.notaryNodes[0]
@ -326,7 +307,7 @@ private class TraderDemoProtocolSeller(val otherSide: Party,
amount, amount,
cpOwnerKey, cpOwnerKey,
progressTracker.getChildProgressTracker(TRADING)!!) progressTracker.getChildProgressTracker(TRADING)!!)
val tradeTX: SignedTransaction = subProtocol(seller, inheritParentSessions = true) val tradeTX: SignedTransaction = subProtocol(seller, shareParentSessions = true)
serviceHub.recordTransactions(listOf(tradeTX)) serviceHub.recordTransactions(listOf(tradeTX))
return tradeTX return tradeTX

View File

@ -12,16 +12,14 @@ import com.r3corda.core.math.InterpolatorFactory
import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.services.ServiceType import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.transactions.WireTransaction import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.utilities.ProgressTracker import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.api.AcceptsFileUpload import com.r3corda.node.services.api.AcceptsFileUpload
import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.node.utilities.FiberBox import com.r3corda.node.utilities.FiberBox
import com.r3corda.protocols.RatesFixProtocol import com.r3corda.protocols.RatesFixProtocol.*
import com.r3corda.protocols.ServiceRequestMessage
import com.r3corda.protocols.TwoPartyDealProtocol import com.r3corda.protocols.TwoPartyDealProtocol
import org.slf4j.LoggerFactory
import java.io.InputStream import java.io.InputStream
import java.math.BigDecimal import java.math.BigDecimal
import java.security.KeyPair import java.security.KeyPair
@ -55,46 +53,31 @@ object NodeInterestRates {
/** /**
* The Service that wraps [Oracle] and handles messages/network interaction/request scrubbing. * The Service that wraps [Oracle] and handles messages/network interaction/request scrubbing.
*/ */
class Service(services: ServiceHubInternal) : AcceptsFileUpload, AbstractNodeService(services) { class Service(services: ServiceHubInternal) : AcceptsFileUpload, SingletonSerializeAsToken() {
val ss = services.storageService val ss = services.storageService
val oracle = Oracle(ss.myLegalIdentity, ss.myLegalIdentityKey, services.clock) val oracle = Oracle(ss.myLegalIdentity, ss.myLegalIdentityKey, services.clock)
private val logger = LoggerFactory.getLogger(Service::class.java)
init { init {
addMessageHandler(RatesFixProtocol.TOPIC, services.registerProtocolInitiator(FixSignProtocol::class) { FixSignHandler(it, oracle) }
{ req: ServiceRequestMessage -> services.registerProtocolInitiator(FixQueryProtocol::class) { FixQueryHandler(it, oracle) }
if (req is RatesFixProtocol.SignRequest) {
oracle.sign(req.tx)
}
else {
/**
* We put this into a protocol so that if it blocks waiting for the interest rate to become
* available, we a) don't block this thread and b) allow the fact we are waiting
* to be persisted/checkpointed.
* Interest rates become available when they are uploaded via the web as per [DataUploadServlet],
* if they haven't already been uploaded that way.
*/
req as RatesFixProtocol.QueryRequest
val handler = FixQueryHandler(this, req)
handler.registerSession(req)
services.startProtocol("fixing", handler)
Unit
}
},
{ message, e -> logger.error("Exception during interest rate oracle request processing", e) }
)
} }
private class FixQueryHandler(val service: Service,
val request: RatesFixProtocol.QueryRequest) : ProtocolLogic<Unit>() { private class FixSignHandler(val otherParty: Party, val oracle: Oracle) : ProtocolLogic<Unit>() {
@Suspendable
override fun call() {
val request = receive<SignRequest>(otherParty).unwrap { it }
send(otherParty, oracle.sign(request.tx))
}
}
private class FixQueryHandler(val otherParty: Party, val oracle: Oracle) : ProtocolLogic<Unit>() {
companion object { companion object {
object RECEIVED : ProgressTracker.Step("Received fix request") object RECEIVED : ProgressTracker.Step("Received fix request")
object SENDING : ProgressTracker.Step("Sending fix response") object SENDING : ProgressTracker.Step("Sending fix response")
} }
override val topic: String get() = RatesFixProtocol.TOPIC
override val progressTracker = ProgressTracker(RECEIVED, SENDING) override val progressTracker = ProgressTracker(RECEIVED, SENDING)
init { init {
@ -103,9 +86,10 @@ object NodeInterestRates {
@Suspendable @Suspendable
override fun call(): Unit { override fun call(): Unit {
val answers = service.oracle.query(request.queries, request.deadline) val request = receive<QueryRequest>(otherParty).unwrap { it }
val answers = oracle.query(request.queries, request.deadline)
progressTracker.currentStep = SENDING progressTracker.currentStep = SENDING
send(request.replyToParty, answers) send(otherParty, answers)
} }
} }

View File

@ -1,20 +1,17 @@
package com.r3corda.demos.protocols package com.r3corda.demos.protocols
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.FutureCallback
import com.r3corda.core.contracts.DealState import com.r3corda.core.contracts.DealState
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.ProgressTracker import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.protocols.TwoPartyDealProtocol import com.r3corda.protocols.TwoPartyDealProtocol
import com.r3corda.protocols.TwoPartyDealProtocol.Acceptor import com.r3corda.protocols.TwoPartyDealProtocol.Acceptor
import com.r3corda.protocols.TwoPartyDealProtocol.DEAL_TOPIC import com.r3corda.protocols.TwoPartyDealProtocol.AutoOffer
import com.r3corda.protocols.TwoPartyDealProtocol.Instigator import com.r3corda.protocols.TwoPartyDealProtocol.Instigator
/** /**
@ -25,58 +22,27 @@ import com.r3corda.protocols.TwoPartyDealProtocol.Instigator
* or the protocol would have to reach out to external systems (or users) to verify the deals. * or the protocol would have to reach out to external systems (or users) to verify the deals.
*/ */
object AutoOfferProtocol { object AutoOfferProtocol {
val TOPIC = "autooffer.topic"
data class AutoOfferMessage(val notary: Party,
val dealBeingOffered: DealState,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
class Plugin : CordaPluginRegistry() { class Plugin : CordaPluginRegistry() {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java) override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
} }
class Service(services: ServiceHubInternal) : AbstractNodeService(services) { class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() {
object DEALING : ProgressTracker.Step("Starting the deal protocol") { object DEALING : ProgressTracker.Step("Starting the deal protocol") {
override fun childProgressTracker(): ProgressTracker = TwoPartyDealProtocol.Secondary.tracker() override fun childProgressTracker(): ProgressTracker = TwoPartyDealProtocol.Secondary.tracker()
} }
fun tracker() = ProgressTracker(DEALING)
class Callback(val success: (SignedTransaction) -> Unit) : FutureCallback<SignedTransaction> {
override fun onFailure(t: Throwable?) {
// TODO handle exceptions
}
override fun onSuccess(st: SignedTransaction?) {
success(st!!)
}
}
init { init {
addProtocolHandler(TOPIC, "$DEAL_TOPIC.seller") { autoOfferMessage: AutoOfferMessage -> services.registerProtocolInitiator(Instigator::class) { Acceptor(it) }
val progressTracker = tracker()
// Put the deal onto the ledger
progressTracker.currentStep = DEALING
Acceptor(
autoOfferMessage.replyToParty,
autoOfferMessage.notary,
autoOfferMessage.dealBeingOffered,
progressTracker.getChildProgressTracker(DEALING)!!
)
} }
} }
}
class Requester(val dealToBeOffered: DealState) : ProtocolLogic<SignedTransaction>() { class Requester(val dealToBeOffered: DealState) : ProtocolLogic<SignedTransaction>() {
companion object { companion object {
object RECEIVED : ProgressTracker.Step("Received API call") object RECEIVED : ProgressTracker.Step("Received API call")
object ANNOUNCING : ProgressTracker.Step("Announcing to the peer node")
object DEALING : ProgressTracker.Step("Starting the deal protocol") { object DEALING : ProgressTracker.Step("Starting the deal protocol") {
override fun childProgressTracker(): ProgressTracker = TwoPartyDealProtocol.Primary.tracker() override fun childProgressTracker(): ProgressTracker = TwoPartyDealProtocol.Primary.tracker()
} }
@ -84,10 +50,9 @@ object AutoOfferProtocol {
// We vend a progress tracker that already knows there's going to be a TwoPartyTradingProtocol involved at some // We vend a progress tracker that already knows there's going to be a TwoPartyTradingProtocol involved at some
// point: by setting up the tracker in advance, the user can see what's coming in more detail, instead of being // point: by setting up the tracker in advance, the user can see what's coming in more detail, instead of being
// surprised when it appears as a new set of tasks below the current one. // surprised when it appears as a new set of tasks below the current one.
fun tracker() = ProgressTracker(RECEIVED, ANNOUNCING, DEALING) fun tracker() = ProgressTracker(RECEIVED, DEALING)
} }
override val topic: String get() = TOPIC
override val progressTracker = tracker() override val progressTracker = tracker()
init { init {
@ -100,17 +65,14 @@ object AutoOfferProtocol {
val notary = serviceHub.networkMapCache.notaryNodes.first().identity val notary = serviceHub.networkMapCache.notaryNodes.first().identity
// need to pick which ever party is not us // need to pick which ever party is not us
val otherParty = notUs(dealToBeOffered.parties).single() val otherParty = notUs(dealToBeOffered.parties).single()
progressTracker.currentStep = ANNOUNCING
send(otherParty, AutoOfferMessage(notary, dealToBeOffered, serviceHub.storageService.myLegalIdentity))
progressTracker.currentStep = DEALING progressTracker.currentStep = DEALING
val instigator = Instigator( val instigator = Instigator(
otherParty, otherParty,
notary, AutoOffer(notary, dealToBeOffered),
dealToBeOffered,
serviceHub.storageService.myLegalIdentityKey, serviceHub.storageService.myLegalIdentityKey,
progressTracker.getChildProgressTracker(DEALING)!! progressTracker.getChildProgressTracker(DEALING)!!
) )
val stx = subProtocol(instigator, inheritParentSessions = true) val stx = subProtocol(instigator)
return stx return stx
} }

View File

@ -5,56 +5,49 @@ import co.paralleluniverse.strands.Strand
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.NodeInfo import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.deserialize
import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.testing.node.MockNetworkMapCache import com.r3corda.testing.node.MockNetworkMapCache
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
object ExitServerProtocol { object ExitServerProtocol {
val TOPIC = "exit.topic"
// Will only be enabled if you install the Handler // Will only be enabled if you install the Handler
@Volatile private var enabled = false @Volatile private var enabled = false
// This is not really a HandshakeMessage but needs to be so that the send uses the default session ID. This will // This is not really a HandshakeMessage but needs to be so that the send uses the default session ID. This will
// resolve itself when the protocol session stuff is done. // resolve itself when the protocol session stuff is done.
data class ExitMessage(val exitCode: Int, data class ExitMessage(val exitCode: Int)
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
class Plugin: CordaPluginRegistry() { class Plugin: CordaPluginRegistry() {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java) override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
} }
class Service(services: ServiceHubInternal) { class Service(services: ServiceHubInternal) {
init { init {
services.networkService.addMessageHandler(TOPIC, DEFAULT_SESSION_ID) { msg, registration -> services.registerProtocolInitiator(Broadcast::class, ::ExitServerHandler)
// Just to validate we got the message
if (enabled) {
val message = msg.data.deserialize<ExitMessage>()
System.exit(message.exitCode)
}
}
enabled = true enabled = true
} }
} }
private class ExitServerHandler(val otherParty: Party) : ProtocolLogic<Unit>() {
override fun call() {
// Just to validate we got the message
if (enabled) {
val message = receive<ExitMessage>(otherParty).unwrap { it }
System.exit(message.exitCode)
}
}
}
/** /**
* This takes a Java Integer rather than Kotlin Int as that is what we end up with in the calling map and currently * This takes a Java Integer rather than Kotlin Int as that is what we end up with in the calling map and currently
* we do not support coercing numeric types in the reflective search for matching constructors. * we do not support coercing numeric types in the reflective search for matching constructors.
*/ */
class Broadcast(val exitCode: Int) : ProtocolLogic<Boolean>() { class Broadcast(val exitCode: Int) : ProtocolLogic<Boolean>() {
override val topic: String get() = TOPIC
@Suspendable @Suspendable
override fun call(): Boolean { override fun call(): Boolean {
if (enabled) { if (enabled) {
@ -73,7 +66,7 @@ object ExitServerProtocol {
if (recipient.address is MockNetworkMapCache.MockAddress) { if (recipient.address is MockNetworkMapCache.MockAddress) {
// Ignore // Ignore
} else { } else {
send(recipient.identity, ExitMessage(exitCode, recipient.identity)) send(recipient.identity, ExitMessage(exitCode))
} }
} }
} }

View File

@ -4,14 +4,10 @@ import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.NodeInfo import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.utilities.ProgressTracker import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.demos.DemoClock import com.r3corda.demos.DemoClock
import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.testing.node.MockNetworkMapCache import com.r3corda.testing.node.MockNetworkMapCache
import java.time.LocalDate import java.time.LocalDate
@ -20,28 +16,27 @@ import java.time.LocalDate
*/ */
object UpdateBusinessDayProtocol { object UpdateBusinessDayProtocol {
val TOPIC = "businessday.topic"
// This is not really a HandshakeMessage but needs to be so that the send uses the default session ID. This will // This is not really a HandshakeMessage but needs to be so that the send uses the default session ID. This will
// resolve itself when the protocol session stuff is done. // resolve itself when the protocol session stuff is done.
data class UpdateBusinessDayMessage(val date: LocalDate, data class UpdateBusinessDayMessage(val date: LocalDate)
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
class Plugin: CordaPluginRegistry() { class Plugin: CordaPluginRegistry() {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java) override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
} }
class Service(services: ServiceHubInternal) { class Service(services: ServiceHubInternal) {
init { init {
services.networkService.addMessageHandler(TOPIC, DEFAULT_SESSION_ID) { msg, registration -> services.registerProtocolInitiator(Broadcast::class, ::UpdateBusinessDayHandler)
val updateBusinessDayMessage = msg.data.deserialize<UpdateBusinessDayMessage>()
(services.clock as DemoClock).updateDate(updateBusinessDayMessage.date)
} }
} }
private class UpdateBusinessDayHandler(val otherParty: Party) : ProtocolLogic<Unit>() {
override fun call() {
val message = receive<UpdateBusinessDayMessage>(otherParty).unwrap { it }
(serviceHub.clock as DemoClock).updateDate(message.date)
} }
}
class Broadcast(val date: LocalDate, class Broadcast(val date: LocalDate,
override val progressTracker: ProgressTracker = Broadcast.tracker()) : ProtocolLogic<Unit>() { override val progressTracker: ProgressTracker = Broadcast.tracker()) : ProtocolLogic<Unit>() {
@ -52,8 +47,6 @@ object UpdateBusinessDayProtocol {
fun tracker() = ProgressTracker(NOTIFYING) fun tracker() = ProgressTracker(NOTIFYING)
} }
override val topic: String get() = TOPIC
@Suspendable @Suspendable
override fun call(): Unit { override fun call(): Unit {
progressTracker.currentStep = NOTIFYING progressTracker.currentStep = NOTIFYING
@ -67,7 +60,7 @@ object UpdateBusinessDayProtocol {
if (recipient.address is MockNetworkMapCache.MockAddress) { if (recipient.address is MockNetworkMapCache.MockAddress) {
// Ignore // Ignore
} else { } else {
send(recipient.identity, UpdateBusinessDayMessage(date, recipient.identity)) send(recipient.identity, UpdateBusinessDayMessage(date))
} }
} }
} }

View File

@ -10,14 +10,16 @@ import com.r3corda.core.RunOnCallerThread
import com.r3corda.core.contracts.StateAndRef import com.r3corda.core.contracts.StateAndRef
import com.r3corda.core.contracts.UniqueIdentifier import com.r3corda.core.contracts.UniqueIdentifier
import com.r3corda.core.failure import com.r3corda.core.failure
import com.r3corda.core.flatMap
import com.r3corda.core.node.services.linearHeadsOfType import com.r3corda.core.node.services.linearHeadsOfType
import com.r3corda.core.success import com.r3corda.core.success
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.protocols.TwoPartyDealProtocol import com.r3corda.protocols.TwoPartyDealProtocol.Acceptor
import com.r3corda.testing.connectProtocols import com.r3corda.protocols.TwoPartyDealProtocol.AutoOffer
import com.r3corda.protocols.TwoPartyDealProtocol.Instigator
import com.r3corda.testing.initiateSingleShotProtocol
import com.r3corda.testing.node.InMemoryMessagingNetwork import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.MockIdentityService import com.r3corda.testing.node.MockIdentityService
import java.security.KeyPair
import java.time.LocalDate import java.time.LocalDate
import java.util.* import java.util.*
@ -73,7 +75,7 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten
val node1: SimulatedNode = banks[i] val node1: SimulatedNode = banks[i]
val node2: SimulatedNode = banks[j] val node2: SimulatedNode = banks[j]
val swaps: Map<UniqueIdentifier, StateAndRef<InterestRateSwap.State>> = node1.services.vaultService.linearHeadsOfType<com.r3corda.contracts.InterestRateSwap.State>() val swaps: Map<UniqueIdentifier, StateAndRef<InterestRateSwap.State>> = node1.services.vaultService.linearHeadsOfType<InterestRateSwap.State>()
val theDealRef: StateAndRef<InterestRateSwap.State> = swaps.values.single() val theDealRef: StateAndRef<InterestRateSwap.State> = swaps.values.single()
// Do we have any more days left in this deal's lifetime? If not, return. // Do we have any more days left in this deal's lifetime? If not, return.
@ -111,22 +113,19 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten
// We load the IRS afresh each time because the leg parts of the structure aren't data classes so they don't // We load the IRS afresh each time because the leg parts of the structure aren't data classes so they don't
// have the convenient copy() method that'd let us make small adjustments. Instead they're partly mutable. // have the convenient copy() method that'd let us make small adjustments. Instead they're partly mutable.
// TODO: We should revisit this in post-Excalibur cleanup and fix, e.g. by introducing an interface. // TODO: We should revisit this in post-Excalibur cleanup and fix, e.g. by introducing an interface.
val irs = om.readValue<com.r3corda.contracts.InterestRateSwap.State>(javaClass.getResource("trade.json")) val irs = om.readValue<InterestRateSwap.State>(javaClass.getResource("trade.json"))
irs.fixedLeg.fixedRatePayer = node1.info.identity irs.fixedLeg.fixedRatePayer = node1.info.identity
irs.floatingLeg.floatingRatePayer = node2.info.identity irs.floatingLeg.floatingRatePayer = node2.info.identity
val instigator = TwoPartyDealProtocol.Instigator(node2.info.identity, notary.info.identity, irs, node1.keyPair!!) val acceptorTx = node2.initiateSingleShotProtocol(Instigator::class) { Acceptor(it) }.flatMap { it.resultFuture }
val acceptor = TwoPartyDealProtocol.Acceptor(node1.info.identity, notary.info.identity, irs)
connectProtocols(instigator, acceptor)
showProgressFor(listOf(node1, node2)) showProgressFor(listOf(node1, node2))
showConsensusFor(listOf(node1, node2, regulators[0])) showConsensusFor(listOf(node1, node2, regulators[0]))
val instigatorFuture: ListenableFuture<SignedTransaction> = node1.services.startProtocol("instigator", instigator) val instigator = Instigator(node2.info.identity, AutoOffer(notary.info.identity, irs), node1.keyPair!!)
val instigatorTx = node1.services.startProtocol("instigator", instigator)
return Futures.transformAsync(Futures.allAsList(instigatorFuture, node2.services.startProtocol("acceptor", acceptor))) { return Futures.transformAsync(Futures.allAsList(instigatorTx, acceptorTx)) { instigatorTx }
instigatorFuture
}
} }
override fun iterate(): InMemoryMessagingNetwork.MessageTransfer? { override fun iterate(): InMemoryMessagingNetwork.MessageTransfer? {

View File

@ -9,12 +9,13 @@ import com.r3corda.core.contracts.DOLLARS
import com.r3corda.core.contracts.OwnableState import com.r3corda.core.contracts.OwnableState
import com.r3corda.core.contracts.`issued by` import com.r3corda.core.contracts.`issued by`
import com.r3corda.core.days import com.r3corda.core.days
import com.r3corda.core.flatMap
import com.r3corda.core.node.recordTransactions import com.r3corda.core.node.recordTransactions
import com.r3corda.core.seconds import com.r3corda.core.seconds
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.protocols.TwoPartyTradeProtocol import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
import com.r3corda.protocols.TwoPartyTradeProtocol.TOPIC import com.r3corda.protocols.TwoPartyTradeProtocol.Seller
import com.r3corda.testing.connectProtocols import com.r3corda.testing.initiateSingleShotProtocol
import com.r3corda.testing.node.InMemoryMessagingNetwork import com.r3corda.testing.node.InMemoryMessagingNetwork
import java.time.Instant import java.time.Instant
@ -45,25 +46,24 @@ class TradeSimulation(runAsync: Boolean, latencyInjector: InMemoryMessagingNetwo
seller.services.recordTransactions(issuance) seller.services.recordTransactions(issuance)
val amount = 1000.DOLLARS val amount = 1000.DOLLARS
val buyerProtocol = TwoPartyTradeProtocol.Buyer(
seller.info.identity, val buyerFuture = buyer.initiateSingleShotProtocol(Seller::class) {
notary.info.identity, Buyer(it, notary.info.identity, amount, CommercialPaper.State::class.java)
amount, }.flatMap { it.resultFuture }
CommercialPaper.State::class.java)
val sellerProtocol = TwoPartyTradeProtocol.Seller( val sellerProtocol = Seller(
buyer.info.identity, buyer.info.identity,
notary.info, notary.info,
issuance.tx.outRef<OwnableState>(0), issuance.tx.outRef<OwnableState>(0),
amount, amount,
seller.storage.myLegalIdentityKey) seller.storage.myLegalIdentityKey)
connectProtocols(buyerProtocol, sellerProtocol)
showConsensusFor(listOf(buyer, seller, notary)) showConsensusFor(listOf(buyer, seller, notary))
showProgressFor(listOf(buyer, seller)) showProgressFor(listOf(buyer, seller))
val buyerFuture = buyer.services.startProtocol("bank.$buyerBankIndex.$TOPIC.buyer", buyerProtocol) val sellerFuture = seller.services.startProtocol("bank.$sellerBankIndex.seller", sellerProtocol)
val sellerFuture = seller.services.startProtocol("bank.$sellerBankIndex.$TOPIC.seller", sellerProtocol)
return Futures.successfulAsList(buyerFuture, sellerFuture) return Futures.successfulAsList(buyerFuture, sellerFuture)
} }
} }

View File

@ -4,22 +4,28 @@ package com.r3corda.testing
import com.google.common.base.Throwables import com.google.common.base.Throwables
import com.google.common.net.HostAndPort import com.google.common.net.HostAndPort
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.contracts.StateRef import com.r3corda.core.contracts.StateRef
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.crypto.generateKeyPair import com.r3corda.core.crypto.generateKeyPair
import com.r3corda.core.node.ServiceHub import com.r3corda.core.node.ServiceHub
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.transactions.TransactionBuilder import com.r3corda.core.transactions.TransactionBuilder
import com.r3corda.core.utilities.DUMMY_NOTARY import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.protocols.HandshakeMessage import com.r3corda.node.internal.AbstractNode
import com.r3corda.node.services.statemachine.StateMachineManager.Change
import com.r3corda.node.utilities.AddOrRemove.ADD
import com.r3corda.testing.node.MockIdentityService import com.r3corda.testing.node.MockIdentityService
import com.r3corda.testing.node.MockServices import com.r3corda.testing.node.MockServices
import rx.Subscriber
import java.net.ServerSocket import java.net.ServerSocket
import java.security.KeyPair import java.security.KeyPair
import java.security.PublicKey import java.security.PublicKey
import kotlin.reflect.KClass
/** /**
* JAVA INTEROP * JAVA INTEROP
@ -129,22 +135,32 @@ fun getFreeLocalPorts(hostName: String, numberToAlloc: Int): List<HostAndPort> {
dsl: TransactionDSL<TransactionDSLInterpreter>.() -> EnforceVerifyOrFail dsl: TransactionDSL<TransactionDSLInterpreter>.() -> EnforceVerifyOrFail
) = ledger { this.transaction(transactionLabel, transactionBuilder, dsl) } ) = ledger { this.transaction(transactionLabel, transactionBuilder, dsl) }
/** /**
* Connect two protocols together for communication. Both protocols must have a property called otherParty of type Party * The given protocol factory will be used to initiate just one instance of a protocol of type [P] when a counterparty
* which points to the other party in the communication. * protocol requests for it using [markerClass].
* @return Returns a [ListenableFuture] holding the single [ProtocolStateMachine] created by the request.
*/ */
fun connectProtocols(protocol1: ProtocolLogic<*>, protocol2: ProtocolLogic<*>) { inline fun <R, reified P : ProtocolLogic<R>> AbstractNode.initiateSingleShotProtocol(
markerClass: KClass<*>,
noinline protocolFactory: (Party) -> P): ListenableFuture<ProtocolStateMachine<R>> {
services.registerProtocolInitiator(markerClass, protocolFactory)
data class Handshake(override val replyToParty: Party, val future = SettableFuture.create<ProtocolStateMachine<R>>()
override val sendSessionID: Long,
override val receiveSessionID: Long) : HandshakeMessage
val sessionId1 = random63BitValue() val subscriber = object : Subscriber<Change>() {
val sessionId2 = random63BitValue() override fun onNext(change: Change) {
protocol1.registerSession(Handshake(protocol1.otherParty, sessionId1, sessionId2)) if (change.logic is P && change.addOrRemove == ADD) {
protocol2.registerSession(Handshake(protocol2.otherParty, sessionId2, sessionId1)) unsubscribe()
future.set(change.logic.psm as ProtocolStateMachine<R>)
}
}
override fun onError(e: Throwable) {
future.setException(e)
}
override fun onCompleted() {}
} }
private val ProtocolLogic<*>.otherParty: Party smm.changes.subscribe(subscriber)
get() = javaClass.getDeclaredField("otherParty").apply { isAccessible = true }.get(this) as Party
return future
}

View File

@ -131,6 +131,10 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
// It is used from the network visualiser tool. // It is used from the network visualiser tool.
@Suppress("unused") val place: PhysicalLocation get() = findMyLocation()!! @Suppress("unused") val place: PhysicalLocation get() = findMyLocation()!!
fun pumpReceive(block: Boolean): InMemoryMessagingNetwork.MessageTransfer? {
return (net as InMemoryMessagingNetwork.InMemoryMessaging).pumpReceive(block)
}
fun send(topic: String, target: MockNode, payload: Any) { fun send(topic: String, target: MockNode, payload: Any) {
services.networkService.send(TopicSession(topic), payload, target.info.address) services.networkService.send(TopicSession(topic), payload, target.info.address)
} }