Automatic session management between two protocols, and removal of explict topics

This commit is contained in:
Shams Asari 2016-09-27 18:25:26 +01:00
parent 4da73e28c7
commit 67fdf9b2ff
54 changed files with 1055 additions and 1113 deletions

2
.gitignore vendored
View File

@ -5,7 +5,6 @@
tags
.DS_Store
*.log
*.log.gz
*.orig
# Created by .ignore support plugin (hsz.mobi)
@ -100,3 +99,4 @@ crashlytics-build.properties
# docs related
docs/virtualenv/
/logs/

View File

@ -675,8 +675,8 @@ class InterestRateSwap() : Contract {
val nextFixingOf = nextFixingOf() ?: return null
// This is perhaps not how we should determine the time point in the business day, but instead expect the schedule to detail some of these aspects
val (instant, duration) = suggestInterestRateAnnouncementTimeWindow(index = nextFixingOf.name, source = floatingLeg.indexSource, date = nextFixingOf.forDay)
return ScheduledActivity(protocolLogicRefFactory.create(TwoPartyDealProtocol.FixingRoleDecider::class.java, thisStateRef, duration), instant)
val instant = suggestInterestRateAnnouncementTimeWindow(index = nextFixingOf.name, source = floatingLeg.indexSource, date = nextFixingOf.forDay).start
return ScheduledActivity(protocolLogicRefFactory.create(TwoPartyDealProtocol.FixingRoleDecider::class.java, thisStateRef), instant)
}
override fun generateAgreement(notary: Party): TransactionBuilder = InterestRateSwap().generateAgreement(floatingLeg, fixedLeg, calculation, common, notary)

View File

