From 67fdf9b2ffb566f05d4a03fc6314334c3f47f052 Mon Sep 17 00:00:00 2001 From: Shams Asari Date: Tue, 27 Sep 2016 18:25:26 +0100 Subject: [PATCH] Automatic session management between two protocols, and removal of explict topics --- .gitignore | 2 +- .../main/kotlin/com/r3corda/contracts/IRS.kt | 4 +- .../protocols/TwoPartyTradeProtocol.kt | 17 +- .../src/main/kotlin/com/r3corda/core/Utils.kt | 18 +- .../r3corda/core/protocols/ProtocolLogic.kt | 86 +--- .../core/protocols/ProtocolStateMachine.kt | 18 +- .../AbstractStateReplacementProtocol.kt | 10 - .../protocols/BroadcastTransactionProtocol.kt | 21 +- .../protocols/FetchAttachmentsProtocol.kt | 6 - .../r3corda/protocols/FetchDataProtocol.kt | 13 +- .../protocols/FetchTransactionsProtocol.kt | 8 +- .../com/r3corda/protocols/FinalityProtocol.kt | 3 - .../r3corda/protocols/NotaryChangeProtocol.kt | 6 - .../com/r3corda/protocols/NotaryProtocol.kt | 43 +- .../com/r3corda/protocols/RatesFixProtocol.kt | 67 ++- .../protocols/ResolveTransactionsProtocol.kt | 2 - .../protocols/ServiceRequestMessage.kt | 15 - .../r3corda/protocols/TwoPartyDealProtocol.kt | 123 +++--- .../ProtocolLogicRefFromJavaTest.java | 21 +- .../BroadcastTransactionProtocolTest.kt | 17 +- .../core/protocols/ProtocolLogicRefTest.kt | 3 - docs/source/protocol-state-machines.rst | 81 ++-- .../com/r3corda/node/internal/AbstractNode.kt | 15 + .../node/services/NotaryChangeService.kt | 10 +- .../node/services/api/AbstractNodeService.kt | 40 -- .../node/services/api/CheckpointStorage.kt | 23 +- .../node/services/api/ServiceHubInternal.kt | 23 +- .../clientapi/FixingSessionInitiation.kt | 15 +- .../services/monitor/NodeMonitorService.kt | 6 +- .../persistence/DataVendingService.kt | 122 +++--- .../statemachine/ProtocolIORequest.kt | 37 +- .../statemachine/ProtocolStateMachineImpl.kt | 174 ++++++-- .../statemachine/StateMachineManager.kt | 399 ++++++++++-------- .../services/transactions/NotaryService.kt | 21 +- .../transactions/SimpleNotaryService.kt | 9 +- .../transactions/ValidatingNotaryService.kt | 12 +- .../messaging/TwoPartyTradeProtocolTests.kt | 61 ++- .../node/services/MockServiceHubInternal.kt | 34 +- .../node/services/NodeSchedulerServiceTest.kt | 13 +- .../node/services/NotaryChangeTests.kt | 7 +- .../node/services/NotaryServiceTests.kt | 33 +- .../services/ValidatingNotaryServiceTests.kt | 18 +- .../persistence/DataVendingServiceTests.kt | 34 +- .../PerFileCheckpointStorageTests.kt | 8 +- .../statemachine/StateMachineManagerTests.kt | 162 ++++--- .../kotlin/com/r3corda/demos/TraderDemo.kt | 37 +- .../r3corda/demos/api/NodeInterestRates.kt | 52 +-- .../demos/protocols/AutoOfferProtocol.kt | 54 +-- .../demos/protocols/ExitServerProtocol.kt | 37 +- .../protocols/UpdateBusinessDayProtocol.kt | 29 +- .../com/r3corda/simulation/IRSSimulation.kt | 23 +- .../com/r3corda/simulation/TradeSimulation.kt | 24 +- .../com/r3corda/testing/CoreTestUtils.kt | 48 ++- .../com/r3corda/testing/node/MockNode.kt | 4 + 54 files changed, 1055 insertions(+), 1113 deletions(-) diff --git a/.gitignore b/.gitignore index c2649c7542..5078c61a1f 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,6 @@ tags .DS_Store *.log -*.log.gz *.orig # Created by .ignore support plugin (hsz.mobi) @@ -100,3 +99,4 @@ crashlytics-build.properties # docs related docs/virtualenv/ +/logs/ diff --git a/contracts/src/main/kotlin/com/r3corda/contracts/IRS.kt b/contracts/src/main/kotlin/com/r3corda/contracts/IRS.kt index 07eece75f5..54342c20bc 100644 --- a/contracts/src/main/kotlin/com/r3corda/contracts/IRS.kt +++ b/contracts/src/main/kotlin/com/r3corda/contracts/IRS.kt @@ -675,8 +675,8 @@ class InterestRateSwap() : Contract { val nextFixingOf = nextFixingOf() ?: return null // This is perhaps not how we should determine the time point in the business day, but instead expect the schedule to detail some of these aspects - val (instant, duration) = suggestInterestRateAnnouncementTimeWindow(index = nextFixingOf.name, source = floatingLeg.indexSource, date = nextFixingOf.forDay) - return ScheduledActivity(protocolLogicRefFactory.create(TwoPartyDealProtocol.FixingRoleDecider::class.java, thisStateRef, duration), instant) + val instant = suggestInterestRateAnnouncementTimeWindow(index = nextFixingOf.name, source = floatingLeg.indexSource, date = nextFixingOf.forDay).start + return ScheduledActivity(protocolLogicRefFactory.create(TwoPartyDealProtocol.FixingRoleDecider::class.java, thisStateRef), instant) } override fun generateAgreement(notary: Party): TransactionBuilder = InterestRateSwap().generateAgreement(floatingLeg, fixedLeg, calculation, common, notary) diff --git a/contracts/src/main/kotlin/com/r3corda/protocols/TwoPartyTradeProtocol.kt b/contracts/src/main/kotlin/com/r3corda/protocols/TwoPartyTradeProtocol.kt index 5748fc36f5..ccb6a33302 100644 --- a/contracts/src/main/kotlin/com/r3corda/protocols/TwoPartyTradeProtocol.kt +++ b/contracts/src/main/kotlin/com/r3corda/protocols/TwoPartyTradeProtocol.kt @@ -46,22 +46,20 @@ import java.util.* // and [AbstractStateReplacementProtocol]. object TwoPartyTradeProtocol { - val TOPIC = "platform.trade" - class UnacceptablePriceException(val givenPrice: Amount) : Exception("Unacceptable price: $givenPrice") class AssetMismatchException(val expectedTypeName: String, val typeName: String) : Exception() { override fun toString() = "The submitted asset didn't match the expected type: $expectedTypeName vs $typeName" } // This object is serialised to the network and is the first protocol message the seller sends to the buyer. - class SellerTradeInfo( + data class SellerTradeInfo( val assetForSale: StateAndRef, val price: Amount, val sellerOwnerKey: PublicKey ) - class SignaturesFromSeller(val sellerSig: DigitalSignature.WithKey, - val notarySig: DigitalSignature.LegallyIdentifiable) + data class SignaturesFromSeller(val sellerSig: DigitalSignature.WithKey, + val notarySig: DigitalSignature.LegallyIdentifiable) open class Seller(val otherParty: Party, val notaryNode: NodeInfo, @@ -84,8 +82,6 @@ object TwoPartyTradeProtocol { fun tracker() = ProgressTracker(AWAITING_PROPOSAL, VERIFYING, SIGNING, NOTARY, SENDING_SIGS) } - override val topic: String get() = TOPIC - @Suspendable override fun call(): SignedTransaction { val partialTX: SignedTransaction = receiveAndCheckProposedTransaction() @@ -172,7 +168,6 @@ object TwoPartyTradeProtocol { object SWAPPING_SIGNATURES : ProgressTracker.Step("Swapping signatures with the seller") - override val topic: String get() = TOPIC override val progressTracker = ProgressTracker(RECEIVING, VERIFYING, SIGNING, SWAPPING_SIGNATURES) @Suspendable @@ -197,7 +192,7 @@ object TwoPartyTradeProtocol { @Suspendable private fun receiveAndValidateTradeRequest(): SellerTradeInfo { progressTracker.currentStep = RECEIVING - // Wait for a trade request to come in on our pre-provided session ID. + // Wait for a trade request to come in from the other side val maybeTradeRequest = receive(otherParty) progressTracker.currentStep = VERIFYING @@ -243,8 +238,8 @@ object TwoPartyTradeProtocol { private fun assembleSharedTX(tradeRequest: SellerTradeInfo): Pair> { val ptx = TransactionType.General.Builder(notary) // Add input and output states for the movement of cash, by using the Cash contract to generate the states. - val wallet = serviceHub.vaultService.currentVault - val cashStates = wallet.statesOfType() + val vault = serviceHub.vaultService.currentVault + val cashStates = vault.statesOfType() val cashSigningPubKeys = Cash().generateSpend(ptx, tradeRequest.price, tradeRequest.sellerOwnerKey, cashStates) // Add inputs/outputs/a command for the movement of the asset. ptx.addInputState(tradeRequest.assetForSale) diff --git a/core/src/main/kotlin/com/r3corda/core/Utils.kt b/core/src/main/kotlin/com/r3corda/core/Utils.kt index f294eebebc..0fa7447194 100644 --- a/core/src/main/kotlin/com/r3corda/core/Utils.kt +++ b/core/src/main/kotlin/com/r3corda/core/Utils.kt @@ -1,5 +1,6 @@ package com.r3corda.core +import com.google.common.base.Function import com.google.common.base.Throwables import com.google.common.io.ByteStreams import com.google.common.util.concurrent.Futures @@ -17,8 +18,8 @@ import java.nio.file.Files import java.nio.file.Path import java.time.Duration import java.time.temporal.Temporal +import java.util.concurrent.ExecutionException import java.util.concurrent.Executor -import java.util.concurrent.Future import java.util.concurrent.locks.ReentrantLock import java.util.zip.ZipInputStream import kotlin.concurrent.withLock @@ -66,19 +67,20 @@ fun ListenableFuture.success(executor: Executor, body: (T) -> Unit) = the fun ListenableFuture.failure(executor: Executor, body: (Throwable) -> Unit) = then(executor) { try { get() - } catch(e: Throwable) { - body(e) + } catch (e: ExecutionException) { + body(e.cause!!) + } catch (t: Throwable) { + body(t) } } -infix fun Future.map(mapper: (F) -> T): Future = Futures.lazyTransform(this) { mapper(it!!) } infix fun ListenableFuture.then(body: () -> Unit): ListenableFuture = apply { then(RunOnCallerThread, body) } infix fun ListenableFuture.success(body: (T) -> Unit): ListenableFuture = apply { success(RunOnCallerThread, body) } infix fun ListenableFuture.failure(body: (Throwable) -> Unit): ListenableFuture = apply { failure(RunOnCallerThread, body) } - -fun Path.use(block: (InputStream) -> R): R = Files.newInputStream(this).use(block) - +infix fun ListenableFuture.map(mapper: (F) -> T): ListenableFuture = Futures.transform(this, Function { mapper(it!!) }) +infix fun ListenableFuture.flatMap(mapper: (F) -> ListenableFuture): ListenableFuture = Futures.transformAsync(this) { mapper(it!!) } /** Executes the given block and sets the future to either the result, or any exception that was thrown. */ +// TODO This is not used but there's existing code that can be replaced by this fun SettableFuture.setFrom(logger: Logger? = null, block: () -> T): SettableFuture { try { set(block()) @@ -89,6 +91,8 @@ fun SettableFuture.setFrom(logger: Logger? = null, block: () -> T): Setta return this } +fun Path.use(block: (InputStream) -> R): R = Files.newInputStream(this).use(block) + // Simple infix function to add back null safety that the JDK lacks: timeA until timeB infix fun Temporal.until(endExclusive: Temporal) = Duration.between(this, endExclusive) diff --git a/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogic.kt b/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogic.kt index 91e5bfed22..6921e04a71 100644 --- a/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogic.kt +++ b/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogic.kt @@ -2,19 +2,11 @@ package com.r3corda.core.protocols import co.paralleluniverse.fibers.Suspendable import com.r3corda.core.crypto.Party -import com.r3corda.core.messaging.Message -import com.r3corda.core.messaging.runOnNextMessage import com.r3corda.core.node.ServiceHub -import com.r3corda.core.node.services.DEFAULT_SESSION_ID -import com.r3corda.core.serialization.deserialize import com.r3corda.core.utilities.ProgressTracker import com.r3corda.core.utilities.UntrustworthyData -import com.r3corda.core.utilities.debug -import com.r3corda.protocols.HandshakeMessage import org.slf4j.Logger import rx.Observable -import java.util.* -import java.util.concurrent.CompletableFuture /** * A sub-class of [ProtocolLogic] implements a protocol flow using direct, straight line blocking code. Thus you @@ -48,23 +40,14 @@ abstract class ProtocolLogic { */ val serviceHub: ServiceHub get() = psm.serviceHub - /** - * The topic to use when communicating with other parties. If more than one topic is required then use sub-protocols. - * Note that this is temporary until protocol sessions are properly implemented. - */ - protected abstract val topic: String - - private val sessions = HashMap() + private var sessionProtocol: ProtocolLogic<*> = this /** - * If a node receives a [HandshakeMessage] it needs to call this method on the initiated receipt protocol to enable - * communication between it and the sender protocol. Calling this method, and other initiation steps, are already - * handled by AbstractNodeService.addProtocolHandler. + * Return the marker [Class] which [party] has used to register the counterparty protocol that is to execute on the + * other side. The default implementation returns the class object of this ProtocolLogic, but any [Class] instance + * will do as long as the other side registers with it. */ - fun registerSession(receivedHandshake: HandshakeMessage) { - // Note that the send and receive session IDs are swapped - addSession(receivedHandshake.replyToParty, receivedHandshake.receiveSessionID, receivedHandshake.sendSessionID) - } + open fun getCounterpartyMarker(party: Party): Class<*> = javaClass // Kotlin helpers that allow the use of generic types. inline fun sendAndReceive(otherParty: Party, payload: Any): UntrustworthyData { @@ -73,69 +56,41 @@ abstract class ProtocolLogic { @Suspendable fun sendAndReceive(otherParty: Party, payload: Any, receiveType: Class): UntrustworthyData { - val sendSessionId = getSendSessionId(otherParty, payload) - val receiveSessionId = getReceiveSessionId(otherParty) - return psm.sendAndReceive(topic, otherParty, sendSessionId, receiveSessionId, payload, receiveType) + return psm.sendAndReceive(otherParty, payload, receiveType, sessionProtocol) } inline fun receive(otherParty: Party): UntrustworthyData = receive(otherParty, T::class.java) @Suspendable fun receive(otherParty: Party, receiveType: Class): UntrustworthyData { - return psm.receive(topic, getReceiveSessionId(otherParty), receiveType) + return psm.receive(otherParty, receiveType, sessionProtocol) } @Suspendable fun send(otherParty: Party, payload: Any) { - psm.send(topic, otherParty, getSendSessionId(otherParty, payload), payload) + psm.send(otherParty, payload, sessionProtocol) } - private fun addSession(party: Party, sendSesssionId: Long, receiveSessionId: Long) { - if (party in sessions) { - logger.debug { "Existing session with party $party to be overwritten by new one" } - } - sessions[party] = Session(sendSesssionId, receiveSessionId) - } - - private fun getSendSessionId(otherParty: Party, payload: Any): Long { - return if (payload is HandshakeMessage) { - addSession(otherParty, payload.sendSessionID, payload.receiveSessionID) - DEFAULT_SESSION_ID - } else { - sessions[otherParty]?.sendSessionId ?: - throw IllegalStateException("Session with party $otherParty hasn't been established yet") - } - } - - private fun getReceiveSessionId(otherParty: Party): Long { - return sessions[otherParty]?.receiveSessionId ?: - throw IllegalStateException("Session with party $otherParty hasn't been established yet") - } - - /** - * Check if we already have a session with this party - */ - protected fun hasSession(otherParty: Party) = sessions.containsKey(otherParty) - /** * Invokes the given subprotocol by simply passing through this [ProtocolLogic]s reference to the * [ProtocolStateMachine] and then calling the [call] method. - * @param inheritParentSessions In certain situations the subprotocol needs to inherit and use the same open - * sessions of the parent. However in most cases this is not desirable as it prevents the subprotocol from - * communicating with the same party on a different topic. For this reason the default value is false. + * @param shareParentSessions In certain situations the need arises to use the same sessions the parent protocol has + * already established. However this also prevents the subprotocol from creating new sessions with those parties. + * For this reason the default value is false. */ - @JvmOverloads + // TODO Rethink the default value for shareParentSessions + // TODO shareParentSessions is a bit too low-level and perhaps can be expresed in a better way @Suspendable - fun subProtocol(subLogic: ProtocolLogic, inheritParentSessions: Boolean = false): R { + fun subProtocol(subLogic: ProtocolLogic, shareParentSessions: Boolean = false): R { subLogic.psm = psm - if (inheritParentSessions) { - subLogic.sessions.putAll(sessions) - } maybeWireUpProgressTracking(subLogic) - val r = subLogic.call() + if (shareParentSessions) { + subLogic.sessionProtocol = this + } + val result = subLogic.call() // It's easy to forget this when writing protocols so we just step it to the DONE state when it completes. subLogic.progressTracker?.currentStep = ProgressTracker.DONE - return r + return result } private fun maybeWireUpProgressTracking(subLogic: ProtocolLogic<*>) { @@ -166,12 +121,11 @@ abstract class ProtocolLogic { @Suspendable abstract fun call(): T - private data class Session(val sendSessionId: Long, val receiveSessionId: Long) - // TODO this is not threadsafe, needs an atomic get-step-and-subscribe fun track(): Pair>? { return progressTracker?.let { Pair(it.currentStep.toString(), it.changes.map { it.toString() }) } } + } diff --git a/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolStateMachine.kt b/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolStateMachine.kt index da0abc3c22..a6b8c3218b 100644 --- a/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolStateMachine.kt +++ b/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolStateMachine.kt @@ -1,7 +1,6 @@ package com.r3corda.core.protocols import co.paralleluniverse.fibers.Suspendable -import co.paralleluniverse.strands.Strand import com.google.common.util.concurrent.ListenableFuture import com.r3corda.core.crypto.Party import com.r3corda.core.node.ServiceHub @@ -10,9 +9,12 @@ import org.slf4j.Logger import java.util.* data class StateMachineRunId private constructor(val uuid: UUID) { + companion object { fun createRandom(): StateMachineRunId = StateMachineRunId(UUID.randomUUID()) } + + override fun toString(): String = "${javaClass.simpleName}($uuid)" } /** @@ -20,18 +22,16 @@ data class StateMachineRunId private constructor(val uuid: UUID) { */ interface ProtocolStateMachine { @Suspendable - fun sendAndReceive(topic: String, - destination: Party, - sessionIDForSend: Long, - sessionIDForReceive: Long, + fun sendAndReceive(otherParty: Party, payload: Any, - receiveType: Class): UntrustworthyData + receiveType: Class, + sessionProtocol: ProtocolLogic<*>): UntrustworthyData @Suspendable - fun receive(topic: String, sessionIDForReceive: Long, receiveType: Class): UntrustworthyData + fun receive(otherParty: Party, receiveType: Class, sessionProtocol: ProtocolLogic<*>): UntrustworthyData @Suspendable - fun send(topic: String, destination: Party, sessionID: Long, payload: Any) + fun send(otherParty: Party, payload: Any, sessionProtocol: ProtocolLogic<*>) val serviceHub: ServiceHub val logger: Logger @@ -41,3 +41,5 @@ interface ProtocolStateMachine { /** This future will complete when the call method returns. */ val resultFuture: ListenableFuture } + +class ProtocolSessionException(message: String) : Exception(message) \ No newline at end of file diff --git a/core/src/main/kotlin/com/r3corda/protocols/AbstractStateReplacementProtocol.kt b/core/src/main/kotlin/com/r3corda/protocols/AbstractStateReplacementProtocol.kt index 0cb21eb600..247f0422b5 100644 --- a/core/src/main/kotlin/com/r3corda/protocols/AbstractStateReplacementProtocol.kt +++ b/core/src/main/kotlin/com/r3corda/protocols/AbstractStateReplacementProtocol.kt @@ -9,7 +9,6 @@ import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.signWithECDSA import com.r3corda.core.node.recordTransactions import com.r3corda.core.protocols.ProtocolLogic -import com.r3corda.core.random63BitValue import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.WireTransaction import com.r3corda.core.utilities.ProgressTracker @@ -35,10 +34,6 @@ abstract class AbstractStateReplacementProtocol { val stx: SignedTransaction } - data class Handshake(override val replyToParty: Party, - override val sendSessionID: Long = random63BitValue(), - override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage - abstract class Instigator(val originalState: StateAndRef, val modification: T, override val progressTracker: ProgressTracker = tracker()) : ProtocolLogic>() { @@ -94,11 +89,6 @@ abstract class AbstractStateReplacementProtocol { private fun getParticipantSignature(party: Party, stx: SignedTransaction): DigitalSignature.WithKey { val proposal = assembleProposal(originalState.ref, modification, stx) - // TODO: Move this into protocol logic as a func on the lines of handshake(Party, HandshakeMessage) - if (!hasSession(party)) { - send(party, Handshake(serviceHub.storageService.myLegalIdentity)) - } - val response = sendAndReceive(party, proposal) val participantSignature = response.unwrap { if (it.sig == null) throw StateReplacementException(it.error!!) diff --git a/core/src/main/kotlin/com/r3corda/protocols/BroadcastTransactionProtocol.kt b/core/src/main/kotlin/com/r3corda/protocols/BroadcastTransactionProtocol.kt index 0d1e868c9c..1c405442e0 100644 --- a/core/src/main/kotlin/com/r3corda/protocols/BroadcastTransactionProtocol.kt +++ b/core/src/main/kotlin/com/r3corda/protocols/BroadcastTransactionProtocol.kt @@ -5,7 +5,6 @@ import com.r3corda.core.contracts.ClientToServiceCommand import com.r3corda.core.crypto.Party import com.r3corda.core.node.recordTransactions import com.r3corda.core.protocols.ProtocolLogic -import com.r3corda.core.random63BitValue import com.r3corda.core.transactions.SignedTransaction @@ -25,31 +24,17 @@ import com.r3corda.core.transactions.SignedTransaction class BroadcastTransactionProtocol(val notarisedTransaction: SignedTransaction, val events: Set, val participants: Set) : ProtocolLogic() { - companion object { - /** Topic for messages notifying a node of a new transaction */ - val TOPIC = "platform.wallet.notify_tx" - } - override val topic: String = TOPIC - - data class NotifyTxRequestMessage(val tx: SignedTransaction, - val events: Set, - override val replyToParty: Party, - override val sendSessionID: Long = random63BitValue(), - override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage + data class NotifyTxRequest(val tx: SignedTransaction, val events: Set) @Suspendable override fun call() { // Record it locally serviceHub.recordTransactions(notarisedTransaction) - // TODO: Messaging layer should handle this broadcast for us (although we need to not be sending - // session ID, for that to work, as well). + // TODO: Messaging layer should handle this broadcast for us + val msg = NotifyTxRequest(notarisedTransaction, events) participants.filter { it != serviceHub.storageService.myLegalIdentity }.forEach { participant -> - val msg = NotifyTxRequestMessage( - notarisedTransaction, - events, - serviceHub.storageService.myLegalIdentity) send(participant, msg) } } diff --git a/core/src/main/kotlin/com/r3corda/protocols/FetchAttachmentsProtocol.kt b/core/src/main/kotlin/com/r3corda/protocols/FetchAttachmentsProtocol.kt index 0e1c579b2b..c8f0ed32ae 100644 --- a/core/src/main/kotlin/com/r3corda/protocols/FetchAttachmentsProtocol.kt +++ b/core/src/main/kotlin/com/r3corda/protocols/FetchAttachmentsProtocol.kt @@ -14,12 +14,6 @@ import java.io.InputStream class FetchAttachmentsProtocol(requests: Set, otherSide: Party) : FetchDataProtocol(requests, otherSide) { - companion object { - const val TOPIC = "platform.fetch.attachment" - } - - override val topic: String get() = TOPIC - override fun load(txid: SecureHash): Attachment? = serviceHub.storageService.attachments.openAttachment(txid) override fun convert(wire: ByteArray): Attachment { diff --git a/core/src/main/kotlin/com/r3corda/protocols/FetchDataProtocol.kt b/core/src/main/kotlin/com/r3corda/protocols/FetchDataProtocol.kt index e3d2a06064..e1388b6973 100644 --- a/core/src/main/kotlin/com/r3corda/protocols/FetchDataProtocol.kt +++ b/core/src/main/kotlin/com/r3corda/protocols/FetchDataProtocol.kt @@ -5,7 +5,6 @@ import com.r3corda.core.contracts.NamedByHash import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.SecureHash import com.r3corda.core.protocols.ProtocolLogic -import com.r3corda.core.random63BitValue import com.r3corda.core.utilities.UntrustworthyData import com.r3corda.protocols.FetchDataProtocol.DownloadedVsRequestedDataMismatch import com.r3corda.protocols.FetchDataProtocol.HashNotFound @@ -21,8 +20,8 @@ import java.util.* * [HashNotFound] exception being thrown. * * By default this class does not insert data into any local database, if you want to do that after missing items were - * fetched then override [maybeWriteToDisk]. You *must* override [load] and [queryTopic]. If the wire type is not the - * same as the ultimate type, you must also override [convert]. + * fetched then override [maybeWriteToDisk]. You *must* override [load]. If the wire type is not the same as the + * ultimate type, you must also override [convert]. * * @param T The ultimate type of the data being fetched. * @param W The wire type of the data being fetched, for when it isn't the same as the ultimate type. @@ -35,10 +34,7 @@ abstract class FetchDataProtocol( class HashNotFound(val requested: SecureHash) : BadAnswer() class DownloadedVsRequestedDataMismatch(val requested: SecureHash, val got: SecureHash) : BadAnswer() - data class Request(val hashes: List, - override val replyToParty: Party, - override val sendSessionID: Long = random63BitValue(), - override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage + data class Request(val hashes: List) data class Result(val fromDisk: List, val downloaded: List) @Suspendable @@ -51,9 +47,8 @@ abstract class FetchDataProtocol( } else { logger.trace("Requesting ${toFetch.size} dependency(s) for verification") - val fetchReq = Request(toFetch, serviceHub.storageService.myLegalIdentity) // TODO: Support "large message" response streaming so response sizes are not limited by RAM. - val maybeItems = sendAndReceive>(otherSide, fetchReq) + val maybeItems = sendAndReceive>(otherSide, Request(toFetch)) // Check for a buggy/malicious peer answering with something that we didn't ask for. val downloaded = validateFetchResponse(maybeItems, toFetch) maybeWriteToDisk(downloaded) diff --git a/core/src/main/kotlin/com/r3corda/protocols/FetchTransactionsProtocol.kt b/core/src/main/kotlin/com/r3corda/protocols/FetchTransactionsProtocol.kt index 1ceacafd0b..3eeedf40f6 100644 --- a/core/src/main/kotlin/com/r3corda/protocols/FetchTransactionsProtocol.kt +++ b/core/src/main/kotlin/com/r3corda/protocols/FetchTransactionsProtocol.kt @@ -1,8 +1,8 @@ package com.r3corda.protocols -import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.SecureHash +import com.r3corda.core.transactions.SignedTransaction /** * Given a set of tx hashes (IDs), either loads them from local disk or asks the remote peer to provide them. @@ -15,11 +15,5 @@ import com.r3corda.core.crypto.SecureHash class FetchTransactionsProtocol(requests: Set, otherSide: Party) : FetchDataProtocol(requests, otherSide) { - companion object { - const val TOPIC = "platform.fetch.tx" - } - - override val topic: String get() = TOPIC - override fun load(txid: SecureHash): SignedTransaction? = serviceHub.storageService.validatedTransactions.getTransaction(txid) } \ No newline at end of file diff --git a/core/src/main/kotlin/com/r3corda/protocols/FinalityProtocol.kt b/core/src/main/kotlin/com/r3corda/protocols/FinalityProtocol.kt index b0065938a5..afcce9c967 100644 --- a/core/src/main/kotlin/com/r3corda/protocols/FinalityProtocol.kt +++ b/core/src/main/kotlin/com/r3corda/protocols/FinalityProtocol.kt @@ -31,9 +31,6 @@ class FinalityProtocol(val transaction: SignedTransaction, fun tracker() = ProgressTracker(NOTARISING, BROADCASTING) } - override val topic: String - get() = throw UnsupportedOperationException() - @Suspendable override fun call() { // TODO: Resolve the tx here: it's probably already been done, but re-resolution is a no-op and it'll make the API more forgiving. diff --git a/core/src/main/kotlin/com/r3corda/protocols/NotaryChangeProtocol.kt b/core/src/main/kotlin/com/r3corda/protocols/NotaryChangeProtocol.kt index d5222b3f0f..618888922e 100644 --- a/core/src/main/kotlin/com/r3corda/protocols/NotaryChangeProtocol.kt +++ b/core/src/main/kotlin/com/r3corda/protocols/NotaryChangeProtocol.kt @@ -24,8 +24,6 @@ import java.security.PublicKey */ object NotaryChangeProtocol: AbstractStateReplacementProtocol() { - val TOPIC = "platform.notary.change" - data class Proposal(override val stateRef: StateRef, override val modification: Party, override val stx: SignedTransaction) : AbstractStateReplacementProtocol.Proposal @@ -35,8 +33,6 @@ object NotaryChangeProtocol: AbstractStateReplacementProtocol() { progressTracker: ProgressTracker = tracker()) : AbstractStateReplacementProtocol.Instigator(originalState, newNotary, progressTracker) { - override val topic: String get() = TOPIC - override fun assembleProposal(stateRef: StateRef, modification: Party, stx: SignedTransaction): AbstractStateReplacementProtocol.Proposal = Proposal(stateRef, modification, stx) @@ -56,8 +52,6 @@ object NotaryChangeProtocol: AbstractStateReplacementProtocol() { override val progressTracker: ProgressTracker = tracker()) : AbstractStateReplacementProtocol.Acceptor(otherSide) { - override val topic: String get() = TOPIC - /** * Check the notary change proposal. * diff --git a/core/src/main/kotlin/com/r3corda/protocols/NotaryProtocol.kt b/core/src/main/kotlin/com/r3corda/protocols/NotaryProtocol.kt index 27d5c769a5..36905b24ae 100644 --- a/core/src/main/kotlin/com/r3corda/protocols/NotaryProtocol.kt +++ b/core/src/main/kotlin/com/r3corda/protocols/NotaryProtocol.kt @@ -5,12 +5,10 @@ import com.r3corda.core.crypto.DigitalSignature import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.SignedData import com.r3corda.core.crypto.signWithECDSA -import com.r3corda.core.messaging.Ack import com.r3corda.core.node.services.TimestampChecker import com.r3corda.core.node.services.UniquenessException import com.r3corda.core.node.services.UniquenessProvider import com.r3corda.core.protocols.ProtocolLogic -import com.r3corda.core.random63BitValue import com.r3corda.core.serialization.SerializedBytes import com.r3corda.core.serialization.serialize import com.r3corda.core.transactions.SignedTransaction @@ -21,8 +19,6 @@ import java.security.PublicKey object NotaryProtocol { - val TOPIC = "platform.notary" - /** * A protocol to be used for obtaining a signature from a [NotaryService] ascertaining the transaction * timestamp is correct and none of its inputs have been used in another completed transaction. @@ -30,8 +26,8 @@ object NotaryProtocol { * @throws NotaryException in case the any of the inputs to the transaction have been consumed * by another transaction or the timestamp is invalid. */ - class Client(private val stx: SignedTransaction, - override val progressTracker: ProgressTracker = Client.tracker()) : ProtocolLogic() { + open class Client(private val stx: SignedTransaction, + override val progressTracker: ProgressTracker = Client.tracker()) : ProtocolLogic() { companion object { @@ -42,8 +38,6 @@ object NotaryProtocol { fun tracker() = ProgressTracker(REQUESTING, VALIDATING) } - override val topic: String get() = TOPIC - lateinit var notaryParty: Party @Suspendable @@ -51,9 +45,9 @@ object NotaryProtocol { progressTracker.currentStep = REQUESTING val wtx = stx.tx notaryParty = wtx.notary ?: throw IllegalStateException("Transaction does not specify a Notary") - check(wtx.inputs.all { stateRef -> serviceHub.loadState(stateRef).notary == notaryParty }) { "Input states must have the same Notary" } - - sendAndReceive(notaryParty, Handshake(serviceHub.storageService.myLegalIdentity)) + check(wtx.inputs.all { stateRef -> serviceHub.loadState(stateRef).notary == notaryParty }) { + "Input states must have the same Notary" + } val request = SignRequest(stx, serviceHub.storageService.myLegalIdentity) val response = sendAndReceive(notaryParty, request) @@ -80,6 +74,10 @@ object NotaryProtocol { } } + + class ValidatingClient(stx: SignedTransaction) : Client(stx) + + /** * Checks that the timestamp command is valid (if present) and commits the input state, or returns a conflict * if any of the input states have been previously committed. @@ -92,11 +90,9 @@ object NotaryProtocol { val timestampChecker: TimestampChecker, val uniquenessProvider: UniquenessProvider) : ProtocolLogic() { - override val topic: String get() = TOPIC - @Suspendable override fun call() { - val (stx, reqIdentity) = sendAndReceive(otherSide, Ack).unwrap { it } + val (stx, reqIdentity) = receive(otherSide).unwrap { it } val wtx = stx.tx val result = try { @@ -148,10 +144,6 @@ object NotaryProtocol { } } - data class Handshake(override val replyToParty: Party, - override val sendSessionID: Long = random63BitValue(), - override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage - /** TODO: The caller must authenticate instead of just specifying its identity */ data class SignRequest(val tx: SignedTransaction, val callerIdentity: Party) @@ -162,23 +154,10 @@ object NotaryProtocol { } } - interface Factory { - fun create(otherSide: Party, - timestampChecker: TimestampChecker, - uniquenessProvider: UniquenessProvider): Service - } - - object DefaultFactory : Factory { - override fun create(otherSide: Party, - timestampChecker: TimestampChecker, - uniquenessProvider: UniquenessProvider): Service { - return Service(otherSide, timestampChecker, uniquenessProvider) - } - } } class NotaryException(val error: NotaryError) : Exception() { - override fun toString() = "${super.toString()}: Error response from Notary - ${error.toString()}" + override fun toString() = "${super.toString()}: Error response from Notary - $error" } sealed class NotaryError { diff --git a/core/src/main/kotlin/com/r3corda/protocols/RatesFixProtocol.kt b/core/src/main/kotlin/com/r3corda/protocols/RatesFixProtocol.kt index 9dd20d9c2f..e955ddbb61 100644 --- a/core/src/main/kotlin/com/r3corda/protocols/RatesFixProtocol.kt +++ b/core/src/main/kotlin/com/r3corda/protocols/RatesFixProtocol.kt @@ -6,7 +6,6 @@ import com.r3corda.core.contracts.FixOf import com.r3corda.core.crypto.DigitalSignature import com.r3corda.core.crypto.Party import com.r3corda.core.protocols.ProtocolLogic -import com.r3corda.core.random63BitValue import com.r3corda.core.transactions.TransactionBuilder import com.r3corda.core.transactions.WireTransaction import com.r3corda.core.utilities.ProgressTracker @@ -34,8 +33,6 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder, override val progressTracker: ProgressTracker = RatesFixProtocol.tracker(fixOf.name)) : ProtocolLogic() { companion object { - val TOPIC = "platform.rates.interest.fix" - class QUERYING(val name: String) : ProgressTracker.Step("Querying oracle for $name interest rate") object WORKING : ProgressTracker.Step("Working with data returned by oracle") object SIGNING : ProgressTracker.Step("Requesting confirmation signature from interest rate oracle") @@ -43,31 +40,22 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder, fun tracker(fixName: String) = ProgressTracker(QUERYING(fixName), WORKING, SIGNING) } - override val topic: String get() = TOPIC - class FixOutOfRange(@Suppress("unused") val byAmount: BigDecimal) : Exception("Fix out of range by $byAmount") - data class QueryRequest(val queries: List, - val deadline: Instant, - override val replyToParty: Party, - override val sendSessionID: Long = random63BitValue(), - override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage - - data class SignRequest(val tx: WireTransaction, - override val replyToParty: Party, - override val sendSessionID: Long = random63BitValue(), - override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage + data class QueryRequest(val queries: List, val deadline: Instant) + data class SignRequest(val tx: WireTransaction) @Suspendable override fun call() { progressTracker.currentStep = progressTracker.steps[1] - val fix = query() + val fix = subProtocol(FixQueryProtocol(fixOf, oracle)) progressTracker.currentStep = WORKING checkFixIsNearExpected(fix) tx.addCommand(fix, oracle.owningKey) beforeSigning(fix) progressTracker.currentStep = SIGNING - tx.addSignatureUnchecked(sign()) + val signature = subProtocol(FixSignProtocol(tx, oracle)) + tx.addSignatureUnchecked(signature) } /** @@ -86,31 +74,36 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder, } } - @Suspendable - private fun sign(): DigitalSignature.LegallyIdentifiable { - val wtx = tx.toWireTransaction() - val req = SignRequest(wtx, serviceHub.storageService.myLegalIdentity) - val resp = sendAndReceive(oracle, req) - return resp.unwrap { sig -> - check(sig.signer == oracle) - tx.checkSignature(sig) - sig + class FixQueryProtocol(val fixOf: FixOf, val oracle: Party) : ProtocolLogic() { + @Suspendable + override fun call(): Fix { + val deadline = suggestInterestRateAnnouncementTimeWindow(fixOf.name, oracle.name, fixOf.forDay).end + // TODO: add deadline to receive + val resp = sendAndReceive>(oracle, QueryRequest(listOf(fixOf), deadline)) + + return resp.unwrap { + val fix = it.first() + // Check the returned fix is for what we asked for. + check(fix.of == fixOf) + fix + } } } - @Suspendable - private fun query(): Fix { - val deadline = suggestInterestRateAnnouncementTimeWindow(fixOf.name, oracle.name, fixOf.forDay).end - val req = QueryRequest(listOf(fixOf), deadline, serviceHub.storageService.myLegalIdentity) - // TODO: add deadline to receive - val resp = sendAndReceive>(oracle, req) - return resp.unwrap { - val fix = it.first() - // Check the returned fix is for what we asked for. - check(fix.of == fixOf) - fix + class FixSignProtocol(val tx: TransactionBuilder, val oracle: Party) : ProtocolLogic() { + @Suspendable + override fun call(): DigitalSignature.LegallyIdentifiable { + val wtx = tx.toWireTransaction() + val resp = sendAndReceive(oracle, SignRequest(wtx)) + + return resp.unwrap { sig -> + check(sig.signer == oracle) + tx.checkSignature(sig) + sig + } } } + } diff --git a/core/src/main/kotlin/com/r3corda/protocols/ResolveTransactionsProtocol.kt b/core/src/main/kotlin/com/r3corda/protocols/ResolveTransactionsProtocol.kt index 749fe4a372..d8420df0aa 100644 --- a/core/src/main/kotlin/com/r3corda/protocols/ResolveTransactionsProtocol.kt +++ b/core/src/main/kotlin/com/r3corda/protocols/ResolveTransactionsProtocol.kt @@ -127,8 +127,6 @@ class ResolveTransactionsProtocol(private val txHashes: Set, return result } - override val topic: String get() = throw UnsupportedOperationException() - @Suspendable private fun downloadDependencies(depsToCheck: Set): Collection { // Maintain a work queue of all hashes to load/download, initialised with our starting set. Then do a breadth diff --git a/core/src/main/kotlin/com/r3corda/protocols/ServiceRequestMessage.kt b/core/src/main/kotlin/com/r3corda/protocols/ServiceRequestMessage.kt index 811f0f5365..b8391acf90 100644 --- a/core/src/main/kotlin/com/r3corda/protocols/ServiceRequestMessage.kt +++ b/core/src/main/kotlin/com/r3corda/protocols/ServiceRequestMessage.kt @@ -32,19 +32,4 @@ interface PartyRequestMessage : ServiceRequestMessage { override fun getReplyTo(networkMapCache: NetworkMapCache): MessageRecipients { return networkMapCache.partyNodes.single { it.identity == replyToParty }.address } -} - -/** - * A Handshake message is sent to initiate communication between two protocol instances. It contains the two session IDs - * the two protocols will need to communicate. - * Note: This is a temperary interface and will be removed once the protocol session work is implemented. - */ -interface HandshakeMessage : PartyRequestMessage { - - val sendSessionID: Long - val receiveSessionID: Long - @Deprecated("sessionID functions as receiveSessionID but it's recommended to use the later for clarity", - replaceWith = ReplaceWith("receiveSessionID")) - override val sessionID: Long get() = receiveSessionID - } \ No newline at end of file diff --git a/core/src/main/kotlin/com/r3corda/protocols/TwoPartyDealProtocol.kt b/core/src/main/kotlin/com/r3corda/protocols/TwoPartyDealProtocol.kt index 721e5a0821..c8caa8d2e1 100644 --- a/core/src/main/kotlin/com/r3corda/protocols/TwoPartyDealProtocol.kt +++ b/core/src/main/kotlin/com/r3corda/protocols/TwoPartyDealProtocol.kt @@ -6,12 +6,10 @@ import com.r3corda.core.contracts.* import com.r3corda.core.crypto.DigitalSignature import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.signWithECDSA -import com.r3corda.core.crypto.toBase58String import com.r3corda.core.node.NodeInfo import com.r3corda.core.node.recordTransactions import com.r3corda.core.node.services.ServiceType import com.r3corda.core.protocols.ProtocolLogic -import com.r3corda.core.random63BitValue import com.r3corda.core.seconds import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.TransactionBuilder @@ -22,7 +20,6 @@ import com.r3corda.core.utilities.trace import java.math.BigDecimal import java.security.KeyPair import java.security.PublicKey -import java.time.Duration /** * Classes for manipulating a two party deal or agreement. @@ -36,10 +33,6 @@ import java.time.Duration */ object TwoPartyDealProtocol { - val DEAL_TOPIC = "platform.deal" - /** This topic exists purely for [FixingSessionInitiation] to be sent from [FixingRoleDecider] to [FixingSessionInitiationHandler] */ - val FIX_INITIATE_TOPIC = "platform.fix.initiate" - class DealMismatchException(val expectedDeal: ContractState, val actualDeal: ContractState) : Exception() { override fun toString() = "The submitted deal didn't match the expected: $expectedDeal vs $actualDeal" } @@ -53,13 +46,19 @@ object TwoPartyDealProtocol { class SignaturesFromPrimary(val sellerSig: DigitalSignature.WithKey, val notarySig: DigitalSignature.LegallyIdentifiable) + /** + * [Primary] at the end sends the signed tx to all the regulator parties. This a seperate workflow which needs a + * sepearate session with the regulator. This interface is used to do that in [Primary.getCounterpartyMarker]. + */ + interface MarkerForBogusRegulatorProtocol + /** * Abstracted bilateral deal protocol participant that initiates communication/handshake. * * There's a good chance we can push at least some of this logic down into core protocol logic * and helper methods etc. */ - abstract class Primary(override val progressTracker: ProgressTracker = Primary.tracker()) : ProtocolLogic() { + abstract class Primary(override val progressTracker: ProgressTracker = Primary.tracker()) : ProtocolLogic() { companion object { object AWAITING_PROPOSAL : ProgressTracker.Step("Handshaking and awaiting transaction proposal") @@ -73,13 +72,19 @@ object TwoPartyDealProtocol { fun tracker() = ProgressTracker(AWAITING_PROPOSAL, VERIFYING, SIGNING, NOTARY, SENDING_SIGS, RECORDING, COPYING_TO_REGULATOR) } - override val topic: String get() = DEAL_TOPIC - - abstract val payload: U + abstract val payload: Any abstract val notaryNode: NodeInfo abstract val otherParty: Party abstract val myKeyPair: KeyPair + override fun getCounterpartyMarker(party: Party): Class<*> { + return if (serviceHub.networkMapCache.regulators.any { it.identity == party }) { + MarkerForBogusRegulatorProtocol::class.java + } else { + super.getCounterpartyMarker(party) + } + } + @Suspendable fun getPartialTransaction(): UntrustworthyData { progressTracker.currentStep = AWAITING_PROPOSAL @@ -199,8 +204,6 @@ object TwoPartyDealProtocol { fun tracker() = ProgressTracker(RECEIVING, VERIFYING, SIGNING, SWAPPING_SIGNATURES, RECORDING) } - override val topic: String get() = DEAL_TOPIC - abstract val otherParty: Party @Suspendable @@ -234,9 +237,7 @@ object TwoPartyDealProtocol { val handshake = receive>(otherParty) progressTracker.currentStep = VERIFYING - handshake.unwrap { - return validateHandshake(it) - } + return handshake.unwrap { validateHandshake(it) } } @Suspendable @@ -263,47 +264,45 @@ object TwoPartyDealProtocol { @Suspendable protected abstract fun assembleSharedTX(handshake: Handshake): Pair> } + + data class AutoOffer(val notary: Party, val dealBeingOffered: DealState) + + /** * One side of the protocol for inserting a pre-agreed deal. */ - open class Instigator(override val otherParty: Party, - val notary: Party, - override val payload: T, - override val myKeyPair: KeyPair, - override val progressTracker: ProgressTracker = Primary.tracker()) : Primary() { + open class Instigator(override val otherParty: Party, + override val payload: AutoOffer, + override val myKeyPair: KeyPair, + override val progressTracker: ProgressTracker = Primary.tracker()) : Primary() { override val notaryNode: NodeInfo get() = - serviceHub.networkMapCache.notaryNodes.filter { it.identity == notary }.single() + serviceHub.networkMapCache.notaryNodes.filter { it.identity == payload.notary }.single() } /** * One side of the protocol for inserting a pre-agreed deal. */ - open class Acceptor(override val otherParty: Party, - val notary: Party, - val dealToBuy: T, - override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary() { + open class Acceptor(override val otherParty: Party, + override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary() { - override fun validateHandshake(handshake: Handshake): Handshake { + override fun validateHandshake(handshake: Handshake): Handshake { // What is the seller trying to sell us? - val deal: T = handshake.payload - logger.trace { "Got deal request for: ${handshake.payload.ref}" } - - check(dealToBuy == deal) - - return handshake.copy(payload = deal) - + val autoOffer = handshake.payload + val deal = autoOffer.dealBeingOffered + logger.trace { "Got deal request for: ${deal.ref}" } + return handshake.copy(payload = autoOffer.copy(dealBeingOffered = deal)) } - override fun assembleSharedTX(handshake: Handshake): Pair> { - val ptx = handshake.payload.generateAgreement(notary) + override fun assembleSharedTX(handshake: Handshake): Pair> { + val deal = handshake.payload.dealBeingOffered + val ptx = deal.generateAgreement(handshake.payload.notary) // And add a request for timestamping: it may be that none of the contracts need this! But it can't hurt // to have one. ptx.setTime(serviceHub.clock.instant(), 30.seconds) - return Pair(ptx, arrayListOf(handshake.payload.parties.single { it.name == serviceHub.storageService.myLegalIdentity.name }.owningKey)) + return Pair(ptx, arrayListOf(deal.parties.single { it.name == serviceHub.storageService.myLegalIdentity.name }.owningKey)) } - } /** @@ -314,16 +313,15 @@ object TwoPartyDealProtocol { * who does what in the protocol. */ class Fixer(override val otherParty: Party, - val oracleType: ServiceType, - override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary() { + override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary() { private lateinit var txState: TransactionState<*> private lateinit var deal: FixableDealState - override fun validateHandshake(handshake: Handshake): Handshake { + override fun validateHandshake(handshake: Handshake): Handshake { logger.trace { "Got fixing request for: ${handshake.payload}" } - txState = serviceHub.loadState(handshake.payload) + txState = serviceHub.loadState(handshake.payload.ref) deal = txState.data as FixableDealState // validate the party that initiated is the one on the deal and that the recipient corresponds with it. @@ -336,7 +334,7 @@ object TwoPartyDealProtocol { } @Suspendable - override fun assembleSharedTX(handshake: Handshake): Pair> { + override fun assembleSharedTX(handshake: Handshake): Pair> { @Suppress("UNCHECKED_CAST") val fixOf = deal.nextFixingOf()!! @@ -348,12 +346,12 @@ object TwoPartyDealProtocol { val ptx = TransactionType.General.Builder(txState.notary) - val oracle = serviceHub.networkMapCache.get(oracleType).first() + val oracle = serviceHub.networkMapCache.get(handshake.payload.oracleType).first() val addFixing = object : RatesFixProtocol(ptx, oracle.identity, fixOf, BigDecimal.ZERO, BigDecimal.ONE) { @Suspendable override fun beforeSigning(fix: Fix) { - newDeal.generateFix(ptx, StateAndRef(txState, handshake.payload), fix) + newDeal.generateFix(ptx, StateAndRef(txState, handshake.payload.ref), fix) // And add a request for timestamping: it may be that none of the contracts need this! But it can't hurt // to have one. @@ -373,12 +371,13 @@ object TwoPartyDealProtocol { * does what in the protocol. */ class Floater(override val otherParty: Party, - override val payload: StateRef, - override val progressTracker: ProgressTracker = Primary.tracker()) : Primary() { + override val payload: FixingSession, + override val progressTracker: ProgressTracker = Primary.tracker()) : Primary() { + @Suppress("UNCHECKED_CAST") internal val dealToFix: StateAndRef by TransientProperty { - val state = serviceHub.loadState(payload) as TransactionState - StateAndRef(state, payload) + val state = serviceHub.loadState(payload.ref) as TransactionState + StateAndRef(state, payload.ref) } override val myKeyPair: KeyPair get() { @@ -393,23 +392,18 @@ object TwoPartyDealProtocol { /** Used to set up the session between [Floater] and [Fixer] */ - data class FixingSessionInitiation(val timeout: Duration, - val oracleType: ServiceType, - override val replyToParty: Party, - override val sendSessionID: Long = random63BitValue(), - override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage + data class FixingSession(val ref: StateRef, val oracleType: ServiceType) /** * This protocol looks at the deal and decides whether to be the Fixer or Floater role in agreeing a fixing. * * It is kicked off as an activity on both participant nodes by the scheduler when it's time for a fixing. If the - * Fixer role is chosen, then that will be initiated by the [FixingSessionInitiation] message sent from the other party and + * Fixer role is chosen, then that will be initiated by the [FixingSession] message sent from the other party and * handled by the [FixingSessionInitiationHandler]. * - * TODO: Replace [FixingSessionInitiation] and [FixingSessionInitiationHandler] with generic session initiation logic once it exists. + * TODO: Replace [FixingSession] and [FixingSessionInitiationHandler] with generic session initiation logic once it exists. */ class FixingRoleDecider(val ref: StateRef, - val timeout: Duration, override val progressTracker: ProgressTracker = tracker()) : ProtocolLogic() { companion object { @@ -418,8 +412,6 @@ object TwoPartyDealProtocol { fun tracker() = ProgressTracker(LOADING()) } - override val topic: String get() = FIX_INITIATE_TOPIC - @Suspendable override fun call(): Unit { progressTracker.nextStep() @@ -427,17 +419,10 @@ object TwoPartyDealProtocol { // TODO: this is not the eventual mechanism for identifying the parties val fixableDeal = (dealToFix.data as FixableDealState) val sortedParties = fixableDeal.parties.sortedBy { it.name } - val oracleType = fixableDeal.oracleType if (sortedParties[0].name == serviceHub.storageService.myLegalIdentity.name) { - val initation = FixingSessionInitiation( - timeout, - oracleType, - serviceHub.storageService.myLegalIdentity) - // Send initiation to other side to launch one side of the fixing protocol (the Fixer). - send(sortedParties[1], initation) - - // Then start the other side of the fixing protocol. - subProtocol(Floater(sortedParties[1], ref), inheritParentSessions = true) + val fixing = FixingSession(ref, fixableDeal.oracleType) + // Start the Floater which will then kick-off the Fixer + subProtocol(Floater(sortedParties[1], fixing)) } } } diff --git a/core/src/test/java/com/r3corda/core/protocols/ProtocolLogicRefFromJavaTest.java b/core/src/test/java/com/r3corda/core/protocols/ProtocolLogicRefFromJavaTest.java index ca9727b3e7..34959eea4e 100644 --- a/core/src/test/java/com/r3corda/core/protocols/ProtocolLogicRefFromJavaTest.java +++ b/core/src/test/java/com/r3corda/core/protocols/ProtocolLogicRefFromJavaTest.java @@ -1,10 +1,11 @@ package com.r3corda.core.protocols; +import org.junit.Test; -import org.jetbrains.annotations.*; -import org.junit.*; - -import java.util.*; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; public class ProtocolLogicRefFromJavaTest { @@ -33,12 +34,6 @@ public class ProtocolLogicRefFromJavaTest { public Void call() { return null; } - - @NotNull - @Override - protected String getTopic() { - throw new UnsupportedOperationException(); - } } private static class JavaNoArgProtocolLogic extends ProtocolLogic { @@ -50,12 +45,6 @@ public class ProtocolLogicRefFromJavaTest { public Void call() { return null; } - - @NotNull - @Override - protected String getTopic() { - throw new UnsupportedOperationException(); - } } @Test diff --git a/core/src/test/kotlin/com/r3corda/core/protocols/BroadcastTransactionProtocolTest.kt b/core/src/test/kotlin/com/r3corda/core/protocols/BroadcastTransactionProtocolTest.kt index f3d53ed7ad..39019a50a0 100644 --- a/core/src/test/kotlin/com/r3corda/core/protocols/BroadcastTransactionProtocolTest.kt +++ b/core/src/test/kotlin/com/r3corda/core/protocols/BroadcastTransactionProtocolTest.kt @@ -10,28 +10,21 @@ import com.pholser.junit.quickcheck.runner.JUnitQuickcheck import com.r3corda.contracts.testing.SignedTransactionGenerator import com.r3corda.core.serialization.createKryo import com.r3corda.core.serialization.serialize -import com.r3corda.core.testing.PartyGenerator -import com.r3corda.protocols.BroadcastTransactionProtocol +import com.r3corda.protocols.BroadcastTransactionProtocol.NotifyTxRequest import org.junit.runner.RunWith import kotlin.test.assertEquals @RunWith(JUnitQuickcheck::class) class BroadcastTransactionProtocolTest { - class NotifyTxRequestMessageGenerator : Generator(BroadcastTransactionProtocol.NotifyTxRequestMessage::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): BroadcastTransactionProtocol.NotifyTxRequestMessage { - return BroadcastTransactionProtocol.NotifyTxRequestMessage( - tx = SignedTransactionGenerator().generate(random, status), - events = setOf(), - replyToParty = PartyGenerator().generate(random, status), - sendSessionID = random.nextLong(), - receiveSessionID = random.nextLong() - ) + class NotifyTxRequestMessageGenerator : Generator(NotifyTxRequest::class.java) { + override fun generate(random: SourceOfRandomness, status: GenerationStatus): NotifyTxRequest { + return NotifyTxRequest(tx = SignedTransactionGenerator().generate(random, status), events = setOf()) } } @Property - fun serialiseDeserialiseOfNotifyMessageWorks(@From(NotifyTxRequestMessageGenerator::class) message: BroadcastTransactionProtocol.NotifyTxRequestMessage) { + fun serialiseDeserialiseOfNotifyMessageWorks(@From(NotifyTxRequestMessageGenerator::class) message: NotifyTxRequest) { val kryo = createKryo() val serialized = message.serialize().bits val deserialized = kryo.readClassAndObject(Input(serialized)) diff --git a/core/src/test/kotlin/com/r3corda/core/protocols/ProtocolLogicRefTest.kt b/core/src/test/kotlin/com/r3corda/core/protocols/ProtocolLogicRefTest.kt index 1ee3831b92..976b82feb2 100644 --- a/core/src/test/kotlin/com/r3corda/core/protocols/ProtocolLogicRefTest.kt +++ b/core/src/test/kotlin/com/r3corda/core/protocols/ProtocolLogicRefTest.kt @@ -23,18 +23,15 @@ class ProtocolLogicRefTest { constructor(kotlinType: Int) : this(ParamType1(kotlinType), ParamType2("b")) override fun call() = Unit - override val topic: String get() = throw UnsupportedOperationException() } class KotlinNoArgProtocolLogic : ProtocolLogic() { override fun call() = Unit - override val topic: String get() = throw UnsupportedOperationException() } @Suppress("UNUSED_PARAMETER") // We will never use A or b class NotWhiteListedKotlinProtocolLogic(A: Int, b: String) : ProtocolLogic() { override fun call() = Unit - override val topic: String get() = throw UnsupportedOperationException() } lateinit var factory: ProtocolLogicRefFactory diff --git a/docs/source/protocol-state-machines.rst b/docs/source/protocol-state-machines.rst index b5bce9a70b..1da9d0f7cc 100644 --- a/docs/source/protocol-state-machines.rst +++ b/docs/source/protocol-state-machines.rst @@ -91,7 +91,7 @@ Our protocol has two parties (B and S for buyer and seller) and will proceed as it lacks a signature from S authorising movement of the asset. 3. S signs it and hands the now finalised ``SignedTransaction`` back to B. -You can find the implementation of this protocol in the file ``contracts/protocols/TwoPartyTradeProtocol.kt``. +You can find the implementation of this protocol in the file ``contracts/src/main/kotlin/com/r3corda/protocols/TwoPartyTradeProtocol.kt``. Assuming no malicious termination, they both end the protocol being in posession of a valid, signed transaction that represents an atomic asset swap. @@ -110,7 +110,6 @@ each side. .. sourcecode:: kotlin object TwoPartyTradeProtocol { - val TOPIC = "platform.trade" class UnacceptablePriceException(val givenPrice: Amount) : Exception("Unacceptable price: $givenPrice") class AssetMismatchException(val expectedTypeName: String, val typeName: String) : Exception() { @@ -118,21 +117,20 @@ each side. } // This object is serialised to the network and is the first protocol message the seller sends to the buyer. - class SellerTradeInfo( + data class SellerTradeInfo( val assetForSale: StateAndRef, - val price: Amount, - val sellerOwnerKey: PublicKey, - val sessionID: Long + val price: Amount, + val sellerOwnerKey: PublicKey ) - class SignaturesFromSeller(val timestampAuthoritySig: DigitalSignature.WithKey, val sellerSig: DigitalSignature.WithKey) + data class SignaturesFromSeller(val sellerSig: DigitalSignature.WithKey, + val notarySig: DigitalSignature.LegallyIdentifiable) open class Seller(val otherSide: Party, val notaryNode: NodeInfo, val assetToSell: StateAndRef, val price: Amount, val myKeyPair: KeyPair, - val buyerSessionID: Long, override val progressTracker: ProgressTracker = Seller.tracker()) : ProtocolLogic() { @Suspendable override fun call(): SignedTransaction { @@ -143,8 +141,7 @@ each side. open class Buyer(val otherSide: Party, val notary: Party, val acceptablePrice: Amount, - val typeToBuy: Class, - val sessionID: Long) : ProtocolLogic() { + val typeToBuy: Class) : ProtocolLogic() { @Suspendable override fun call(): SignedTransaction { TODO() @@ -152,25 +149,17 @@ each side. } } -Let's unpack what this code does: - -- It defines a several classes nested inside the main ``TwoPartyTradeProtocol`` singleton. Some of the classes - are simply protocol messages or exceptions. The other two represent the buyer and seller side of the protocol. -- It defines the "trade topic", which is just a string that namespaces this protocol. The prefix "platform." is reserved - by Corda, but you can define your own protocol namespaces using standard Java-style reverse DNS notation. +This code defines several classes nested inside the main ``TwoPartyTradeProtocol`` singleton. Some of the classes are +simply protocol messages or exceptions. The other two represent the buyer and seller side of the protocol. Going through the data needed to become a seller, we have: -- ``otherSide: SingleMessageRecipient`` - the network address of the node with which you are trading. +- ``otherSide: Party`` - the party with which you are trading. - ``notaryNode: NodeInfo`` - the entry in the network map for the chosen notary. See ":doc:`consensus`" for more information on notaries. - ``assetToSell: StateAndRef`` - a pointer to the ledger entry that represents the thing being sold. - ``price: Amount`` - the agreed on price that the asset is being sold for (without an issuer constraint). - ``myKeyPair: KeyPair`` - the key pair that controls the asset being sold. It will be used to sign the transaction. -- ``buyerSessionID: Long`` - a unique number that identifies this trade to the buyer. It is expected that the buyer - knows that the trade is going to take place and has sent you such a number already. - -.. note:: Session IDs will be automatically handled in a future version of the framework. And for the buyer: @@ -178,7 +167,6 @@ And for the buyer: a price less than or equal to this, then the trade will go ahead. - ``typeToBuy: Class`` - the type of state that is being purchased. This is used to check that the sell side of the protocol isn't trying to sell us the wrong thing, whether by accident or on purpose. -- ``sessionID: Long`` - the session ID that was handed to the seller in order to start the protocol. Alright, so using this protocol shouldn't be too hard: in the simplest case we can just create a Buyer or Seller with the details of the trade, depending on who we are. We then have to start the protocol in some way. Just @@ -221,6 +209,27 @@ protocol are checked against a whitelist, which can be extended by apps themselv The process of starting a protocol returns a ``ListenableFuture`` that you can use to either block waiting for the result, or register a callback that will be invoked when the result is ready. +In a two party protocol only one side is to be manually started using ``ServiceHub.invokeProtocolAsync``. The other side +has to be registered by its node to respond to the initiating protocol via ``ServiceHubInternal.registerProtocolInitiator``. +In our example it doesn't matter which protocol is the initiator and which is the initiated. For example, if we are to +take the seller as the initiator then we would register the buyer as such: + +.. container:: codeset + + .. sourcecode:: kotlin + + val services: ServiceHubInternal = TODO() + + services.registerProtocolInitiator(Seller::class) { otherParty -> + val notary = services.networkMapCache.notaryNodes[0] + val acceptablePrice = TODO() + val typeToBuy = TODO() + Buyer(otherParty, notary, acceptablePrice, typeToBuy) + } + +This is telling the buyer node to fire up an instance of ``Buyer`` (the code in the lambda) when the initiating protocol +is a seller (``Seller::class``). + Implementing the seller ----------------------- @@ -253,12 +262,10 @@ Let's fill out the ``receiveAndCheckProposedTransaction()`` method. @Suspendable private fun receiveAndCheckProposedTransaction(): SignedTransaction { - val sessionID = random63BitValue() - // Make the first message we'll send to kick off the protocol. - val hello = SellerTradeInfo(assetToSell, price, myKeyPair.public, sessionID) + val hello = SellerTradeInfo(assetToSell, price, myKeyPair.public) - val maybeSTX = sendAndReceive(otherSide, buyerSessionID, sessionID, hello) + val maybeSTX = sendAndReceive(otherSide, hello) maybeSTX.unwrap { // Check that the tx proposed by the buyer is valid. @@ -281,11 +288,10 @@ Let's fill out the ``receiveAndCheckProposedTransaction()`` method. } } -Let's break this down. We generate a session ID to identify what's happening on the seller side, fill out -the initial protocol message, and then call ``sendAndReceive``. This function takes a few arguments: +Let's break this down. We fill out the initial protocol message with the trade info, and then call ``sendAndReceive``. +This function takes a few arguments: -- The topic string that ensures the message is routed to the right bit of code in the other side's node. -- The session IDs that ensure the messages don't get mixed up with other simultaneous trades. +- The party on the other side. - The thing to send. It'll be serialised and sent automatically. - Finally a type argument, which is the kind of object we're expecting to receive from the other side. If we get back something else an exception is thrown. @@ -370,7 +376,7 @@ Here's the rest of the code: notarySignature: DigitalSignature.LegallyIdentifiable): SignedTransaction { val fullySigned = partialTX + ourSignature + notarySignature logger.trace { "Built finished transaction, sending back to secondary!" } - send(otherSide, buyerSessionID, SignaturesFromSeller(ourSignature, notarySignature)) + send(otherSide, SignaturesFromSeller(ourSignature, notarySignature)) return fullySigned } @@ -406,7 +412,7 @@ OK, let's do the same for the buyer side: val (ptx, cashSigningPubKeys) = assembleSharedTX(tradeRequest) val stx = signWithOurKeys(cashSigningPubKeys, ptx) - val signatures = swapSignaturesWithSeller(stx, tradeRequest.sessionID) + val signatures = swapSignaturesWithSeller(stx) logger.trace { "Got signatures from seller, verifying ... " } @@ -419,16 +425,14 @@ OK, let's do the same for the buyer side: @Suspendable private fun receiveAndValidateTradeRequest(): SellerTradeInfo { - // Wait for a trade request to come in on our pre-provided session ID. - val maybeTradeRequest = receive(sessionID) + // Wait for a trade request to come in from the other side + val maybeTradeRequest = receive(otherParty) maybeTradeRequest.unwrap { // What is the seller trying to sell us? val asset = it.assetForSale.state.data val assetTypeName = asset.javaClass.name logger.trace { "Got trade request for a $assetTypeName: ${it.assetForSale}" } - // Check the start message for acceptability. - check(it.sessionID > 0) if (it.price > acceptablePrice) throw UnacceptablePriceException(it.price) if (!typeToBuy.isInstance(asset)) @@ -443,13 +447,13 @@ OK, let's do the same for the buyer side: } @Suspendable - private fun swapSignaturesWithSeller(stx: SignedTransaction, theirSessionID: Long): SignaturesFromSeller { + private fun swapSignaturesWithSeller(stx: SignedTransaction): SignaturesFromSeller { progressTracker.currentStep = SWAPPING_SIGNATURES logger.trace { "Sending partially signed transaction to seller" } // TODO: Protect against the seller terminating here and leaving us in the lurch without the final tx. - return sendAndReceive(otherSide, theirSessionID, sessionID, stx).unwrap { it } + return sendAndReceive(otherSide, stx).unwrap { it } } private fun signWithOurKeys(cashSigningPubKeys: List, ptx: TransactionBuilder): SignedTransaction { @@ -676,7 +680,6 @@ Future features The protocol framework is a key part of the platform and will be extended in major ways in future. Here are some of the features we have planned: -* Automatic session ID management * Identity based addressing * Exposing progress trackers to local (inside the firewall) clients using message queues and/or WebSockets * Exception propagation and management, with a "protocol hospital" tool to manually provide solutions to unavoidable diff --git a/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt index 5fc67a15a0..82a7f4e8a2 100644 --- a/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt @@ -24,6 +24,7 @@ import com.r3corda.core.serialization.SingletonSerializeAsToken import com.r3corda.core.serialization.deserialize import com.r3corda.core.serialization.serialize import com.r3corda.core.transactions.SignedTransaction +import com.r3corda.core.utilities.debug import com.r3corda.node.api.APIServer import com.r3corda.node.services.api.* import com.r3corda.node.services.config.NodeConfiguration @@ -54,8 +55,10 @@ import java.nio.file.Path import java.security.KeyPair import java.time.Clock import java.util.* +import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ExecutorService import java.util.concurrent.TimeUnit +import kotlin.reflect.KClass /** * A base node implementation that can be customised either for production (with real implementations that do real @@ -91,6 +94,8 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap protected val _servicesThatAcceptUploads = ArrayList() val servicesThatAcceptUploads: List = _servicesThatAcceptUploads + private val protocolFactories = ConcurrentHashMap, (Party) -> ProtocolLogic<*>>() + val services = object : ServiceHubInternal() { override val networkService: MessagingServiceInternal get() = net override val networkMapCache: NetworkMapCache get() = netMapCache @@ -109,6 +114,16 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap return smm.add(loggerName, logic).resultFuture } + override fun registerProtocolInitiator(markerClass: KClass<*>, protocolFactory: (Party) -> ProtocolLogic<*>) { + require(markerClass !in protocolFactories) { "${markerClass.java.name} has already been used to register a protocol" } + log.debug { "Registering ${markerClass.java.name}" } + protocolFactories[markerClass.java] = protocolFactory + } + + override fun getProtocolFactory(markerClass: Class<*>): ((Party) -> ProtocolLogic<*>)? { + return protocolFactories[markerClass] + } + override fun recordTransactions(txs: Iterable) = recordTransactionsInternal(storage, txs) } diff --git a/node/src/main/kotlin/com/r3corda/node/services/NotaryChangeService.kt b/node/src/main/kotlin/com/r3corda/node/services/NotaryChangeService.kt index c9753a6cf3..b6e89211b8 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/NotaryChangeService.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/NotaryChangeService.kt @@ -1,11 +1,9 @@ package com.r3corda.node.services import com.r3corda.core.node.CordaPluginRegistry -import com.r3corda.node.services.api.AbstractNodeService +import com.r3corda.core.serialization.SingletonSerializeAsToken import com.r3corda.node.services.api.ServiceHubInternal -import com.r3corda.protocols.AbstractStateReplacementProtocol import com.r3corda.protocols.NotaryChangeProtocol -import com.r3corda.protocols.NotaryChangeProtocol.TOPIC object NotaryChange { class Plugin : CordaPluginRegistry() { @@ -16,11 +14,9 @@ object NotaryChange { * A service that monitors the network for requests for changing the notary of a state, * and immediately runs the [NotaryChangeProtocol] if the auto-accept criteria are met. */ - class Service(services: ServiceHubInternal) : AbstractNodeService(services) { + class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() { init { - addProtocolHandler(TOPIC, TOPIC) { req: AbstractStateReplacementProtocol.Handshake -> - NotaryChangeProtocol.Acceptor(req.replyToParty) - } + services.registerProtocolInitiator(NotaryChangeProtocol.Instigator::class) { NotaryChangeProtocol.Acceptor(it) } } } } diff --git a/node/src/main/kotlin/com/r3corda/node/services/api/AbstractNodeService.kt b/node/src/main/kotlin/com/r3corda/node/services/api/AbstractNodeService.kt index badc92a009..9952728f03 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/api/AbstractNodeService.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/api/AbstractNodeService.kt @@ -1,16 +1,12 @@ package com.r3corda.node.services.api -import com.google.common.util.concurrent.ListenableFuture import com.r3corda.core.messaging.Message import com.r3corda.core.messaging.MessageHandlerRegistration import com.r3corda.core.messaging.createMessage import com.r3corda.core.node.services.DEFAULT_SESSION_ID -import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.serialization.SingletonSerializeAsToken import com.r3corda.core.serialization.deserialize import com.r3corda.core.serialization.serialize -import com.r3corda.core.utilities.loggerFor -import com.r3corda.protocols.HandshakeMessage import com.r3corda.protocols.ServiceRequestMessage import javax.annotation.concurrent.ThreadSafe @@ -20,10 +16,6 @@ import javax.annotation.concurrent.ThreadSafe @ThreadSafe abstract class AbstractNodeService(val services: ServiceHubInternal) : SingletonSerializeAsToken() { - companion object { - val logger = loggerFor() - } - val net: MessagingServiceInternal get() = services.networkService /** @@ -68,36 +60,4 @@ abstract class AbstractNodeService(val services: ServiceHubInternal) : Singleton return addMessageHandler(topic, handler, { message: Message, exception: Exception -> throw exception }) } - /** - * Register a handler to kick-off a protocol when a [HandshakeMessage] is received by the node. This performs the - * necessary steps to enable communication between the two protocols, including calling ProtocolLogic.registerSession. - * @param topic the topic on which the handshake is sent from the other party - * @param loggerName the logger name to use when starting the protocol - * @param protocolFactory a function to create the protocol with the given handshake message - * @param onResultFuture provides access to the [ListenableFuture] when the protocol starts - */ - protected inline fun addProtocolHandler( - topic: String, - loggerName: String, - crossinline protocolFactory: (H) -> ProtocolLogic, - crossinline onResultFuture: ProtocolLogic.(ListenableFuture, H) -> Unit) { - net.addMessageHandler(topic, DEFAULT_SESSION_ID, null) { message, reg -> - try { - val handshake = message.data.deserialize() - val protocol = protocolFactory(handshake) - protocol.registerSession(handshake) - val resultFuture = services.startProtocol(loggerName, protocol) - protocol.onResultFuture(resultFuture, handshake) - } catch (e: Exception) { - logger.error("Unable to process ${H::class.java.name} message", e) - } - } - } - - protected inline fun addProtocolHandler( - topic: String, - loggerName: String, - crossinline protocolFactory: (H) -> ProtocolLogic) { - addProtocolHandler(topic, loggerName, protocolFactory, { future, handshake -> }) - } } diff --git a/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt b/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt index cc8d180a53..7693cd770c 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt @@ -1,7 +1,7 @@ package com.r3corda.node.services.api +import com.r3corda.core.crypto.SecureHash import com.r3corda.core.serialization.SerializedBytes -import com.r3corda.node.services.statemachine.ProtocolIORequest import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl /** @@ -30,14 +30,13 @@ interface CheckpointStorage { } // This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo). -data class Checkpoint( - val serialisedFiber: SerializedBytes>, - val request: ProtocolIORequest?, - val receivedPayload: Any? -) { - // This flag is always false when loaded from storage as it isn't serialised. - // It is used to track when the associated fiber has been created, but not necessarily started when - // messages for protocols arrive before the system has fully loaded at startup. - @Transient - var fiberCreated: Boolean = false -} \ No newline at end of file +class Checkpoint(val serialisedFiber: SerializedBytes>) { + + val id: SecureHash get() = serialisedFiber.hash + + override fun equals(other: Any?): Boolean = other === this || other is Checkpoint && other.id == this.id + + override fun hashCode(): Int = id.hashCode() + + override fun toString(): String = "${javaClass.simpleName}(id=$id)" +} diff --git a/node/src/main/kotlin/com/r3corda/node/services/api/ServiceHubInternal.kt b/node/src/main/kotlin/com/r3corda/node/services/api/ServiceHubInternal.kt index 05758f4908..fa573b3514 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/api/ServiceHubInternal.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/api/ServiceHubInternal.kt @@ -1,14 +1,16 @@ package com.r3corda.node.services.api import com.google.common.util.concurrent.ListenableFuture -import com.r3corda.core.transactions.SignedTransaction +import com.r3corda.core.crypto.Party import com.r3corda.core.messaging.MessagingService import com.r3corda.core.node.ServiceHub import com.r3corda.core.node.services.TxWritableStorageService import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogicRefFactory +import com.r3corda.core.transactions.SignedTransaction import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl import org.slf4j.LoggerFactory +import kotlin.reflect.KClass interface MessagingServiceInternal : MessagingService { /** @@ -49,7 +51,7 @@ abstract class ServiceHubInternal : ServiceHub { * @param txs The transactions to record. */ internal fun recordTransactionsInternal(writableStorageService: TxWritableStorageService, txs: Iterable) { - val stateMachineRunId = ProtocolStateMachineImpl.retrieveCurrentStateMachine()?.id + val stateMachineRunId = ProtocolStateMachineImpl.currentStateMachine()?.id if (stateMachineRunId != null) { txs.forEach { storageService.stateMachineRecordedTransactionMapping.addMapping(stateMachineRunId, it.id) @@ -68,6 +70,23 @@ abstract class ServiceHubInternal : ServiceHub { */ abstract fun startProtocol(loggerName: String, logic: ProtocolLogic): ListenableFuture + /** + * Register the protocol factory we wish to use when a initiating party attempts to communicate with us. The + * registration is done against a marker [KClass] which is sent in the session handsake by the other party. If this + * marker class has been registered then the corresponding factory will be used to create the protocol which will + * communicate with the other side. If there is no mapping then the session attempt is rejected. + * @param markerClass The marker [KClass] present in a session initiation attempt, which is a 1:1 mapping to a [Class] + * using the
::class
construct. Any marker class can be used, with the default being the class of the initiating + * protocol. This enables the registration to be of the form: registerProtocolInitiator(InitiatorProtocol::class, ::InitiatedProtocol) + * @param protocolFactory The protocol factory generating the initiated protocol. + */ + abstract fun registerProtocolInitiator(markerClass: KClass<*>, protocolFactory: (Party) -> ProtocolLogic<*>) + + /** + * Return the protocol factory that has been registered with [markerClass], or null if no factory is found. + */ + abstract fun getProtocolFactory(markerClass: Class<*>): ((Party) -> ProtocolLogic<*>)? + override fun invokeProtocolAsync(logicType: Class>, vararg args: Any?): ListenableFuture { val logicRef = protocolLogicRefFactory.create(logicType, *args) @Suppress("UNCHECKED_CAST") diff --git a/node/src/main/kotlin/com/r3corda/node/services/clientapi/FixingSessionInitiation.kt b/node/src/main/kotlin/com/r3corda/node/services/clientapi/FixingSessionInitiation.kt index 90686543d7..8e3cf7f591 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/clientapi/FixingSessionInitiation.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/clientapi/FixingSessionInitiation.kt @@ -1,28 +1,25 @@ package com.r3corda.node.services.clientapi import com.r3corda.core.node.CordaPluginRegistry -import com.r3corda.node.services.api.AbstractNodeService +import com.r3corda.core.serialization.SingletonSerializeAsToken import com.r3corda.node.services.api.ServiceHubInternal -import com.r3corda.protocols.TwoPartyDealProtocol -import com.r3corda.protocols.TwoPartyDealProtocol.FIX_INITIATE_TOPIC -import com.r3corda.protocols.TwoPartyDealProtocol.FixingSessionInitiation +import com.r3corda.protocols.TwoPartyDealProtocol.Fixer +import com.r3corda.protocols.TwoPartyDealProtocol.Floater /** * This is a temporary handler required for establishing random sessionIDs for the [Fixer] and [Floater] as part of * running scheduled fixings for the [InterestRateSwap] contract. * - * TODO: This will be replaced with the automatic sessionID / session setup work. + * TODO: This will be replaced with the symmetric session work */ object FixingSessionInitiation { class Plugin: CordaPluginRegistry() { override val servicePlugins: List> = listOf(Service::class.java) } - class Service(services: ServiceHubInternal) : AbstractNodeService(services) { + class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() { init { - addProtocolHandler(FIX_INITIATE_TOPIC, "fixings") { initiation: FixingSessionInitiation -> - TwoPartyDealProtocol.Fixer(initiation.replyToParty, initiation.oracleType) - } + services.registerProtocolInitiator(Floater::class) { Fixer(it) } } } } diff --git a/node/src/main/kotlin/com/r3corda/node/services/monitor/NodeMonitorService.kt b/node/src/main/kotlin/com/r3corda/node/services/monitor/NodeMonitorService.kt index 508ed471ba..acea7d6148 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/monitor/NodeMonitorService.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/monitor/NodeMonitorService.kt @@ -169,7 +169,7 @@ class NodeMonitorService(services: ServiceHubInternal, val smm: StateMachineMana val tx = builder.toSignedTransaction(checkSufficientSignatures = false) val protocol = FinalityProtocol(tx, setOf(req), setOf(req.recipient)) return TransactionBuildResult.ProtocolStarted( - smm.add(BroadcastTransactionProtocol.TOPIC, protocol).id, + smm.add("broadcast", protocol).id, tx, "Cash payment transaction generated" ) @@ -203,7 +203,7 @@ class NodeMonitorService(services: ServiceHubInternal, val smm: StateMachineMana val tx = builder.toSignedTransaction(checkSufficientSignatures = false) val protocol = FinalityProtocol(tx, setOf(req), participants) return TransactionBuildResult.ProtocolStarted( - smm.add(BroadcastTransactionProtocol.TOPIC, protocol).id, + smm.add("broadcast", protocol).id, tx, "Cash destruction transaction generated" ) @@ -222,7 +222,7 @@ class NodeMonitorService(services: ServiceHubInternal, val smm: StateMachineMana // Issuance transactions do not need to be notarised, so we can skip directly to broadcasting it val protocol = BroadcastTransactionProtocol(tx, setOf(req), setOf(req.recipient)) return TransactionBuildResult.ProtocolStarted( - smm.add(BroadcastTransactionProtocol.TOPIC, protocol).id, + smm.add("broadcast", protocol).id, tx, "Cash issuance completed" ) diff --git a/node/src/main/kotlin/com/r3corda/node/services/persistence/DataVendingService.kt b/node/src/main/kotlin/com/r3corda/node/services/persistence/DataVendingService.kt index f5b925a4ef..6ed9dc5208 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/persistence/DataVendingService.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/persistence/DataVendingService.kt @@ -1,17 +1,12 @@ package com.r3corda.node.services.persistence +import co.paralleluniverse.fibers.Suspendable import com.r3corda.core.crypto.Party -import com.r3corda.core.failure -import com.r3corda.core.messaging.MessagingService -import com.r3corda.core.messaging.TopicSession import com.r3corda.core.node.CordaPluginRegistry -import com.r3corda.core.node.NodeInfo import com.r3corda.core.node.recordTransactions -import com.r3corda.core.serialization.serialize -import com.r3corda.core.success -import com.r3corda.core.transactions.SignedTransaction +import com.r3corda.core.protocols.ProtocolLogic +import com.r3corda.core.serialization.SingletonSerializeAsToken import com.r3corda.core.utilities.loggerFor -import com.r3corda.node.services.api.AbstractNodeService import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.protocols.* import java.io.InputStream @@ -39,78 +34,73 @@ object DataVending { // TODO: I don't like that this needs ServiceHubInternal, but passing in a state machine breaks MockServices because // the state machine isn't set when this is constructed. [NodeSchedulerService] has the same problem, and both // should be fixed at the same time. - class Service(services: ServiceHubInternal) : AbstractNodeService(services) { + class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() { + companion object { val logger = loggerFor() - - /** - * Notify a node of a transaction. Normally any notarisation required would happen before this is called. - */ - fun notify(net: MessagingService, - myIdentity: Party, - recipient: NodeInfo, - transaction: SignedTransaction) { - val msg = BroadcastTransactionProtocol.NotifyTxRequestMessage(transaction, emptySet(), myIdentity) - net.send(net.createMessage(TopicSession(BroadcastTransactionProtocol.TOPIC, 0), msg.serialize().bits), recipient.address) - } } - val storage = services.storageService - class TransactionRejectedError(msg: String) : Exception(msg) init { - addMessageHandler(FetchTransactionsProtocol.TOPIC, - { req: FetchDataProtocol.Request -> handleTXRequest(req) }, - { message, e -> logger.error("Failure processing data vending request.", e) } - ) - - addMessageHandler(FetchAttachmentsProtocol.TOPIC, - { req: FetchDataProtocol.Request -> handleAttachmentRequest(req) }, - { message, e -> logger.error("Failure processing data vending request.", e) } - ) - - // TODO: We should have a whitelist of contracts we're willing to accept at all, and reject if the transaction - // includes us in any outside that list. Potentially just if it includes any outside that list at all. - // TODO: Do we want to be able to reject specific transactions on more complex rules, for example reject incoming - // cash without from unknown parties? - addProtocolHandler( - BroadcastTransactionProtocol.TOPIC, - "Resolving transactions", - { req: BroadcastTransactionProtocol.NotifyTxRequestMessage -> - ResolveTransactionsProtocol(req.tx, req.replyToParty) - }, - { future, req -> - future.success { - serviceHub.recordTransactions(req.tx) - }.failure { throwable -> - logger.warn("Received invalid transaction ${req.tx.id} from ${req.replyToParty}", throwable) - } - }) + services.registerProtocolInitiator(FetchTransactionsProtocol::class, ::FetchTransactionsHandler) + services.registerProtocolInitiator(FetchAttachmentsProtocol::class, ::FetchAttachmentsHandler) + services.registerProtocolInitiator(BroadcastTransactionProtocol::class, ::NotifyTransactionHandler) } - private fun handleTXRequest(req: FetchDataProtocol.Request): List { - require(req.hashes.isNotEmpty()) - return req.hashes.map { - val tx = storage.validatedTransactions.getTransaction(it) - if (tx == null) - logger.info("Got request for unknown tx $it") - tx + + private class FetchTransactionsHandler(val otherParty: Party) : ProtocolLogic() { + @Suspendable + override fun call() { + val request = receive(otherParty).unwrap { + require(it.hashes.isNotEmpty()) + it + } + val txs = request.hashes.map { + val tx = serviceHub.storageService.validatedTransactions.getTransaction(it) + if (tx == null) + logger.info("Got request for unknown tx $it") + tx + } + send(otherParty, txs) } } - private fun handleAttachmentRequest(req: FetchDataProtocol.Request): List { - // TODO: Use Artemis message streaming support here, called "large messages". This avoids the need to buffer. - require(req.hashes.isNotEmpty()) - return req.hashes.map { - val jar: InputStream? = storage.attachments.openAttachment(it)?.open() - if (jar == null) { - logger.info("Got request for unknown attachment $it") - null - } else { - jar.readBytes() + + // TODO: Use Artemis message streaming support here, called "large messages". This avoids the need to buffer. + private class FetchAttachmentsHandler(val otherParty: Party) : ProtocolLogic() { + @Suspendable + override fun call() { + val request = receive(otherParty).unwrap { + require(it.hashes.isNotEmpty()) + it } + val attachments = request.hashes.map { + val jar: InputStream? = serviceHub.storageService.attachments.openAttachment(it)?.open() + if (jar == null) { + logger.info("Got request for unknown attachment $it") + null + } else { + jar.readBytes() + } + } + send(otherParty, attachments) + } + } + + + // TODO: We should have a whitelist of contracts we're willing to accept at all, and reject if the transaction + // includes us in any outside that list. Potentially just if it includes any outside that list at all. + // TODO: Do we want to be able to reject specific transactions on more complex rules, for example reject incoming + // cash without from unknown parties? + class NotifyTransactionHandler(val otherParty: Party) : ProtocolLogic() { + @Suspendable + override fun call() { + val request = receive(otherParty).unwrap { it } + subProtocol(ResolveTransactionsProtocol(request.tx, otherParty), shareParentSessions = true) + serviceHub.recordTransactions(request.tx) } } } + } diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolIORequest.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolIORequest.kt index 7df7f68f3a..a0028fd7a2 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolIORequest.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolIORequest.kt @@ -1,53 +1,38 @@ package com.r3corda.node.services.statemachine -import com.r3corda.core.crypto.Party -import com.r3corda.core.messaging.TopicSession -import java.util.* +import com.r3corda.node.services.statemachine.StateMachineManager.ProtocolSession +import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage // TODO revisit when Kotlin 1.1 is released and data classes can extend other classes interface ProtocolIORequest { // This is used to identify where we suspended, in case of message mismatch errors and other things where we // don't have the original stack trace because it's in a suspended fiber. val stackTraceInCaseOfProblems: StackSnapshot - val topic: String + val session: ProtocolSession } interface SendRequest : ProtocolIORequest { - val destination: Party - val payload: Any - val sendSessionID: Long - val uniqueMessageId: UUID + val message: SessionMessage } -interface ReceiveRequest : ProtocolIORequest { +interface ReceiveRequest : ProtocolIORequest { val receiveType: Class - val receiveSessionID: Long - val receiveTopicSession: TopicSession get() = TopicSession(topic, receiveSessionID) } -data class SendAndReceive(override val topic: String, - override val destination: Party, - override val payload: Any, - override val sendSessionID: Long, - override val uniqueMessageId: UUID, - override val receiveType: Class, - override val receiveSessionID: Long) : SendRequest, ReceiveRequest { +data class SendAndReceive(override val session: ProtocolSession, + override val message: SessionMessage, + override val receiveType: Class) : SendRequest, ReceiveRequest { @Transient override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() } -data class ReceiveOnly(override val topic: String, - override val receiveType: Class, - override val receiveSessionID: Long) : ReceiveRequest { +data class ReceiveOnly(override val session: ProtocolSession, + override val receiveType: Class) : ReceiveRequest { @Transient override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() } -data class SendOnly(override val destination: Party, - override val topic: String, - override val payload: Any, - override val sendSessionID: Long, - override val uniqueMessageId: UUID) : SendRequest { +data class SendOnly(override val session: ProtocolSession, override val message: SessionMessage) : SendRequest { @Transient override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() } diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt index 1e6248848b..ef5a07a176 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt @@ -8,16 +8,22 @@ import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.SettableFuture import com.r3corda.core.crypto.Party import com.r3corda.core.protocols.ProtocolLogic +import com.r3corda.core.protocols.ProtocolSessionException import com.r3corda.core.protocols.ProtocolStateMachine import com.r3corda.core.protocols.StateMachineRunId +import com.r3corda.core.random63BitValue +import com.r3corda.core.rootCause import com.r3corda.core.utilities.UntrustworthyData import com.r3corda.core.utilities.trace import com.r3corda.node.services.api.ServiceHubInternal +import com.r3corda.node.services.statemachine.StateMachineManager.* import com.r3corda.node.utilities.createDatabaseTransaction import org.jetbrains.exposed.sql.Database import org.jetbrains.exposed.sql.transactions.TransactionManager import org.slf4j.Logger import org.slf4j.LoggerFactory +import java.io.PrintWriter +import java.io.StringWriter import java.sql.SQLException import java.util.* import java.util.concurrent.ExecutionException @@ -36,12 +42,26 @@ class ProtocolStateMachineImpl(override val id: StateMachineRunId, private val loggerName: String) : Fiber("protocol", scheduler), ProtocolStateMachine { + companion object { + // Used to work around a small limitation in Quasar. + private val QUASAR_UNBLOCKER = run { + val field = Fiber::class.java.getDeclaredField("SERIALIZER_BLOCKER") + field.isAccessible = true + field.get(null) + } + + /** + * Return the current [ProtocolStateMachineImpl] or null if executing outside of one. + */ + fun currentStateMachine(): ProtocolStateMachineImpl<*>? = Strand.currentStrand() as? ProtocolStateMachineImpl<*> + } + // These fields shouldn't be serialised, so they are marked @Transient. @Transient lateinit override var serviceHub: ServiceHubInternal - @Transient internal lateinit var suspendAction: (ProtocolIORequest) -> Unit + @Transient internal lateinit var actionOnSuspend: (ProtocolIORequest) -> Unit @Transient internal lateinit var actionOnEnd: () -> Unit - @Transient internal var receivedPayload: Any? = null @Transient internal lateinit var database: Database + @Transient internal var fromCheckpoint: Boolean = false @Transient private var _logger: Logger? = null override val logger: Logger get() { @@ -62,18 +82,20 @@ class ProtocolStateMachineImpl(override val id: StateMachineRunId, } } + internal val openSessions = HashMap, Party>, ProtocolSession>() + init { logic.psm = this + name = id.toString() } - @Suspendable @Suppress("UNCHECKED_CAST") + @Suspendable override fun run(): R { createTransaction() val result = try { logic.call() } catch (t: Throwable) { - actionOnEnd() - _resultFuture?.setException(t) + processException(t) commitTransaction() throw ExecutionException(t) } @@ -106,56 +128,140 @@ class ProtocolStateMachineImpl(override val id: StateMachineRunId, } @Suspendable - private fun suspendAndExpectReceive(receiveRequest: ReceiveRequest): UntrustworthyData { - suspend(receiveRequest) - check(receivedPayload != null) { "Expected to receive something" } - val untrustworthy = UntrustworthyData(receiveRequest.receiveType.cast(receivedPayload)) - receivedPayload = null - return untrustworthy - } - - @Suspendable - override fun sendAndReceive(topic: String, - destination: Party, - sessionIDForSend: Long, - sessionIDForReceive: Long, + override fun sendAndReceive(otherParty: Party, payload: Any, - receiveType: Class): UntrustworthyData { - return suspendAndExpectReceive(SendAndReceive(topic, destination, payload, sessionIDForSend, UUID.randomUUID(), receiveType, sessionIDForReceive)) + receiveType: Class, + sessionProtocol: ProtocolLogic<*>): UntrustworthyData { + val session = getSession(otherParty, sessionProtocol) + val sendSessionData = createSessionData(session, payload) + val receivedSessionData = sendAndReceiveInternal(session, sendSessionData, SessionData::class.java) + return UntrustworthyData(receiveType.cast(receivedSessionData.payload)) } @Suspendable - override fun receive(topic: String, sessionIDForReceive: Long, receiveType: Class): UntrustworthyData { - return suspendAndExpectReceive(ReceiveOnly(topic, receiveType, sessionIDForReceive)) + override fun receive(otherParty: Party, + receiveType: Class, + sessionProtocol: ProtocolLogic<*>): UntrustworthyData { + val receivedSessionData = receiveInternal(getSession(otherParty, sessionProtocol), SessionData::class.java) + return UntrustworthyData(receiveType.cast(receivedSessionData.payload)) } @Suspendable - override fun send(topic: String, destination: Party, sessionID: Long, payload: Any) { - suspend(SendOnly(destination, topic, payload, sessionID, UUID.randomUUID())) + override fun send(otherParty: Party, payload: Any, sessionProtocol: ProtocolLogic<*>) { + val session = getSession(otherParty, sessionProtocol) + val sendSessionData = createSessionData(session, payload) + sendInternal(session, sendSessionData) + } + + private fun createSessionData(session: ProtocolSession, payload: Any): SessionData { + val otherPartySessionId = session.otherPartySessionId + ?: throw IllegalStateException("We've somehow held onto an unconfirmed session: $session") + return SessionData(otherPartySessionId, payload) } @Suspendable - private fun suspend(protocolIORequest: ProtocolIORequest) { + private fun sendInternal(session: ProtocolSession, message: SessionMessage) { + suspend(SendOnly(session, message)) + } + + @Suspendable + private fun receiveInternal(session: ProtocolSession, receiveType: Class): T { + return suspendAndExpectReceive(ReceiveOnly(session, receiveType)) + } + + @Suspendable + private fun sendAndReceiveInternal(session: ProtocolSession, message: SessionMessage, receiveType: Class): T { + return suspendAndExpectReceive(SendAndReceive(session, message, receiveType)) + } + + @Suspendable + private fun getSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>): ProtocolSession { + return openSessions[Pair(sessionProtocol, otherParty)] ?: startNewSession(otherParty, sessionProtocol) + } + + @Suspendable + private fun startNewSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>) : ProtocolSession { + val session = ProtocolSession(sessionProtocol, otherParty, random63BitValue(), null) + openSessions[Pair(sessionProtocol, otherParty)] = session + val counterpartyProtocol = sessionProtocol.getCounterpartyMarker(otherParty).name + val sessionInit = SessionInit(session.ourSessionId, serviceHub.storageService.myLegalIdentity, counterpartyProtocol) + val sessionInitResponse = sendAndReceiveInternal(session, sessionInit, SessionInitResponse::class.java) + if (sessionInitResponse is SessionConfirm) { + session.otherPartySessionId = sessionInitResponse.initiatedSessionId + return session + } else { + sessionInitResponse as SessionReject + throw ProtocolSessionException("Party $otherParty rejected session attempt: ${sessionInitResponse.errorMessage}") + } + } + + @Suspendable + private fun suspendAndExpectReceive(receiveRequest: ReceiveRequest): T { + fun getReceivedMessage(): ExistingSessionMessage? = receiveRequest.session.receivedMessages.poll() + + val receivedMessage = getReceivedMessage() ?: run { + // Suspend while we wait for the receive + receiveRequest.session.waitingForResponse = true + suspend(receiveRequest) + receiveRequest.session.waitingForResponse = false + getReceivedMessage() + ?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $id $receiveRequest") + } + + if (receivedMessage is SessionEnd) { + openSessions.values.remove(receiveRequest.session) + throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurly ended") + } else if (receiveRequest.receiveType.isInstance(receivedMessage)) { + return receiveRequest.receiveType.cast(receivedMessage) + } else { + throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got $receivedMessage: $id $receiveRequest") + } + } + + @Suspendable + private fun suspend(ioRequest: ProtocolIORequest) { commitTransaction() parkAndSerialize { fiber, serializer -> + logger.trace { "Suspended $id on $ioRequest" } try { - suspendAction(protocolIORequest) + actionOnSuspend(ioRequest) } catch (t: Throwable) { // Do not throw exception again - Quasar completely bins it. logger.warn("Captured exception which was swallowed by Quasar", t) - actionOnEnd() - _resultFuture?.setException(t) + // TODO When error handling is introduced, look into whether we should be deleting the checkpoint and + // completing the Future + processException(t) } } createTransaction() } - companion object { - /** - * Retrieves our state machine id if we are running a [ProtocolStateMachineImpl]. - */ - fun retrieveCurrentStateMachine(): ProtocolStateMachineImpl<*>? { - return Strand.currentStrand() as? ProtocolStateMachineImpl<*> + private fun processException(t: Throwable) { + actionOnEnd() + _resultFuture?.setException(t) + } + + internal fun resume(scheduler: FiberScheduler) { + try { + if (fromCheckpoint) { + logger.info("$id resumed from checkpoint") + fromCheckpoint = false + Fiber.unparkDeserialized(this, scheduler) + } else if (state == State.NEW) { + logger.trace { "$id started" } + start() + } else { + logger.trace { "$id resumed" } + Fiber.unpark(this, QUASAR_UNBLOCKER) + } + } catch (t: Throwable) { + logger.error("$id threw '${t.rootCause}'") + logger.trace { + val s = StringWriter() + t.rootCause.printStackTrace(PrintWriter(s)) + "Stack trace of protocol error: $s" + } } } + } diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt index bd85ca7235..f61a5a7086 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt @@ -3,34 +3,38 @@ package com.r3corda.node.services.statemachine import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.FiberExecutorScheduler import co.paralleluniverse.io.serialization.kryo.KryoSerializer +import co.paralleluniverse.strands.Strand import com.codahale.metrics.Gauge import com.esotericsoftware.kryo.Kryo -import com.google.common.base.Throwables import com.google.common.util.concurrent.ListenableFuture import com.r3corda.core.ThreadBox import com.r3corda.core.abbreviate +import com.r3corda.core.crypto.Party import com.r3corda.core.messaging.TopicSession -import com.r3corda.core.messaging.runOnNextMessage import com.r3corda.core.messaging.send import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolStateMachine import com.r3corda.core.protocols.StateMachineRunId +import com.r3corda.core.random63BitValue import com.r3corda.core.serialization.* import com.r3corda.core.then import com.r3corda.core.utilities.ProgressTracker +import com.r3corda.core.utilities.debug +import com.r3corda.core.utilities.loggerFor import com.r3corda.core.utilities.trace import com.r3corda.node.services.api.Checkpoint import com.r3corda.node.services.api.CheckpointStorage import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.utilities.AddOrRemove import com.r3corda.node.utilities.AffinityExecutor +import kotlinx.support.jdk8.collections.removeIf import org.jetbrains.exposed.sql.Database import rx.Observable import rx.subjects.PublishSubject import rx.subjects.UnicastSubject -import java.io.PrintWriter -import java.io.StringWriter import java.util.* +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.ExecutionException import javax.annotation.concurrent.ThreadSafe @@ -48,7 +52,6 @@ import javax.annotation.concurrent.ThreadSafe * The SMM will always invoke the protocol fibers on the given [AffinityExecutor], regardless of which thread actually * starts them via [add]. * - * TODO: Session IDs should be set up and propagated automatically, on demand. * TODO: Consider the issue of continuation identity more deeply: is it a safe assumption that a serialised * continuation is always unique? * TODO: Think about how to bring the system to a clean stop so it can be upgraded without any serialised stacks on disk @@ -58,12 +61,19 @@ import javax.annotation.concurrent.ThreadSafe * TODO: Implement stub/skel classes that provide a basic RPC framework on top of this. */ @ThreadSafe -class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableServices: List, +class StateMachineManager(val serviceHub: ServiceHubInternal, + tokenizableServices: List, val checkpointStorage: CheckpointStorage, val executor: AffinityExecutor, val database: Database) { + inner class FiberScheduler : FiberExecutorScheduler("Same thread scheduler", executor) + companion object { + private val logger = loggerFor() + internal val sessionTopic = TopicSession("platform.session") + } + val scheduler = FiberScheduler() data class Change( @@ -95,6 +105,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService private val totalStartedProtocols = metrics.counter("Protocols.Started") private val totalFinishedProtocols = metrics.counter("Protocols.Finished") + private val openSessions = ConcurrentHashMap() + private val recentlyClosedSessions = ConcurrentHashMap() + // Context for tokenized services in checkpoints private val serializationContext = SerializeAsTokenContext(tokenizableServices, quasarKryo()) @@ -119,6 +132,17 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService val changes: Observable get() = mutex.content.changesPublisher + init { + Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable -> + (fiber as ProtocolStateMachineImpl<*>).logger.error("Caught exception from protocol", throwable) + } + } + + fun start() { + restoreFibersFromCheckpoints() + serviceHub.networkMapCache.mapServiceRegistered.then(executor) { resumeRestoredFibers() } + } + /** * Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and * calls to [allStateMachines] @@ -131,69 +155,99 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService } } - // Used to work around a small limitation in Quasar. - private val QUASAR_UNBLOCKER = run { - val field = Fiber::class.java.getDeclaredField("SERIALIZER_BLOCKER") - field.isAccessible = true - field.get(null) - } - - init { - Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable -> - (fiber as ProtocolStateMachineImpl<*>).logger.error("Caught exception from protocol", throwable) - } - } - - fun start() { - checkpointStorage.checkpoints.forEach { createFiberForCheckpoint(it) } - serviceHub.networkMapCache.mapServiceRegistered.then(executor) { - mutex.locked { - started = true - stateMachines.forEach { restartFiber(it.key, it.value) } - } - } - } - - private fun createFiberForCheckpoint(checkpoint: Checkpoint) { - if (!checkpoint.fiberCreated) { - val fiber = deserializeFiber(checkpoint.serialisedFiber) - initFiber(fiber, { checkpoint }) - } - } - - private fun restartFiber(fiber: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint) { - if (checkpoint.request is ReceiveRequest<*>) { - val topicSession = checkpoint.request.receiveTopicSession - fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${checkpoint.request.receiveType.name} on $topicSession") - iterateOnResponse(fiber, checkpoint.serialisedFiber, checkpoint.request) { - try { - Fiber.unparkDeserialized(fiber, scheduler) - } catch (e: Throwable) { - logError(e, it, topicSession, fiber) + private fun restoreFibersFromCheckpoints() { + mutex.locked { + checkpointStorage.checkpoints.forEach { + // If a protocol is added before start() then don't attempt to restore it + if (!stateMachines.containsValue(it)) { + val fiber = deserializeFiber(it.serialisedFiber) + initFiber(fiber) + stateMachines[fiber] = it } } - if (checkpoint.request is SendRequest) { - sendMessage(fiber, checkpoint.request) + } + } + + private fun resumeRestoredFibers() { + mutex.locked { + started = true + stateMachines.keys.forEach { resumeRestoredFiber(it) } + } + serviceHub.networkService.addMessageHandler(sessionTopic, executor) { message, reg -> + executor.checkOnThread() + val sessionMessage = message.data.deserialize() + when (sessionMessage) { + is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage) + is SessionInit -> onSessionInit(sessionMessage) + } + } + } + + private fun resumeRestoredFiber(fiber: ProtocolStateMachineImpl<*>) { + fiber.openSessions.values.forEach { openSessions[it.ourSessionId] = it } + if (fiber.openSessions.values.any { it.waitingForResponse }) { + fiber.logger.info("Restored fiber pending on receive ${fiber.id}}") + } else { + resumeFiber(fiber) + } + } + + private fun onExistingSessionMessage(message: ExistingSessionMessage) { + val session = openSessions[message.recipientSessionId] + if (session != null) { + session.psm.logger.trace { "${session.psm.id} received $message on $session" } + if (message is SessionEnd) { + openSessions.remove(message.recipientSessionId) + } + session.receivedMessages += message + if (session.waitingForResponse) { + updateCheckpoint(session.psm) + resumeFiber(session.psm) } } else { - fiber.logger.info("Restored ${fiber.logic} - it was not waiting on any message; received payload: ${checkpoint.receivedPayload.toString().abbreviate(50)}") - executor.executeASAP { - if (checkpoint.request is SendRequest) { - sendMessage(fiber, checkpoint.request) - } - iterateStateMachine(fiber, checkpoint.receivedPayload) { - try { - Fiber.unparkDeserialized(fiber, scheduler) - } catch (e: Throwable) { - logError(e, it, null, fiber) - } + val otherParty = recentlyClosedSessions.remove(message.recipientSessionId) + if (otherParty != null) { + if (message is SessionConfirm) { + logger.debug { "Received session confirmation but associated fiber has already terminated, so sending session end" } + sendSessionMessage(otherParty, SessionEnd(message.initiatedSessionId), null) + } else { + logger.trace { "Ignoring session end message for already closed session: $message" } } + } else { + logger.warn("Received a session message for unknown session: $message") } } } + private fun onSessionInit(sessionInit: SessionInit) { + logger.trace { "Received $sessionInit" } + //TODO Verify the other party are who they say they are from the TLS subsystem + val otherParty = sessionInit.initiatorParty + val otherPartySessionId = sessionInit.initiatorSessionId + try { + val markerClass = Class.forName(sessionInit.protocolName) + val protocolFactory = serviceHub.getProtocolFactory(markerClass) + if (protocolFactory != null) { + val protocol = protocolFactory(otherParty) + val psm = createFiber(sessionInit.protocolName, protocol) + val session = ProtocolSession(protocol, otherParty, random63BitValue(), otherPartySessionId) + openSessions[session.ourSessionId] = session + psm.openSessions[Pair(protocol, otherParty)] = session + updateCheckpoint(psm) + sendSessionMessage(otherParty, SessionConfirm(otherPartySessionId, session.ourSessionId), psm) + psm.logger.debug { "Starting new ${psm.id} from $sessionInit on $session" } + startFiber(psm) + } else { + logger.warn("Unknown protocol marker class in $sessionInit") + sendSessionMessage(otherParty, SessionReject(otherPartySessionId, "Don't know ${markerClass.name}"), null) + } + } catch (e: Exception) { + logger.warn("Received invalid $sessionInit", e) + sendSessionMessage(otherParty, SessionReject(otherPartySessionId, "Unable to establish session"), null) + } + } + private fun serializeFiber(fiber: ProtocolStateMachineImpl<*>): SerializedBytes> { - // We don't use the passed-in serializer here, because we need to use our own augmented Kryo. val kryo = quasarKryo() // add the map of tokens -> tokenizedServices to the kyro context SerializeAsTokenSerializer.setContext(kryo, serializationContext) @@ -204,7 +258,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService val kryo = quasarKryo() // put the map of token -> tokenized into the kryo context SerializeAsTokenSerializer.setContext(kryo, serializationContext) - return serialisedFiber.deserialize(kryo) + return serialisedFiber.deserialize(kryo).apply { fromCheckpoint = true } } private fun quasarKryo(): Kryo { @@ -212,70 +266,51 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService return createKryo(serializer.kryo) } - private fun logError(e: Throwable, payload: Any?, topicSession: TopicSession?, psm: ProtocolStateMachineImpl<*>) { - psm.logger.error("Protocol state machine ${psm.javaClass.name} threw '${Throwables.getRootCause(e)}' " + - "when handling a message of type ${payload?.javaClass?.name} on queue $topicSession") - if (psm.logger.isTraceEnabled) { - val s = StringWriter() - Throwables.getRootCause(e).printStackTrace(PrintWriter(s)) - psm.logger.trace("Stack trace of protocol error is: $s") - } + private fun createFiber(loggerName: String, logic: ProtocolLogic): ProtocolStateMachineImpl { + val id = StateMachineRunId.createRandom() + return ProtocolStateMachineImpl(id, logic, scheduler, loggerName).apply { initFiber(this) } } - private fun initFiber(psm: ProtocolStateMachineImpl<*>, startingCheckpoint: () -> Checkpoint): Checkpoint { + private fun initFiber(psm: ProtocolStateMachineImpl<*>) { psm.database = database psm.serviceHub = serviceHub - psm.suspendAction = { request -> - psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" } - onNextSuspend(psm, request) + psm.actionOnSuspend = { ioRequest -> + updateCheckpoint(psm) + processIORequest(ioRequest) } psm.actionOnEnd = { psm.logic.progressTracker?.currentStep = ProgressTracker.DONE mutex.locked { - val finalCheckpoint = stateMachines.remove(psm) - if (finalCheckpoint != null) { - checkpointStorage.removeCheckpoint(finalCheckpoint) - } + stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) } totalFinishedProtocols.inc() notifyChangeObservers(psm, AddOrRemove.REMOVE) } + endAllFiberSessions(psm) } - val checkpoint = startingCheckpoint() - checkpoint.fiberCreated = true - totalStartedProtocols.inc() mutex.locked { - stateMachines[psm] = checkpoint + totalStartedProtocols.inc() notifyChangeObservers(psm, AddOrRemove.ADD) } - return checkpoint } - /** - * Kicks off a brand new state machine of the given class. It will log with the named logger. - * The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is - * restarted with checkpointed state machines in the storage service. - */ - fun add(loggerName: String, logic: ProtocolLogic): ProtocolStateMachine { - val id = StateMachineRunId.createRandom() - val fiber = ProtocolStateMachineImpl(id, logic, scheduler, loggerName) - // Need to add before iterating in case of immediate completion - val checkpoint = initFiber(fiber) { - val checkpoint = Checkpoint(serializeFiber(fiber), null, null) - checkpoint - } - checkpointStorage.addCheckpoint(checkpoint) - mutex.locked { // If we are not started then our checkpoint will be picked up during start - if (!started) { - return fiber - } - } - - try { - executor.executeASAP { - iterateStateMachine(fiber, null) { - fiber.start() + private fun endAllFiberSessions(psm: ProtocolStateMachineImpl<*>) { + openSessions.values.removeIf { session -> + if (session.psm == psm) { + val otherPartySessionId = session.otherPartySessionId + if (otherPartySessionId != null) { + sendSessionMessage(session.otherParty, SessionEnd(otherPartySessionId), psm) } + recentlyClosedSessions[session.ourSessionId] = session.otherParty + true + } else { + false } + } + } + + private fun startFiber(fiber: ProtocolStateMachineImpl<*>) { + try { + resumeFiber(fiber) } catch (e: ExecutionException) { // There are two ways we can take exceptions in this method: // @@ -290,17 +325,29 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService if (e.cause !is ExecutionException) throw e } + } + + /** + * Kicks off a brand new state machine of the given class. It will log with the named logger. + * The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is + * restarted with checkpointed state machines in the storage service. + */ + fun add(loggerName: String, logic: ProtocolLogic): ProtocolStateMachine { + val fiber = createFiber(loggerName, logic) + updateCheckpoint(fiber) + // If we are not started then our checkpoint will be picked up during start + mutex.locked { + if (started) { + startFiber(fiber) + } + } return fiber } - private fun updateCheckpoint(psm: ProtocolStateMachineImpl<*>, - serialisedFiber: SerializedBytes>, - request: ProtocolIORequest?, - receivedPayload: Any?) { - val newCheckpoint = Checkpoint(serialisedFiber, request, receivedPayload) - val previousCheckpoint = mutex.locked { - stateMachines.put(psm, newCheckpoint) - } + private fun updateCheckpoint(psm: ProtocolStateMachineImpl<*>) { + check(psm.state != Strand.State.RUNNING) { "Fiber cannot be running when checkpointing" } + val newCheckpoint = Checkpoint(serializeFiber(psm)) + val previousCheckpoint = mutex.locked { stateMachines.put(psm, newCheckpoint) } if (previousCheckpoint != null) { checkpointStorage.removeCheckpoint(previousCheckpoint) } @@ -308,90 +355,70 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService checkpointingMeter.mark() } - private fun iterateStateMachine(psm: ProtocolStateMachineImpl<*>, - receivedPayload: Any?, - resumeAction: (Any?) -> Unit) { - executor.checkOnThread() - psm.receivedPayload = receivedPayload - psm.logger.trace { "Waking up fiber ${psm.id} ${psm.logic}" } - resumeAction(receivedPayload) - } - - private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, request: ProtocolIORequest) { - val serialisedFiber = serializeFiber(psm) - updateCheckpoint(psm, serialisedFiber, request, null) - // We have a request to do something: send, receive, or send-and-receive. - if (request is ReceiveRequest<*>) { - // Prepare a listener on the network that runs in the background thread when we receive a message. - prepareToReceiveForRequest(psm, serialisedFiber, request) - } - if (request is SendRequest) { - performSendRequest(psm, request) + private fun resumeFiber(psm: ProtocolStateMachineImpl<*>) { + executor.executeASAP { + psm.resume(scheduler) } } - private fun prepareToReceiveForRequest(psm: ProtocolStateMachineImpl<*>, serialisedFiber: SerializedBytes>, request: ReceiveRequest<*>) { - executor.checkOnThread() - val queueID = request.receiveTopicSession - psm.logger.trace { "Preparing to receive message of type ${request.receiveType.name} on queue $queueID" } - iterateOnResponse(psm, serialisedFiber, request) { - try { - Fiber.unpark(psm, QUASAR_UNBLOCKER) - } catch(e: Throwable) { - logError(e, it, queueID, psm) + private fun processIORequest(ioRequest: ProtocolIORequest) { + if (ioRequest is SendRequest) { + if (ioRequest.message is SessionInit) { + openSessions[ioRequest.session.ourSessionId] = ioRequest.session + } + sendSessionMessage(ioRequest.session.otherParty, ioRequest.message, ioRequest.session.psm) + if (ioRequest !is ReceiveRequest<*>) { + // We sent a message, but don't expect a response, so re-enter the continuation to let it keep going. + resumeFiber(ioRequest.session.psm) } } } - private fun performSendRequest(psm: ProtocolStateMachineImpl<*>, request: SendRequest) { - val topicSession = sendMessage(psm, request) + private fun sendSessionMessage(party: Party, message: SessionMessage, psm: ProtocolStateMachineImpl<*>?) { + val node = serviceHub.networkMapCache.getNodeByLegalName(party.name) + ?: throw IllegalArgumentException("Don't know about party $party") + val logger = psm?.logger ?: logger + logger.trace { "${psm?.id} sending $message to party $party" } + serviceHub.networkService.send(sessionTopic, message, node.address) + } - if (request is SendOnly) { - // We sent a message, but don't expect a response, so re-enter the continuation to let it keep going. - iterateStateMachine(psm, null) { - try { - Fiber.unpark(psm, QUASAR_UNBLOCKER) - } catch(e: Throwable) { - logError(e, request.payload, topicSession, psm) - } - } + + interface SessionMessage + + interface ExistingSessionMessage: SessionMessage { + val recipientSessionId: Long + } + + data class SessionInit(val initiatorSessionId: Long, val initiatorParty: Party, val protocolName: String) : SessionMessage + + interface SessionInitResponse : ExistingSessionMessage + + data class SessionConfirm(val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse { + override val recipientSessionId: Long get() = initiatorSessionId + } + + data class SessionReject(val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse { + override val recipientSessionId: Long get() = initiatorSessionId + } + + data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage { + override fun toString(): String { + return "${javaClass.simpleName}(recipientSessionId=$recipientSessionId, payload=${payload.toString().abbreviate(100)})" } } - private fun sendMessage(psm: ProtocolStateMachineImpl<*>, request: SendRequest): TopicSession { - val topicSession = TopicSession(request.topic, request.sendSessionID) - val payload = request.payload - psm.logger.trace { "Sending message of type ${payload.javaClass.name} using queue $topicSession to ${request.destination} (${payload.toString().abbreviate(50)})" } - val node = serviceHub.networkMapCache.getNodeByLegalName(request.destination.name) ?: - throw IllegalArgumentException("Don't know about ${request.destination} but trying to send a message of type ${payload.javaClass.name} on $topicSession (${payload.toString().abbreviate(50)})", request.stackTraceInCaseOfProblems) - serviceHub.networkService.send(topicSession, payload, node.address, request.uniqueMessageId) - return topicSession + data class SessionEnd(override val recipientSessionId: Long) : ExistingSessionMessage + + + data class ProtocolSession(val protocol: ProtocolLogic<*>, + val otherParty: Party, + val ourSessionId: Long, + var otherPartySessionId: Long?, + @Volatile var waitingForResponse: Boolean = false) { + + val receivedMessages = ConcurrentLinkedQueue() + val psm: ProtocolStateMachineImpl<*> get() = protocol.psm as ProtocolStateMachineImpl<*> + } - /** - * Add a trigger to the [MessagingService] to deserialize the fiber and pass message content to it, once a message is - * received. - */ - private fun iterateOnResponse(psm: ProtocolStateMachineImpl<*>, - serialisedFiber: SerializedBytes>, - request: ReceiveRequest<*>, - resumeAction: (Any?) -> Unit) { - val topicSession = request.receiveTopicSession - serviceHub.networkService.runOnNextMessage(topicSession, executor) { netMsg -> - // Assertion to ensure we don't execute on the wrong thread. - executor.checkOnThread() - // TODO: This is insecure: we should not deserialise whatever we find and *then* check. - // We should instead verify as we read the data that it's what we are expecting and throw as early as - // possible. We only do it this way for convenience during the prototyping stage. Note that this means - // we could simply not require the programmer to specify the expected return type at all, and catch it - // at the last moment when we do the downcast. However this would make protocol code harder to read and - // make it more difficult to migrate to a more explicit serialisation scheme later. - val payload = netMsg.data.deserialize() - check(request.receiveType.isInstance(payload)) { "Expected message of type ${request.receiveType.name} but got ${payload.javaClass.name}" } - // Update the fiber's checkpoint so that it's no longer waiting on a response, but rather has the received payload - updateCheckpoint(psm, serialisedFiber, null, payload) - psm.logger.trace { "Received message of type ${payload.javaClass.name} on $topicSession (${payload.toString().abbreviate(50)})" } - iterateStateMachine(psm, payload, resumeAction) - } - } } diff --git a/node/src/main/kotlin/com/r3corda/node/services/transactions/NotaryService.kt b/node/src/main/kotlin/com/r3corda/node/services/transactions/NotaryService.kt index 0780751dd9..65073d5a32 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/transactions/NotaryService.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/transactions/NotaryService.kt @@ -1,12 +1,11 @@ package com.r3corda.node.services.transactions +import com.r3corda.core.crypto.Party import com.r3corda.core.node.services.ServiceType -import com.r3corda.core.node.services.TimestampChecker -import com.r3corda.core.node.services.UniquenessProvider -import com.r3corda.node.services.api.AbstractNodeService +import com.r3corda.core.serialization.SingletonSerializeAsToken import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.protocols.NotaryProtocol -import com.r3corda.protocols.NotaryProtocol.TOPIC +import kotlin.reflect.KClass /** * A Notary service acts as the final signer of a transaction ensuring two things: @@ -17,22 +16,18 @@ import com.r3corda.protocols.NotaryProtocol.TOPIC * * This is the base implementation that can be customised with specific Notary transaction commit protocol. */ -abstract class NotaryService(services: ServiceHubInternal, - val timestampChecker: TimestampChecker, - val uniquenessProvider: UniquenessProvider) : AbstractNodeService(services) { +abstract class NotaryService(markerClass: KClass, services: ServiceHubInternal) : SingletonSerializeAsToken() { // Do not specify this as an advertised service. Use a concrete implementation. // TODO: We do not want a service type that cannot be used. Fix the type system abuse here. object Type : ServiceType("corda.notary") abstract val logger: org.slf4j.Logger - /** Implement a factory that specifies the transaction commit protocol for the notary service to use */ - abstract val protocolFactory: NotaryProtocol.Factory - init { - addProtocolHandler(TOPIC, TOPIC) { req: NotaryProtocol.Handshake -> - protocolFactory.create(req.replyToParty, timestampChecker, uniquenessProvider) - } + services.registerProtocolInitiator(markerClass) { createProtocol(it) } } + /** Implement a factory that specifies the transaction commit protocol for the notary service to use */ + abstract fun createProtocol(otherParty: Party): NotaryProtocol.Service + } diff --git a/node/src/main/kotlin/com/r3corda/node/services/transactions/SimpleNotaryService.kt b/node/src/main/kotlin/com/r3corda/node/services/transactions/SimpleNotaryService.kt index a4689eb94b..3f108b0e42 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/transactions/SimpleNotaryService.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/transactions/SimpleNotaryService.kt @@ -1,5 +1,6 @@ package com.r3corda.node.services.transactions +import com.r3corda.core.crypto.Party import com.r3corda.core.node.services.ServiceType import com.r3corda.core.node.services.TimestampChecker import com.r3corda.core.node.services.UniquenessProvider @@ -9,11 +10,13 @@ import com.r3corda.protocols.NotaryProtocol /** A simple Notary service that does not perform transaction validation */ class SimpleNotaryService(services: ServiceHubInternal, - timestampChecker: TimestampChecker, - uniquenessProvider: UniquenessProvider) : NotaryService(services, timestampChecker, uniquenessProvider) { + val timestampChecker: TimestampChecker, + val uniquenessProvider: UniquenessProvider) : NotaryService(NotaryProtocol.Client::class, services) { object Type : ServiceType("corda.notary.simple") override val logger = loggerFor() - override val protocolFactory = NotaryProtocol.DefaultFactory + override fun createProtocol(otherParty: Party): NotaryProtocol.Service { + return NotaryProtocol.Service(otherParty, timestampChecker, uniquenessProvider) + } } diff --git a/node/src/main/kotlin/com/r3corda/node/services/transactions/ValidatingNotaryService.kt b/node/src/main/kotlin/com/r3corda/node/services/transactions/ValidatingNotaryService.kt index 0b1ea44b41..94d6c39ed3 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/transactions/ValidatingNotaryService.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/transactions/ValidatingNotaryService.kt @@ -11,17 +11,13 @@ import com.r3corda.protocols.ValidatingNotaryProtocol /** A Notary service that validates the transaction chain of he submitted transaction before committing it */ class ValidatingNotaryService(services: ServiceHubInternal, - timestampChecker: TimestampChecker, - uniquenessProvider: UniquenessProvider) : NotaryService(services, timestampChecker, uniquenessProvider) { + val timestampChecker: TimestampChecker, + val uniquenessProvider: UniquenessProvider) : NotaryService(NotaryProtocol.ValidatingClient::class, services) { object Type : ServiceType("corda.notary.validating") override val logger = loggerFor() - override val protocolFactory = object : NotaryProtocol.Factory { - override fun create(otherSide: Party, - timestampChecker: TimestampChecker, - uniquenessProvider: UniquenessProvider): NotaryProtocol.Service { - return ValidatingNotaryProtocol(otherSide, timestampChecker, uniquenessProvider) - } + override fun createProtocol(otherParty: Party): ValidatingNotaryProtocol { + return ValidatingNotaryProtocol(otherParty, timestampChecker, uniquenessProvider) } } diff --git a/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt b/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt index 78823e890a..8d65900e6a 100644 --- a/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt @@ -10,6 +10,7 @@ import com.r3corda.core.days import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.core.node.ServiceHub import com.r3corda.core.node.services.* +import com.r3corda.core.protocols.ProtocolStateMachine import com.r3corda.core.protocols.StateMachineRunId import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.WireTransaction @@ -23,7 +24,6 @@ import com.r3corda.node.services.persistence.PerFileTransactionStorage import com.r3corda.node.services.persistence.StorageServiceImpl import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer import com.r3corda.protocols.TwoPartyTradeProtocol.Seller -import com.r3corda.protocols.TwoPartyTradeProtocol.TOPIC import com.r3corda.testing.* import com.r3corda.testing.node.InMemoryMessagingNetwork import com.r3corda.testing.node.MockNetwork @@ -89,11 +89,11 @@ class TwoPartyTradeProtocolTests { insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey, notaryNode.storage.myLegalIdentityKey) - val (bobResult, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef()) + val (bobPsm, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef()) // TODO: Verify that the result was inserted into the transaction database. // assertEquals(bobResult.get(), aliceNode.storage.validatedTransactions[aliceResult.get().id]) - assertEquals(aliceResult.get(), bobResult.get()) + assertEquals(aliceResult.get(), bobPsm.get().resultFuture.get()) aliceNode.stop() bobNode.stop() @@ -120,21 +120,19 @@ class TwoPartyTradeProtocolTests { 1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null).second insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey) - val aliceFuture = runBuyerAndSeller("alice's paper".outputStateAndRef()).sellerFuture + val aliceFuture = runBuyerAndSeller("alice's paper".outputStateAndRef()).sellerResult // Everything is on this thread so we can now step through the protocol one step at a time. // Seller Alice already sent a message to Buyer Bob. Pump once: - fun pumpAlice() = (aliceNode.net as InMemoryMessagingNetwork.InMemoryMessaging).pumpReceive(false) - - fun pumpBob() = (bobNode.net as InMemoryMessagingNetwork.InMemoryMessaging).pumpReceive(false) - - pumpBob() + bobNode.pumpReceive(false) // Bob sends a couple of queries for the dependencies back to Alice. Alice reponds. - pumpAlice() - pumpBob() - pumpAlice() - pumpBob() + aliceNode.pumpReceive(false) + bobNode.pumpReceive(false) + aliceNode.pumpReceive(false) + bobNode.pumpReceive(false) + aliceNode.pumpReceive(false) + bobNode.pumpReceive(false) // OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature. assertThat(bobNode.checkpointStorage.checkpoints).hasSize(1) @@ -147,7 +145,7 @@ class TwoPartyTradeProtocolTests { // Alice doesn't know that and carries on: she wants to know about the cash transactions he's trying to use. // She will wait around until Bob comes back. - assertThat(pumpAlice()).isNotNull() + assertThat(aliceNode.pumpReceive(false)).isNotNull() // ... bring the node back up ... the act of constructing the SMM will re-register the message handlers // that Bob was waiting on before the reboot occurred. @@ -309,16 +307,16 @@ class TwoPartyTradeProtocolTests { val attachmentID = attachment(ByteArrayInputStream(stream.toByteArray())) val bobsFakeCash = fillUpForBuyer(false, bobNode.keyManagement.freshKey().public).second - val bobsSignedTxns = insertFakeTransactions(bobsFakeCash, bobNode.services) + insertFakeTransactions(bobsFakeCash, bobNode.services) val alicesFakePaper = fillUpForSeller(false, aliceNode.storage.myLegalIdentity.owningKey, 1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, attachmentID).second - val alicesSignedTxns = insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey) + insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey) net.runNetwork() // Clear network map registration messages val aliceTxStream = aliceNode.storage.validatedTransactions.track().second val aliceTxMappings = aliceNode.storage.stateMachineRecordedTransactionMapping.track().second - val (bobResult, aliceResult, bobSmId, aliceSmId) = runBuyerAndSeller("alice's paper".outputStateAndRef()) + val aliceSmId = runBuyerAndSeller("alice's paper".outputStateAndRef()).sellerId net.runNetwork() @@ -367,21 +365,20 @@ class TwoPartyTradeProtocolTests { } } - data class RunResult( - val buyerFuture: Future, - val sellerFuture: Future, - val buyerSmId: StateMachineRunId, - val sellerSmId: StateMachineRunId + private data class RunResult( + // The buyer is not created immediately, only when the seller starts running + val buyer: Future>, + val sellerResult: Future, + val sellerId: StateMachineRunId ) - private fun runBuyerAndSeller(assetToSell: StateAndRef): RunResult { - val buyer = Buyer(aliceNode.info.identity, notaryNode.info.identity, 1000.DOLLARS, CommercialPaper.State::class.java) + private fun runBuyerAndSeller(assetToSell: StateAndRef) : RunResult { + val buyerFuture = bobNode.initiateSingleShotProtocol(Seller::class) { otherParty -> + Buyer(otherParty, notaryNode.info.identity, 1000.DOLLARS, CommercialPaper.State::class.java) + } val seller = Seller(bobNode.info.identity, notaryNode.info, assetToSell, 1000.DOLLARS, ALICE_KEY) - connectProtocols(buyer, seller) - // We start the Buyer first, as the Seller sends the first message - val buyerPsm = bobNode.smm.add("$TOPIC.buyer", buyer) - val sellerPsm = aliceNode.smm.add("$TOPIC.seller", seller) - return RunResult(buyerPsm.resultFuture, sellerPsm.resultFuture, buyerPsm.id, sellerPsm.id) + val sellerResultFuture = aliceNode.smm.add("seller", seller).resultFuture + return RunResult(buyerFuture, sellerResultFuture, seller.psm.id) } private fun LedgerDSL.runWithError( @@ -404,7 +401,7 @@ class TwoPartyTradeProtocolTests { net.runNetwork() // Clear network map registration messages - val (bobResult, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef()) + val (bobPsm, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef()) net.runNetwork() @@ -412,7 +409,7 @@ class TwoPartyTradeProtocolTests { if (bobError) aliceResult.get() else - bobResult.get() + bobPsm.get().resultFuture.get() } assertTrue(e.cause is TransactionVerificationException) assertNotNull(e.cause!!.cause) @@ -506,6 +503,7 @@ class TwoPartyTradeProtocolTests { return Pair(vault, listOf(ap)) } + class RecordingTransactionStorage(val delegate: TransactionStorage) : TransactionStorage { override fun track(): Pair, Observable> { return delegate.track() @@ -530,4 +528,5 @@ class TwoPartyTradeProtocolTests { data class Add(val transaction: SignedTransaction) : TxRecord data class Get(val id: SecureHash) : TxRecord } + } diff --git a/node/src/test/kotlin/com/r3corda/node/services/MockServiceHubInternal.kt b/node/src/test/kotlin/com/r3corda/node/services/MockServiceHubInternal.kt index 509d8fe971..657d23a79b 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/MockServiceHubInternal.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/MockServiceHubInternal.kt @@ -2,23 +2,24 @@ package com.r3corda.node.services import com.codahale.metrics.MetricRegistry import com.google.common.util.concurrent.ListenableFuture -import com.r3corda.core.transactions.SignedTransaction +import com.r3corda.core.crypto.Party import com.r3corda.core.node.services.* import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogicRefFactory -import com.r3corda.core.protocols.StateMachineRunId +import com.r3corda.core.testing.InMemoryVaultService +import com.r3corda.core.transactions.SignedTransaction import com.r3corda.node.serialization.NodeClock import com.r3corda.node.services.api.MessagingServiceInternal import com.r3corda.node.services.api.MonitoringService import com.r3corda.node.services.api.ServiceHubInternal -import com.r3corda.testing.node.MockNetworkMapCache -import com.r3corda.node.services.network.NetworkMapService import com.r3corda.node.services.persistence.DataVending import com.r3corda.node.services.statemachine.StateMachineManager -import com.r3corda.core.testing.InMemoryVaultService -import com.r3corda.testing.node.MockStorageService import com.r3corda.testing.MOCK_IDENTITY_SERVICE +import com.r3corda.testing.node.MockNetworkMapCache +import com.r3corda.testing.node.MockStorageService import java.time.Clock +import java.util.concurrent.ConcurrentHashMap +import kotlin.reflect.KClass @Suppress("LeakingThis") open class MockServiceHubInternal( @@ -28,7 +29,6 @@ open class MockServiceHubInternal( val identity: IdentityService? = MOCK_IDENTITY_SERVICE, val storage: TxWritableStorageService? = MockStorageService(), val mapCache: NetworkMapCache? = MockNetworkMapCache(), - val mapService: NetworkMapService? = null, val scheduler: SchedulerService? = null, val overrideClock: Clock? = NodeClock(), val protocolFactory: ProtocolLogicRefFactory? = ProtocolLogicRefFactory() @@ -57,14 +57,10 @@ open class MockServiceHubInternal( private val txStorageService: TxWritableStorageService get() = storage ?: throw UnsupportedOperationException() - override fun recordTransactions(txs: Iterable) = recordTransactionsInternal(txStorageService, txs) + private val protocolFactories = ConcurrentHashMap, (Party) -> ProtocolLogic<*>>() lateinit var smm: StateMachineManager - override fun startProtocol(loggerName: String, logic: ProtocolLogic): ListenableFuture { - return smm.add(loggerName, logic).resultFuture - } - init { if (net != null && storage != null) { // Creating this class is sufficient, we don't have to store it anywhere, because it registers a listener @@ -72,4 +68,18 @@ open class MockServiceHubInternal( DataVending.Service(this) } } + + override fun recordTransactions(txs: Iterable) = recordTransactionsInternal(txStorageService, txs) + + override fun startProtocol(loggerName: String, logic: ProtocolLogic): ListenableFuture { + return smm.add(loggerName, logic).resultFuture + } + + override fun registerProtocolInitiator(markerClass: KClass<*>, protocolFactory: (Party) -> ProtocolLogic<*>) { + protocolFactories[markerClass.java] = protocolFactory + } + + override fun getProtocolFactory(markerClass: Class<*>): ((Party) -> ProtocolLogic<*>)? { + return protocolFactories[markerClass] + } } diff --git a/node/src/test/kotlin/com/r3corda/node/services/NodeSchedulerServiceTest.kt b/node/src/test/kotlin/com/r3corda/node/services/NodeSchedulerServiceTest.kt index 892cf91b98..3fdfb28195 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/NodeSchedulerServiceTest.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/NodeSchedulerServiceTest.kt @@ -3,7 +3,6 @@ package com.r3corda.node.services import com.google.common.jimfs.Configuration import com.google.common.jimfs.Jimfs import com.r3corda.core.contracts.* -import com.r3corda.core.crypto.SecureHash import com.r3corda.core.days import com.r3corda.core.node.ServiceHub import com.r3corda.core.node.recordTransactions @@ -12,18 +11,16 @@ import com.r3corda.core.protocols.ProtocolLogicRef import com.r3corda.core.protocols.ProtocolLogicRefFactory import com.r3corda.core.serialization.SingletonSerializeAsToken import com.r3corda.core.utilities.DUMMY_NOTARY -import com.r3corda.core.utilities.LogHelper -import com.r3corda.testing.node.TestClock import com.r3corda.node.services.events.NodeSchedulerService -import com.r3corda.testing.node.InMemoryMessagingNetwork import com.r3corda.node.services.persistence.PerFileCheckpointStorage import com.r3corda.node.services.statemachine.StateMachineManager -import com.r3corda.node.services.vault.NodeVaultService import com.r3corda.node.utilities.AddOrRemove import com.r3corda.node.utilities.AffinityExecutor import com.r3corda.node.utilities.configureDatabase import com.r3corda.testing.ALICE_KEY +import com.r3corda.testing.node.InMemoryMessagingNetwork import com.r3corda.testing.node.MockKeyManagementService +import com.r3corda.testing.node.TestClock import com.r3corda.testing.node.makeTestDataSourceProperties import org.assertj.core.api.Assertions.assertThat import org.junit.After @@ -34,7 +31,9 @@ import java.nio.file.FileSystem import java.security.PublicKey import java.time.Clock import java.time.Instant -import java.util.concurrent.* +import java.util.concurrent.CountDownLatch +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit import kotlin.test.assertTrue class NodeSchedulerServiceTest : SingletonSerializeAsToken() { @@ -128,8 +127,6 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() { (serviceHub as TestReference).testReference.calls += increment (serviceHub as TestReference).testReference.countDown.countDown() } - - override val topic: String get() = throw UnsupportedOperationException() } class Command : TypeOnlyCommandData() diff --git a/node/src/test/kotlin/com/r3corda/node/services/NotaryChangeTests.kt b/node/src/test/kotlin/com/r3corda/node/services/NotaryChangeTests.kt index 5d7477edcc..7d31ed4f3e 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/NotaryChangeTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/NotaryChangeTests.kt @@ -9,7 +9,6 @@ import com.r3corda.core.utilities.DUMMY_NOTARY_KEY import com.r3corda.node.internal.AbstractNode import com.r3corda.node.services.network.NetworkMapService import com.r3corda.node.services.transactions.SimpleNotaryService -import com.r3corda.protocols.NotaryChangeProtocol import com.r3corda.protocols.NotaryChangeProtocol.Instigator import com.r3corda.protocols.StateReplacementException import com.r3corda.protocols.StateReplacementRefused @@ -49,7 +48,7 @@ class NotaryChangeTests { val state = issueState(clientNodeA) val newNotary = newNotaryNode.info.identity val protocol = Instigator(state, newNotary) - val future = clientNodeA.services.startProtocol(NotaryChangeProtocol.TOPIC, protocol) + val future = clientNodeA.services.startProtocol("notary-change", protocol) net.runNetwork() @@ -62,7 +61,7 @@ class NotaryChangeTests { val state = issueMultiPartyState(clientNodeA, clientNodeB) val newNotary = newNotaryNode.info.identity val protocol = Instigator(state, newNotary) - val future = clientNodeA.services.startProtocol(NotaryChangeProtocol.TOPIC, protocol) + val future = clientNodeA.services.startProtocol("notary-change", protocol) net.runNetwork() @@ -78,7 +77,7 @@ class NotaryChangeTests { val state = issueMultiPartyState(clientNodeA, clientNodeB) val newEvilNotary = Party("Evil Notary", generateKeyPair().public) val protocol = Instigator(state, newEvilNotary) - val future = clientNodeA.services.startProtocol(NotaryChangeProtocol.TOPIC, protocol) + val future = clientNodeA.services.startProtocol("notary-change", protocol) net.runNetwork() diff --git a/node/src/test/kotlin/com/r3corda/node/services/NotaryServiceTests.kt b/node/src/test/kotlin/com/r3corda/node/services/NotaryServiceTests.kt index de37f035b1..16951ec7b9 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/NotaryServiceTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/NotaryServiceTests.kt @@ -1,17 +1,19 @@ package com.r3corda.node.services -import com.r3corda.core.contracts.Timestamp +import com.google.common.util.concurrent.ListenableFuture import com.r3corda.core.contracts.TransactionType +import com.r3corda.core.crypto.DigitalSignature import com.r3corda.core.seconds +import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.utilities.DUMMY_NOTARY import com.r3corda.core.utilities.DUMMY_NOTARY_KEY -import com.r3corda.testing.node.MockNetwork import com.r3corda.node.services.network.NetworkMapService import com.r3corda.node.services.transactions.SimpleNotaryService import com.r3corda.protocols.NotaryError import com.r3corda.protocols.NotaryException import com.r3corda.protocols.NotaryProtocol import com.r3corda.testing.MINI_CORP_KEY +import com.r3corda.testing.node.MockNetwork import org.junit.Before import org.junit.Test import java.time.Instant @@ -45,10 +47,7 @@ class NotaryServiceTests { tx.toSignedTransaction(false) } - val protocol = NotaryProtocol.Client(stx) - val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol) - net.runNetwork() - + val future = runNotaryClient(stx) val signature = future.get() signature.verifyWithECDSA(stx.txBits) } @@ -61,10 +60,7 @@ class NotaryServiceTests { tx.toSignedTransaction(false) } - val protocol = NotaryProtocol.Client(stx) - val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol) - net.runNetwork() - + val future = runNotaryClient(stx) val signature = future.get() signature.verifyWithECDSA(stx.txBits) } @@ -78,16 +74,13 @@ class NotaryServiceTests { tx.toSignedTransaction(false) } - val protocol = NotaryProtocol.Client(stx) - val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol) - net.runNetwork() + val future = runNotaryClient(stx) val ex = assertFailsWith(ExecutionException::class) { future.get() } val error = (ex.cause as NotaryException).error assertTrue(error is NotaryError.TimestampInvalid) } - @Test fun `should report conflict for a duplicate transaction`() { val stx = run { val inputState = issueState(clientNode) @@ -98,8 +91,8 @@ class NotaryServiceTests { val firstSpend = NotaryProtocol.Client(stx) val secondSpend = NotaryProtocol.Client(stx) - clientNode.services.startProtocol("${NotaryProtocol.TOPIC}.first", firstSpend) - val future = clientNode.services.startProtocol("${NotaryProtocol.TOPIC}.second", secondSpend) + clientNode.services.startProtocol("notary.first", firstSpend) + val future = clientNode.services.startProtocol("notary.second", secondSpend) net.runNetwork() @@ -108,4 +101,12 @@ class NotaryServiceTests { assertEquals(notaryError.tx, stx.tx) notaryError.conflict.verified() } + + + private fun runNotaryClient(stx: SignedTransaction): ListenableFuture { + val protocol = NotaryProtocol.Client(stx) + val future = clientNode.services.startProtocol("notary-test", protocol) + net.runNetwork() + return future + } } \ No newline at end of file diff --git a/node/src/test/kotlin/com/r3corda/node/services/ValidatingNotaryServiceTests.kt b/node/src/test/kotlin/com/r3corda/node/services/ValidatingNotaryServiceTests.kt index 752cd4a778..95ea4a41da 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/ValidatingNotaryServiceTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/ValidatingNotaryServiceTests.kt @@ -1,8 +1,11 @@ package com.r3corda.node.services +import com.google.common.util.concurrent.ListenableFuture import com.r3corda.core.contracts.Command import com.r3corda.core.contracts.DummyContract import com.r3corda.core.contracts.TransactionType +import com.r3corda.core.crypto.DigitalSignature +import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.utilities.DUMMY_NOTARY import com.r3corda.core.utilities.DUMMY_NOTARY_KEY import com.r3corda.node.services.network.NetworkMapService @@ -44,9 +47,7 @@ class ValidatingNotaryServiceTests { tx.toSignedTransaction(false) } - val protocol = NotaryProtocol.Client(stx) - val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol) - net.runNetwork() + val future = runValidatingClient(stx) val ex = assertFailsWith(ExecutionException::class) { future.get() } val notaryError = (ex.cause as NotaryException).error @@ -64,9 +65,7 @@ class ValidatingNotaryServiceTests { tx.toSignedTransaction(false) } - val protocol = NotaryProtocol.Client(stx) - val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol) - net.runNetwork() + val future = runValidatingClient(stx) val ex = assertFailsWith(ExecutionException::class) { future.get() } val notaryError = (ex.cause as NotaryException).error @@ -75,4 +74,11 @@ class ValidatingNotaryServiceTests { val missingKeys = (notaryError as NotaryError.SignaturesMissing).missingSigners assertEquals(setOf(expectedMissingKey), missingKeys) } + + private fun runValidatingClient(stx: SignedTransaction): ListenableFuture { + val protocol = NotaryProtocol.ValidatingClient(stx) + val future = clientNode.services.startProtocol("notary", protocol) + net.runNetwork() + return future + } } \ No newline at end of file diff --git a/node/src/test/kotlin/com/r3corda/node/services/persistence/DataVendingServiceTests.kt b/node/src/test/kotlin/com/r3corda/node/services/persistence/DataVendingServiceTests.kt index 487a930931..3a5c32c58a 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/persistence/DataVendingServiceTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/persistence/DataVendingServiceTests.kt @@ -1,13 +1,20 @@ package com.r3corda.node.services.persistence +import co.paralleluniverse.fibers.Suspendable import com.r3corda.contracts.asset.Cash import com.r3corda.core.contracts.Amount import com.r3corda.core.contracts.Issued import com.r3corda.core.contracts.TransactionType import com.r3corda.core.contracts.USD +import com.r3corda.core.crypto.Party +import com.r3corda.core.protocols.ProtocolLogic +import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.utilities.DUMMY_NOTARY -import com.r3corda.testing.node.MockNetwork +import com.r3corda.node.services.persistence.DataVending.Service.NotifyTransactionHandler +import com.r3corda.protocols.BroadcastTransactionProtocol.NotifyTxRequest import com.r3corda.testing.MEGA_CORP +import com.r3corda.testing.node.MockNetwork +import com.r3corda.testing.node.MockNetwork.MockNode import org.junit.Before import org.junit.Test import kotlin.test.assertEquals @@ -38,9 +45,8 @@ class DataVendingServiceTests { ptx.signWith(registerNode.services.storageService.myLegalIdentityKey) val tx = ptx.toSignedTransaction() assertEquals(0, vaultServiceNode.services.vaultService.currentVault.states.toList().size) - DataVending.Service.notify(registerNode.net, registerNode.services.storageService.myLegalIdentity, - vaultServiceNode.info, tx) - network.runNetwork() + + registerNode.sendNotifyTx(tx, vaultServiceNode) // Check the transaction is in the receiving node val actual = vaultServiceNode.services.vaultService.currentVault.states.singleOrNull() @@ -67,11 +73,23 @@ class DataVendingServiceTests { ptx.signWith(registerNode.services.storageService.myLegalIdentityKey) val tx = ptx.toSignedTransaction(false) assertEquals(0, vaultServiceNode.services.vaultService.currentVault.states.toList().size) - DataVending.Service.notify(registerNode.net, registerNode.services.storageService.myLegalIdentity, - vaultServiceNode.info, tx) - network.runNetwork() + + registerNode.sendNotifyTx(tx, vaultServiceNode) // Check the transaction is not in the receiving node assertEquals(0, vaultServiceNode.services.vaultService.currentVault.states.toList().size) } -} \ No newline at end of file + + private fun MockNode.sendNotifyTx(tx: SignedTransaction, walletServiceNode: MockNode) { + walletServiceNode.services.registerProtocolInitiator(NotifyTxProtocol::class, ::NotifyTransactionHandler) + services.startProtocol("notify-tx", NotifyTxProtocol(walletServiceNode.info.identity, tx)) + network.runNetwork() + } + + + private class NotifyTxProtocol(val otherParty: Party, val stx: SignedTransaction) : ProtocolLogic() { + @Suspendable + override fun call() = send(otherParty, NotifyTxRequest(stx, emptySet())) + } + +} diff --git a/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt b/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt index 9c76e8d7c4..e15046bfc1 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt @@ -10,12 +10,14 @@ import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.junit.After import org.junit.Before import org.junit.Test +import java.nio.file.FileSystem import java.nio.file.Files +import java.nio.file.Path class PerFileCheckpointStorageTests { - val fileSystem = Jimfs.newFileSystem(unix()) - val storeDir = fileSystem.getPath("store") + val fileSystem: FileSystem = Jimfs.newFileSystem(unix()) + val storeDir: Path = fileSystem.getPath("store") lateinit var checkpointStorage: PerFileCheckpointStorage @Before @@ -92,6 +94,6 @@ class PerFileCheckpointStorageTests { } private var checkpointCount = 1 - private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), null, null) + private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++))) } \ No newline at end of file diff --git a/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt b/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt index f0d00c2ce1..b2796b9153 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt @@ -2,14 +2,20 @@ package com.r3corda.node.services.statemachine import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Suspendable +import com.google.common.util.concurrent.ListenableFuture import com.r3corda.core.crypto.Party import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.core.protocols.ProtocolLogic +import com.r3corda.core.protocols.ProtocolSessionException import com.r3corda.core.random63BitValue -import com.r3corda.testing.connectProtocols +import com.r3corda.core.serialization.deserialize +import com.r3corda.node.services.statemachine.StateMachineManager.SessionData +import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage +import com.r3corda.testing.node.InMemoryMessagingNetwork import com.r3corda.testing.node.MockNetwork import com.r3corda.testing.node.MockNetwork.MockNode import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.After import org.junit.Before import org.junit.Test @@ -50,18 +56,18 @@ class StateMachineManagerTests { } @Test - fun `protocol suspended just after receiving payload`() { - val topic = "send-and-receive" + fun `protocol restarted just after receiving payload`() { + node2.services.registerProtocolInitiator(SendProtocol::class) { ReceiveThenSuspendProtocol(it) } val payload = random63BitValue() - val sendProtocol = SendProtocol(topic, node2.info.identity, payload) - val receiveProtocol = ReceiveProtocol(topic, node1.info.identity) - connectProtocols(sendProtocol, receiveProtocol) - node1.smm.add("test", sendProtocol) - node2.smm.add("test", receiveProtocol) - net.runNetwork() + node1.smm.add("test", SendProtocol(payload, node2.info.identity)) + + // We push through just enough messages to get only the SessionData sent + // TODO We should be able to give runNetwork a predicate for when to stop + net.runNetwork(2) node2.stop() - val restoredProtocol = node2.restartAndGetRestoredProtocol(node1.info.address) - assertThat(restoredProtocol.receivedPayload).isEqualTo(payload) + net.runNetwork() + val restoredProtocol = node2.restartAndGetRestoredProtocol(node1.info.address) + assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload) } @Test @@ -83,7 +89,7 @@ class StateMachineManagerTests { node3.stop() node3 = net.createNode(node1.info.address, forcedID = node3.id) - val restoredProtocol = node3.smm.findStateMachines(ProtocolNoBlocking::class.java).single().first + val restoredProtocol = node3.getSingleProtocol().first assertEquals(false, restoredProtocol.protocolStarted) // Not started yet as no network activity has been allowed yet net.runNetwork() // Allow network map messages to flow node3.smm.executor.flush() @@ -99,43 +105,44 @@ class StateMachineManagerTests { @Test fun `protocol loaded from checkpoint will respond to messages from before start`() { - val topic = "send-and-receive" val payload = random63BitValue() - val sendProtocol = SendProtocol(topic, node2.info.identity, payload) - val receiveProtocol = ReceiveProtocol(topic, node1.info.identity) - connectProtocols(sendProtocol, receiveProtocol) + node1.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(payload, it) } + val receiveProtocol = ReceiveThenSuspendProtocol(node1.info.identity) node2.smm.add("test", receiveProtocol) // Prepare checkpointed receive protocol node2.stop() // kill receiver - node1.smm.add("test", sendProtocol) // now generate message to spool up and thus come in ahead of messages for NetworkMapService - val restoredProtocol = node2.restartAndGetRestoredProtocol(node1.info.address) - assertThat(restoredProtocol.receivedPayload).isEqualTo(payload) + val restoredProtocol = node2.restartAndGetRestoredProtocol(node1.info.address) + assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload) } @Test fun `protocol with send will resend on interrupted restart`() { - val topic = "send-and-receive" val payload = random63BitValue() val payload2 = random63BitValue() + var sentCount = 0 var receivedCount = 0 - net.messagingNetwork.sentMessages.subscribe { if (it.message.topicSession.topic == topic) sentCount++ } - net.messagingNetwork.receivedMessages.subscribe { if (it.message.topicSession.topic == topic) receivedCount++ } + net.messagingNetwork.sentMessages.subscribe { if (isDataMessage(it)) sentCount++ } + net.messagingNetwork.receivedMessages.subscribe { if (isDataMessage(it)) receivedCount++ } val node3 = net.createNode(node1.info.address) net.runNetwork() - val firstProtocol = PingPongProtocol(topic, node3.info.identity, payload) - val secondProtocol = PingPongProtocol(topic, node2.info.identity, payload2) - connectProtocols(firstProtocol, secondProtocol) + + var secondProtocol: PingPongProtocol? = null + node3.services.registerProtocolInitiator(PingPongProtocol::class) { + val protocol = PingPongProtocol(it, payload2) + secondProtocol = protocol + protocol + } + // Kick off first send and receive - node2.smm.add("test", firstProtocol) + node2.smm.add("test", PingPongProtocol(node3.info.identity, payload)) assertEquals(1, node2.checkpointStorage.checkpoints.count()) // Restart node and thus reload the checkpoint and resend the message with same UUID node2.stop() val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray()) - val (firstAgain, fut1) = node2b.smm.findStateMachines(PingPongProtocol::class.java).single() + val (firstAgain, fut1) = node2b.getSingleProtocol() net.runNetwork() assertEquals(1, node2.checkpointStorage.checkpoints.count()) - // Now add in the other half of the protocol. First message should get deduped. So message data stays in sync. - node3.smm.add("test", secondProtocol) + // Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync. net.runNetwork() node2b.smm.executor.flush() fut1.get() @@ -146,15 +153,66 @@ class StateMachineManagerTests { assertEquals(0, node3.checkpointStorage.checkpoints.count(), "Checkpoints left after restored protocol should have ended") assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3") assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3") - assertEquals(payload, secondProtocol.receivedPayload, "Received payload does not match the (restarted) first value on Node 2") - assertEquals(payload + 1, secondProtocol.receivedPayload2, "Received payload does not match the expected second value on Node 2") + assertEquals(payload, secondProtocol!!.receivedPayload, "Received payload does not match the (restarted) first value on Node 2") + assertEquals(payload + 1, secondProtocol!!.receivedPayload2, "Received payload does not match the expected second value on Node 2") + } + + @Test + fun `sending to multiple parties`() { + val node3 = net.createNode(node1.info.address) + net.runNetwork() + node2.services.registerProtocolInitiator(SendProtocol::class) { ReceiveThenSuspendProtocol(it) } + node3.services.registerProtocolInitiator(SendProtocol::class) { ReceiveThenSuspendProtocol(it) } + val payload = random63BitValue() + node1.smm.add("multiple-send", SendProtocol(payload, node2.info.identity, node3.info.identity)) + net.runNetwork() + val node2Protocol = node2.getSingleProtocol().first + val node3Protocol = node3.getSingleProtocol().first + assertThat(node2Protocol.receivedPayloads[0]).isEqualTo(payload) + assertThat(node3Protocol.receivedPayloads[0]).isEqualTo(payload) + } + + @Test + fun `receiving from multiple parties`() { + val node3 = net.createNode(node1.info.address) + net.runNetwork() + val node2Payload = random63BitValue() + val node3Payload = random63BitValue() + node2.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node2Payload, it) } + node3.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node3Payload, it) } + val multiReceiveProtocol = ReceiveThenSuspendProtocol(node2.info.identity, node3.info.identity) + node1.smm.add("multiple-receive", multiReceiveProtocol) + net.runNetwork(1) // session handshaking + // have the messages arrive in reverse order of receive + node3.pumpReceive(false) + node2.pumpReceive(false) + net.runNetwork() // pump remaining messages + assertThat(multiReceiveProtocol.receivedPayloads[0]).isEqualTo(node2Payload) + assertThat(multiReceiveProtocol.receivedPayloads[1]).isEqualTo(node3Payload) + } + + @Test + fun `exception thrown on other side`() { + node2.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { ExceptionProtocol } + val future = node1.smm.add("exception", ReceiveThenSuspendProtocol(node2.info.identity)).resultFuture + net.runNetwork() + assertThatThrownBy { future.get() }.hasCauseInstanceOf(ProtocolSessionException::class.java) + } + + private fun isDataMessage(transfer: InMemoryMessagingNetwork.MessageTransfer): Boolean { + return transfer.message.topicSession == StateMachineManager.sessionTopic + && transfer.message.data.deserialize() is SessionData } private inline fun MockNode.restartAndGetRestoredProtocol(networkMapAddress: SingleMessageRecipient? = null): P { - val servicesArray = advertisedServices.toTypedArray() - val node = mockNet.createNode(networkMapAddress, id, advertisedServices = *servicesArray) + stop() + val newNode = mockNet.createNode(networkMapAddress, id, advertisedServices = *advertisedServices.toTypedArray()) mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine - return node.smm.findStateMachines(P::class.java).single().first + return newNode.getSingleProtocol

().first + } + + private inline fun > MockNode.getSingleProtocol(): Pair> { + return smm.findStateMachines(P::class.java).single() } @@ -165,8 +223,6 @@ class StateMachineManagerTests { override fun call() { protocolStarted = true } - - override val topic: String get() = throw UnsupportedOperationException() } private class ProtocolWithoutCheckpoints : NonTerminatingProtocol() { @@ -177,8 +233,6 @@ class StateMachineManagerTests { override fun doCall() { protocolStarted = true } - - override val topic: String get() = throw UnsupportedOperationException() } @@ -187,30 +241,37 @@ class StateMachineManagerTests { val lazyTime by lazy { serviceHub.clock.instant() } @Suspendable - override fun call() { + override fun call() = Unit + } + + + private class SendProtocol(val payload: Any, vararg val otherParties: Party) : ProtocolLogic() { + + init { + require(otherParties.isNotEmpty()) } - override val topic: String get() = throw UnsupportedOperationException() - } - - - private class SendProtocol(override val topic: String, val otherParty: Party, val payload: Any) : ProtocolLogic() { @Suspendable - override fun call() = send(otherParty, payload) + override fun call() = otherParties.forEach { send(it, payload) } } - private class ReceiveProtocol(override val topic: String, val otherParty: Party) : NonTerminatingProtocol() { + private class ReceiveThenSuspendProtocol(vararg val otherParties: Party) : NonTerminatingProtocol() { - @Transient var receivedPayload: Any? = null + init { + require(otherParties.isNotEmpty()) + } + + @Transient var receivedPayloads: List = emptyList() @Suspendable override fun doCall() { - receivedPayload = receive(otherParty).unwrap { it } + receivedPayloads = otherParties.map { receive(it).unwrap { it } } } } - private class PingPongProtocol(override val topic: String, val otherParty: Party, val payload: Long) : ProtocolLogic() { + private class PingPongProtocol(val otherParty: Party, val payload: Long) : ProtocolLogic() { + @Transient var receivedPayload: Long? = null @Transient var receivedPayload2: Long? = null @@ -219,7 +280,10 @@ class StateMachineManagerTests { receivedPayload = sendAndReceive(otherParty, payload).unwrap { it } receivedPayload2 = sendAndReceive(otherParty, (payload + 1)).unwrap { it } } + } + private object ExceptionProtocol : ProtocolLogic() { + override fun call(): Nothing = throw Exception() } /** diff --git a/src/main/kotlin/com/r3corda/demos/TraderDemo.kt b/src/main/kotlin/com/r3corda/demos/TraderDemo.kt index 85f7df9927..b8e3058c03 100644 --- a/src/main/kotlin/com/r3corda/demos/TraderDemo.kt +++ b/src/main/kotlin/com/r3corda/demos/TraderDemo.kt @@ -7,20 +7,22 @@ import com.r3corda.contracts.CommercialPaper import com.r3corda.contracts.asset.DUMMY_CASH_ISSUER import com.r3corda.contracts.asset.cashBalances import com.r3corda.contracts.testing.fillWithSomeTestCash -import com.r3corda.core.* import com.r3corda.core.contracts.* import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.SecureHash import com.r3corda.core.crypto.generateKeyPair +import com.r3corda.core.days +import com.r3corda.core.logElapsedTime import com.r3corda.core.node.NodeInfo import com.r3corda.core.node.services.ServiceType import com.r3corda.core.protocols.ProtocolLogic +import com.r3corda.core.seconds +import com.r3corda.core.success import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.utilities.Emoji import com.r3corda.core.utilities.LogHelper import com.r3corda.core.utilities.ProgressTracker import com.r3corda.node.internal.Node -import com.r3corda.node.services.api.AbstractNodeService import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.config.NodeConfigurationFromConfig import com.r3corda.node.services.messaging.NodeMessagingClient @@ -28,7 +30,6 @@ import com.r3corda.node.services.network.NetworkMapService import com.r3corda.node.services.persistence.NodeAttachmentService import com.r3corda.node.services.transactions.SimpleNotaryService import com.r3corda.node.utilities.databaseTransaction -import com.r3corda.protocols.HandshakeMessage import com.r3corda.protocols.NotaryProtocol import com.r3corda.protocols.TwoPartyTradeProtocol import joptsimple.OptionParser @@ -210,23 +211,13 @@ private fun runBuyer(node: Node, amount: Amount) { // next stage in our building site, we will just auto-generate fake trades to give our nodes something to do. // // As the seller initiates the two-party trade protocol, here, we will be the buyer. - object : AbstractNodeService(node.services) { - init { - addProtocolHandler(DEMO_TOPIC, "demo.buyer") { handshake: TraderDemoHandshake -> - TraderDemoProtocolBuyer(handshake.replyToParty, attachmentsPath, amount) - } - } + node.services.registerProtocolInitiator(TraderDemoProtocolSeller::class) { otherParty -> + TraderDemoProtocolBuyer(otherParty, attachmentsPath, amount) } } // We create a couple of ad-hoc test protocols that wrap the two party trade protocol, to give us the demo logic. -private data class TraderDemoHandshake(override val replyToParty: Party, - override val sendSessionID: Long = random63BitValue(), - override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage - -private val DEMO_TOPIC = "initiate.demo.trade" - private class TraderDemoProtocolBuyer(val otherSide: Party, private val attachmentsPath: Path, val amount: Amount, @@ -234,8 +225,6 @@ private class TraderDemoProtocolBuyer(val otherSide: Party, object STARTING_BUY : ProgressTracker.Step("Seller connected, purchasing commercial paper asset") - override val topic: String get() = DEMO_TOPIC - @Suspendable override fun call() { progressTracker.currentStep = STARTING_BUY @@ -248,7 +237,7 @@ private class TraderDemoProtocolBuyer(val otherSide: Party, CommercialPaper.State::class.java) // This invokes the trading protocol and out pops our finished transaction. - val tradeTX: SignedTransaction = subProtocol(buyer, inheritParentSessions = true) + val tradeTX: SignedTransaction = subProtocol(buyer, shareParentSessions = true) // TODO: This should be moved into the protocol itself. serviceHub.recordTransactions(listOf(tradeTX)) @@ -289,8 +278,6 @@ private class TraderDemoProtocolSeller(val otherSide: Party, companion object { val PROSPECTUS_HASH = SecureHash.parse("decd098666b9657314870e192ced0c3519c2c9d395507a238338f8d003929de9") - object ANNOUNCING : ProgressTracker.Step("Announcing to the buyer node") - object SELF_ISSUING : ProgressTracker.Step("Got session ID back, issuing and timestamping some commercial paper") object TRADING : ProgressTracker.Step("Starting the trade protocol") { @@ -300,17 +287,11 @@ private class TraderDemoProtocolSeller(val otherSide: Party, // We vend a progress tracker that already knows there's going to be a TwoPartyTradingProtocol involved at some // point: by setting up the tracker in advance, the user can see what's coming in more detail, instead of being // surprised when it appears as a new set of tasks below the current one. - fun tracker() = ProgressTracker(ANNOUNCING, SELF_ISSUING, TRADING) + fun tracker() = ProgressTracker(SELF_ISSUING, TRADING) } - override val topic: String get() = DEMO_TOPIC - @Suspendable override fun call(): SignedTransaction { - progressTracker.currentStep = ANNOUNCING - - send(otherSide, TraderDemoHandshake(serviceHub.storageService.myLegalIdentity)) - progressTracker.currentStep = SELF_ISSUING val notary: NodeInfo = serviceHub.networkMapCache.notaryNodes[0] @@ -326,7 +307,7 @@ private class TraderDemoProtocolSeller(val otherSide: Party, amount, cpOwnerKey, progressTracker.getChildProgressTracker(TRADING)!!) - val tradeTX: SignedTransaction = subProtocol(seller, inheritParentSessions = true) + val tradeTX: SignedTransaction = subProtocol(seller, shareParentSessions = true) serviceHub.recordTransactions(listOf(tradeTX)) return tradeTX diff --git a/src/main/kotlin/com/r3corda/demos/api/NodeInterestRates.kt b/src/main/kotlin/com/r3corda/demos/api/NodeInterestRates.kt index 7bd1b16156..3c250e24fd 100644 --- a/src/main/kotlin/com/r3corda/demos/api/NodeInterestRates.kt +++ b/src/main/kotlin/com/r3corda/demos/api/NodeInterestRates.kt @@ -12,16 +12,14 @@ import com.r3corda.core.math.InterpolatorFactory import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.services.ServiceType import com.r3corda.core.protocols.ProtocolLogic +import com.r3corda.core.serialization.SingletonSerializeAsToken import com.r3corda.core.transactions.WireTransaction import com.r3corda.core.utilities.ProgressTracker -import com.r3corda.node.services.api.AbstractNodeService import com.r3corda.node.services.api.AcceptsFileUpload import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.utilities.FiberBox -import com.r3corda.protocols.RatesFixProtocol -import com.r3corda.protocols.ServiceRequestMessage +import com.r3corda.protocols.RatesFixProtocol.* import com.r3corda.protocols.TwoPartyDealProtocol -import org.slf4j.LoggerFactory import java.io.InputStream import java.math.BigDecimal import java.security.KeyPair @@ -55,46 +53,31 @@ object NodeInterestRates { /** * The Service that wraps [Oracle] and handles messages/network interaction/request scrubbing. */ - class Service(services: ServiceHubInternal) : AcceptsFileUpload, AbstractNodeService(services) { + class Service(services: ServiceHubInternal) : AcceptsFileUpload, SingletonSerializeAsToken() { val ss = services.storageService val oracle = Oracle(ss.myLegalIdentity, ss.myLegalIdentityKey, services.clock) - private val logger = LoggerFactory.getLogger(Service::class.java) - init { - addMessageHandler(RatesFixProtocol.TOPIC, - { req: ServiceRequestMessage -> - if (req is RatesFixProtocol.SignRequest) { - oracle.sign(req.tx) - } - else { - /** - * We put this into a protocol so that if it blocks waiting for the interest rate to become - * available, we a) don't block this thread and b) allow the fact we are waiting - * to be persisted/checkpointed. - * Interest rates become available when they are uploaded via the web as per [DataUploadServlet], - * if they haven't already been uploaded that way. - */ - req as RatesFixProtocol.QueryRequest - val handler = FixQueryHandler(this, req) - handler.registerSession(req) - services.startProtocol("fixing", handler) - Unit - } - }, - { message, e -> logger.error("Exception during interest rate oracle request processing", e) } - ) + services.registerProtocolInitiator(FixSignProtocol::class) { FixSignHandler(it, oracle) } + services.registerProtocolInitiator(FixQueryProtocol::class) { FixQueryHandler(it, oracle) } } - private class FixQueryHandler(val service: Service, - val request: RatesFixProtocol.QueryRequest) : ProtocolLogic() { + + private class FixSignHandler(val otherParty: Party, val oracle: Oracle) : ProtocolLogic() { + @Suspendable + override fun call() { + val request = receive(otherParty).unwrap { it } + send(otherParty, oracle.sign(request.tx)) + } + } + + private class FixQueryHandler(val otherParty: Party, val oracle: Oracle) : ProtocolLogic() { companion object { object RECEIVED : ProgressTracker.Step("Received fix request") object SENDING : ProgressTracker.Step("Sending fix response") } - override val topic: String get() = RatesFixProtocol.TOPIC override val progressTracker = ProgressTracker(RECEIVED, SENDING) init { @@ -103,9 +86,10 @@ object NodeInterestRates { @Suspendable override fun call(): Unit { - val answers = service.oracle.query(request.queries, request.deadline) + val request = receive(otherParty).unwrap { it } + val answers = oracle.query(request.queries, request.deadline) progressTracker.currentStep = SENDING - send(request.replyToParty, answers) + send(otherParty, answers) } } diff --git a/src/main/kotlin/com/r3corda/demos/protocols/AutoOfferProtocol.kt b/src/main/kotlin/com/r3corda/demos/protocols/AutoOfferProtocol.kt index bd50522118..c9674a7de3 100644 --- a/src/main/kotlin/com/r3corda/demos/protocols/AutoOfferProtocol.kt +++ b/src/main/kotlin/com/r3corda/demos/protocols/AutoOfferProtocol.kt @@ -1,20 +1,17 @@ package com.r3corda.demos.protocols import co.paralleluniverse.fibers.Suspendable -import com.google.common.util.concurrent.FutureCallback import com.r3corda.core.contracts.DealState import com.r3corda.core.crypto.Party import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.protocols.ProtocolLogic -import com.r3corda.core.random63BitValue +import com.r3corda.core.serialization.SingletonSerializeAsToken import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.utilities.ProgressTracker -import com.r3corda.node.services.api.AbstractNodeService import com.r3corda.node.services.api.ServiceHubInternal -import com.r3corda.protocols.HandshakeMessage import com.r3corda.protocols.TwoPartyDealProtocol import com.r3corda.protocols.TwoPartyDealProtocol.Acceptor -import com.r3corda.protocols.TwoPartyDealProtocol.DEAL_TOPIC +import com.r3corda.protocols.TwoPartyDealProtocol.AutoOffer import com.r3corda.protocols.TwoPartyDealProtocol.Instigator /** @@ -25,58 +22,27 @@ import com.r3corda.protocols.TwoPartyDealProtocol.Instigator * or the protocol would have to reach out to external systems (or users) to verify the deals. */ object AutoOfferProtocol { - val TOPIC = "autooffer.topic" - data class AutoOfferMessage(val notary: Party, - val dealBeingOffered: DealState, - override val replyToParty: Party, - override val sendSessionID: Long = random63BitValue(), - override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage - - class Plugin: CordaPluginRegistry() { + class Plugin : CordaPluginRegistry() { override val servicePlugins: List> = listOf(Service::class.java) } - class Service(services: ServiceHubInternal) : AbstractNodeService(services) { + class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() { object DEALING : ProgressTracker.Step("Starting the deal protocol") { override fun childProgressTracker(): ProgressTracker = TwoPartyDealProtocol.Secondary.tracker() } - fun tracker() = ProgressTracker(DEALING) - - class Callback(val success: (SignedTransaction) -> Unit) : FutureCallback { - override fun onFailure(t: Throwable?) { - // TODO handle exceptions - } - - override fun onSuccess(st: SignedTransaction?) { - success(st!!) - } - } - init { - addProtocolHandler(TOPIC, "$DEAL_TOPIC.seller") { autoOfferMessage: AutoOfferMessage -> - val progressTracker = tracker() - // Put the deal onto the ledger - progressTracker.currentStep = DEALING - Acceptor( - autoOfferMessage.replyToParty, - autoOfferMessage.notary, - autoOfferMessage.dealBeingOffered, - progressTracker.getChildProgressTracker(DEALING)!! - ) - } + services.registerProtocolInitiator(Instigator::class) { Acceptor(it) } } - } class Requester(val dealToBeOffered: DealState) : ProtocolLogic() { companion object { object RECEIVED : ProgressTracker.Step("Received API call") - object ANNOUNCING : ProgressTracker.Step("Announcing to the peer node") object DEALING : ProgressTracker.Step("Starting the deal protocol") { override fun childProgressTracker(): ProgressTracker = TwoPartyDealProtocol.Primary.tracker() } @@ -84,10 +50,9 @@ object AutoOfferProtocol { // We vend a progress tracker that already knows there's going to be a TwoPartyTradingProtocol involved at some // point: by setting up the tracker in advance, the user can see what's coming in more detail, instead of being // surprised when it appears as a new set of tasks below the current one. - fun tracker() = ProgressTracker(RECEIVED, ANNOUNCING, DEALING) + fun tracker() = ProgressTracker(RECEIVED, DEALING) } - override val topic: String get() = TOPIC override val progressTracker = tracker() init { @@ -100,17 +65,14 @@ object AutoOfferProtocol { val notary = serviceHub.networkMapCache.notaryNodes.first().identity // need to pick which ever party is not us val otherParty = notUs(dealToBeOffered.parties).single() - progressTracker.currentStep = ANNOUNCING - send(otherParty, AutoOfferMessage(notary, dealToBeOffered, serviceHub.storageService.myLegalIdentity)) progressTracker.currentStep = DEALING val instigator = Instigator( otherParty, - notary, - dealToBeOffered, + AutoOffer(notary, dealToBeOffered), serviceHub.storageService.myLegalIdentityKey, progressTracker.getChildProgressTracker(DEALING)!! ) - val stx = subProtocol(instigator, inheritParentSessions = true) + val stx = subProtocol(instigator) return stx } diff --git a/src/main/kotlin/com/r3corda/demos/protocols/ExitServerProtocol.kt b/src/main/kotlin/com/r3corda/demos/protocols/ExitServerProtocol.kt index 1f56d98820..8a8b750f9a 100644 --- a/src/main/kotlin/com/r3corda/demos/protocols/ExitServerProtocol.kt +++ b/src/main/kotlin/com/r3corda/demos/protocols/ExitServerProtocol.kt @@ -5,56 +5,49 @@ import co.paralleluniverse.strands.Strand import com.r3corda.core.crypto.Party import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.NodeInfo -import com.r3corda.core.node.services.DEFAULT_SESSION_ID import com.r3corda.core.protocols.ProtocolLogic -import com.r3corda.core.random63BitValue -import com.r3corda.core.serialization.deserialize import com.r3corda.node.services.api.ServiceHubInternal -import com.r3corda.protocols.HandshakeMessage import com.r3corda.testing.node.MockNetworkMapCache import java.util.concurrent.TimeUnit - object ExitServerProtocol { - val TOPIC = "exit.topic" - // Will only be enabled if you install the Handler @Volatile private var enabled = false // This is not really a HandshakeMessage but needs to be so that the send uses the default session ID. This will // resolve itself when the protocol session stuff is done. - data class ExitMessage(val exitCode: Int, - override val replyToParty: Party, - override val sendSessionID: Long = random63BitValue(), - override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage + data class ExitMessage(val exitCode: Int) class Plugin: CordaPluginRegistry() { override val servicePlugins: List> = listOf(Service::class.java) } class Service(services: ServiceHubInternal) { - init { - services.networkService.addMessageHandler(TOPIC, DEFAULT_SESSION_ID) { msg, registration -> - // Just to validate we got the message - if (enabled) { - val message = msg.data.deserialize() - System.exit(message.exitCode) - } - } + services.registerProtocolInitiator(Broadcast::class, ::ExitServerHandler) enabled = true } } + + private class ExitServerHandler(val otherParty: Party) : ProtocolLogic() { + override fun call() { + // Just to validate we got the message + if (enabled) { + val message = receive(otherParty).unwrap { it } + System.exit(message.exitCode) + } + } + } + + /** * This takes a Java Integer rather than Kotlin Int as that is what we end up with in the calling map and currently * we do not support coercing numeric types in the reflective search for matching constructors. */ class Broadcast(val exitCode: Int) : ProtocolLogic() { - override val topic: String get() = TOPIC - @Suspendable override fun call(): Boolean { if (enabled) { @@ -73,7 +66,7 @@ object ExitServerProtocol { if (recipient.address is MockNetworkMapCache.MockAddress) { // Ignore } else { - send(recipient.identity, ExitMessage(exitCode, recipient.identity)) + send(recipient.identity, ExitMessage(exitCode)) } } } diff --git a/src/main/kotlin/com/r3corda/demos/protocols/UpdateBusinessDayProtocol.kt b/src/main/kotlin/com/r3corda/demos/protocols/UpdateBusinessDayProtocol.kt index af3340ce16..b3c3e92a66 100644 --- a/src/main/kotlin/com/r3corda/demos/protocols/UpdateBusinessDayProtocol.kt +++ b/src/main/kotlin/com/r3corda/demos/protocols/UpdateBusinessDayProtocol.kt @@ -4,14 +4,10 @@ import co.paralleluniverse.fibers.Suspendable import com.r3corda.core.crypto.Party import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.NodeInfo -import com.r3corda.core.node.services.DEFAULT_SESSION_ID import com.r3corda.core.protocols.ProtocolLogic -import com.r3corda.core.random63BitValue -import com.r3corda.core.serialization.deserialize import com.r3corda.core.utilities.ProgressTracker import com.r3corda.demos.DemoClock import com.r3corda.node.services.api.ServiceHubInternal -import com.r3corda.protocols.HandshakeMessage import com.r3corda.testing.node.MockNetworkMapCache import java.time.LocalDate @@ -20,29 +16,28 @@ import java.time.LocalDate */ object UpdateBusinessDayProtocol { - val TOPIC = "businessday.topic" - // This is not really a HandshakeMessage but needs to be so that the send uses the default session ID. This will // resolve itself when the protocol session stuff is done. - data class UpdateBusinessDayMessage(val date: LocalDate, - override val replyToParty: Party, - override val sendSessionID: Long = random63BitValue(), - override val receiveSessionID: Long = random63BitValue()) : HandshakeMessage + data class UpdateBusinessDayMessage(val date: LocalDate) class Plugin: CordaPluginRegistry() { override val servicePlugins: List> = listOf(Service::class.java) } class Service(services: ServiceHubInternal) { - init { - services.networkService.addMessageHandler(TOPIC, DEFAULT_SESSION_ID) { msg, registration -> - val updateBusinessDayMessage = msg.data.deserialize() - (services.clock as DemoClock).updateDate(updateBusinessDayMessage.date) - } + services.registerProtocolInitiator(Broadcast::class, ::UpdateBusinessDayHandler) } } + private class UpdateBusinessDayHandler(val otherParty: Party) : ProtocolLogic() { + override fun call() { + val message = receive(otherParty).unwrap { it } + (serviceHub.clock as DemoClock).updateDate(message.date) + } + } + + class Broadcast(val date: LocalDate, override val progressTracker: ProgressTracker = Broadcast.tracker()) : ProtocolLogic() { @@ -52,8 +47,6 @@ object UpdateBusinessDayProtocol { fun tracker() = ProgressTracker(NOTIFYING) } - override val topic: String get() = TOPIC - @Suspendable override fun call(): Unit { progressTracker.currentStep = NOTIFYING @@ -67,7 +60,7 @@ object UpdateBusinessDayProtocol { if (recipient.address is MockNetworkMapCache.MockAddress) { // Ignore } else { - send(recipient.identity, UpdateBusinessDayMessage(date, recipient.identity)) + send(recipient.identity, UpdateBusinessDayMessage(date)) } } } diff --git a/src/main/kotlin/com/r3corda/simulation/IRSSimulation.kt b/src/main/kotlin/com/r3corda/simulation/IRSSimulation.kt index 33c86ae13d..1e30cb4e10 100644 --- a/src/main/kotlin/com/r3corda/simulation/IRSSimulation.kt +++ b/src/main/kotlin/com/r3corda/simulation/IRSSimulation.kt @@ -10,14 +10,16 @@ import com.r3corda.core.RunOnCallerThread import com.r3corda.core.contracts.StateAndRef import com.r3corda.core.contracts.UniqueIdentifier import com.r3corda.core.failure +import com.r3corda.core.flatMap import com.r3corda.core.node.services.linearHeadsOfType import com.r3corda.core.success import com.r3corda.core.transactions.SignedTransaction -import com.r3corda.protocols.TwoPartyDealProtocol -import com.r3corda.testing.connectProtocols +import com.r3corda.protocols.TwoPartyDealProtocol.Acceptor +import com.r3corda.protocols.TwoPartyDealProtocol.AutoOffer +import com.r3corda.protocols.TwoPartyDealProtocol.Instigator +import com.r3corda.testing.initiateSingleShotProtocol import com.r3corda.testing.node.InMemoryMessagingNetwork import com.r3corda.testing.node.MockIdentityService -import java.security.KeyPair import java.time.LocalDate import java.util.* @@ -73,7 +75,7 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten val node1: SimulatedNode = banks[i] val node2: SimulatedNode = banks[j] - val swaps: Map> = node1.services.vaultService.linearHeadsOfType() + val swaps: Map> = node1.services.vaultService.linearHeadsOfType() val theDealRef: StateAndRef = swaps.values.single() // Do we have any more days left in this deal's lifetime? If not, return. @@ -111,22 +113,19 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten // We load the IRS afresh each time because the leg parts of the structure aren't data classes so they don't // have the convenient copy() method that'd let us make small adjustments. Instead they're partly mutable. // TODO: We should revisit this in post-Excalibur cleanup and fix, e.g. by introducing an interface. - val irs = om.readValue(javaClass.getResource("trade.json")) + val irs = om.readValue(javaClass.getResource("trade.json")) irs.fixedLeg.fixedRatePayer = node1.info.identity irs.floatingLeg.floatingRatePayer = node2.info.identity - val instigator = TwoPartyDealProtocol.Instigator(node2.info.identity, notary.info.identity, irs, node1.keyPair!!) - val acceptor = TwoPartyDealProtocol.Acceptor(node1.info.identity, notary.info.identity, irs) - connectProtocols(instigator, acceptor) + val acceptorTx = node2.initiateSingleShotProtocol(Instigator::class) { Acceptor(it) }.flatMap { it.resultFuture } showProgressFor(listOf(node1, node2)) showConsensusFor(listOf(node1, node2, regulators[0])) - val instigatorFuture: ListenableFuture = node1.services.startProtocol("instigator", instigator) + val instigator = Instigator(node2.info.identity, AutoOffer(notary.info.identity, irs), node1.keyPair!!) + val instigatorTx = node1.services.startProtocol("instigator", instigator) - return Futures.transformAsync(Futures.allAsList(instigatorFuture, node2.services.startProtocol("acceptor", acceptor))) { - instigatorFuture - } + return Futures.transformAsync(Futures.allAsList(instigatorTx, acceptorTx)) { instigatorTx } } override fun iterate(): InMemoryMessagingNetwork.MessageTransfer? { diff --git a/src/main/kotlin/com/r3corda/simulation/TradeSimulation.kt b/src/main/kotlin/com/r3corda/simulation/TradeSimulation.kt index 6831acb0a9..ff9e8e3913 100644 --- a/src/main/kotlin/com/r3corda/simulation/TradeSimulation.kt +++ b/src/main/kotlin/com/r3corda/simulation/TradeSimulation.kt @@ -9,12 +9,13 @@ import com.r3corda.core.contracts.DOLLARS import com.r3corda.core.contracts.OwnableState import com.r3corda.core.contracts.`issued by` import com.r3corda.core.days +import com.r3corda.core.flatMap import com.r3corda.core.node.recordTransactions import com.r3corda.core.seconds import com.r3corda.core.transactions.SignedTransaction -import com.r3corda.protocols.TwoPartyTradeProtocol -import com.r3corda.protocols.TwoPartyTradeProtocol.TOPIC -import com.r3corda.testing.connectProtocols +import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer +import com.r3corda.protocols.TwoPartyTradeProtocol.Seller +import com.r3corda.testing.initiateSingleShotProtocol import com.r3corda.testing.node.InMemoryMessagingNetwork import java.time.Instant @@ -45,25 +46,24 @@ class TradeSimulation(runAsync: Boolean, latencyInjector: InMemoryMessagingNetwo seller.services.recordTransactions(issuance) val amount = 1000.DOLLARS - val buyerProtocol = TwoPartyTradeProtocol.Buyer( - seller.info.identity, - notary.info.identity, - amount, - CommercialPaper.State::class.java) - val sellerProtocol = TwoPartyTradeProtocol.Seller( + + val buyerFuture = buyer.initiateSingleShotProtocol(Seller::class) { + Buyer(it, notary.info.identity, amount, CommercialPaper.State::class.java) + }.flatMap { it.resultFuture } + + val sellerProtocol = Seller( buyer.info.identity, notary.info, issuance.tx.outRef(0), amount, seller.storage.myLegalIdentityKey) - connectProtocols(buyerProtocol, sellerProtocol) showConsensusFor(listOf(buyer, seller, notary)) showProgressFor(listOf(buyer, seller)) - val buyerFuture = buyer.services.startProtocol("bank.$buyerBankIndex.$TOPIC.buyer", buyerProtocol) - val sellerFuture = seller.services.startProtocol("bank.$sellerBankIndex.$TOPIC.seller", sellerProtocol) + val sellerFuture = seller.services.startProtocol("bank.$sellerBankIndex.seller", sellerProtocol) return Futures.successfulAsList(buyerFuture, sellerFuture) } + } diff --git a/test-utils/src/main/kotlin/com/r3corda/testing/CoreTestUtils.kt b/test-utils/src/main/kotlin/com/r3corda/testing/CoreTestUtils.kt index 95b0ea1191..6cd13f1d53 100644 --- a/test-utils/src/main/kotlin/com/r3corda/testing/CoreTestUtils.kt +++ b/test-utils/src/main/kotlin/com/r3corda/testing/CoreTestUtils.kt @@ -4,22 +4,28 @@ package com.r3corda.testing import com.google.common.base.Throwables import com.google.common.net.HostAndPort +import com.google.common.util.concurrent.ListenableFuture +import com.google.common.util.concurrent.SettableFuture import com.r3corda.core.contracts.StateRef import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.SecureHash import com.r3corda.core.crypto.generateKeyPair import com.r3corda.core.node.ServiceHub import com.r3corda.core.protocols.ProtocolLogic -import com.r3corda.core.random63BitValue +import com.r3corda.core.protocols.ProtocolStateMachine import com.r3corda.core.transactions.TransactionBuilder import com.r3corda.core.utilities.DUMMY_NOTARY import com.r3corda.core.utilities.DUMMY_NOTARY_KEY -import com.r3corda.protocols.HandshakeMessage +import com.r3corda.node.internal.AbstractNode +import com.r3corda.node.services.statemachine.StateMachineManager.Change +import com.r3corda.node.utilities.AddOrRemove.ADD import com.r3corda.testing.node.MockIdentityService import com.r3corda.testing.node.MockServices +import rx.Subscriber import java.net.ServerSocket import java.security.KeyPair import java.security.PublicKey +import kotlin.reflect.KClass /** * JAVA INTEROP @@ -129,22 +135,32 @@ fun getFreeLocalPorts(hostName: String, numberToAlloc: Int): List { dsl: TransactionDSL.() -> EnforceVerifyOrFail ) = ledger { this.transaction(transactionLabel, transactionBuilder, dsl) } - /** - * Connect two protocols together for communication. Both protocols must have a property called otherParty of type Party - * which points to the other party in the communication. + * The given protocol factory will be used to initiate just one instance of a protocol of type [P] when a counterparty + * protocol requests for it using [markerClass]. + * @return Returns a [ListenableFuture] holding the single [ProtocolStateMachine] created by the request. */ -fun connectProtocols(protocol1: ProtocolLogic<*>, protocol2: ProtocolLogic<*>) { +inline fun > AbstractNode.initiateSingleShotProtocol( + markerClass: KClass<*>, + noinline protocolFactory: (Party) -> P): ListenableFuture> { + services.registerProtocolInitiator(markerClass, protocolFactory) - data class Handshake(override val replyToParty: Party, - override val sendSessionID: Long, - override val receiveSessionID: Long) : HandshakeMessage + val future = SettableFuture.create>() - val sessionId1 = random63BitValue() - val sessionId2 = random63BitValue() - protocol1.registerSession(Handshake(protocol1.otherParty, sessionId1, sessionId2)) - protocol2.registerSession(Handshake(protocol2.otherParty, sessionId2, sessionId1)) -} + val subscriber = object : Subscriber() { + override fun onNext(change: Change) { + if (change.logic is P && change.addOrRemove == ADD) { + unsubscribe() + future.set(change.logic.psm as ProtocolStateMachine) + } + } + override fun onError(e: Throwable) { + future.setException(e) + } + override fun onCompleted() {} + } -private val ProtocolLogic<*>.otherParty: Party - get() = javaClass.getDeclaredField("otherParty").apply { isAccessible = true }.get(this) as Party \ No newline at end of file + smm.changes.subscribe(subscriber) + + return future +} \ No newline at end of file diff --git a/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt b/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt index 2a33c31d09..13029f8d2d 100644 --- a/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt +++ b/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt @@ -131,6 +131,10 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, // It is used from the network visualiser tool. @Suppress("unused") val place: PhysicalLocation get() = findMyLocation()!! + fun pumpReceive(block: Boolean): InMemoryMessagingNetwork.MessageTransfer? { + return (net as InMemoryMessagingNetwork.InMemoryMessaging).pumpReceive(block) + } + fun send(topic: String, target: MockNode, payload: Any) { services.networkService.send(TopicSession(topic), payload, target.info.address) }