@ -46,22 +46,20 @@ import java.util.*
// and [AbstractStateReplacementProtocol].
object TwoPartyTradeProtocol {
val TOPIC = "platform.trade"
class UnacceptablePriceException(val givenPrice: Amount<Currency>) : Exception("Unacceptable price: $givenPrice")
class AssetMismatchException(val expectedTypeName: String, val typeName: String) : Exception() {
override fun toString() = "The submitted asset didn't match the expected type: $expectedTypeName vs $typeName"
}
// This object is serialised to the network and is the first protocol message the seller sends to the buyer.
class SellerTradeInfo(
data class SellerTradeInfo(
val assetForSale: StateAndRef<OwnableState>,
val price: Amount<Currency>,
val sellerOwnerKey: PublicKey
)
class SignaturesFromSeller(val sellerSig: DigitalSignature.WithKey,
val notarySig: DigitalSignature.LegallyIdentifiable)
data class SignaturesFromSeller(val sellerSig: DigitalSignature.WithKey,
val notarySig: DigitalSignature.LegallyIdentifiable)
open class Seller(val otherParty: Party,
val notaryNode: NodeInfo,
@ -84,8 +82,6 @@ object TwoPartyTradeProtocol {
fun tracker() = ProgressTracker(AWAITING_PROPOSAL, VERIFYING, SIGNING, NOTARY, SENDING_SIGS)
}
override val topic: String get() = TOPIC
@Suspendable
override fun call(): SignedTransaction {
val partialTX: SignedTransaction = receiveAndCheckProposedTransaction()
@ -172,7 +168,6 @@ object TwoPartyTradeProtocol {
object SWAPPING_SIGNATURES : ProgressTracker.Step("Swapping signatures with the seller")
override val topic: String get() = TOPIC
override val progressTracker = ProgressTracker(RECEIVING, VERIFYING, SIGNING, SWAPPING_SIGNATURES)
@Suspendable
@ -197,7 +192,7 @@ object TwoPartyTradeProtocol {
@Suspendable
private fun receiveAndValidateTradeRequest(): SellerTradeInfo {
progressTracker.currentStep = RECEIVING
// Wait for a trade request to come in on our pre-provided session ID.
// Wait for a trade request to come in from the other side
val maybeTradeRequest = receive<SellerTradeInfo>(otherParty)
progressTracker.currentStep = VERIFYING
@ -243,8 +238,8 @@ object TwoPartyTradeProtocol {
private fun assembleSharedTX(tradeRequest: SellerTradeInfo): Pair<TransactionBuilder, List<PublicKey>> {
val ptx = TransactionType.General.Builder(notary)
// Add input and output states for the movement of cash, by using the Cash contract to generate the states.
val wallet = serviceHub.vaultService.currentVault
val cashStates = wallet.statesOfType<Cash.State>()
val vault = serviceHub.vaultService.currentVault
val cashStates = vault.statesOfType<Cash.State>()
val cashSigningPubKeys = Cash().generateSpend(ptx, tradeRequest.price, tradeRequest.sellerOwnerKey, cashStates)
// Add inputs/outputs/a command for the movement of the asset.
ptx.addInputState(tradeRequest.assetForSale)

View File

@ -1,5 +1,6 @@
package com.r3corda.core
import com.google.common.base.Function
import com.google.common.base.Throwables
import com.google.common.io.ByteStreams
import com.google.common.util.concurrent.Futures
@ -17,8 +18,8 @@ import java.nio.file.Files
import java.nio.file.Path
import java.time.Duration
import java.time.temporal.Temporal
import java.util.concurrent.ExecutionException
import java.util.concurrent.Executor
import java.util.concurrent.Future
import java.util.concurrent.locks.ReentrantLock
import java.util.zip.ZipInputStream
import kotlin.concurrent.withLock
@ -66,19 +67,20 @@ fun <T> ListenableFuture<T>.success(executor: Executor, body: (T) -> Unit) = the
fun <T> ListenableFuture<T>.failure(executor: Executor, body: (Throwable) -> Unit) = then(executor) {
try {
get()
} catch(e: Throwable) {
body(e)
} catch (e: ExecutionException) {
body(e.cause!!)
} catch (t: Throwable) {
body(t)
}
}
infix fun <F, T> Future<F>.map(mapper: (F) -> T): Future<T> = Futures.lazyTransform(this) { mapper(it!!) }
infix fun <T> ListenableFuture<T>.then(body: () -> Unit): ListenableFuture<T> = apply { then(RunOnCallerThread, body) }
infix fun <T> ListenableFuture<T>.success(body: (T) -> Unit): ListenableFuture<T> = apply { success(RunOnCallerThread, body) }
infix fun <T> ListenableFuture<T>.failure(body: (Throwable) -> Unit): ListenableFuture<T> = apply { failure(RunOnCallerThread, body) }
fun <R> Path.use(block: (InputStream) -> R): R = Files.newInputStream(this).use(block)
infix fun <F, T> ListenableFuture<F>.map(mapper: (F) -> T): ListenableFuture<T> = Futures.transform(this, Function { mapper(it!!) })
infix fun <F, T> ListenableFuture<F>.flatMap(mapper: (F) -> ListenableFuture<T>): ListenableFuture<T> = Futures.transformAsync(this) { mapper(it!!) }
/** Executes the given block and sets the future to either the result, or any exception that was thrown. */
// TODO This is not used but there's existing code that can be replaced by this
fun <T> SettableFuture<T>.setFrom(logger: Logger? = null, block: () -> T): SettableFuture<T> {
try {
set(block())
@ -89,6 +91,8 @@ fun <T> SettableFuture<T>.setFrom(logger: Logger? = null, block: () -> T): Setta
return this
}
fun <R> Path.use(block: (InputStream) -> R): R = Files.newInputStream(this).use(block)
// Simple infix function to add back null safety that the JDK lacks: timeA until timeB
infix fun Temporal.until(endExclusive: Temporal) = Duration.between(this, endExclusive)

View File

@ -2,19 +2,11 @@ package com.r3corda.core.protocols
import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.Message
import com.r3corda.core.messaging.runOnNextMessage
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.core.utilities.UntrustworthyData
import com.r3corda.core.utilities.debug
import com.r3corda.protocols.HandshakeMessage
import org.slf4j.Logger
import rx.Observable
import java.util.*
import java.util.concurrent.CompletableFuture
/**
* A sub-class of [ProtocolLogic<T>] implements a protocol flow using direct, straight line blocking code. Thus you
@ -48,23 +40,14 @@ abstract class ProtocolLogic<out T> {
*/
val serviceHub: ServiceHub get() = psm.serviceHub
/**
* The topic to use when communicating with other parties. If more than one topic is required then use sub-protocols.
* Note that this is temporary until protocol sessions are properly implemented.
*/
protected abstract val topic: String
private val sessions = HashMap<Party, Session>()
private var sessionProtocol: ProtocolLogic<*> = this
/**
* If a node receives a [HandshakeMessage] it needs to call this method on the initiated receipt protocol to enable
* communication between it and the sender protocol. Calling this method, and other initiation steps, are already
* handled by AbstractNodeService.addProtocolHandler.
* Return the marker [Class] which [party] has used to register the counterparty protocol that is to execute on the
* other side. The default implementation returns the class object of this ProtocolLogic, but any [Class] instance
* will do as long as the other side registers with it.
*/
fun registerSession(receivedHandshake: HandshakeMessage) {
// Note that the send and receive session IDs are swapped
addSession(receivedHandshake.replyToParty, receivedHandshake.receiveSessionID, receivedHandshake.sendSessionID)
}
open fun getCounterpartyMarker(party: Party): Class<*> = javaClass
// Kotlin helpers that allow the use of generic types.
inline fun <reified T : Any> sendAndReceive(otherParty: Party, payload: Any): UntrustworthyData<T> {
@ -73,69 +56,41 @@ abstract class ProtocolLogic<out T> {
@Suspendable
fun <T : Any> sendAndReceive(otherParty: Party, payload: Any, receiveType: Class<T>): UntrustworthyData<T> {
val sendSessionId = getSendSessionId(otherParty, payload)
val receiveSessionId = getReceiveSessionId(otherParty)
return psm.sendAndReceive(topic, otherParty, sendSessionId, receiveSessionId, payload, receiveType)
return psm.sendAndReceive(otherParty, payload, receiveType, sessionProtocol)
}
inline fun <reified T : Any> receive(otherParty: Party): UntrustworthyData<T> = receive(otherParty, T::class.java)
@Suspendable
fun <T : Any> receive(otherParty: Party, receiveType: Class<T>): UntrustworthyData<T> {
return psm.receive(topic, getReceiveSessionId(otherParty), receiveType)
return psm.receive(otherParty, receiveType, sessionProtocol)
}
@Suspendable
fun send(otherParty: Party, payload: Any) {
psm.send(topic, otherParty, getSendSessionId(otherParty, payload), payload)
psm.send(otherParty, payload, sessionProtocol)
}
private fun addSession(party: Party, sendSesssionId: Long, receiveSessionId: Long) {
if (party in sessions) {
logger.debug { "Existing session with party $party to be overwritten by new one" }
}
sessions[party] = Session(sendSesssionId, receiveSessionId)
}
private fun getSendSessionId(otherParty: Party, payload: Any): Long {
return if (payload is HandshakeMessage) {
addSession(otherParty, payload.sendSessionID, payload.receiveSessionID)
DEFAULT_SESSION_ID
} else {
sessions[otherParty]?.sendSessionId ?:
throw IllegalStateException("Session with party $otherParty hasn't been established yet")
}
}
private fun getReceiveSessionId(otherParty: Party): Long {
return sessions[otherParty]?.receiveSessionId ?:
throw IllegalStateException("Session with party $otherParty hasn't been established yet")
}
/**
* Check if we already have a session with this party
*/
protected fun hasSession(otherParty: Party) = sessions.containsKey(otherParty)
/**
* Invokes the given subprotocol by simply passing through this [ProtocolLogic]s reference to the
* [ProtocolStateMachine] and then calling the [call] method.
* @param inheritParentSessions In certain situations the subprotocol needs to inherit and use the same open
* sessions of the parent. However in most cases this is not desirable as it prevents the subprotocol from
* communicating with the same party on a different topic. For this reason the default value is false.
* @param shareParentSessions In certain situations the need arises to use the same sessions the parent protocol has
* already established. However this also prevents the subprotocol from creating new sessions with those parties.
* For this reason the default value is false.
*/
@JvmOverloads
// TODO Rethink the default value for shareParentSessions
// TODO shareParentSessions is a bit too low-level and perhaps can be expresed in a better way
@Suspendable
fun <R> subProtocol(subLogic: ProtocolLogic<R>, inheritParentSessions: Boolean = false): R {
fun <R> subProtocol(subLogic: ProtocolLogic<R>, shareParentSessions: Boolean = false): R {
subLogic.psm = psm
if (inheritParentSessions) {
subLogic.sessions.putAll(sessions)
}
maybeWireUpProgressTracking(subLogic)
val r = subLogic.call()
if (shareParentSessions) {
subLogic.sessionProtocol = this
}
val result = subLogic.call()
// It's easy to forget this when writing protocols so we just step it to the DONE state when it completes.
subLogic.progressTracker?.currentStep = ProgressTracker.DONE
return r
return result
}
private fun maybeWireUpProgressTracking(subLogic: ProtocolLogic<*>) {
@ -166,12 +121,11 @@ abstract class ProtocolLogic<out T> {
@Suspendable
abstract fun call(): T
private data class Session(val sendSessionId: Long, val receiveSessionId: Long)
// TODO this is not threadsafe, needs an atomic get-step-and-subscribe
fun track(): Pair<String, Observable<String>>? {
return progressTracker?.let {
Pair(it.currentStep.toString(), it.changes.map { it.toString() })
}
}
}

View File

@ -1,7 +1,6 @@
package com.r3corda.core.protocols
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.ServiceHub
@ -10,9 +9,12 @@ import org.slf4j.Logger
import java.util.*
data class StateMachineRunId private constructor(val uuid: UUID) {
companion object {
fun createRandom(): StateMachineRunId = StateMachineRunId(UUID.randomUUID())
}
override fun toString(): String = "${javaClass.simpleName}($uuid)"
}
/**
@ -20,18 +22,16 @@ data class StateMachineRunId private constructor(val uuid: UUID) {
*/
interface ProtocolStateMachine<R> {
@Suspendable
fun <T : Any> sendAndReceive(topic: String,
destination: Party,
sessionIDForSend: Long,
sessionIDForReceive: Long,
fun <T : Any> sendAndReceive(otherParty: Party,
payload: Any,
receiveType: Class<T>): UntrustworthyData<T>
receiveType: Class<T>,
sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T>
@Suspendable
fun <T : Any> receive(topic: String, sessionIDForReceive: Long, receiveType: Class<T>): UntrustworthyData<T>
fun <T : Any> receive(otherParty: Party, receiveType: Class<T>, sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T>
@Suspendable
fun send(topic: String, destination: Party, sessionID: Long, payload: Any)
fun send(otherParty: Party, payload: Any, sessionProtocol: ProtocolLogic<*>)
val serviceHub: ServiceHub
val logger: Logger
@ -41,3 +41,5 @@ interface ProtocolStateMachine<R> {
/** This future will complete when the call method returns. */
val resultFuture: ListenableFuture<R>
}
class ProtocolSessionException(message: String) : Exception(message)

View File

@ -9,7 +9,6 @@ import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.signWithECDSA
import com.r3corda.core.node.recordTransactions
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.utilities.ProgressTracker
@ -35,10 +34,6 @@ abstract class AbstractStateReplacementProtocol<T> {
val stx: SignedTransaction
}
data class Handshake(override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
abstract class Instigator<out S : ContractState, T>(val originalState: StateAndRef<S>,
val modification: T,
override val progressTracker: ProgressTracker = tracker()) : ProtocolLogic<StateAndRef<S>>() {
@ -94,11 +89,6 @@ abstract class AbstractStateReplacementProtocol<T> {
private fun getParticipantSignature(party: Party, stx: SignedTransaction): DigitalSignature.WithKey {
val proposal = assembleProposal(originalState.ref, modification, stx)
// TODO: Move this into protocol logic as a func on the lines of handshake(Party, HandshakeMessage)
if (!hasSession(party)) {
send(party, Handshake(serviceHub.storageService.myLegalIdentity))
}
val response = sendAndReceive<Result>(party, proposal)
val participantSignature = response.unwrap {
if (it.sig == null) throw StateReplacementException(it.error!!)

View File

@ -5,7 +5,6 @@ import com.r3corda.core.contracts.ClientToServiceCommand
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.recordTransactions
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.transactions.SignedTransaction
@ -25,31 +24,17 @@ import com.r3corda.core.transactions.SignedTransaction
class BroadcastTransactionProtocol(val notarisedTransaction: SignedTransaction,
val events: Set<ClientToServiceCommand>,
val participants: Set<Party>) : ProtocolLogic<Unit>() {
companion object {
/** Topic for messages notifying a node of a new transaction */
val TOPIC = "platform.wallet.notify_tx"
}
override val topic: String = TOPIC
data class NotifyTxRequestMessage(val tx: SignedTransaction,
val events: Set<ClientToServiceCommand>,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
data class NotifyTxRequest(val tx: SignedTransaction, val events: Set<ClientToServiceCommand>)
@Suspendable
override fun call() {
// Record it locally
serviceHub.recordTransactions(notarisedTransaction)
// TODO: Messaging layer should handle this broadcast for us (although we need to not be sending
// session ID, for that to work, as well).
// TODO: Messaging layer should handle this broadcast for us
val msg = NotifyTxRequest(notarisedTransaction, events)
participants.filter { it != serviceHub.storageService.myLegalIdentity }.forEach { participant ->
val msg = NotifyTxRequestMessage(
notarisedTransaction,
events,
serviceHub.storageService.myLegalIdentity)
send(participant, msg)
}
}

View File

@ -14,12 +14,6 @@ import java.io.InputStream
class FetchAttachmentsProtocol(requests: Set<SecureHash>,
otherSide: Party) : FetchDataProtocol<Attachment, ByteArray>(requests, otherSide) {
companion object {
const val TOPIC = "platform.fetch.attachment"
}
override val topic: String get() = TOPIC
override fun load(txid: SecureHash): Attachment? = serviceHub.storageService.attachments.openAttachment(txid)
override fun convert(wire: ByteArray): Attachment {

View File

@ -5,7 +5,6 @@ import com.r3corda.core.contracts.NamedByHash
import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.utilities.UntrustworthyData
import com.r3corda.protocols.FetchDataProtocol.DownloadedVsRequestedDataMismatch
import com.r3corda.protocols.FetchDataProtocol.HashNotFound
@ -21,8 +20,8 @@ import java.util.*
* [HashNotFound] exception being thrown.
*
* By default this class does not insert data into any local database, if you want to do that after missing items were
* fetched then override [maybeWriteToDisk]. You *must* override [load] and [queryTopic]. If the wire type is not the
* same as the ultimate type, you must also override [convert].
* fetched then override [maybeWriteToDisk]. You *must* override [load]. If the wire type is not the same as the
* ultimate type, you must also override [convert].
*
* @param T The ultimate type of the data being fetched.
* @param W The wire type of the data being fetched, for when it isn't the same as the ultimate type.
@ -35,10 +34,7 @@ abstract class FetchDataProtocol<T : NamedByHash, in W : Any>(
class HashNotFound(val requested: SecureHash) : BadAnswer()
class DownloadedVsRequestedDataMismatch(val requested: SecureHash, val got: SecureHash) : BadAnswer()
data class Request(val hashes: List<SecureHash>,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
data class Request(val hashes: List<SecureHash>)
data class Result<out T : NamedByHash>(val fromDisk: List<T>, val downloaded: List<T>)
@Suspendable
@ -51,9 +47,8 @@ abstract class FetchDataProtocol<T : NamedByHash, in W : Any>(
} else {
logger.trace("Requesting ${toFetch.size} dependency(s) for verification")
val fetchReq = Request(toFetch, serviceHub.storageService.myLegalIdentity)
// TODO: Support "large message" response streaming so response sizes are not limited by RAM.
val maybeItems = sendAndReceive<ArrayList<W?>>(otherSide, fetchReq)
val maybeItems = sendAndReceive<ArrayList<W?>>(otherSide, Request(toFetch))
// Check for a buggy/malicious peer answering with something that we didn't ask for.
val downloaded = validateFetchResponse(maybeItems, toFetch)
maybeWriteToDisk(downloaded)

View File

@ -1,8 +1,8 @@
package com.r3corda.protocols
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.transactions.SignedTransaction
/**
* Given a set of tx hashes (IDs), either loads them from local disk or asks the remote peer to provide them.
@ -15,11 +15,5 @@ import com.r3corda.core.crypto.SecureHash
class FetchTransactionsProtocol(requests: Set<SecureHash>, otherSide: Party) :
FetchDataProtocol<SignedTransaction, SignedTransaction>(requests, otherSide) {
companion object {
const val TOPIC = "platform.fetch.tx"
}
override val topic: String get() = TOPIC
override fun load(txid: SecureHash): SignedTransaction? = serviceHub.storageService.validatedTransactions.getTransaction(txid)
}

View File

@ -31,9 +31,6 @@ class FinalityProtocol(val transaction: SignedTransaction,
fun tracker() = ProgressTracker(NOTARISING, BROADCASTING)
}
override val topic: String
get() = throw UnsupportedOperationException()
@Suspendable
override fun call() {
// TODO: Resolve the tx here: it's probably already been done, but re-resolution is a no-op and it'll make the API more forgiving.

View File

@ -24,8 +24,6 @@ import java.security.PublicKey
*/
object NotaryChangeProtocol: AbstractStateReplacementProtocol<Party>() {
val TOPIC = "platform.notary.change"
data class Proposal(override val stateRef: StateRef,
override val modification: Party,
override val stx: SignedTransaction) : AbstractStateReplacementProtocol.Proposal<Party>
@ -35,8 +33,6 @@ object NotaryChangeProtocol: AbstractStateReplacementProtocol<Party>() {
progressTracker: ProgressTracker = tracker())
: AbstractStateReplacementProtocol.Instigator<T, Party>(originalState, newNotary, progressTracker) {
override val topic: String get() = TOPIC
override fun assembleProposal(stateRef: StateRef, modification: Party, stx: SignedTransaction): AbstractStateReplacementProtocol.Proposal<Party>
= Proposal(stateRef, modification, stx)
@ -56,8 +52,6 @@ object NotaryChangeProtocol: AbstractStateReplacementProtocol<Party>() {
override val progressTracker: ProgressTracker = tracker())
: AbstractStateReplacementProtocol.Acceptor<Party>(otherSide) {
override val topic: String get() = TOPIC
/**
* Check the notary change proposal.
*

View File

@ -5,12 +5,10 @@ import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SignedData
import com.r3corda.core.crypto.signWithECDSA
import com.r3corda.core.messaging.Ack
import com.r3corda.core.node.services.TimestampChecker
import com.r3corda.core.node.services.UniquenessException
import com.r3corda.core.node.services.UniquenessProvider
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.core.serialization.serialize
import com.r3corda.core.transactions.SignedTransaction
@ -21,8 +19,6 @@ import java.security.PublicKey
object NotaryProtocol {
val TOPIC = "platform.notary"
/**
* A protocol to be used for obtaining a signature from a [NotaryService] ascertaining the transaction
* timestamp is correct and none of its inputs have been used in another completed transaction.
@ -30,8 +26,8 @@ object NotaryProtocol {
* @throws NotaryException in case the any of the inputs to the transaction have been consumed
* by another transaction or the timestamp is invalid.
*/
class Client(private val stx: SignedTransaction,
override val progressTracker: ProgressTracker = Client.tracker()) : ProtocolLogic<DigitalSignature.LegallyIdentifiable>() {
open class Client(private val stx: SignedTransaction,
override val progressTracker: ProgressTracker = Client.tracker()) : ProtocolLogic<DigitalSignature.LegallyIdentifiable>() {
companion object {
@ -42,8 +38,6 @@ object NotaryProtocol {
fun tracker() = ProgressTracker(REQUESTING, VALIDATING)
}
override val topic: String get() = TOPIC
lateinit var notaryParty: Party
@Suspendable
@ -51,9 +45,9 @@ object NotaryProtocol {
progressTracker.currentStep = REQUESTING
val wtx = stx.tx
notaryParty = wtx.notary ?: throw IllegalStateException("Transaction does not specify a Notary")
check(wtx.inputs.all { stateRef -> serviceHub.loadState(stateRef).notary == notaryParty }) { "Input states must have the same Notary" }
sendAndReceive<Ack>(notaryParty, Handshake(serviceHub.storageService.myLegalIdentity))
check(wtx.inputs.all { stateRef -> serviceHub.loadState(stateRef).notary == notaryParty }) {
"Input states must have the same Notary"
}
val request = SignRequest(stx, serviceHub.storageService.myLegalIdentity)
val response = sendAndReceive<Result>(notaryParty, request)
@ -80,6 +74,10 @@ object NotaryProtocol {
}
}
class ValidatingClient(stx: SignedTransaction) : Client(stx)
/**
* Checks that the timestamp command is valid (if present) and commits the input state, or returns a conflict
* if any of the input states have been previously committed.
@ -92,11 +90,9 @@ object NotaryProtocol {
val timestampChecker: TimestampChecker,
val uniquenessProvider: UniquenessProvider) : ProtocolLogic<Unit>() {
override val topic: String get() = TOPIC
@Suspendable
override fun call() {
val (stx, reqIdentity) = sendAndReceive<SignRequest>(otherSide, Ack).unwrap { it }
val (stx, reqIdentity) = receive<SignRequest>(otherSide).unwrap { it }
val wtx = stx.tx
val result = try {
@ -148,10 +144,6 @@ object NotaryProtocol {
}
}
data class Handshake(override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
/** TODO: The caller must authenticate instead of just specifying its identity */
data class SignRequest(val tx: SignedTransaction, val callerIdentity: Party)
@ -162,23 +154,10 @@ object NotaryProtocol {
}
}
interface Factory {
fun create(otherSide: Party,
timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider): Service
}
object DefaultFactory : Factory {
override fun create(otherSide: Party,
timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider): Service {
return Service(otherSide, timestampChecker, uniquenessProvider)
}
}
}
class NotaryException(val error: NotaryError) : Exception() {
override fun toString() = "${super.toString()}: Error response from Notary - ${error.toString()}"
override fun toString() = "${super.toString()}: Error response from Notary - $error"
}
sealed class NotaryError {

View File

@ -6,7 +6,6 @@ import com.r3corda.core.contracts.FixOf
import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.Party
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.transactions.TransactionBuilder
import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.utilities.ProgressTracker
@ -34,8 +33,6 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
override val progressTracker: ProgressTracker = RatesFixProtocol.tracker(fixOf.name)) : ProtocolLogic<Unit>() {
companion object {
val TOPIC = "platform.rates.interest.fix"
class QUERYING(val name: String) : ProgressTracker.Step("Querying oracle for $name interest rate")
object WORKING : ProgressTracker.Step("Working with data returned by oracle")
object SIGNING : ProgressTracker.Step("Requesting confirmation signature from interest rate oracle")
@ -43,31 +40,22 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
fun tracker(fixName: String) = ProgressTracker(QUERYING(fixName), WORKING, SIGNING)
}
override val topic: String get() = TOPIC
class FixOutOfRange(@Suppress("unused") val byAmount: BigDecimal) : Exception("Fix out of range by $byAmount")
data class QueryRequest(val queries: List<FixOf>,
val deadline: Instant,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
data class SignRequest(val tx: WireTransaction,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
data class QueryRequest(val queries: List<FixOf>, val deadline: Instant)
data class SignRequest(val tx: WireTransaction)
@Suspendable
override fun call() {
progressTracker.currentStep = progressTracker.steps[1]
val fix = query()
val fix = subProtocol(FixQueryProtocol(fixOf, oracle))
progressTracker.currentStep = WORKING
checkFixIsNearExpected(fix)
tx.addCommand(fix, oracle.owningKey)
beforeSigning(fix)
progressTracker.currentStep = SIGNING
tx.addSignatureUnchecked(sign())
val signature = subProtocol(FixSignProtocol(tx, oracle))
tx.addSignatureUnchecked(signature)
}
/**
@ -86,31 +74,36 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
}
}
@Suspendable
private fun sign(): DigitalSignature.LegallyIdentifiable {
val wtx = tx.toWireTransaction()
val req = SignRequest(wtx, serviceHub.storageService.myLegalIdentity)
val resp = sendAndReceive<DigitalSignature.LegallyIdentifiable>(oracle, req)
return resp.unwrap { sig ->
check(sig.signer == oracle)
tx.checkSignature(sig)
sig
class FixQueryProtocol(val fixOf: FixOf, val oracle: Party) : ProtocolLogic<Fix>() {
@Suspendable
override fun call(): Fix {
val deadline = suggestInterestRateAnnouncementTimeWindow(fixOf.name, oracle.name, fixOf.forDay).end
// TODO: add deadline to receive
val resp = sendAndReceive<ArrayList<Fix>>(oracle, QueryRequest(listOf(fixOf), deadline))
return resp.unwrap {
val fix = it.first()
// Check the returned fix is for what we asked for.
check(fix.of == fixOf)
fix
}
}
}
@Suspendable
private fun query(): Fix {
val deadline = suggestInterestRateAnnouncementTimeWindow(fixOf.name, oracle.name, fixOf.forDay).end
val req = QueryRequest(listOf(fixOf), deadline, serviceHub.storageService.myLegalIdentity)
// TODO: add deadline to receive
val resp = sendAndReceive<ArrayList<Fix>>(oracle, req)
return resp.unwrap {
val fix = it.first()
// Check the returned fix is for what we asked for.
check(fix.of == fixOf)
fix
class FixSignProtocol(val tx: TransactionBuilder, val oracle: Party) : ProtocolLogic<DigitalSignature.LegallyIdentifiable>() {
@Suspendable
override fun call(): DigitalSignature.LegallyIdentifiable {
val wtx = tx.toWireTransaction()
val resp = sendAndReceive<DigitalSignature.LegallyIdentifiable>(oracle, SignRequest(wtx))
return resp.unwrap { sig ->
check(sig.signer == oracle)
tx.checkSignature(sig)
sig
}
}
}
}

View File

@ -127,8 +127,6 @@ class ResolveTransactionsProtocol(private val txHashes: Set<SecureHash>,
return result
}
override val topic: String get() = throw UnsupportedOperationException()
@Suspendable
private fun downloadDependencies(depsToCheck: Set<SecureHash>): Collection<SignedTransaction> {
// Maintain a work queue of all hashes to load/download, initialised with our starting set. Then do a breadth

View File

@ -32,19 +32,4 @@ interface PartyRequestMessage : ServiceRequestMessage {
override fun getReplyTo(networkMapCache: NetworkMapCache): MessageRecipients {
return networkMapCache.partyNodes.single { it.identity == replyToParty }.address
}
}
/**
* A Handshake message is sent to initiate communication between two protocol instances. It contains the two session IDs
* the two protocols will need to communicate.
* Note: This is a temperary interface and will be removed once the protocol session work is implemented.
*/
interface HandshakeMessage : PartyRequestMessage {
val sendSessionID: Long
val receiveSessionID: Long
@Deprecated("sessionID functions as receiveSessionID but it's recommended to use the later for clarity",
replaceWith = ReplaceWith("receiveSessionID"))
override val sessionID: Long get() = receiveSessionID
}

View File

@ -6,12 +6,10 @@ import com.r3corda.core.contracts.*
import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.signWithECDSA
import com.r3corda.core.crypto.toBase58String
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.recordTransactions
import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.seconds
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.transactions.TransactionBuilder
@ -22,7 +20,6 @@ import com.r3corda.core.utilities.trace
import java.math.BigDecimal
import java.security.KeyPair
import java.security.PublicKey
import java.time.Duration
/**
* Classes for manipulating a two party deal or agreement.
@ -36,10 +33,6 @@ import java.time.Duration
*/
object TwoPartyDealProtocol {
val DEAL_TOPIC = "platform.deal"
/** This topic exists purely for [FixingSessionInitiation] to be sent from [FixingRoleDecider] to [FixingSessionInitiationHandler] */
val FIX_INITIATE_TOPIC = "platform.fix.initiate"
class DealMismatchException(val expectedDeal: ContractState, val actualDeal: ContractState) : Exception() {
override fun toString() = "The submitted deal didn't match the expected: $expectedDeal vs $actualDeal"
}
@ -53,13 +46,19 @@ object TwoPartyDealProtocol {
class SignaturesFromPrimary(val sellerSig: DigitalSignature.WithKey, val notarySig: DigitalSignature.LegallyIdentifiable)
/**
* [Primary] at the end sends the signed tx to all the regulator parties. This a seperate workflow which needs a
* sepearate session with the regulator. This interface is used to do that in [Primary.getCounterpartyMarker].
*/
interface MarkerForBogusRegulatorProtocol
/**
* Abstracted bilateral deal protocol participant that initiates communication/handshake.
*
* There's a good chance we can push at least some of this logic down into core protocol logic
* and helper methods etc.
*/
abstract class Primary<out U>(override val progressTracker: ProgressTracker = Primary.tracker()) : ProtocolLogic<SignedTransaction>() {
abstract class Primary(override val progressTracker: ProgressTracker = Primary.tracker()) : ProtocolLogic<SignedTransaction>() {
companion object {
object AWAITING_PROPOSAL : ProgressTracker.Step("Handshaking and awaiting transaction proposal")
@ -73,13 +72,19 @@ object TwoPartyDealProtocol {
fun tracker() = ProgressTracker(AWAITING_PROPOSAL, VERIFYING, SIGNING, NOTARY, SENDING_SIGS, RECORDING, COPYING_TO_REGULATOR)
}
override val topic: String get() = DEAL_TOPIC
abstract val payload: U
abstract val payload: Any
abstract val notaryNode: NodeInfo
abstract val otherParty: Party
abstract val myKeyPair: KeyPair
override fun getCounterpartyMarker(party: Party): Class<*> {
return if (serviceHub.networkMapCache.regulators.any { it.identity == party }) {
MarkerForBogusRegulatorProtocol::class.java
} else {
super.getCounterpartyMarker(party)
}
}
@Suspendable
fun getPartialTransaction(): UntrustworthyData<SignedTransaction> {
progressTracker.currentStep = AWAITING_PROPOSAL
@ -199,8 +204,6 @@ object TwoPartyDealProtocol {
fun tracker() = ProgressTracker(RECEIVING, VERIFYING, SIGNING, SWAPPING_SIGNATURES, RECORDING)
}
override val topic: String get() = DEAL_TOPIC
abstract val otherParty: Party
@Suspendable
@ -234,9 +237,7 @@ object TwoPartyDealProtocol {
val handshake = receive<Handshake<U>>(otherParty)
progressTracker.currentStep = VERIFYING
handshake.unwrap {
return validateHandshake(it)
}
return handshake.unwrap { validateHandshake(it) }
}
@Suspendable
@ -263,47 +264,45 @@ object TwoPartyDealProtocol {
@Suspendable protected abstract fun assembleSharedTX(handshake: Handshake<U>): Pair<TransactionBuilder, List<PublicKey>>
}
data class AutoOffer(val notary: Party, val dealBeingOffered: DealState)
/**
* One side of the protocol for inserting a pre-agreed deal.
*/
open class Instigator<out T : DealState>(override val otherParty: Party,
val notary: Party,
override val payload: T,
override val myKeyPair: KeyPair,
override val progressTracker: ProgressTracker = Primary.tracker()) : Primary<T>() {
open class Instigator(override val otherParty: Party,
override val payload: AutoOffer,
override val myKeyPair: KeyPair,
override val progressTracker: ProgressTracker = Primary.tracker()) : Primary() {
override val notaryNode: NodeInfo get() =
serviceHub.networkMapCache.notaryNodes.filter { it.identity == notary }.single()
serviceHub.networkMapCache.notaryNodes.filter { it.identity == payload.notary }.single()
}
/**
* One side of the protocol for inserting a pre-agreed deal.
*/
open class Acceptor<T : DealState>(override val otherParty: Party,
val notary: Party,
val dealToBuy: T,
override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary<T>() {
open class Acceptor(override val otherParty: Party,
override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary<AutoOffer>() {
override fun validateHandshake(handshake: Handshake<T>): Handshake<T> {
override fun validateHandshake(handshake: Handshake<AutoOffer>): Handshake<AutoOffer> {
// What is the seller trying to sell us?
val deal: T = handshake.payload
logger.trace { "Got deal request for: ${handshake.payload.ref}" }
check(dealToBuy == deal)
return handshake.copy(payload = deal)
val autoOffer = handshake.payload
val deal = autoOffer.dealBeingOffered
logger.trace { "Got deal request for: ${deal.ref}" }
return handshake.copy(payload = autoOffer.copy(dealBeingOffered = deal))
}
override fun assembleSharedTX(handshake: Handshake<T>): Pair<TransactionBuilder, List<PublicKey>> {
val ptx = handshake.payload.generateAgreement(notary)
override fun assembleSharedTX(handshake: Handshake<AutoOffer>): Pair<TransactionBuilder, List<PublicKey>> {
val deal = handshake.payload.dealBeingOffered
val ptx = deal.generateAgreement(handshake.payload.notary)
// And add a request for timestamping: it may be that none of the contracts need this! But it can't hurt
// to have one.
ptx.setTime(serviceHub.clock.instant(), 30.seconds)
return Pair(ptx, arrayListOf(handshake.payload.parties.single { it.name == serviceHub.storageService.myLegalIdentity.name }.owningKey))
return Pair(ptx, arrayListOf(deal.parties.single { it.name == serviceHub.storageService.myLegalIdentity.name }.owningKey))
}
}
/**
@ -314,16 +313,15 @@ object TwoPartyDealProtocol {
* who does what in the protocol.
*/
class Fixer(override val otherParty: Party,
val oracleType: ServiceType,
override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary<StateRef>() {
override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary<FixingSession>() {
private lateinit var txState: TransactionState<*>
private lateinit var deal: FixableDealState
override fun validateHandshake(handshake: Handshake<StateRef>): Handshake<StateRef> {
override fun validateHandshake(handshake: Handshake<FixingSession>): Handshake<FixingSession> {
logger.trace { "Got fixing request for: ${handshake.payload}" }
txState = serviceHub.loadState(handshake.payload)
txState = serviceHub.loadState(handshake.payload.ref)
deal = txState.data as FixableDealState
// validate the party that initiated is the one on the deal and that the recipient corresponds with it.
@ -336,7 +334,7 @@ object TwoPartyDealProtocol {
}
@Suspendable
override fun assembleSharedTX(handshake: Handshake<StateRef>): Pair<TransactionBuilder, List<PublicKey>> {
override fun assembleSharedTX(handshake: Handshake<FixingSession>): Pair<TransactionBuilder, List<PublicKey>> {
@Suppress("UNCHECKED_CAST")
val fixOf = deal.nextFixingOf()!!
@ -348,12 +346,12 @@ object TwoPartyDealProtocol {
val ptx = TransactionType.General.Builder(txState.notary)
val oracle = serviceHub.networkMapCache.get(oracleType).first()
val oracle = serviceHub.networkMapCache.get(handshake.payload.oracleType).first()
val addFixing = object : RatesFixProtocol(ptx, oracle.identity, fixOf, BigDecimal.ZERO, BigDecimal.ONE) {
@Suspendable
override fun beforeSigning(fix: Fix) {
newDeal.generateFix(ptx, StateAndRef(txState, handshake.payload), fix)
newDeal.generateFix(ptx, StateAndRef(txState, handshake.payload.ref), fix)
// And add a request for timestamping: it may be that none of the contracts need this! But it can't hurt
// to have one.
@ -373,12 +371,13 @@ object TwoPartyDealProtocol {
* does what in the protocol.
*/
class Floater(override val otherParty: Party,
override val payload: StateRef,
override val progressTracker: ProgressTracker = Primary.tracker()) : Primary<StateRef>() {
override val payload: FixingSession,
override val progressTracker: ProgressTracker = Primary.tracker()) : Primary() {
@Suppress("UNCHECKED_CAST")
internal val dealToFix: StateAndRef<FixableDealState> by TransientProperty {
val state = serviceHub.loadState(payload) as TransactionState<FixableDealState>
StateAndRef(state, payload)
val state = serviceHub.loadState(payload.ref) as TransactionState<FixableDealState>
StateAndRef(state, payload.ref)
}
override val myKeyPair: KeyPair get() {
@ -393,23 +392,18 @@ object TwoPartyDealProtocol {
/** Used to set up the session between [Floater] and [Fixer] */
data class FixingSessionInitiation(val timeout: Duration,
val oracleType: ServiceType,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
data class FixingSession(val ref: StateRef, val oracleType: ServiceType)
/**
* This protocol looks at the deal and decides whether to be the Fixer or Floater role in agreeing a fixing.
*
* It is kicked off as an activity on both participant nodes by the scheduler when it's time for a fixing. If the
* Fixer role is chosen, then that will be initiated by the [FixingSessionInitiation] message sent from the other party and
* Fixer role is chosen, then that will be initiated by the [FixingSession] message sent from the other party and
* handled by the [FixingSessionInitiationHandler].
*
* TODO: Replace [FixingSessionInitiation] and [FixingSessionInitiationHandler] with generic session initiation logic once it exists.
* TODO: Replace [FixingSession] and [FixingSessionInitiationHandler] with generic session initiation logic once it exists.
*/
class FixingRoleDecider(val ref: StateRef,
val timeout: Duration,
override val progressTracker: ProgressTracker = tracker()) : ProtocolLogic<Unit>() {
companion object {
@ -418,8 +412,6 @@ object TwoPartyDealProtocol {
fun tracker() = ProgressTracker(LOADING())
}
override val topic: String get() = FIX_INITIATE_TOPIC
@Suspendable
override fun call(): Unit {
progressTracker.nextStep()
@ -427,17 +419,10 @@ object TwoPartyDealProtocol {
// TODO: this is not the eventual mechanism for identifying the parties
val fixableDeal = (dealToFix.data as FixableDealState)
val sortedParties = fixableDeal.parties.sortedBy { it.name }
val oracleType = fixableDeal.oracleType
if (sortedParties[0].name == serviceHub.storageService.myLegalIdentity.name) {
val initation = FixingSessionInitiation(
timeout,
oracleType,
serviceHub.storageService.myLegalIdentity)
// Send initiation to other side to launch one side of the fixing protocol (the Fixer).
send(sortedParties[1], initation)
// Then start the other side of the fixing protocol.
subProtocol(Floater(sortedParties[1], ref), inheritParentSessions = true)
val fixing = FixingSession(ref, fixableDeal.oracleType)
// Start the Floater which will then kick-off the Fixer
subProtocol(Floater(sortedParties[1], fixing))
}
}
}

View File

@ -1,10 +1,11 @@
package com.r3corda.core.protocols;
import org.junit.Test;
import org.jetbrains.annotations.*;
import org.junit.*;
import java.util.*;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
public class ProtocolLogicRefFromJavaTest {
@ -33,12 +34,6 @@ public class ProtocolLogicRefFromJavaTest {
public Void call() {
return null;
}
@NotNull
@Override
protected String getTopic() {
throw new UnsupportedOperationException();
}
}
private static class JavaNoArgProtocolLogic extends ProtocolLogic<Void> {
@ -50,12 +45,6 @@ public class ProtocolLogicRefFromJavaTest {
public Void call() {
return null;
}
@NotNull
@Override
protected String getTopic() {
throw new UnsupportedOperationException();
}
}
@Test

View File

@ -10,28 +10,21 @@ import com.pholser.junit.quickcheck.runner.JUnitQuickcheck
import com.r3corda.contracts.testing.SignedTransactionGenerator
import com.r3corda.core.serialization.createKryo
import com.r3corda.core.serialization.serialize
import com.r3corda.core.testing.PartyGenerator
import com.r3corda.protocols.BroadcastTransactionProtocol
import com.r3corda.protocols.BroadcastTransactionProtocol.NotifyTxRequest
import org.junit.runner.RunWith
import kotlin.test.assertEquals
@RunWith(JUnitQuickcheck::class)
class BroadcastTransactionProtocolTest {
class NotifyTxRequestMessageGenerator : Generator<BroadcastTransactionProtocol.NotifyTxRequestMessage>(BroadcastTransactionProtocol.NotifyTxRequestMessage::class.java) {
override fun generate(random: SourceOfRandomness, status: GenerationStatus): BroadcastTransactionProtocol.NotifyTxRequestMessage {
return BroadcastTransactionProtocol.NotifyTxRequestMessage(
tx = SignedTransactionGenerator().generate(random, status),
events = setOf(),
replyToParty = PartyGenerator().generate(random, status),
sendSessionID = random.nextLong(),
receiveSessionID = random.nextLong()
)
class NotifyTxRequestMessageGenerator : Generator<NotifyTxRequest>(NotifyTxRequest::class.java) {
override fun generate(random: SourceOfRandomness, status: GenerationStatus): NotifyTxRequest {
return NotifyTxRequest(tx = SignedTransactionGenerator().generate(random, status), events = setOf())
}
}
@Property
fun serialiseDeserialiseOfNotifyMessageWorks(@From(NotifyTxRequestMessageGenerator::class) message: BroadcastTransactionProtocol.NotifyTxRequestMessage) {
fun serialiseDeserialiseOfNotifyMessageWorks(@From(NotifyTxRequestMessageGenerator::class) message: NotifyTxRequest) {
val kryo = createKryo()
val serialized = message.serialize().bits
val deserialized = kryo.readClassAndObject(Input(serialized))

View File

@ -23,18 +23,15 @@ class ProtocolLogicRefTest {
constructor(kotlinType: Int) : this(ParamType1(kotlinType), ParamType2("b"))
override fun call() = Unit
override val topic: String get() = throw UnsupportedOperationException()
}
class KotlinNoArgProtocolLogic : ProtocolLogic<Unit>() {
override fun call() = Unit
override val topic: String get() = throw UnsupportedOperationException()
}
@Suppress("UNUSED_PARAMETER") // We will never use A or b
class NotWhiteListedKotlinProtocolLogic(A: Int, b: String) : ProtocolLogic<Unit>() {
override fun call() = Unit
override val topic: String get() = throw UnsupportedOperationException()
}
lateinit var factory: ProtocolLogicRefFactory

View File

@ -91,7 +91,7 @@ Our protocol has two parties (B and S for buyer and seller) and will proceed as
it lacks a signature from S authorising movement of the asset.
3. S signs it and hands the now finalised ``SignedTransaction`` back to B.
You can find the implementation of this protocol in the file ``contracts/protocols/TwoPartyTradeProtocol.kt``.
You can find the implementation of this protocol in the file ``contracts/src/main/kotlin/com/r3corda/protocols/TwoPartyTradeProtocol.kt``.
Assuming no malicious termination, they both end the protocol being in posession of a valid, signed transaction that
represents an atomic asset swap.
@ -110,7 +110,6 @@ each side.
.. sourcecode:: kotlin
object TwoPartyTradeProtocol {
val TOPIC = "platform.trade"
class UnacceptablePriceException(val givenPrice: Amount<Currency>) : Exception("Unacceptable price: $givenPrice")
class AssetMismatchException(val expectedTypeName: String, val typeName: String) : Exception() {
@ -118,21 +117,20 @@ each side.
}
// This object is serialised to the network and is the first protocol message the seller sends to the buyer.
class SellerTradeInfo(
data class SellerTradeInfo(
val assetForSale: StateAndRef<OwnableState>,
val price: Amount,
val sellerOwnerKey: PublicKey,
val sessionID: Long
val price: Amount<Currency>,
val sellerOwnerKey: PublicKey
)
class SignaturesFromSeller(val timestampAuthoritySig: DigitalSignature.WithKey, val sellerSig: DigitalSignature.WithKey)
data class SignaturesFromSeller(val sellerSig: DigitalSignature.WithKey,
val notarySig: DigitalSignature.LegallyIdentifiable)
open class Seller(val otherSide: Party,
val notaryNode: NodeInfo,
val assetToSell: StateAndRef<OwnableState>,
val price: Amount<Currency>,
val myKeyPair: KeyPair,
val buyerSessionID: Long,
override val progressTracker: ProgressTracker = Seller.tracker()) : ProtocolLogic<SignedTransaction>() {
@Suspendable
override fun call(): SignedTransaction {
@ -143,8 +141,7 @@ each side.
open class Buyer(val otherSide: Party,
val notary: Party,
val acceptablePrice: Amount<Currency>,
val typeToBuy: Class<out OwnableState>,
val sessionID: Long) : ProtocolLogic<SignedTransaction>() {
val typeToBuy: Class<out OwnableState>) : ProtocolLogic<SignedTransaction>() {
@Suspendable
override fun call(): SignedTransaction {
TODO()
@ -152,25 +149,17 @@ each side.
}
}
Let's unpack what this code does:
- It defines a several classes nested inside the main ``TwoPartyTradeProtocol`` singleton. Some of the classes
are simply protocol messages or exceptions. The other two represent the buyer and seller side of the protocol.
- It defines the "trade topic", which is just a string that namespaces this protocol. The prefix "platform." is reserved
by Corda, but you can define your own protocol namespaces using standard Java-style reverse DNS notation.
This code defines several classes nested inside the main ``TwoPartyTradeProtocol`` singleton. Some of the classes are
simply protocol messages or exceptions. The other two represent the buyer and seller side of the protocol.
Going through the data needed to become a seller, we have:
- ``otherSide: SingleMessageRecipient`` - the network address of the node with which you are trading.
- ``otherSide: Party`` - the party with which you are trading.
- ``notaryNode: NodeInfo`` - the entry in the network map for the chosen notary. See ":doc:`consensus`" for more
information on notaries.
- ``assetToSell: StateAndRef<OwnableState>`` - a pointer to the ledger entry that represents the thing being sold.
- ``price: Amount<Currency>`` - the agreed on price that the asset is being sold for (without an issuer constraint).
- ``myKeyPair: KeyPair`` - the key pair that controls the asset being sold. It will be used to sign the transaction.
- ``buyerSessionID: Long`` - a unique number that identifies this trade to the buyer. It is expected that the buyer
knows that the trade is going to take place and has sent you such a number already.
.. note:: Session IDs will be automatically handled in a future version of the framework.
And for the buyer:
@ -178,7 +167,6 @@ And for the buyer:
a price less than or equal to this, then the trade will go ahead.
- ``typeToBuy: Class<out OwnableState>`` - the type of state that is being purchased. This is used to check that the
sell side of the protocol isn't trying to sell us the wrong thing, whether by accident or on purpose.
- ``sessionID: Long`` - the session ID that was handed to the seller in order to start the protocol.
Alright, so using this protocol shouldn't be too hard: in the simplest case we can just create a Buyer or Seller
with the details of the trade, depending on who we are. We then have to start the protocol in some way. Just
@ -221,6 +209,27 @@ protocol are checked against a whitelist, which can be extended by apps themselv
The process of starting a protocol returns a ``ListenableFuture`` that you can use to either block waiting for
the result, or register a callback that will be invoked when the result is ready.
In a two party protocol only one side is to be manually started using ``ServiceHub.invokeProtocolAsync``. The other side
has to be registered by its node to respond to the initiating protocol via ``ServiceHubInternal.registerProtocolInitiator``.
In our example it doesn't matter which protocol is the initiator and which is the initiated. For example, if we are to
take the seller as the initiator then we would register the buyer as such:
.. container:: codeset
.. sourcecode:: kotlin
val services: ServiceHubInternal = TODO()
services.registerProtocolInitiator(Seller::class) { otherParty ->
val notary = services.networkMapCache.notaryNodes[0]
val acceptablePrice = TODO()
val typeToBuy = TODO()
Buyer(otherParty, notary, acceptablePrice, typeToBuy)
}
This is telling the buyer node to fire up an instance of ``Buyer`` (the code in the lambda) when the initiating protocol
is a seller (``Seller::class``).
Implementing the seller
-----------------------
@ -253,12 +262,10 @@ Let's fill out the ``receiveAndCheckProposedTransaction()`` method.
@Suspendable
private fun receiveAndCheckProposedTransaction(): SignedTransaction {
val sessionID = random63BitValue()
// Make the first message we'll send to kick off the protocol.
val hello = SellerTradeInfo(assetToSell, price, myKeyPair.public, sessionID)
val hello = SellerTradeInfo(assetToSell, price, myKeyPair.public)
val maybeSTX = sendAndReceive<SignedTransaction>(otherSide, buyerSessionID, sessionID, hello)
val maybeSTX = sendAndReceive<SignedTransaction>(otherSide, hello)
maybeSTX.unwrap {
// Check that the tx proposed by the buyer is valid.
@ -281,11 +288,10 @@ Let's fill out the ``receiveAndCheckProposedTransaction()`` method.
}
}
Let's break this down. We generate a session ID to identify what's happening on the seller side, fill out
the initial protocol message, and then call ``sendAndReceive``. This function takes a few arguments:
Let's break this down. We fill out the initial protocol message with the trade info, and then call ``sendAndReceive``.
This function takes a few arguments:
- The topic string that ensures the message is routed to the right bit of code in the other side's node.
- The session IDs that ensure the messages don't get mixed up with other simultaneous trades.
- The party on the other side.
- The thing to send. It'll be serialised and sent automatically.
- Finally a type argument, which is the kind of object we're expecting to receive from the other side. If we get
back something else an exception is thrown.
@ -370,7 +376,7 @@ Here's the rest of the code:
notarySignature: DigitalSignature.LegallyIdentifiable): SignedTransaction {
val fullySigned = partialTX + ourSignature + notarySignature
logger.trace { "Built finished transaction, sending back to secondary!" }
send(otherSide, buyerSessionID, SignaturesFromSeller(ourSignature, notarySignature))
send(otherSide, SignaturesFromSeller(ourSignature, notarySignature))
return fullySigned
}
@ -406,7 +412,7 @@ OK, let's do the same for the buyer side:
val (ptx, cashSigningPubKeys) = assembleSharedTX(tradeRequest)
val stx = signWithOurKeys(cashSigningPubKeys, ptx)
val signatures = swapSignaturesWithSeller(stx, tradeRequest.sessionID)
val signatures = swapSignaturesWithSeller(stx)
logger.trace { "Got signatures from seller, verifying ... " }
@ -419,16 +425,14 @@ OK, let's do the same for the buyer side:
@Suspendable
private fun receiveAndValidateTradeRequest(): SellerTradeInfo {
// Wait for a trade request to come in on our pre-provided session ID.
val maybeTradeRequest = receive<SellerTradeInfo>(sessionID)
// Wait for a trade request to come in from the other side
val maybeTradeRequest = receive<SellerTradeInfo>(otherParty)
maybeTradeRequest.unwrap {
// What is the seller trying to sell us?
val asset = it.assetForSale.state.data
val assetTypeName = asset.javaClass.name
logger.trace { "Got trade request for a $assetTypeName: ${it.assetForSale}" }
// Check the start message for acceptability.
check(it.sessionID > 0)
if (it.price > acceptablePrice)
throw UnacceptablePriceException(it.price)
if (!typeToBuy.isInstance(asset))
@ -443,13 +447,13 @@ OK, let's do the same for the buyer side:
}
@Suspendable
private fun swapSignaturesWithSeller(stx: SignedTransaction, theirSessionID: Long): SignaturesFromSeller {
private fun swapSignaturesWithSeller(stx: SignedTransaction): SignaturesFromSeller {
progressTracker.currentStep = SWAPPING_SIGNATURES
logger.trace { "Sending partially signed transaction to seller" }
// TODO: Protect against the seller terminating here and leaving us in the lurch without the final tx.
return sendAndReceive<SignaturesFromSeller>(otherSide, theirSessionID, sessionID, stx).unwrap { it }
return sendAndReceive<SignaturesFromSeller>(otherSide, stx).unwrap { it }
}
private fun signWithOurKeys(cashSigningPubKeys: List<PublicKey>, ptx: TransactionBuilder): SignedTransaction {
@ -676,7 +680,6 @@ Future features
The protocol framework is a key part of the platform and will be extended in major ways in future. Here are some of
the features we have planned:
* Automatic session ID management
* Identity based addressing
* Exposing progress trackers to local (inside the firewall) clients using message queues and/or WebSockets
* Exception propagation and management, with a "protocol hospital" tool to manually provide solutions to unavoidable

View File

@ -24,6 +24,7 @@ import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.debug
import com.r3corda.node.api.APIServer
import com.r3corda.node.services.api.*
import com.r3corda.node.services.config.NodeConfiguration
@ -54,8 +55,10 @@ import java.nio.file.Path
import java.security.KeyPair
import java.time.Clock
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import java.util.concurrent.TimeUnit
import kotlin.reflect.KClass
/**
* A base node implementation that can be customised either for production (with real implementations that do real
@ -91,6 +94,8 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
protected val _servicesThatAcceptUploads = ArrayList<AcceptsFileUpload>()
val servicesThatAcceptUploads: List<AcceptsFileUpload> = _servicesThatAcceptUploads
private val protocolFactories = ConcurrentHashMap<Class<*>, (Party) -> ProtocolLogic<*>>()
val services = object : ServiceHubInternal() {
override val networkService: MessagingServiceInternal get() = net
override val networkMapCache: NetworkMapCache get() = netMapCache
@ -109,6 +114,16 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
return smm.add(loggerName, logic).resultFuture
}
override fun registerProtocolInitiator(markerClass: KClass<*>, protocolFactory: (Party) -> ProtocolLogic<*>) {
require(markerClass !in protocolFactories) { "${markerClass.java.name} has already been used to register a protocol" }
log.debug { "Registering ${markerClass.java.name}" }
protocolFactories[markerClass.java] = protocolFactory
}
override fun getProtocolFactory(markerClass: Class<*>): ((Party) -> ProtocolLogic<*>)? {
return protocolFactories[markerClass]
}
override fun recordTransactions(txs: Iterable<SignedTransaction>) = recordTransactionsInternal(storage, txs)
}

View File

@ -1,11 +1,9 @@
package com.r3corda.node.services
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.AbstractStateReplacementProtocol
import com.r3corda.protocols.NotaryChangeProtocol
import com.r3corda.protocols.NotaryChangeProtocol.TOPIC
object NotaryChange {
class Plugin : CordaPluginRegistry() {
@ -16,11 +14,9 @@ object NotaryChange {
* A service that monitors the network for requests for changing the notary of a state,
* and immediately runs the [NotaryChangeProtocol] if the auto-accept criteria are met.
*/
class Service(services: ServiceHubInternal) : AbstractNodeService(services) {
class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() {
init {
addProtocolHandler(TOPIC, TOPIC) { req: AbstractStateReplacementProtocol.Handshake ->
NotaryChangeProtocol.Acceptor(req.replyToParty)
}
services.registerProtocolInitiator(NotaryChangeProtocol.Instigator::class) { NotaryChangeProtocol.Acceptor(it) }
}
}
}

View File

@ -1,16 +1,12 @@
package com.r3corda.node.services.api
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.messaging.Message
import com.r3corda.core.messaging.MessageHandlerRegistration
import com.r3corda.core.messaging.createMessage
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize
import com.r3corda.core.utilities.loggerFor
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.protocols.ServiceRequestMessage
import javax.annotation.concurrent.ThreadSafe
@ -20,10 +16,6 @@ import javax.annotation.concurrent.ThreadSafe
@ThreadSafe
abstract class AbstractNodeService(val services: ServiceHubInternal) : SingletonSerializeAsToken() {
companion object {
val logger = loggerFor<AbstractNodeService>()
}
val net: MessagingServiceInternal get() = services.networkService
/**
@ -68,36 +60,4 @@ abstract class AbstractNodeService(val services: ServiceHubInternal) : Singleton
return addMessageHandler(topic, handler, { message: Message, exception: Exception -> throw exception })
}
/**
* Register a handler to kick-off a protocol when a [HandshakeMessage] is received by the node. This performs the
* necessary steps to enable communication between the two protocols, including calling ProtocolLogic.registerSession.
* @param topic the topic on which the handshake is sent from the other party
* @param loggerName the logger name to use when starting the protocol
* @param protocolFactory a function to create the protocol with the given handshake message
* @param onResultFuture provides access to the [ListenableFuture] when the protocol starts
*/
protected inline fun <reified H : HandshakeMessage, R : Any> addProtocolHandler(
topic: String,
loggerName: String,
crossinline protocolFactory: (H) -> ProtocolLogic<R>,
crossinline onResultFuture: ProtocolLogic<R>.(ListenableFuture<R>, H) -> Unit) {
net.addMessageHandler(topic, DEFAULT_SESSION_ID, null) { message, reg ->
try {
val handshake = message.data.deserialize<H>()
val protocol = protocolFactory(handshake)
protocol.registerSession(handshake)
val resultFuture = services.startProtocol(loggerName, protocol)
protocol.onResultFuture(resultFuture, handshake)
} catch (e: Exception) {
logger.error("Unable to process ${H::class.java.name} message", e)
}
}
}
protected inline fun <reified H : HandshakeMessage, R : Any> addProtocolHandler(
topic: String,
loggerName: String,
crossinline protocolFactory: (H) -> ProtocolLogic<R>) {
addProtocolHandler(topic, loggerName, protocolFactory, { future, handshake -> })
}
}

View File

@ -1,7 +1,7 @@
package com.r3corda.node.services.api
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.node.services.statemachine.ProtocolIORequest
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
/**
@ -30,14 +30,13 @@ interface CheckpointStorage {
}
// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
data class Checkpoint(
val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
val request: ProtocolIORequest?,
val receivedPayload: Any?
) {
// This flag is always false when loaded from storage as it isn't serialised.
// It is used to track when the associated fiber has been created, but not necessarily started when
// messages for protocols arrive before the system has fully loaded at startup.
@Transient
var fiberCreated: Boolean = false
}
class Checkpoint(val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>) {
val id: SecureHash get() = serialisedFiber.hash
override fun equals(other: Any?): Boolean = other === this || other is Checkpoint && other.id == this.id
override fun hashCode(): Int = id.hashCode()
override fun toString(): String = "${javaClass.simpleName}(id=$id)"
}

View File

@ -1,14 +1,16 @@
package com.r3corda.node.services.api
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.MessagingService
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.TxWritableStorageService
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolLogicRefFactory
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
import org.slf4j.LoggerFactory
import kotlin.reflect.KClass
interface MessagingServiceInternal : MessagingService {
/**
@ -49,7 +51,7 @@ abstract class ServiceHubInternal : ServiceHub {
* @param txs The transactions to record.
*/
internal fun recordTransactionsInternal(writableStorageService: TxWritableStorageService, txs: Iterable<SignedTransaction>) {
val stateMachineRunId = ProtocolStateMachineImpl.retrieveCurrentStateMachine()?.id
val stateMachineRunId = ProtocolStateMachineImpl.currentStateMachine()?.id
if (stateMachineRunId != null) {
txs.forEach {
storageService.stateMachineRecordedTransactionMapping.addMapping(stateMachineRunId, it.id)
@ -68,6 +70,23 @@ abstract class ServiceHubInternal : ServiceHub {
*/
abstract fun <T> startProtocol(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T>
/**
* Register the protocol factory we wish to use when a initiating party attempts to communicate with us. The
* registration is done against a marker [KClass] which is sent in the session handsake by the other party. If this
* marker class has been registered then the corresponding factory will be used to create the protocol which will
* communicate with the other side. If there is no mapping then the session attempt is rejected.
* @param markerClass The marker [KClass] present in a session initiation attempt, which is a 1:1 mapping to a [Class]
* using the <pre>::class</pre> construct. Any marker class can be used, with the default being the class of the initiating
* protocol. This enables the registration to be of the form: registerProtocolInitiator(InitiatorProtocol::class, ::InitiatedProtocol)
* @param protocolFactory The protocol factory generating the initiated protocol.
*/
abstract fun registerProtocolInitiator(markerClass: KClass<*>, protocolFactory: (Party) -> ProtocolLogic<*>)
/**
* Return the protocol factory that has been registered with [markerClass], or null if no factory is found.
*/
abstract fun getProtocolFactory(markerClass: Class<*>): ((Party) -> ProtocolLogic<*>)?
override fun <T : Any> invokeProtocolAsync(logicType: Class<out ProtocolLogic<T>>, vararg args: Any?): ListenableFuture<T> {
val logicRef = protocolLogicRefFactory.create(logicType, *args)
@Suppress("UNCHECKED_CAST")

View File

@ -1,28 +1,25 @@
package com.r3corda.node.services.clientapi
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.TwoPartyDealProtocol
import com.r3corda.protocols.TwoPartyDealProtocol.FIX_INITIATE_TOPIC
import com.r3corda.protocols.TwoPartyDealProtocol.FixingSessionInitiation
import com.r3corda.protocols.TwoPartyDealProtocol.Fixer
import com.r3corda.protocols.TwoPartyDealProtocol.Floater
/**
* This is a temporary handler required for establishing random sessionIDs for the [Fixer] and [Floater] as part of
* running scheduled fixings for the [InterestRateSwap] contract.
*
* TODO: This will be replaced with the automatic sessionID / session setup work.
* TODO: This will be replaced with the symmetric session work
*/
object FixingSessionInitiation {
class Plugin: CordaPluginRegistry() {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
}
class Service(services: ServiceHubInternal) : AbstractNodeService(services) {
class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() {
init {
addProtocolHandler(FIX_INITIATE_TOPIC, "fixings") { initiation: FixingSessionInitiation ->
TwoPartyDealProtocol.Fixer(initiation.replyToParty, initiation.oracleType)
}
services.registerProtocolInitiator(Floater::class) { Fixer(it) }
}
}
}

View File

@ -169,7 +169,7 @@ class NodeMonitorService(services: ServiceHubInternal, val smm: StateMachineMana
val tx = builder.toSignedTransaction(checkSufficientSignatures = false)
val protocol = FinalityProtocol(tx, setOf(req), setOf(req.recipient))
return TransactionBuildResult.ProtocolStarted(
smm.add(BroadcastTransactionProtocol.TOPIC, protocol).id,
smm.add("broadcast", protocol).id,
tx,
"Cash payment transaction generated"
)
@ -203,7 +203,7 @@ class NodeMonitorService(services: ServiceHubInternal, val smm: StateMachineMana
val tx = builder.toSignedTransaction(checkSufficientSignatures = false)
val protocol = FinalityProtocol(tx, setOf(req), participants)
return TransactionBuildResult.ProtocolStarted(
smm.add(BroadcastTransactionProtocol.TOPIC, protocol).id,
smm.add("broadcast", protocol).id,
tx,
"Cash destruction transaction generated"
)
@ -222,7 +222,7 @@ class NodeMonitorService(services: ServiceHubInternal, val smm: StateMachineMana
// Issuance transactions do not need to be notarised, so we can skip directly to broadcasting it
val protocol = BroadcastTransactionProtocol(tx, setOf(req), setOf(req.recipient))
return TransactionBuildResult.ProtocolStarted(
smm.add(BroadcastTransactionProtocol.TOPIC, protocol).id,
smm.add("broadcast", protocol).id,
tx,
"Cash issuance completed"
)

View File

@ -1,17 +1,12 @@
package com.r3corda.node.services.persistence
import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.crypto.Party
import com.r3corda.core.failure
import com.r3corda.core.messaging.MessagingService
import com.r3corda.core.messaging.TopicSession
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.recordTransactions
import com.r3corda.core.serialization.serialize
import com.r3corda.core.success
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.utilities.loggerFor
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.*
import java.io.InputStream
@ -39,78 +34,73 @@ object DataVending {
// TODO: I don't like that this needs ServiceHubInternal, but passing in a state machine breaks MockServices because
// the state machine isn't set when this is constructed. [NodeSchedulerService] has the same problem, and both
// should be fixed at the same time.
class Service(services: ServiceHubInternal) : AbstractNodeService(services) {
class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() {
companion object {
val logger = loggerFor<DataVending.Service>()
/**
* Notify a node of a transaction. Normally any notarisation required would happen before this is called.
*/
fun notify(net: MessagingService,
myIdentity: Party,
recipient: NodeInfo,
transaction: SignedTransaction) {
val msg = BroadcastTransactionProtocol.NotifyTxRequestMessage(transaction, emptySet(), myIdentity)
net.send(net.createMessage(TopicSession(BroadcastTransactionProtocol.TOPIC, 0), msg.serialize().bits), recipient.address)
}
}
val storage = services.storageService
class TransactionRejectedError(msg: String) : Exception(msg)
init {
addMessageHandler(FetchTransactionsProtocol.TOPIC,
{ req: FetchDataProtocol.Request -> handleTXRequest(req) },
{ message, e -> logger.error("Failure processing data vending request.", e) }
)
addMessageHandler(FetchAttachmentsProtocol.TOPIC,
{ req: FetchDataProtocol.Request -> handleAttachmentRequest(req) },
{ message, e -> logger.error("Failure processing data vending request.", e) }
)
// TODO: We should have a whitelist of contracts we're willing to accept at all, and reject if the transaction
// includes us in any outside that list. Potentially just if it includes any outside that list at all.
// TODO: Do we want to be able to reject specific transactions on more complex rules, for example reject incoming
// cash without from unknown parties?
addProtocolHandler(
BroadcastTransactionProtocol.TOPIC,
"Resolving transactions",
{ req: BroadcastTransactionProtocol.NotifyTxRequestMessage ->
ResolveTransactionsProtocol(req.tx, req.replyToParty)
},
{ future, req ->
future.success {
serviceHub.recordTransactions(req.tx)
}.failure { throwable ->
logger.warn("Received invalid transaction ${req.tx.id} from ${req.replyToParty}", throwable)
}
})
services.registerProtocolInitiator(FetchTransactionsProtocol::class, ::FetchTransactionsHandler)
services.registerProtocolInitiator(FetchAttachmentsProtocol::class, ::FetchAttachmentsHandler)
services.registerProtocolInitiator(BroadcastTransactionProtocol::class, ::NotifyTransactionHandler)
}
private fun handleTXRequest(req: FetchDataProtocol.Request): List<SignedTransaction?> {
require(req.hashes.isNotEmpty())
return req.hashes.map {
val tx = storage.validatedTransactions.getTransaction(it)
if (tx == null)
logger.info("Got request for unknown tx $it")
tx
private class FetchTransactionsHandler(val otherParty: Party) : ProtocolLogic<Unit>() {
@Suspendable
override fun call() {
val request = receive<FetchDataProtocol.Request>(otherParty).unwrap {
require(it.hashes.isNotEmpty())
it
}
val txs = request.hashes.map {
val tx = serviceHub.storageService.validatedTransactions.getTransaction(it)
if (tx == null)
logger.info("Got request for unknown tx $it")
tx
}
send(otherParty, txs)
}
}
private fun handleAttachmentRequest(req: FetchDataProtocol.Request): List<ByteArray?> {
// TODO: Use Artemis message streaming support here, called "large messages". This avoids the need to buffer.
require(req.hashes.isNotEmpty())
return req.hashes.map {
val jar: InputStream? = storage.attachments.openAttachment(it)?.open()
if (jar == null) {
logger.info("Got request for unknown attachment $it")
null
} else {
jar.readBytes()
// TODO: Use Artemis message streaming support here, called "large messages". This avoids the need to buffer.
private class FetchAttachmentsHandler(val otherParty: Party) : ProtocolLogic<Unit>() {
@Suspendable
override fun call() {
val request = receive<FetchDataProtocol.Request>(otherParty).unwrap {
require(it.hashes.isNotEmpty())
it
}
val attachments = request.hashes.map {
val jar: InputStream? = serviceHub.storageService.attachments.openAttachment(it)?.open()
if (jar == null) {
logger.info("Got request for unknown attachment $it")
null
} else {
jar.readBytes()
}
}
send(otherParty, attachments)
}
}
// TODO: We should have a whitelist of contracts we're willing to accept at all, and reject if the transaction
// includes us in any outside that list. Potentially just if it includes any outside that list at all.
// TODO: Do we want to be able to reject specific transactions on more complex rules, for example reject incoming
// cash without from unknown parties?
class NotifyTransactionHandler(val otherParty: Party) : ProtocolLogic<Unit>() {
@Suspendable
override fun call() {
val request = receive<BroadcastTransactionProtocol.NotifyTxRequest>(otherParty).unwrap { it }
subProtocol(ResolveTransactionsProtocol(request.tx, otherParty), shareParentSessions = true)
serviceHub.recordTransactions(request.tx)
}
}
}
}

View File

@ -1,53 +1,38 @@
package com.r3corda.node.services.statemachine
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.TopicSession
import java.util.*
import com.r3corda.node.services.statemachine.StateMachineManager.ProtocolSession
import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage
// TODO revisit when Kotlin 1.1 is released and data classes can extend other classes
interface ProtocolIORequest {
// This is used to identify where we suspended, in case of message mismatch errors and other things where we
// don't have the original stack trace because it's in a suspended fiber.
val stackTraceInCaseOfProblems: StackSnapshot
val topic: String
val session: ProtocolSession
}
interface SendRequest : ProtocolIORequest {
val destination: Party
val payload: Any
val sendSessionID: Long
val uniqueMessageId: UUID
val message: SessionMessage
}
interface ReceiveRequest<T> : ProtocolIORequest {
interface ReceiveRequest<T : SessionMessage> : ProtocolIORequest {
val receiveType: Class<T>
val receiveSessionID: Long
val receiveTopicSession: TopicSession get() = TopicSession(topic, receiveSessionID)
}
data class SendAndReceive<T>(override val topic: String,
override val destination: Party,
override val payload: Any,
override val sendSessionID: Long,
override val uniqueMessageId: UUID,
override val receiveType: Class<T>,
override val receiveSessionID: Long) : SendRequest, ReceiveRequest<T> {
data class SendAndReceive<T : SessionMessage>(override val session: ProtocolSession,
override val message: SessionMessage,
override val receiveType: Class<T>) : SendRequest, ReceiveRequest<T> {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}
data class ReceiveOnly<T>(override val topic: String,
override val receiveType: Class<T>,
override val receiveSessionID: Long) : ReceiveRequest<T> {
data class ReceiveOnly<T : SessionMessage>(override val session: ProtocolSession,
override val receiveType: Class<T>) : ReceiveRequest<T> {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}
data class SendOnly(override val destination: Party,
override val topic: String,
override val payload: Any,
override val sendSessionID: Long,
override val uniqueMessageId: UUID) : SendRequest {
data class SendOnly(override val session: ProtocolSession, override val message: SessionMessage) : SendRequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}

View File

@ -8,16 +8,22 @@ import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.crypto.Party
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolSessionException
import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.protocols.StateMachineRunId
import com.r3corda.core.random63BitValue
import com.r3corda.core.rootCause
import com.r3corda.core.utilities.UntrustworthyData
import com.r3corda.core.utilities.trace
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.node.services.statemachine.StateMachineManager.*
import com.r3corda.node.utilities.createDatabaseTransaction
import org.jetbrains.exposed.sql.Database
import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.io.PrintWriter
import java.io.StringWriter
import java.sql.SQLException
import java.util.*
import java.util.concurrent.ExecutionException
@ -36,12 +42,26 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
private val loggerName: String)
: Fiber<R>("protocol", scheduler), ProtocolStateMachine<R> {
companion object {
// Used to work around a small limitation in Quasar.
private val QUASAR_UNBLOCKER = run {
val field = Fiber::class.java.getDeclaredField("SERIALIZER_BLOCKER")
field.isAccessible = true
field.get(null)
}
/**
* Return the current [ProtocolStateMachineImpl] or null if executing outside of one.
*/
fun currentStateMachine(): ProtocolStateMachineImpl<*>? = Strand.currentStrand() as? ProtocolStateMachineImpl<*>
}
// These fields shouldn't be serialised, so they are marked @Transient.
@Transient lateinit override var serviceHub: ServiceHubInternal
@Transient internal lateinit var suspendAction: (ProtocolIORequest) -> Unit
@Transient internal lateinit var actionOnSuspend: (ProtocolIORequest) -> Unit
@Transient internal lateinit var actionOnEnd: () -> Unit
@Transient internal var receivedPayload: Any? = null
@Transient internal lateinit var database: Database
@Transient internal var fromCheckpoint: Boolean = false
@Transient private var _logger: Logger? = null
override val logger: Logger get() {
@ -62,18 +82,20 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
}
}
internal val openSessions = HashMap<Pair<ProtocolLogic<*>, Party>, ProtocolSession>()
init {
logic.psm = this
name = id.toString()
}
@Suspendable @Suppress("UNCHECKED_CAST")
@Suspendable
override fun run(): R {
createTransaction()
val result = try {
logic.call()
} catch (t: Throwable) {
actionOnEnd()
_resultFuture?.setException(t)
processException(t)
commitTransaction()
throw ExecutionException(t)
}
@ -106,56 +128,140 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
}
@Suspendable
private fun <T : Any> suspendAndExpectReceive(receiveRequest: ReceiveRequest<T>): UntrustworthyData<T> {
suspend(receiveRequest)
check(receivedPayload != null) { "Expected to receive something" }
val untrustworthy = UntrustworthyData(receiveRequest.receiveType.cast(receivedPayload))
receivedPayload = null
return untrustworthy
}
@Suspendable
override fun <T : Any> sendAndReceive(topic: String,
destination: Party,
sessionIDForSend: Long,
sessionIDForReceive: Long,
override fun <T : Any> sendAndReceive(otherParty: Party,
payload: Any,
receiveType: Class<T>): UntrustworthyData<T> {
return suspendAndExpectReceive(SendAndReceive(topic, destination, payload, sessionIDForSend, UUID.randomUUID(), receiveType, sessionIDForReceive))
receiveType: Class<T>,
sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T> {
val session = getSession(otherParty, sessionProtocol)
val sendSessionData = createSessionData(session, payload)
val receivedSessionData = sendAndReceiveInternal(session, sendSessionData, SessionData::class.java)
return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
}
@Suspendable
override fun <T : Any> receive(topic: String, sessionIDForReceive: Long, receiveType: Class<T>): UntrustworthyData<T> {
return suspendAndExpectReceive(ReceiveOnly(topic, receiveType, sessionIDForReceive))
override fun <T : Any> receive(otherParty: Party,
receiveType: Class<T>,
sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T> {
val receivedSessionData = receiveInternal(getSession(otherParty, sessionProtocol), SessionData::class.java)
return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
}
@Suspendable
override fun send(topic: String, destination: Party, sessionID: Long, payload: Any) {
suspend(SendOnly(destination, topic, payload, sessionID, UUID.randomUUID()))
override fun send(otherParty: Party, payload: Any, sessionProtocol: ProtocolLogic<*>) {
val session = getSession(otherParty, sessionProtocol)
val sendSessionData = createSessionData(session, payload)
sendInternal(session, sendSessionData)
}
private fun createSessionData(session: ProtocolSession, payload: Any): SessionData {
val otherPartySessionId = session.otherPartySessionId
?: throw IllegalStateException("We've somehow held onto an unconfirmed session: $session")
return SessionData(otherPartySessionId, payload)
}
@Suspendable
private fun suspend(protocolIORequest: ProtocolIORequest) {
private fun sendInternal(session: ProtocolSession, message: SessionMessage) {
suspend(SendOnly(session, message))
}
@Suspendable
private fun <T : SessionMessage> receiveInternal(session: ProtocolSession, receiveType: Class<T>): T {
return suspendAndExpectReceive(ReceiveOnly(session, receiveType))
}
@Suspendable
private fun <T : SessionMessage> sendAndReceiveInternal(session: ProtocolSession, message: SessionMessage, receiveType: Class<T>): T {
return suspendAndExpectReceive(SendAndReceive(session, message, receiveType))
}
@Suspendable
private fun getSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>): ProtocolSession {
return openSessions[Pair(sessionProtocol, otherParty)] ?: startNewSession(otherParty, sessionProtocol)
}
@Suspendable
private fun startNewSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>) : ProtocolSession {
val session = ProtocolSession(sessionProtocol, otherParty, random63BitValue(), null)
openSessions[Pair(sessionProtocol, otherParty)] = session
val counterpartyProtocol = sessionProtocol.getCounterpartyMarker(otherParty).name
val sessionInit = SessionInit(session.ourSessionId, serviceHub.storageService.myLegalIdentity, counterpartyProtocol)
val sessionInitResponse = sendAndReceiveInternal(session, sessionInit, SessionInitResponse::class.java)
if (sessionInitResponse is SessionConfirm) {
session.otherPartySessionId = sessionInitResponse.initiatedSessionId
return session
} else {
sessionInitResponse as SessionReject
throw ProtocolSessionException("Party $otherParty rejected session attempt: ${sessionInitResponse.errorMessage}")
}
}
@Suspendable
private fun <T : SessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<T>): T {
fun getReceivedMessage(): ExistingSessionMessage? = receiveRequest.session.receivedMessages.poll()
val receivedMessage = getReceivedMessage() ?: run {
// Suspend while we wait for the receive
receiveRequest.session.waitingForResponse = true
suspend(receiveRequest)
receiveRequest.session.waitingForResponse = false
getReceivedMessage()
?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $id $receiveRequest")
}
if (receivedMessage is SessionEnd) {
openSessions.values.remove(receiveRequest.session)
throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurly ended")
} else if (receiveRequest.receiveType.isInstance(receivedMessage)) {
return receiveRequest.receiveType.cast(receivedMessage)
} else {
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got $receivedMessage: $id $receiveRequest")
}
}
@Suspendable
private fun suspend(ioRequest: ProtocolIORequest) {
commitTransaction()
parkAndSerialize { fiber, serializer ->
logger.trace { "Suspended $id on $ioRequest" }
try {
suspendAction(protocolIORequest)
actionOnSuspend(ioRequest)
} catch (t: Throwable) {
// Do not throw exception again - Quasar completely bins it.
logger.warn("Captured exception which was swallowed by Quasar", t)
actionOnEnd()
_resultFuture?.setException(t)
// TODO When error handling is introduced, look into whether we should be deleting the checkpoint and
// completing the Future
processException(t)
}
}
createTransaction()
}
companion object {
/**
* Retrieves our state machine id if we are running a [ProtocolStateMachineImpl].
*/
fun retrieveCurrentStateMachine(): ProtocolStateMachineImpl<*>? {
return Strand.currentStrand() as? ProtocolStateMachineImpl<*>
private fun processException(t: Throwable) {
actionOnEnd()
_resultFuture?.setException(t)
}
internal fun resume(scheduler: FiberScheduler) {
try {
if (fromCheckpoint) {
logger.info("$id resumed from checkpoint")
fromCheckpoint = false
Fiber.unparkDeserialized(this, scheduler)
} else if (state == State.NEW) {
logger.trace { "$id started" }
start()
} else {
logger.trace { "$id resumed" }
Fiber.unpark(this, QUASAR_UNBLOCKER)
}
} catch (t: Throwable) {
logger.error("$id threw '${t.rootCause}'")
logger.trace {
val s = StringWriter()
t.rootCause.printStackTrace(PrintWriter(s))
"Stack trace of protocol error: $s"
}
}
}
}

View File

@ -3,34 +3,38 @@ package com.r3corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.FiberExecutorScheduler
import co.paralleluniverse.io.serialization.kryo.KryoSerializer
import co.paralleluniverse.strands.Strand
import com.codahale.metrics.Gauge
import com.esotericsoftware.kryo.Kryo
import com.google.common.base.Throwables
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.ThreadBox
import com.r3corda.core.abbreviate
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.TopicSession
import com.r3corda.core.messaging.runOnNextMessage
import com.r3corda.core.messaging.send
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.protocols.StateMachineRunId
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.*
import com.r3corda.core.then
import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.core.utilities.debug
import com.r3corda.core.utilities.loggerFor
import com.r3corda.core.utilities.trace
import com.r3corda.node.services.api.Checkpoint
import com.r3corda.node.services.api.CheckpointStorage
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.node.utilities.AffinityExecutor
import kotlinx.support.jdk8.collections.removeIf
import org.jetbrains.exposed.sql.Database
import rx.Observable
import rx.subjects.PublishSubject
import rx.subjects.UnicastSubject
import java.io.PrintWriter
import java.io.StringWriter
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.ExecutionException
import javax.annotation.concurrent.ThreadSafe
@ -48,7 +52,6 @@ import javax.annotation.concurrent.ThreadSafe
* The SMM will always invoke the protocol fibers on the given [AffinityExecutor], regardless of which thread actually
* starts them via [add].
*
* TODO: Session IDs should be set up and propagated automatically, on demand.
* TODO: Consider the issue of continuation identity more deeply: is it a safe assumption that a serialised
* continuation is always unique?
* TODO: Think about how to bring the system to a clean stop so it can be upgraded without any serialised stacks on disk
@ -58,12 +61,19 @@ import javax.annotation.concurrent.ThreadSafe
* TODO: Implement stub/skel classes that provide a basic RPC framework on top of this.
*/
@ThreadSafe
class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableServices: List<Any>,
class StateMachineManager(val serviceHub: ServiceHubInternal,
tokenizableServices: List<Any>,
val checkpointStorage: CheckpointStorage,
val executor: AffinityExecutor,
val database: Database) {
inner class FiberScheduler : FiberExecutorScheduler("Same thread scheduler", executor)
companion object {
private val logger = loggerFor<StateMachineManager>()
internal val sessionTopic = TopicSession("platform.session")
}
val scheduler = FiberScheduler()
data class Change(
@ -95,6 +105,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
private val totalStartedProtocols = metrics.counter("Protocols.Started")
private val totalFinishedProtocols = metrics.counter("Protocols.Finished")
private val openSessions = ConcurrentHashMap<Long, ProtocolSession>()
private val recentlyClosedSessions = ConcurrentHashMap<Long, Party>()
// Context for tokenized services in checkpoints
private val serializationContext = SerializeAsTokenContext(tokenizableServices, quasarKryo())
@ -119,6 +132,17 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
val changes: Observable<Change>
get() = mutex.content.changesPublisher
init {
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
(fiber as ProtocolStateMachineImpl<*>).logger.error("Caught exception from protocol", throwable)
}
}
fun start() {
restoreFibersFromCheckpoints()
serviceHub.networkMapCache.mapServiceRegistered.then(executor) { resumeRestoredFibers() }
}
/**
* Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and
* calls to [allStateMachines]
@ -131,69 +155,99 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
}
}
// Used to work around a small limitation in Quasar.
private val QUASAR_UNBLOCKER = run {
val field = Fiber::class.java.getDeclaredField("SERIALIZER_BLOCKER")
field.isAccessible = true
field.get(null)
}
init {
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
(fiber as ProtocolStateMachineImpl<*>).logger.error("Caught exception from protocol", throwable)
}
}
fun start() {
checkpointStorage.checkpoints.forEach { createFiberForCheckpoint(it) }
serviceHub.networkMapCache.mapServiceRegistered.then(executor) {
mutex.locked {
started = true
stateMachines.forEach { restartFiber(it.key, it.value) }
}
}
}
private fun createFiberForCheckpoint(checkpoint: Checkpoint) {
if (!checkpoint.fiberCreated) {
val fiber = deserializeFiber(checkpoint.serialisedFiber)
initFiber(fiber, { checkpoint })
}
}
private fun restartFiber(fiber: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint) {
if (checkpoint.request is ReceiveRequest<*>) {
val topicSession = checkpoint.request.receiveTopicSession
fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${checkpoint.request.receiveType.name} on $topicSession")
iterateOnResponse(fiber, checkpoint.serialisedFiber, checkpoint.request) {
try {
Fiber.unparkDeserialized(fiber, scheduler)
} catch (e: Throwable) {
logError(e, it, topicSession, fiber)
private fun restoreFibersFromCheckpoints() {
mutex.locked {
checkpointStorage.checkpoints.forEach {
// If a protocol is added before start() then don't attempt to restore it
if (!stateMachines.containsValue(it)) {
val fiber = deserializeFiber(it.serialisedFiber)
initFiber(fiber)
stateMachines[fiber] = it
}
}
if (checkpoint.request is SendRequest) {
sendMessage(fiber, checkpoint.request)
}
}
private fun resumeRestoredFibers() {
mutex.locked {
started = true
stateMachines.keys.forEach { resumeRestoredFiber(it) }
}
serviceHub.networkService.addMessageHandler(sessionTopic, executor) { message, reg ->
executor.checkOnThread()
val sessionMessage = message.data.deserialize<SessionMessage>()
when (sessionMessage) {
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage)
is SessionInit -> onSessionInit(sessionMessage)
}
}
}
private fun resumeRestoredFiber(fiber: ProtocolStateMachineImpl<*>) {
fiber.openSessions.values.forEach { openSessions[it.ourSessionId] = it }
if (fiber.openSessions.values.any { it.waitingForResponse }) {
fiber.logger.info("Restored fiber pending on receive ${fiber.id}}")
} else {
resumeFiber(fiber)
}
}
private fun onExistingSessionMessage(message: ExistingSessionMessage) {
val session = openSessions[message.recipientSessionId]
if (session != null) {
session.psm.logger.trace { "${session.psm.id} received $message on $session" }
if (message is SessionEnd) {
openSessions.remove(message.recipientSessionId)
}
session.receivedMessages += message
if (session.waitingForResponse) {
updateCheckpoint(session.psm)
resumeFiber(session.psm)
}
} else {
fiber.logger.info("Restored ${fiber.logic} - it was not waiting on any message; received payload: ${checkpoint.receivedPayload.toString().abbreviate(50)}")
executor.executeASAP {
if (checkpoint.request is SendRequest) {
sendMessage(fiber, checkpoint.request)
}
iterateStateMachine(fiber, checkpoint.receivedPayload) {
try {
Fiber.unparkDeserialized(fiber, scheduler)
} catch (e: Throwable) {
logError(e, it, null, fiber)
}
val otherParty = recentlyClosedSessions.remove(message.recipientSessionId)
if (otherParty != null) {
if (message is SessionConfirm) {
logger.debug { "Received session confirmation but associated fiber has already terminated, so sending session end" }
sendSessionMessage(otherParty, SessionEnd(message.initiatedSessionId), null)
} else {
logger.trace { "Ignoring session end message for already closed session: $message" }
}
} else {
logger.warn("Received a session message for unknown session: $message")
}
}
}
private fun onSessionInit(sessionInit: SessionInit) {
logger.trace { "Received $sessionInit" }
//TODO Verify the other party are who they say they are from the TLS subsystem
val otherParty = sessionInit.initiatorParty
val otherPartySessionId = sessionInit.initiatorSessionId
try {
val markerClass = Class.forName(sessionInit.protocolName)
val protocolFactory = serviceHub.getProtocolFactory(markerClass)
if (protocolFactory != null) {
val protocol = protocolFactory(otherParty)
val psm = createFiber(sessionInit.protocolName, protocol)
val session = ProtocolSession(protocol, otherParty, random63BitValue(), otherPartySessionId)
openSessions[session.ourSessionId] = session
psm.openSessions[Pair(protocol, otherParty)] = session
updateCheckpoint(psm)
sendSessionMessage(otherParty, SessionConfirm(otherPartySessionId, session.ourSessionId), psm)
psm.logger.debug { "Starting new ${psm.id} from $sessionInit on $session" }
startFiber(psm)
} else {
logger.warn("Unknown protocol marker class in $sessionInit")
sendSessionMessage(otherParty, SessionReject(otherPartySessionId, "Don't know ${markerClass.name}"), null)
}
} catch (e: Exception) {
logger.warn("Received invalid $sessionInit", e)
sendSessionMessage(otherParty, SessionReject(otherPartySessionId, "Unable to establish session"), null)
}
}
private fun serializeFiber(fiber: ProtocolStateMachineImpl<*>): SerializedBytes<ProtocolStateMachineImpl<*>> {
// We don't use the passed-in serializer here, because we need to use our own augmented Kryo.
val kryo = quasarKryo()
// add the map of tokens -> tokenizedServices to the kyro context
SerializeAsTokenSerializer.setContext(kryo, serializationContext)
@ -204,7 +258,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
val kryo = quasarKryo()
// put the map of token -> tokenized into the kryo context
SerializeAsTokenSerializer.setContext(kryo, serializationContext)
return serialisedFiber.deserialize(kryo)
return serialisedFiber.deserialize(kryo).apply { fromCheckpoint = true }
}
private fun quasarKryo(): Kryo {
@ -212,70 +266,51 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
return createKryo(serializer.kryo)
}
private fun logError(e: Throwable, payload: Any?, topicSession: TopicSession?, psm: ProtocolStateMachineImpl<*>) {
psm.logger.error("Protocol state machine ${psm.javaClass.name} threw '${Throwables.getRootCause(e)}' " +
"when handling a message of type ${payload?.javaClass?.name} on queue $topicSession")
if (psm.logger.isTraceEnabled) {
val s = StringWriter()
Throwables.getRootCause(e).printStackTrace(PrintWriter(s))
psm.logger.trace("Stack trace of protocol error is: $s")
}
private fun <T> createFiber(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachineImpl<T> {
val id = StateMachineRunId.createRandom()
return ProtocolStateMachineImpl(id, logic, scheduler, loggerName).apply { initFiber(this) }
}
private fun initFiber(psm: ProtocolStateMachineImpl<*>, startingCheckpoint: () -> Checkpoint): Checkpoint {
private fun initFiber(psm: ProtocolStateMachineImpl<*>) {
psm.database = database
psm.serviceHub = serviceHub
psm.suspendAction = { request ->
psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" }
onNextSuspend(psm, request)
psm.actionOnSuspend = { ioRequest ->
updateCheckpoint(psm)
processIORequest(ioRequest)
}
psm.actionOnEnd = {
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
mutex.locked {
val finalCheckpoint = stateMachines.remove(psm)
if (finalCheckpoint != null) {
checkpointStorage.removeCheckpoint(finalCheckpoint)
}
stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) }
totalFinishedProtocols.inc()
notifyChangeObservers(psm, AddOrRemove.REMOVE)
}
endAllFiberSessions(psm)
}
val checkpoint = startingCheckpoint()
checkpoint.fiberCreated = true
totalStartedProtocols.inc()
mutex.locked {
stateMachines[psm] = checkpoint
totalStartedProtocols.inc()
notifyChangeObservers(psm, AddOrRemove.ADD)
}
return checkpoint
}
/**
* Kicks off a brand new state machine of the given class. It will log with the named logger.
* The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is
* restarted with checkpointed state machines in the storage service.
*/
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachine<T> {
val id = StateMachineRunId.createRandom()
val fiber = ProtocolStateMachineImpl(id, logic, scheduler, loggerName)
// Need to add before iterating in case of immediate completion
val checkpoint = initFiber(fiber) {
val checkpoint = Checkpoint(serializeFiber(fiber), null, null)
checkpoint
}
checkpointStorage.addCheckpoint(checkpoint)
mutex.locked { // If we are not started then our checkpoint will be picked up during start
if (!started) {
return fiber
}
}
try {
executor.executeASAP {
iterateStateMachine(fiber, null) {
fiber.start()
private fun endAllFiberSessions(psm: ProtocolStateMachineImpl<*>) {
openSessions.values.removeIf { session ->
if (session.psm == psm) {
val otherPartySessionId = session.otherPartySessionId
if (otherPartySessionId != null) {
sendSessionMessage(session.otherParty, SessionEnd(otherPartySessionId), psm)
}
recentlyClosedSessions[session.ourSessionId] = session.otherParty
true
} else {
false
}
}
}
private fun startFiber(fiber: ProtocolStateMachineImpl<*>) {
try {
resumeFiber(fiber)
} catch (e: ExecutionException) {
// There are two ways we can take exceptions in this method:
//
@ -290,17 +325,29 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
if (e.cause !is ExecutionException)
throw e
}
}
/**
* Kicks off a brand new state machine of the given class. It will log with the named logger.
* The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is
* restarted with checkpointed state machines in the storage service.
*/
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachine<T> {
val fiber = createFiber(loggerName, logic)
updateCheckpoint(fiber)
// If we are not started then our checkpoint will be picked up during start
mutex.locked {
if (started) {
startFiber(fiber)
}
}
return fiber
}
private fun updateCheckpoint(psm: ProtocolStateMachineImpl<*>,
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
request: ProtocolIORequest?,
receivedPayload: Any?) {
val newCheckpoint = Checkpoint(serialisedFiber, request, receivedPayload)
val previousCheckpoint = mutex.locked {
stateMachines.put(psm, newCheckpoint)
}
private fun updateCheckpoint(psm: ProtocolStateMachineImpl<*>) {
check(psm.state != Strand.State.RUNNING) { "Fiber cannot be running when checkpointing" }
val newCheckpoint = Checkpoint(serializeFiber(psm))
val previousCheckpoint = mutex.locked { stateMachines.put(psm, newCheckpoint) }
if (previousCheckpoint != null) {
checkpointStorage.removeCheckpoint(previousCheckpoint)
}
@ -308,90 +355,70 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
checkpointingMeter.mark()
}
private fun iterateStateMachine(psm: ProtocolStateMachineImpl<*>,
receivedPayload: Any?,
resumeAction: (Any?) -> Unit) {
executor.checkOnThread()
psm.receivedPayload = receivedPayload
psm.logger.trace { "Waking up fiber ${psm.id} ${psm.logic}" }
resumeAction(receivedPayload)
}
private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, request: ProtocolIORequest) {
val serialisedFiber = serializeFiber(psm)
updateCheckpoint(psm, serialisedFiber, request, null)
// We have a request to do something: send, receive, or send-and-receive.
if (request is ReceiveRequest<*>) {
// Prepare a listener on the network that runs in the background thread when we receive a message.
prepareToReceiveForRequest(psm, serialisedFiber, request)
}
if (request is SendRequest) {
performSendRequest(psm, request)
private fun resumeFiber(psm: ProtocolStateMachineImpl<*>) {
executor.executeASAP {
psm.resume(scheduler)
}
}
private fun prepareToReceiveForRequest(psm: ProtocolStateMachineImpl<*>, serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>, request: ReceiveRequest<*>) {
executor.checkOnThread()
val queueID = request.receiveTopicSession
psm.logger.trace { "Preparing to receive message of type ${request.receiveType.name} on queue $queueID" }
iterateOnResponse(psm, serialisedFiber, request) {
try {
Fiber.unpark(psm, QUASAR_UNBLOCKER)
} catch(e: Throwable) {
logError(e, it, queueID, psm)
private fun processIORequest(ioRequest: ProtocolIORequest) {
if (ioRequest is SendRequest) {
if (ioRequest.message is SessionInit) {
openSessions[ioRequest.session.ourSessionId] = ioRequest.session
}
sendSessionMessage(ioRequest.session.otherParty, ioRequest.message, ioRequest.session.psm)
if (ioRequest !is ReceiveRequest<*>) {
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
resumeFiber(ioRequest.session.psm)
}
}
}
private fun performSendRequest(psm: ProtocolStateMachineImpl<*>, request: SendRequest) {
val topicSession = sendMessage(psm, request)
private fun sendSessionMessage(party: Party, message: SessionMessage, psm: ProtocolStateMachineImpl<*>?) {
val node = serviceHub.networkMapCache.getNodeByLegalName(party.name)
?: throw IllegalArgumentException("Don't know about party $party")
val logger = psm?.logger ?: logger
logger.trace { "${psm?.id} sending $message to party $party" }
serviceHub.networkService.send(sessionTopic, message, node.address)
}
if (request is SendOnly) {
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
iterateStateMachine(psm, null) {
try {
Fiber.unpark(psm, QUASAR_UNBLOCKER)
} catch(e: Throwable) {
logError(e, request.payload, topicSession, psm)
}
}
interface SessionMessage
interface ExistingSessionMessage: SessionMessage {
val recipientSessionId: Long
}
data class SessionInit(val initiatorSessionId: Long, val initiatorParty: Party, val protocolName: String) : SessionMessage
interface SessionInitResponse : ExistingSessionMessage
data class SessionConfirm(val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse {
override val recipientSessionId: Long get() = initiatorSessionId
}
data class SessionReject(val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse {
override val recipientSessionId: Long get() = initiatorSessionId
}
data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage {
override fun toString(): String {
return "${javaClass.simpleName}(recipientSessionId=$recipientSessionId, payload=${payload.toString().abbreviate(100)})"
}
}
private fun sendMessage(psm: ProtocolStateMachineImpl<*>, request: SendRequest): TopicSession {
val topicSession = TopicSession(request.topic, request.sendSessionID)
val payload = request.payload
psm.logger.trace { "Sending message of type ${payload.javaClass.name} using queue $topicSession to ${request.destination} (${payload.toString().abbreviate(50)})" }
val node = serviceHub.networkMapCache.getNodeByLegalName(request.destination.name) ?:
throw IllegalArgumentException("Don't know about ${request.destination} but trying to send a message of type ${payload.javaClass.name} on $topicSession (${payload.toString().abbreviate(50)})", request.stackTraceInCaseOfProblems)
serviceHub.networkService.send(topicSession, payload, node.address, request.uniqueMessageId)
return topicSession
data class SessionEnd(override val recipientSessionId: Long) : ExistingSessionMessage
data class ProtocolSession(val protocol: ProtocolLogic<*>,
val otherParty: Party,
val ourSessionId: Long,
var otherPartySessionId: Long?,
@Volatile var waitingForResponse: Boolean = false) {
val receivedMessages = ConcurrentLinkedQueue<ExistingSessionMessage>()
val psm: ProtocolStateMachineImpl<*> get() = protocol.psm as ProtocolStateMachineImpl<*>
}
/**
* Add a trigger to the [MessagingService] to deserialize the fiber and pass message content to it, once a message is
* received.
*/
private fun iterateOnResponse(psm: ProtocolStateMachineImpl<*>,
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
request: ReceiveRequest<*>,
resumeAction: (Any?) -> Unit) {
val topicSession = request.receiveTopicSession
serviceHub.networkService.runOnNextMessage(topicSession, executor) { netMsg ->
// Assertion to ensure we don't execute on the wrong thread.
executor.checkOnThread()
// TODO: This is insecure: we should not deserialise whatever we find and *then* check.
// We should instead verify as we read the data that it's what we are expecting and throw as early as
// possible. We only do it this way for convenience during the prototyping stage. Note that this means
// we could simply not require the programmer to specify the expected return type at all, and catch it
// at the last moment when we do the downcast. However this would make protocol code harder to read and
// make it more difficult to migrate to a more explicit serialisation scheme later.
val payload = netMsg.data.deserialize<Any>()
check(request.receiveType.isInstance(payload)) { "Expected message of type ${request.receiveType.name} but got ${payload.javaClass.name}" }
// Update the fiber's checkpoint so that it's no longer waiting on a response, but rather has the received payload
updateCheckpoint(psm, serialisedFiber, null, payload)
psm.logger.trace { "Received message of type ${payload.javaClass.name} on $topicSession (${payload.toString().abbreviate(50)})" }
iterateStateMachine(psm, payload, resumeAction)
}
}
}

View File

@ -1,12 +1,11 @@
package com.r3corda.node.services.transactions
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.node.services.TimestampChecker
import com.r3corda.core.node.services.UniquenessProvider
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.NotaryProtocol
import com.r3corda.protocols.NotaryProtocol.TOPIC
import kotlin.reflect.KClass
/**
* A Notary service acts as the final signer of a transaction ensuring two things:
@ -17,22 +16,18 @@ import com.r3corda.protocols.NotaryProtocol.TOPIC
*
* This is the base implementation that can be customised with specific Notary transaction commit protocol.
*/
abstract class NotaryService(services: ServiceHubInternal,
val timestampChecker: TimestampChecker,
val uniquenessProvider: UniquenessProvider) : AbstractNodeService(services) {
abstract class NotaryService(markerClass: KClass<out NotaryProtocol.Client>, services: ServiceHubInternal) : SingletonSerializeAsToken() {
// Do not specify this as an advertised service. Use a concrete implementation.
// TODO: We do not want a service type that cannot be used. Fix the type system abuse here.
object Type : ServiceType("corda.notary")
abstract val logger: org.slf4j.Logger
/** Implement a factory that specifies the transaction commit protocol for the notary service to use */
abstract val protocolFactory: NotaryProtocol.Factory
init {
addProtocolHandler(TOPIC, TOPIC) { req: NotaryProtocol.Handshake ->
protocolFactory.create(req.replyToParty, timestampChecker, uniquenessProvider)
}
services.registerProtocolInitiator(markerClass) { createProtocol(it) }
}
/** Implement a factory that specifies the transaction commit protocol for the notary service to use */
abstract fun createProtocol(otherParty: Party): NotaryProtocol.Service
}

View File

@ -1,5 +1,6 @@
package com.r3corda.node.services.transactions
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.node.services.TimestampChecker
import com.r3corda.core.node.services.UniquenessProvider
@ -9,11 +10,13 @@ import com.r3corda.protocols.NotaryProtocol
/** A simple Notary service that does not perform transaction validation */
class SimpleNotaryService(services: ServiceHubInternal,
timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider) : NotaryService(services, timestampChecker, uniquenessProvider) {
val timestampChecker: TimestampChecker,
val uniquenessProvider: UniquenessProvider) : NotaryService(NotaryProtocol.Client::class, services) {
object Type : ServiceType("corda.notary.simple")
override val logger = loggerFor<SimpleNotaryService>()
override val protocolFactory = NotaryProtocol.DefaultFactory
override fun createProtocol(otherParty: Party): NotaryProtocol.Service {
return NotaryProtocol.Service(otherParty, timestampChecker, uniquenessProvider)
}
}

View File

@ -11,17 +11,13 @@ import com.r3corda.protocols.ValidatingNotaryProtocol
/** A Notary service that validates the transaction chain of he submitted transaction before committing it */
class ValidatingNotaryService(services: ServiceHubInternal,
timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider) : NotaryService(services, timestampChecker, uniquenessProvider) {
val timestampChecker: TimestampChecker,
val uniquenessProvider: UniquenessProvider) : NotaryService(NotaryProtocol.ValidatingClient::class, services) {
object Type : ServiceType("corda.notary.validating")
override val logger = loggerFor<ValidatingNotaryService>()
override val protocolFactory = object : NotaryProtocol.Factory {
override fun create(otherSide: Party,
timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider): NotaryProtocol.Service {
return ValidatingNotaryProtocol(otherSide, timestampChecker, uniquenessProvider)
}
override fun createProtocol(otherParty: Party): ValidatingNotaryProtocol {
return ValidatingNotaryProtocol(otherParty, timestampChecker, uniquenessProvider)
}
}

View File

@ -10,6 +10,7 @@ import com.r3corda.core.days
import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.*
import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.protocols.StateMachineRunId
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.transactions.WireTransaction
@ -23,7 +24,6 @@ import com.r3corda.node.services.persistence.PerFileTransactionStorage
import com.r3corda.node.services.persistence.StorageServiceImpl
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
import com.r3corda.protocols.TwoPartyTradeProtocol.Seller
import com.r3corda.protocols.TwoPartyTradeProtocol.TOPIC
import com.r3corda.testing.*
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.MockNetwork
@ -89,11 +89,11 @@ class TwoPartyTradeProtocolTests {
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey, notaryNode.storage.myLegalIdentityKey)
val (bobResult, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
val (bobPsm, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
// TODO: Verify that the result was inserted into the transaction database.
// assertEquals(bobResult.get(), aliceNode.storage.validatedTransactions[aliceResult.get().id])
assertEquals(aliceResult.get(), bobResult.get())
assertEquals(aliceResult.get(), bobPsm.get().resultFuture.get())
aliceNode.stop()
bobNode.stop()
@ -120,21 +120,19 @@ class TwoPartyTradeProtocolTests {
1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null).second
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
val aliceFuture = runBuyerAndSeller("alice's paper".outputStateAndRef()).sellerFuture
val aliceFuture = runBuyerAndSeller("alice's paper".outputStateAndRef()).sellerResult
// Everything is on this thread so we can now step through the protocol one step at a time.
// Seller Alice already sent a message to Buyer Bob. Pump once:
fun pumpAlice() = (aliceNode.net as InMemoryMessagingNetwork.InMemoryMessaging).pumpReceive(false)
fun pumpBob() = (bobNode.net as InMemoryMessagingNetwork.InMemoryMessaging).pumpReceive(false)
pumpBob()
bobNode.pumpReceive(false)
// Bob sends a couple of queries for the dependencies back to Alice. Alice reponds.
pumpAlice()
pumpBob()
pumpAlice()
pumpBob()
aliceNode.pumpReceive(false)
bobNode.pumpReceive(false)
aliceNode.pumpReceive(false)
bobNode.pumpReceive(false)
aliceNode.pumpReceive(false)
bobNode.pumpReceive(false)
// OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature.
assertThat(bobNode.checkpointStorage.checkpoints).hasSize(1)
@ -147,7 +145,7 @@ class TwoPartyTradeProtocolTests {
// Alice doesn't know that and carries on: she wants to know about the cash transactions he's trying to use.
// She will wait around until Bob comes back.
assertThat(pumpAlice()).isNotNull()
assertThat(aliceNode.pumpReceive(false)).isNotNull()
// ... bring the node back up ... the act of constructing the SMM will re-register the message handlers
// that Bob was waiting on before the reboot occurred.
@ -309,16 +307,16 @@ class TwoPartyTradeProtocolTests {
val attachmentID = attachment(ByteArrayInputStream(stream.toByteArray()))
val bobsFakeCash = fillUpForBuyer(false, bobNode.keyManagement.freshKey().public).second
val bobsSignedTxns = insertFakeTransactions(bobsFakeCash, bobNode.services)
insertFakeTransactions(bobsFakeCash, bobNode.services)
val alicesFakePaper = fillUpForSeller(false, aliceNode.storage.myLegalIdentity.owningKey,
1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, attachmentID).second
val alicesSignedTxns = insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
net.runNetwork() // Clear network map registration messages
val aliceTxStream = aliceNode.storage.validatedTransactions.track().second
val aliceTxMappings = aliceNode.storage.stateMachineRecordedTransactionMapping.track().second
val (bobResult, aliceResult, bobSmId, aliceSmId) = runBuyerAndSeller("alice's paper".outputStateAndRef())
val aliceSmId = runBuyerAndSeller("alice's paper".outputStateAndRef()).sellerId
net.runNetwork()
@ -367,21 +365,20 @@ class TwoPartyTradeProtocolTests {
}
}
data class RunResult(
val buyerFuture: Future<SignedTransaction>,
val sellerFuture: Future<SignedTransaction>,
val buyerSmId: StateMachineRunId,
val sellerSmId: StateMachineRunId
private data class RunResult(
// The buyer is not created immediately, only when the seller starts running
val buyer: Future<ProtocolStateMachine<SignedTransaction>>,
val sellerResult: Future<SignedTransaction>,
val sellerId: StateMachineRunId
)
private fun runBuyerAndSeller(assetToSell: StateAndRef<OwnableState>): RunResult {
val buyer = Buyer(aliceNode.info.identity, notaryNode.info.identity, 1000.DOLLARS, CommercialPaper.State::class.java)
private fun runBuyerAndSeller(assetToSell: StateAndRef<OwnableState>) : RunResult {
val buyerFuture = bobNode.initiateSingleShotProtocol(Seller::class) { otherParty ->
Buyer(otherParty, notaryNode.info.identity, 1000.DOLLARS, CommercialPaper.State::class.java)
}
val seller = Seller(bobNode.info.identity, notaryNode.info, assetToSell, 1000.DOLLARS, ALICE_KEY)
connectProtocols(buyer, seller)
// We start the Buyer first, as the Seller sends the first message
val buyerPsm = bobNode.smm.add("$TOPIC.buyer", buyer)
val sellerPsm = aliceNode.smm.add("$TOPIC.seller", seller)
return RunResult(buyerPsm.resultFuture, sellerPsm.resultFuture, buyerPsm.id, sellerPsm.id)
val sellerResultFuture = aliceNode.smm.add("seller", seller).resultFuture
return RunResult(buyerFuture, sellerResultFuture, seller.psm.id)
}
private fun LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter>.runWithError(
@ -404,7 +401,7 @@ class TwoPartyTradeProtocolTests {
net.runNetwork() // Clear network map registration messages
val (bobResult, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
val (bobPsm, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
net.runNetwork()
@ -412,7 +409,7 @@ class TwoPartyTradeProtocolTests {
if (bobError)
aliceResult.get()
else
bobResult.get()
bobPsm.get().resultFuture.get()
}
assertTrue(e.cause is TransactionVerificationException)
assertNotNull(e.cause!!.cause)
@ -506,6 +503,7 @@ class TwoPartyTradeProtocolTests {
return Pair(vault, listOf(ap))
}
class RecordingTransactionStorage(val delegate: TransactionStorage) : TransactionStorage {
override fun track(): Pair<List<SignedTransaction>, Observable<SignedTransaction>> {
return delegate.track()
@ -530,4 +528,5 @@ class TwoPartyTradeProtocolTests {
data class Add(val transaction: SignedTransaction) : TxRecord
data class Get(val id: SecureHash) : TxRecord
}
}

View File

@ -2,23 +2,24 @@ package com.r3corda.node.services
import com.codahale.metrics.MetricRegistry
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.services.*
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolLogicRefFactory
import com.r3corda.core.protocols.StateMachineRunId
import com.r3corda.core.testing.InMemoryVaultService
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.node.serialization.NodeClock
import com.r3corda.node.services.api.MessagingServiceInternal
import com.r3corda.node.services.api.MonitoringService
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.testing.node.MockNetworkMapCache
import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.persistence.DataVending
import com.r3corda.node.services.statemachine.StateMachineManager
import com.r3corda.core.testing.InMemoryVaultService
import com.r3corda.testing.node.MockStorageService
import com.r3corda.testing.MOCK_IDENTITY_SERVICE
import com.r3corda.testing.node.MockNetworkMapCache
import com.r3corda.testing.node.MockStorageService
import java.time.Clock
import java.util.concurrent.ConcurrentHashMap
import kotlin.reflect.KClass
@Suppress("LeakingThis")
open class MockServiceHubInternal(
@ -28,7 +29,6 @@ open class MockServiceHubInternal(
val identity: IdentityService? = MOCK_IDENTITY_SERVICE,
val storage: TxWritableStorageService? = MockStorageService(),
val mapCache: NetworkMapCache? = MockNetworkMapCache(),
val mapService: NetworkMapService? = null,
val scheduler: SchedulerService? = null,
val overrideClock: Clock? = NodeClock(),
val protocolFactory: ProtocolLogicRefFactory? = ProtocolLogicRefFactory()
@ -57,14 +57,10 @@ open class MockServiceHubInternal(
private val txStorageService: TxWritableStorageService
get() = storage ?: throw UnsupportedOperationException()
override fun recordTransactions(txs: Iterable<SignedTransaction>) = recordTransactionsInternal(txStorageService, txs)
private val protocolFactories = ConcurrentHashMap<Class<*>, (Party) -> ProtocolLogic<*>>()
lateinit var smm: StateMachineManager
override fun <T> startProtocol(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> {
return smm.add(loggerName, logic).resultFuture
}
init {
if (net != null && storage != null) {
// Creating this class is sufficient, we don't have to store it anywhere, because it registers a listener
@ -72,4 +68,18 @@ open class MockServiceHubInternal(
DataVending.Service(this)
}
}
override fun recordTransactions(txs: Iterable<SignedTransaction>) = recordTransactionsInternal(txStorageService, txs)
override fun <T> startProtocol(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> {
return smm.add(loggerName, logic).resultFuture
}
override fun registerProtocolInitiator(markerClass: KClass<*>, protocolFactory: (Party) -> ProtocolLogic<*>) {
protocolFactories[markerClass.java] = protocolFactory
}
override fun getProtocolFactory(markerClass: Class<*>): ((Party) -> ProtocolLogic<*>)? {
return protocolFactories[markerClass]
}
}

View File

@ -3,7 +3,6 @@ package com.r3corda.node.services
import com.google.common.jimfs.Configuration
import com.google.common.jimfs.Jimfs
import com.r3corda.core.contracts.*
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.days
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.recordTransactions
@ -12,18 +11,16 @@ import com.r3corda.core.protocols.ProtocolLogicRef
import com.r3corda.core.protocols.ProtocolLogicRefFactory
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.LogHelper
import com.r3corda.testing.node.TestClock
import com.r3corda.node.services.events.NodeSchedulerService
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.node.services.persistence.PerFileCheckpointStorage
import com.r3corda.node.services.statemachine.StateMachineManager
import com.r3corda.node.services.vault.NodeVaultService
import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.node.utilities.configureDatabase
import com.r3corda.testing.ALICE_KEY
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.MockKeyManagementService
import com.r3corda.testing.node.TestClock
import com.r3corda.testing.node.makeTestDataSourceProperties
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
@ -34,7 +31,9 @@ import java.nio.file.FileSystem
import java.security.PublicKey
import java.time.Clock
import java.time.Instant
import java.util.concurrent.*
import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import kotlin.test.assertTrue
class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
@ -128,8 +127,6 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
(serviceHub as TestReference).testReference.calls += increment
(serviceHub as TestReference).testReference.countDown.countDown()
}
override val topic: String get() = throw UnsupportedOperationException()
}
class Command : TypeOnlyCommandData()

View File

@ -9,7 +9,6 @@ import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.node.internal.AbstractNode
import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.transactions.SimpleNotaryService
import com.r3corda.protocols.NotaryChangeProtocol
import com.r3corda.protocols.NotaryChangeProtocol.Instigator
import com.r3corda.protocols.StateReplacementException
import com.r3corda.protocols.StateReplacementRefused
@ -49,7 +48,7 @@ class NotaryChangeTests {
val state = issueState(clientNodeA)
val newNotary = newNotaryNode.info.identity
val protocol = Instigator(state, newNotary)
val future = clientNodeA.services.startProtocol(NotaryChangeProtocol.TOPIC, protocol)
val future = clientNodeA.services.startProtocol("notary-change", protocol)
net.runNetwork()
@ -62,7 +61,7 @@ class NotaryChangeTests {
val state = issueMultiPartyState(clientNodeA, clientNodeB)
val newNotary = newNotaryNode.info.identity
val protocol = Instigator(state, newNotary)
val future = clientNodeA.services.startProtocol(NotaryChangeProtocol.TOPIC, protocol)
val future = clientNodeA.services.startProtocol("notary-change", protocol)
net.runNetwork()
@ -78,7 +77,7 @@ class NotaryChangeTests {
val state = issueMultiPartyState(clientNodeA, clientNodeB)
val newEvilNotary = Party("Evil Notary", generateKeyPair().public)
val protocol = Instigator(state, newEvilNotary)
val future = clientNodeA.services.startProtocol(NotaryChangeProtocol.TOPIC, protocol)
val future = clientNodeA.services.startProtocol("notary-change", protocol)
net.runNetwork()

View File

@ -1,17 +1,19 @@
package com.r3corda.node.services
import com.r3corda.core.contracts.Timestamp
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.contracts.TransactionType
import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.seconds
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.testing.node.MockNetwork
import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.transactions.SimpleNotaryService
import com.r3corda.protocols.NotaryError
import com.r3corda.protocols.NotaryException
import com.r3corda.protocols.NotaryProtocol
import com.r3corda.testing.MINI_CORP_KEY
import com.r3corda.testing.node.MockNetwork
import org.junit.Before
import org.junit.Test
import java.time.Instant
@ -45,10 +47,7 @@ class NotaryServiceTests {
tx.toSignedTransaction(false)
}
val protocol = NotaryProtocol.Client(stx)
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
net.runNetwork()
val future = runNotaryClient(stx)
val signature = future.get()
signature.verifyWithECDSA(stx.txBits)
}
@ -61,10 +60,7 @@ class NotaryServiceTests {
tx.toSignedTransaction(false)
}
val protocol = NotaryProtocol.Client(stx)
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
net.runNetwork()
val future = runNotaryClient(stx)
val signature = future.get()
signature.verifyWithECDSA(stx.txBits)
}
@ -78,16 +74,13 @@ class NotaryServiceTests {
tx.toSignedTransaction(false)
}
val protocol = NotaryProtocol.Client(stx)
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
net.runNetwork()
val future = runNotaryClient(stx)
val ex = assertFailsWith(ExecutionException::class) { future.get() }
val error = (ex.cause as NotaryException).error
assertTrue(error is NotaryError.TimestampInvalid)
}
@Test fun `should report conflict for a duplicate transaction`() {
val stx = run {
val inputState = issueState(clientNode)
@ -98,8 +91,8 @@ class NotaryServiceTests {
val firstSpend = NotaryProtocol.Client(stx)
val secondSpend = NotaryProtocol.Client(stx)
clientNode.services.startProtocol("${NotaryProtocol.TOPIC}.first", firstSpend)
val future = clientNode.services.startProtocol("${NotaryProtocol.TOPIC}.second", secondSpend)
clientNode.services.startProtocol("notary.first", firstSpend)
val future = clientNode.services.startProtocol("notary.second", secondSpend)
net.runNetwork()
@ -108,4 +101,12 @@ class NotaryServiceTests {
assertEquals(notaryError.tx, stx.tx)
notaryError.conflict.verified()
}
private fun runNotaryClient(stx: SignedTransaction): ListenableFuture<DigitalSignature.LegallyIdentifiable> {
val protocol = NotaryProtocol.Client(stx)
val future = clientNode.services.startProtocol("notary-test", protocol)
net.runNetwork()
return future
}
}

View File

@ -1,8 +1,11 @@
package com.r3corda.node.services
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.contracts.Command
import com.r3corda.core.contracts.DummyContract
import com.r3corda.core.contracts.TransactionType
import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.node.services.network.NetworkMapService
@ -44,9 +47,7 @@ class ValidatingNotaryServiceTests {
tx.toSignedTransaction(false)
}
val protocol = NotaryProtocol.Client(stx)
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
net.runNetwork()
val future = runValidatingClient(stx)
val ex = assertFailsWith(ExecutionException::class) { future.get() }
val notaryError = (ex.cause as NotaryException).error
@ -64,9 +65,7 @@ class ValidatingNotaryServiceTests {
tx.toSignedTransaction(false)
}
val protocol = NotaryProtocol.Client(stx)
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
net.runNetwork()
val future = runValidatingClient(stx)
val ex = assertFailsWith(ExecutionException::class) { future.get() }
val notaryError = (ex.cause as NotaryException).error
@ -75,4 +74,11 @@ class ValidatingNotaryServiceTests {
val missingKeys = (notaryError as NotaryError.SignaturesMissing).missingSigners
assertEquals(setOf(expectedMissingKey), missingKeys)
}
private fun runValidatingClient(stx: SignedTransaction): ListenableFuture<DigitalSignature.LegallyIdentifiable> {
val protocol = NotaryProtocol.ValidatingClient(stx)
val future = clientNode.services.startProtocol("notary", protocol)
net.runNetwork()
return future
}
}

View File

@ -1,13 +1,20 @@
package com.r3corda.node.services.persistence
import co.paralleluniverse.fibers.Suspendable
import com.r3corda.contracts.asset.Cash
import com.r3corda.core.contracts.Amount
import com.r3corda.core.contracts.Issued
import com.r3corda.core.contracts.TransactionType
import com.r3corda.core.contracts.USD
import com.r3corda.core.crypto.Party
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.testing.node.MockNetwork
import com.r3corda.node.services.persistence.DataVending.Service.NotifyTransactionHandler
import com.r3corda.protocols.BroadcastTransactionProtocol.NotifyTxRequest
import com.r3corda.testing.MEGA_CORP
import com.r3corda.testing.node.MockNetwork
import com.r3corda.testing.node.MockNetwork.MockNode
import org.junit.Before
import org.junit.Test
import kotlin.test.assertEquals
@ -38,9 +45,8 @@ class DataVendingServiceTests {
ptx.signWith(registerNode.services.storageService.myLegalIdentityKey)
val tx = ptx.toSignedTransaction()
assertEquals(0, vaultServiceNode.services.vaultService.currentVault.states.toList().size)
DataVending.Service.notify(registerNode.net, registerNode.services.storageService.myLegalIdentity,
vaultServiceNode.info, tx)
network.runNetwork()
registerNode.sendNotifyTx(tx, vaultServiceNode)
// Check the transaction is in the receiving node
val actual = vaultServiceNode.services.vaultService.currentVault.states.singleOrNull()
@ -67,11 +73,23 @@ class DataVendingServiceTests {
ptx.signWith(registerNode.services.storageService.myLegalIdentityKey)
val tx = ptx.toSignedTransaction(false)
assertEquals(0, vaultServiceNode.services.vaultService.currentVault.states.toList().size)
DataVending.Service.notify(registerNode.net, registerNode.services.storageService.myLegalIdentity,
vaultServiceNode.info, tx)
network.runNetwork()
registerNode.sendNotifyTx(tx, vaultServiceNode)
// Check the transaction is not in the receiving node
assertEquals(0, vaultServiceNode.services.vaultService.currentVault.states.toList().size)
}
}
private fun MockNode.sendNotifyTx(tx: SignedTransaction, walletServiceNode: MockNode) {
walletServiceNode.services.registerProtocolInitiator(NotifyTxProtocol::class, ::NotifyTransactionHandler)
services.startProtocol("notify-tx", NotifyTxProtocol(walletServiceNode.info.identity, tx))
network.runNetwork()
}
private class NotifyTxProtocol(val otherParty: Party, val stx: SignedTransaction) : ProtocolLogic<Unit>() {
@Suspendable
override fun call() = send(otherParty, NotifyTxRequest(stx, emptySet()))
}
}

View File

@ -10,12 +10,14 @@ import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.nio.file.FileSystem
import java.nio.file.Files
import java.nio.file.Path
class PerFileCheckpointStorageTests {
val fileSystem = Jimfs.newFileSystem(unix())
val storeDir = fileSystem.getPath("store")
val fileSystem: FileSystem = Jimfs.newFileSystem(unix())
val storeDir: Path = fileSystem.getPath("store")
lateinit var checkpointStorage: PerFileCheckpointStorage
@Before
@ -92,6 +94,6 @@ class PerFileCheckpointStorageTests {
}
private var checkpointCount = 1
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), null, null)
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)))
}

View File

@ -2,14 +2,20 @@ package com.r3corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolSessionException
import com.r3corda.core.random63BitValue
import com.r3corda.testing.connectProtocols
import com.r3corda.core.serialization.deserialize
import com.r3corda.node.services.statemachine.StateMachineManager.SessionData
import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.MockNetwork
import com.r3corda.testing.node.MockNetwork.MockNode
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.After
import org.junit.Before
import org.junit.Test
@ -50,18 +56,18 @@ class StateMachineManagerTests {
}
@Test
fun `protocol suspended just after receiving payload`() {
val topic = "send-and-receive"
fun `protocol restarted just after receiving payload`() {
node2.services.registerProtocolInitiator(SendProtocol::class) { ReceiveThenSuspendProtocol(it) }
val payload = random63BitValue()
val sendProtocol = SendProtocol(topic, node2.info.identity, payload)
val receiveProtocol = ReceiveProtocol(topic, node1.info.identity)
connectProtocols(sendProtocol, receiveProtocol)
node1.smm.add("test", sendProtocol)
node2.smm.add("test", receiveProtocol)
net.runNetwork()
node1.smm.add("test", SendProtocol(payload, node2.info.identity))
// We push through just enough messages to get only the SessionData sent
// TODO We should be able to give runNetwork a predicate for when to stop
net.runNetwork(2)
node2.stop()
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveProtocol>(node1.info.address)
assertThat(restoredProtocol.receivedPayload).isEqualTo(payload)
net.runNetwork()
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1.info.address)
assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload)
}
@Test
@ -83,7 +89,7 @@ class StateMachineManagerTests {
node3.stop()
node3 = net.createNode(node1.info.address, forcedID = node3.id)
val restoredProtocol = node3.smm.findStateMachines(ProtocolNoBlocking::class.java).single().first
val restoredProtocol = node3.getSingleProtocol<ProtocolNoBlocking>().first
assertEquals(false, restoredProtocol.protocolStarted) // Not started yet as no network activity has been allowed yet
net.runNetwork() // Allow network map messages to flow
node3.smm.executor.flush()
@ -99,43 +105,44 @@ class StateMachineManagerTests {
@Test
fun `protocol loaded from checkpoint will respond to messages from before start`() {
val topic = "send-and-receive"
val payload = random63BitValue()
val sendProtocol = SendProtocol(topic, node2.info.identity, payload)
val receiveProtocol = ReceiveProtocol(topic, node1.info.identity)
connectProtocols(sendProtocol, receiveProtocol)
node1.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(payload, it) }
val receiveProtocol = ReceiveThenSuspendProtocol(node1.info.identity)
node2.smm.add("test", receiveProtocol) // Prepare checkpointed receive protocol
node2.stop() // kill receiver
node1.smm.add("test", sendProtocol) // now generate message to spool up and thus come in ahead of messages for NetworkMapService
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveProtocol>(node1.info.address)
assertThat(restoredProtocol.receivedPayload).isEqualTo(payload)
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1.info.address)
assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload)
}
@Test
fun `protocol with send will resend on interrupted restart`() {
val topic = "send-and-receive"
val payload = random63BitValue()
val payload2 = random63BitValue()
var sentCount = 0
var receivedCount = 0
net.messagingNetwork.sentMessages.subscribe { if (it.message.topicSession.topic == topic) sentCount++ }
net.messagingNetwork.receivedMessages.subscribe { if (it.message.topicSession.topic == topic) receivedCount++ }
net.messagingNetwork.sentMessages.subscribe { if (isDataMessage(it)) sentCount++ }
net.messagingNetwork.receivedMessages.subscribe { if (isDataMessage(it)) receivedCount++ }
val node3 = net.createNode(node1.info.address)
net.runNetwork()
val firstProtocol = PingPongProtocol(topic, node3.info.identity, payload)
val secondProtocol = PingPongProtocol(topic, node2.info.identity, payload2)
connectProtocols(firstProtocol, secondProtocol)
var secondProtocol: PingPongProtocol? = null
node3.services.registerProtocolInitiator(PingPongProtocol::class) {
val protocol = PingPongProtocol(it, payload2)
secondProtocol = protocol
protocol
}
// Kick off first send and receive
node2.smm.add("test", firstProtocol)
node2.smm.add("test", PingPongProtocol(node3.info.identity, payload))
assertEquals(1, node2.checkpointStorage.checkpoints.count())
// Restart node and thus reload the checkpoint and resend the message with same UUID
node2.stop()
val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray())
val (firstAgain, fut1) = node2b.smm.findStateMachines(PingPongProtocol::class.java).single()
val (firstAgain, fut1) = node2b.getSingleProtocol<PingPongProtocol>()
net.runNetwork()
assertEquals(1, node2.checkpointStorage.checkpoints.count())
// Now add in the other half of the protocol. First message should get deduped. So message data stays in sync.
node3.smm.add("test", secondProtocol)
// Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync.
net.runNetwork()
node2b.smm.executor.flush()
fut1.get()
@ -146,15 +153,66 @@ class StateMachineManagerTests {
assertEquals(0, node3.checkpointStorage.checkpoints.count(), "Checkpoints left after restored protocol should have ended")
assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3")
assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3")
assertEquals(payload, secondProtocol.receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
assertEquals(payload + 1, secondProtocol.receivedPayload2, "Received payload does not match the expected second value on Node 2")
assertEquals(payload, secondProtocol!!.receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
assertEquals(payload + 1, secondProtocol!!.receivedPayload2, "Received payload does not match the expected second value on Node 2")
}
@Test
fun `sending to multiple parties`() {
val node3 = net.createNode(node1.info.address)
net.runNetwork()
node2.services.registerProtocolInitiator(SendProtocol::class) { ReceiveThenSuspendProtocol(it) }
node3.services.registerProtocolInitiator(SendProtocol::class) { ReceiveThenSuspendProtocol(it) }
val payload = random63BitValue()
node1.smm.add("multiple-send", SendProtocol(payload, node2.info.identity, node3.info.identity))
net.runNetwork()
val node2Protocol = node2.getSingleProtocol<ReceiveThenSuspendProtocol>().first
val node3Protocol = node3.getSingleProtocol<ReceiveThenSuspendProtocol>().first
assertThat(node2Protocol.receivedPayloads[0]).isEqualTo(payload)
assertThat(node3Protocol.receivedPayloads[0]).isEqualTo(payload)
}
@Test
fun `receiving from multiple parties`() {
val node3 = net.createNode(node1.info.address)
net.runNetwork()
val node2Payload = random63BitValue()
val node3Payload = random63BitValue()
node2.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node2Payload, it) }
node3.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node3Payload, it) }
val multiReceiveProtocol = ReceiveThenSuspendProtocol(node2.info.identity, node3.info.identity)
node1.smm.add("multiple-receive", multiReceiveProtocol)
net.runNetwork(1) // session handshaking
// have the messages arrive in reverse order of receive
node3.pumpReceive(false)
node2.pumpReceive(false)
net.runNetwork() // pump remaining messages
assertThat(multiReceiveProtocol.receivedPayloads[0]).isEqualTo(node2Payload)
assertThat(multiReceiveProtocol.receivedPayloads[1]).isEqualTo(node3Payload)
}
@Test
fun `exception thrown on other side`() {
node2.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { ExceptionProtocol }
val future = node1.smm.add("exception", ReceiveThenSuspendProtocol(node2.info.identity)).resultFuture
net.runNetwork()
assertThatThrownBy { future.get() }.hasCauseInstanceOf(ProtocolSessionException::class.java)
}
private fun isDataMessage(transfer: InMemoryMessagingNetwork.MessageTransfer): Boolean {
return transfer.message.topicSession == StateMachineManager.sessionTopic
&& transfer.message.data.deserialize<SessionMessage>() is SessionData
}
private inline fun <reified P : NonTerminatingProtocol> MockNode.restartAndGetRestoredProtocol(networkMapAddress: SingleMessageRecipient? = null): P {
val servicesArray = advertisedServices.toTypedArray()
val node = mockNet.createNode(networkMapAddress, id, advertisedServices = *servicesArray)
stop()
val newNode = mockNet.createNode(networkMapAddress, id, advertisedServices = *advertisedServices.toTypedArray())
mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine
return node.smm.findStateMachines(P::class.java).single().first
return newNode.getSingleProtocol<P>().first
}
private inline fun <reified P : ProtocolLogic<*>> MockNode.getSingleProtocol(): Pair<P, ListenableFuture<*>> {
return smm.findStateMachines(P::class.java).single()
}
@ -165,8 +223,6 @@ class StateMachineManagerTests {
override fun call() {
protocolStarted = true
}
override val topic: String get() = throw UnsupportedOperationException()
}
private class ProtocolWithoutCheckpoints : NonTerminatingProtocol() {
@ -177,8 +233,6 @@ class StateMachineManagerTests {
override fun doCall() {
protocolStarted = true
}
override val topic: String get() = throw UnsupportedOperationException()
}
@ -187,30 +241,37 @@ class StateMachineManagerTests {
val lazyTime by lazy { serviceHub.clock.instant() }
@Suspendable
override fun call() {
override fun call() = Unit
}
private class SendProtocol(val payload: Any, vararg val otherParties: Party) : ProtocolLogic<Unit>() {
init {
require(otherParties.isNotEmpty())
}
override val topic: String get() = throw UnsupportedOperationException()
}
private class SendProtocol(override val topic: String, val otherParty: Party, val payload: Any) : ProtocolLogic<Unit>() {
@Suspendable
override fun call() = send(otherParty, payload)
override fun call() = otherParties.forEach { send(it, payload) }
}
private class ReceiveProtocol(override val topic: String, val otherParty: Party) : NonTerminatingProtocol() {
private class ReceiveThenSuspendProtocol(vararg val otherParties: Party) : NonTerminatingProtocol() {
@Transient var receivedPayload: Any? = null
init {
require(otherParties.isNotEmpty())
}
@Transient var receivedPayloads: List<Any> = emptyList()
@Suspendable
override fun doCall() {
receivedPayload = receive<Any>(otherParty).unwrap { it }
receivedPayloads = otherParties.map { receive<Any>(it).unwrap { it } }
}
}
private class PingPongProtocol(override val topic: String, val otherParty: Party, val payload: Long) : ProtocolLogic<Unit>() {
private class PingPongProtocol(val otherParty: Party, val payload: Long) : ProtocolLogic<Unit>() {
@Transient var receivedPayload: Long? = null
@Transient var receivedPayload2: Long? = null
@ -219,7 +280,10 @@ class StateMachineManagerTests {
receivedPayload = sendAndReceive<Long>(otherParty, payload).unwrap { it }
receivedPayload2 = sendAndReceive<Long>(otherParty, (payload + 1)).unwrap { it }
}
}
private object ExceptionProtocol : ProtocolLogic<Nothing>() {
override fun call(): Nothing = throw Exception()
}
/**

View File

@ -7,20 +7,22 @@ import com.r3corda.contracts.CommercialPaper
import com.r3corda.contracts.asset.DUMMY_CASH_ISSUER
import com.r3corda.contracts.asset.cashBalances
import com.r3corda.contracts.testing.fillWithSomeTestCash
import com.r3corda.core.*
import com.r3corda.core.contracts.*
import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.crypto.generateKeyPair
import com.r3corda.core.days
import com.r3corda.core.logElapsedTime
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.seconds
import com.r3corda.core.success
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.Emoji
import com.r3corda.core.utilities.LogHelper
import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.node.internal.Node
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.config.NodeConfigurationFromConfig
import com.r3corda.node.services.messaging.NodeMessagingClient
@ -28,7 +30,6 @@ import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.persistence.NodeAttachmentService
import com.r3corda.node.services.transactions.SimpleNotaryService
import com.r3corda.node.utilities.databaseTransaction
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.protocols.NotaryProtocol
import com.r3corda.protocols.TwoPartyTradeProtocol
import joptsimple.OptionParser
@ -210,23 +211,13 @@ private fun runBuyer(node: Node, amount: Amount<Currency>) {
// next stage in our building site, we will just auto-generate fake trades to give our nodes something to do.
//
// As the seller initiates the two-party trade protocol, here, we will be the buyer.
object : AbstractNodeService(node.services) {
init {
addProtocolHandler(DEMO_TOPIC, "demo.buyer") { handshake: TraderDemoHandshake ->
TraderDemoProtocolBuyer(handshake.replyToParty, attachmentsPath, amount)
}
}
node.services.registerProtocolInitiator(TraderDemoProtocolSeller::class) { otherParty ->
TraderDemoProtocolBuyer(otherParty, attachmentsPath, amount)
}
}
// We create a couple of ad-hoc test protocols that wrap the two party trade protocol, to give us the demo logic.
private data class TraderDemoHandshake(override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
private val DEMO_TOPIC = "initiate.demo.trade"
private class TraderDemoProtocolBuyer(val otherSide: Party,
private val attachmentsPath: Path,
val amount: Amount<Currency>,
@ -234,8 +225,6 @@ private class TraderDemoProtocolBuyer(val otherSide: Party,
object STARTING_BUY : ProgressTracker.Step("Seller connected, purchasing commercial paper asset")
override val topic: String get() = DEMO_TOPIC
@Suspendable
override fun call() {
progressTracker.currentStep = STARTING_BUY
@ -248,7 +237,7 @@ private class TraderDemoProtocolBuyer(val otherSide: Party,
CommercialPaper.State::class.java)
// This invokes the trading protocol and out pops our finished transaction.
val tradeTX: SignedTransaction = subProtocol(buyer, inheritParentSessions = true)
val tradeTX: SignedTransaction = subProtocol(buyer, shareParentSessions = true)
// TODO: This should be moved into the protocol itself.
serviceHub.recordTransactions(listOf(tradeTX))
@ -289,8 +278,6 @@ private class TraderDemoProtocolSeller(val otherSide: Party,
companion object {
val PROSPECTUS_HASH = SecureHash.parse("decd098666b9657314870e192ced0c3519c2c9d395507a238338f8d003929de9")
object ANNOUNCING : ProgressTracker.Step("Announcing to the buyer node")
object SELF_ISSUING : ProgressTracker.Step("Got session ID back, issuing and timestamping some commercial paper")
object TRADING : ProgressTracker.Step("Starting the trade protocol") {
@ -300,17 +287,11 @@ private class TraderDemoProtocolSeller(val otherSide: Party,
// We vend a progress tracker that already knows there's going to be a TwoPartyTradingProtocol involved at some
// point: by setting up the tracker in advance, the user can see what's coming in more detail, instead of being
// surprised when it appears as a new set of tasks below the current one.
fun tracker() = ProgressTracker(ANNOUNCING, SELF_ISSUING, TRADING)
fun tracker() = ProgressTracker(SELF_ISSUING, TRADING)
}
override val topic: String get() = DEMO_TOPIC
@Suspendable
override fun call(): SignedTransaction {
progressTracker.currentStep = ANNOUNCING
send(otherSide, TraderDemoHandshake(serviceHub.storageService.myLegalIdentity))
progressTracker.currentStep = SELF_ISSUING
val notary: NodeInfo = serviceHub.networkMapCache.notaryNodes[0]
@ -326,7 +307,7 @@ private class TraderDemoProtocolSeller(val otherSide: Party,
amount,
cpOwnerKey,
progressTracker.getChildProgressTracker(TRADING)!!)
val tradeTX: SignedTransaction = subProtocol(seller, inheritParentSessions = true)
val tradeTX: SignedTransaction = subProtocol(seller, shareParentSessions = true)
serviceHub.recordTransactions(listOf(tradeTX))
return tradeTX

View File

@ -12,16 +12,14 @@ import com.r3corda.core.math.InterpolatorFactory
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.api.AcceptsFileUpload
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.node.utilities.FiberBox
import com.r3corda.protocols.RatesFixProtocol
import com.r3corda.protocols.ServiceRequestMessage
import com.r3corda.protocols.RatesFixProtocol.*
import com.r3corda.protocols.TwoPartyDealProtocol
import org.slf4j.LoggerFactory
import java.io.InputStream
import java.math.BigDecimal
import java.security.KeyPair
@ -55,46 +53,31 @@ object NodeInterestRates {
/**
* The Service that wraps [Oracle] and handles messages/network interaction/request scrubbing.
*/
class Service(services: ServiceHubInternal) : AcceptsFileUpload, AbstractNodeService(services) {
class Service(services: ServiceHubInternal) : AcceptsFileUpload, SingletonSerializeAsToken() {
val ss = services.storageService
val oracle = Oracle(ss.myLegalIdentity, ss.myLegalIdentityKey, services.clock)
private val logger = LoggerFactory.getLogger(Service::class.java)
init {
addMessageHandler(RatesFixProtocol.TOPIC,
{ req: ServiceRequestMessage ->
if (req is RatesFixProtocol.SignRequest) {
oracle.sign(req.tx)
}
else {
/**
* We put this into a protocol so that if it blocks waiting for the interest rate to become
* available, we a) don't block this thread and b) allow the fact we are waiting
* to be persisted/checkpointed.
* Interest rates become available when they are uploaded via the web as per [DataUploadServlet],
* if they haven't already been uploaded that way.
*/
req as RatesFixProtocol.QueryRequest
val handler = FixQueryHandler(this, req)
handler.registerSession(req)
services.startProtocol("fixing", handler)
Unit
}
},
{ message, e -> logger.error("Exception during interest rate oracle request processing", e) }
)
services.registerProtocolInitiator(FixSignProtocol::class) { FixSignHandler(it, oracle) }
services.registerProtocolInitiator(FixQueryProtocol::class) { FixQueryHandler(it, oracle) }
}
private class FixQueryHandler(val service: Service,
val request: RatesFixProtocol.QueryRequest) : ProtocolLogic<Unit>() {
private class FixSignHandler(val otherParty: Party, val oracle: Oracle) : ProtocolLogic<Unit>() {
@Suspendable
override fun call() {
val request = receive<SignRequest>(otherParty).unwrap { it }
send(otherParty, oracle.sign(request.tx))
}
}
private class FixQueryHandler(val otherParty: Party, val oracle: Oracle) : ProtocolLogic<Unit>() {
companion object {
object RECEIVED : ProgressTracker.Step("Received fix request")
object SENDING : ProgressTracker.Step("Sending fix response")
}
override val topic: String get() = RatesFixProtocol.TOPIC
override val progressTracker = ProgressTracker(RECEIVED, SENDING)
init {
@ -103,9 +86,10 @@ object NodeInterestRates {
@Suspendable
override fun call(): Unit {
val answers = service.oracle.query(request.queries, request.deadline)
val request = receive<QueryRequest>(otherParty).unwrap { it }
val answers = oracle.query(request.queries, request.deadline)
progressTracker.currentStep = SENDING
send(request.replyToParty, answers)
send(otherParty, answers)
}
}

View File

@ -1,20 +1,17 @@
package com.r3corda.demos.protocols
import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.FutureCallback
import com.r3corda.core.contracts.DealState
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.protocols.TwoPartyDealProtocol
import com.r3corda.protocols.TwoPartyDealProtocol.Acceptor
import com.r3corda.protocols.TwoPartyDealProtocol.DEAL_TOPIC
import com.r3corda.protocols.TwoPartyDealProtocol.AutoOffer
import com.r3corda.protocols.TwoPartyDealProtocol.Instigator
/**
@ -25,58 +22,27 @@ import com.r3corda.protocols.TwoPartyDealProtocol.Instigator
* or the protocol would have to reach out to external systems (or users) to verify the deals.
*/
object AutoOfferProtocol {
val TOPIC = "autooffer.topic"
data class AutoOfferMessage(val notary: Party,
val dealBeingOffered: DealState,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
class Plugin: CordaPluginRegistry() {
class Plugin : CordaPluginRegistry() {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
}
class Service(services: ServiceHubInternal) : AbstractNodeService(services) {
class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() {
object DEALING : ProgressTracker.Step("Starting the deal protocol") {
override fun childProgressTracker(): ProgressTracker = TwoPartyDealProtocol.Secondary.tracker()
}
fun tracker() = ProgressTracker(DEALING)
class Callback(val success: (SignedTransaction) -> Unit) : FutureCallback<SignedTransaction> {
override fun onFailure(t: Throwable?) {
// TODO handle exceptions
}
override fun onSuccess(st: SignedTransaction?) {
success(st!!)
}
}
init {
addProtocolHandler(TOPIC, "$DEAL_TOPIC.seller") { autoOfferMessage: AutoOfferMessage ->
val progressTracker = tracker()
// Put the deal onto the ledger
progressTracker.currentStep = DEALING
Acceptor(
autoOfferMessage.replyToParty,
autoOfferMessage.notary,
autoOfferMessage.dealBeingOffered,
progressTracker.getChildProgressTracker(DEALING)!!
)
}
services.registerProtocolInitiator(Instigator::class) { Acceptor(it) }
}
}
class Requester(val dealToBeOffered: DealState) : ProtocolLogic<SignedTransaction>() {
companion object {
object RECEIVED : ProgressTracker.Step("Received API call")
object ANNOUNCING : ProgressTracker.Step("Announcing to the peer node")
object DEALING : ProgressTracker.Step("Starting the deal protocol") {
override fun childProgressTracker(): ProgressTracker = TwoPartyDealProtocol.Primary.tracker()
}
@ -84,10 +50,9 @@ object AutoOfferProtocol {
// We vend a progress tracker that already knows there's going to be a TwoPartyTradingProtocol involved at some
// point: by setting up the tracker in advance, the user can see what's coming in more detail, instead of being
// surprised when it appears as a new set of tasks below the current one.
fun tracker() = ProgressTracker(RECEIVED, ANNOUNCING, DEALING)
fun tracker() = ProgressTracker(RECEIVED, DEALING)
}
override val topic: String get() = TOPIC
override val progressTracker = tracker()
init {
@ -100,17 +65,14 @@ object AutoOfferProtocol {
val notary = serviceHub.networkMapCache.notaryNodes.first().identity
// need to pick which ever party is not us
val otherParty = notUs(dealToBeOffered.parties).single()
progressTracker.currentStep = ANNOUNCING
send(otherParty, AutoOfferMessage(notary, dealToBeOffered, serviceHub.storageService.myLegalIdentity))
progressTracker.currentStep = DEALING
val instigator = Instigator(
otherParty,
notary,
dealToBeOffered,
AutoOffer(notary, dealToBeOffered),
serviceHub.storageService.myLegalIdentityKey,
progressTracker.getChildProgressTracker(DEALING)!!
)
val stx = subProtocol(instigator, inheritParentSessions = true)
val stx = subProtocol(instigator)
return stx
}

View File

@ -5,56 +5,49 @@ import co.paralleluniverse.strands.Strand
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.deserialize
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.testing.node.MockNetworkMapCache
import java.util.concurrent.TimeUnit
object ExitServerProtocol {
val TOPIC = "exit.topic"
// Will only be enabled if you install the Handler
@Volatile private var enabled = false
// This is not really a HandshakeMessage but needs to be so that the send uses the default session ID. This will
// resolve itself when the protocol session stuff is done.
data class ExitMessage(val exitCode: Int,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
data class ExitMessage(val exitCode: Int)
class Plugin: CordaPluginRegistry() {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
}
class Service(services: ServiceHubInternal) {
init {
services.networkService.addMessageHandler(TOPIC, DEFAULT_SESSION_ID) { msg, registration ->
// Just to validate we got the message
if (enabled) {
val message = msg.data.deserialize<ExitMessage>()
System.exit(message.exitCode)
}
}
services.registerProtocolInitiator(Broadcast::class, ::ExitServerHandler)
enabled = true
}
}
private class ExitServerHandler(val otherParty: Party) : ProtocolLogic<Unit>() {
override fun call() {
// Just to validate we got the message
if (enabled) {
val message = receive<ExitMessage>(otherParty).unwrap { it }
System.exit(message.exitCode)
}
}
}
/**
* This takes a Java Integer rather than Kotlin Int as that is what we end up with in the calling map and currently
* we do not support coercing numeric types in the reflective search for matching constructors.
*/
class Broadcast(val exitCode: Int) : ProtocolLogic<Boolean>() {
override val topic: String get() = TOPIC
@Suspendable
override fun call(): Boolean {
if (enabled) {
@ -73,7 +66,7 @@ object ExitServerProtocol {
if (recipient.address is MockNetworkMapCache.MockAddress) {
// Ignore
} else {
send(recipient.identity, ExitMessage(exitCode, recipient.identity))
send(recipient.identity, ExitMessage(exitCode))
}
}
}

View File

@ -4,14 +4,10 @@ import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.demos.DemoClock
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.testing.node.MockNetworkMapCache
import java.time.LocalDate
@ -20,29 +16,28 @@ import java.time.LocalDate
*/
object UpdateBusinessDayProtocol {
val TOPIC = "businessday.topic"
// This is not really a HandshakeMessage but needs to be so that the send uses the default session ID. This will
// resolve itself when the protocol session stuff is done.
data class UpdateBusinessDayMessage(val date: LocalDate,
override val replyToParty: Party,
override val sendSessionID: Long = random63BitValue(),
override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage
data class UpdateBusinessDayMessage(val date: LocalDate)
class Plugin: CordaPluginRegistry() {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
}
class Service(services: ServiceHubInternal) {
init {
services.networkService.addMessageHandler(TOPIC, DEFAULT_SESSION_ID) { msg, registration ->
val updateBusinessDayMessage = msg.data.deserialize<UpdateBusinessDayMessage>()
(services.clock as DemoClock).updateDate(updateBusinessDayMessage.date)
}
services.registerProtocolInitiator(Broadcast::class, ::UpdateBusinessDayHandler)
}
}
private class UpdateBusinessDayHandler(val otherParty: Party) : ProtocolLogic<Unit>() {
override fun call() {
val message = receive<UpdateBusinessDayMessage>(otherParty).unwrap { it }
(serviceHub.clock as DemoClock).updateDate(message.date)
}
}
class Broadcast(val date: LocalDate,
override val progressTracker: ProgressTracker = Broadcast.tracker()) : ProtocolLogic<Unit>() {
@ -52,8 +47,6 @@ object UpdateBusinessDayProtocol {
fun tracker() = ProgressTracker(NOTIFYING)
}
override val topic: String get() = TOPIC
@Suspendable
override fun call(): Unit {
progressTracker.currentStep = NOTIFYING
@ -67,7 +60,7 @@ object UpdateBusinessDayProtocol {
if (recipient.address is MockNetworkMapCache.MockAddress) {
// Ignore
} else {
send(recipient.identity, UpdateBusinessDayMessage(date, recipient.identity))
send(recipient.identity, UpdateBusinessDayMessage(date))
}
}
}

View File

@ -10,14 +10,16 @@ import com.r3corda.core.RunOnCallerThread
import com.r3corda.core.contracts.StateAndRef
import com.r3corda.core.contracts.UniqueIdentifier
import com.r3corda.core.failure
import com.r3corda.core.flatMap
import com.r3corda.core.node.services.linearHeadsOfType
import com.r3corda.core.success
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.protocols.TwoPartyDealProtocol
import com.r3corda.testing.connectProtocols
import com.r3corda.protocols.TwoPartyDealProtocol.Acceptor
import com.r3corda.protocols.TwoPartyDealProtocol.AutoOffer
import com.r3corda.protocols.TwoPartyDealProtocol.Instigator
import com.r3corda.testing.initiateSingleShotProtocol
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.MockIdentityService
import java.security.KeyPair
import java.time.LocalDate
import java.util.*
@ -73,7 +75,7 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten
val node1: SimulatedNode = banks[i]
val node2: SimulatedNode = banks[j]
val swaps: Map<UniqueIdentifier, StateAndRef<InterestRateSwap.State>> = node1.services.vaultService.linearHeadsOfType<com.r3corda.contracts.InterestRateSwap.State>()
val swaps: Map<UniqueIdentifier, StateAndRef<InterestRateSwap.State>> = node1.services.vaultService.linearHeadsOfType<InterestRateSwap.State>()
val theDealRef: StateAndRef<InterestRateSwap.State> = swaps.values.single()
// Do we have any more days left in this deal's lifetime? If not, return.
@ -111,22 +113,19 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten
// We load the IRS afresh each time because the leg parts of the structure aren't data classes so they don't
// have the convenient copy() method that'd let us make small adjustments. Instead they're partly mutable.
// TODO: We should revisit this in post-Excalibur cleanup and fix, e.g. by introducing an interface.
val irs = om.readValue<com.r3corda.contracts.InterestRateSwap.State>(javaClass.getResource("trade.json"))
val irs = om.readValue<InterestRateSwap.State>(javaClass.getResource("trade.json"))
irs.fixedLeg.fixedRatePayer = node1.info.identity
irs.floatingLeg.floatingRatePayer = node2.info.identity
val instigator = TwoPartyDealProtocol.Instigator(node2.info.identity, notary.info.identity, irs, node1.keyPair!!)
val acceptor = TwoPartyDealProtocol.Acceptor(node1.info.identity, notary.info.identity, irs)
connectProtocols(instigator, acceptor)
val acceptorTx = node2.initiateSingleShotProtocol(Instigator::class) { Acceptor(it) }.flatMap { it.resultFuture }
showProgressFor(listOf(node1, node2))
showConsensusFor(listOf(node1, node2, regulators[0]))
val instigatorFuture: ListenableFuture<SignedTransaction> = node1.services.startProtocol("instigator", instigator)
val instigator = Instigator(node2.info.identity, AutoOffer(notary.info.identity, irs), node1.keyPair!!)
val instigatorTx = node1.services.startProtocol("instigator", instigator)
return Futures.transformAsync(Futures.allAsList(instigatorFuture, node2.services.startProtocol("acceptor", acceptor))) {
instigatorFuture
}
return Futures.transformAsync(Futures.allAsList(instigatorTx, acceptorTx)) { instigatorTx }
}
override fun iterate(): InMemoryMessagingNetwork.MessageTransfer? {

View File

@ -9,12 +9,13 @@ import com.r3corda.core.contracts.DOLLARS
import com.r3corda.core.contracts.OwnableState
import com.r3corda.core.contracts.`issued by`
import com.r3corda.core.days
import com.r3corda.core.flatMap
import com.r3corda.core.node.recordTransactions
import com.r3corda.core.seconds
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.protocols.TwoPartyTradeProtocol
import com.r3corda.protocols.TwoPartyTradeProtocol.TOPIC
import com.r3corda.testing.connectProtocols
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
import com.r3corda.protocols.TwoPartyTradeProtocol.Seller
import com.r3corda.testing.initiateSingleShotProtocol
import com.r3corda.testing.node.InMemoryMessagingNetwork
import java.time.Instant
@ -45,25 +46,24 @@ class TradeSimulation(runAsync: Boolean, latencyInjector: InMemoryMessagingNetwo
seller.services.recordTransactions(issuance)
val amount = 1000.DOLLARS
val buyerProtocol = TwoPartyTradeProtocol.Buyer(
seller.info.identity,
notary.info.identity,
amount,
CommercialPaper.State::class.java)
val sellerProtocol = TwoPartyTradeProtocol.Seller(
val buyerFuture = buyer.initiateSingleShotProtocol(Seller::class) {
Buyer(it, notary.info.identity, amount, CommercialPaper.State::class.java)
}.flatMap { it.resultFuture }
val sellerProtocol = Seller(
buyer.info.identity,
notary.info,
issuance.tx.outRef<OwnableState>(0),
amount,
seller.storage.myLegalIdentityKey)
connectProtocols(buyerProtocol, sellerProtocol)
showConsensusFor(listOf(buyer, seller, notary))
showProgressFor(listOf(buyer, seller))
val buyerFuture = buyer.services.startProtocol("bank.$buyerBankIndex.$TOPIC.buyer", buyerProtocol)
val sellerFuture = seller.services.startProtocol("bank.$sellerBankIndex.$TOPIC.seller", sellerProtocol)
val sellerFuture = seller.services.startProtocol("bank.$sellerBankIndex.seller", sellerProtocol)
return Futures.successfulAsList(buyerFuture, sellerFuture)
}
}

View File

@ -4,22 +4,28 @@ package com.r3corda.testing
import com.google.common.base.Throwables
import com.google.common.net.HostAndPort
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.contracts.StateRef
import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.crypto.generateKeyPair
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.transactions.TransactionBuilder
import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.node.internal.AbstractNode
import com.r3corda.node.services.statemachine.StateMachineManager.Change
import com.r3corda.node.utilities.AddOrRemove.ADD
import com.r3corda.testing.node.MockIdentityService
import com.r3corda.testing.node.MockServices
import rx.Subscriber
import java.net.ServerSocket
import java.security.KeyPair
import java.security.PublicKey
import kotlin.reflect.KClass
/**
* JAVA INTEROP
@ -129,22 +135,32 @@ fun getFreeLocalPorts(hostName: String, numberToAlloc: Int): List<HostAndPort> {
dsl: TransactionDSL<TransactionDSLInterpreter>.() -> EnforceVerifyOrFail
) = ledger { this.transaction(transactionLabel, transactionBuilder, dsl) }
/**
* Connect two protocols together for communication. Both protocols must have a property called otherParty of type Party
* which points to the other party in the communication.
* The given protocol factory will be used to initiate just one instance of a protocol of type [P] when a counterparty
* protocol requests for it using [markerClass].
* @return Returns a [ListenableFuture] holding the single [ProtocolStateMachine] created by the request.
*/
fun connectProtocols(protocol1: ProtocolLogic<*>, protocol2: ProtocolLogic<*>) {
inline fun <R, reified P : ProtocolLogic<R>> AbstractNode.initiateSingleShotProtocol(
markerClass: KClass<*>,
noinline protocolFactory: (Party) -> P): ListenableFuture<ProtocolStateMachine<R>> {
services.registerProtocolInitiator(markerClass, protocolFactory)
data class Handshake(override val replyToParty: Party,
override val sendSessionID: Long,
override val receiveSessionID: Long) : HandshakeMessage
val future = SettableFuture.create<ProtocolStateMachine<R>>()
val sessionId1 = random63BitValue()
val sessionId2 = random63BitValue()
protocol1.registerSession(Handshake(protocol1.otherParty, sessionId1, sessionId2))
protocol2.registerSession(Handshake(protocol2.otherParty, sessionId2, sessionId1))
}
val subscriber = object : Subscriber<Change>() {
override fun onNext(change: Change) {
if (change.logic is P && change.addOrRemove == ADD) {
unsubscribe()
future.set(change.logic.psm as ProtocolStateMachine<R>)
}
}
override fun onError(e: Throwable) {
future.setException(e)
}
override fun onCompleted() {}
}
private val ProtocolLogic<*>.otherParty: Party
get() = javaClass.getDeclaredField("otherParty").apply { isAccessible = true }.get(this) as Party
smm.changes.subscribe(subscriber)
return future
}

View File

@ -131,6 +131,10 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
// It is used from the network visualiser tool.
@Suppress("unused") val place: PhysicalLocation get() = findMyLocation()!!
fun pumpReceive(block: Boolean): InMemoryMessagingNetwork.MessageTransfer? {
return (net as InMemoryMessagingNetwork.InMemoryMessaging).pumpReceive(block)
}
fun send(topic: String, target: MockNode, payload: Any) {
services.networkService.send(TopicSession(topic), payload, target.info.address)
}