From 78849f44d2dc1841f2d016d964c178be94c3f97e Mon Sep 17 00:00:00 2001 From: Mike Hearn Date: Fri, 8 Jan 2016 17:26:02 +0100 Subject: [PATCH] Protocols: simplify the two party (dvp) protocol some more, now that we've switched to Quasar. There's no longer any need to define InitialArgs objects. --- .../protocols/TwoPartyTradeProtocol.kt | 114 ++++++++---------- .../kotlin/core/messaging/StateMachines.kt | 42 ++++--- .../core/node/TimestamperNodeService.kt | 2 +- .../messaging/TwoPartyTradeProtocolTests.kt | 61 +++++----- .../core/node/TimestamperNodeServiceTest.kt | 6 +- 5 files changed, 106 insertions(+), 119 deletions(-) diff --git a/src/main/kotlin/contracts/protocols/TwoPartyTradeProtocol.kt b/src/main/kotlin/contracts/protocols/TwoPartyTradeProtocol.kt index 8e1b3f42d0..f9b5905d4b 100644 --- a/src/main/kotlin/contracts/protocols/TwoPartyTradeProtocol.kt +++ b/src/main/kotlin/contracts/protocols/TwoPartyTradeProtocol.kt @@ -9,6 +9,7 @@ package contracts.protocols import co.paralleluniverse.fibers.Suspendable +import com.google.common.util.concurrent.ListenableFuture import contracts.Cash import contracts.sumCashBy import core.* @@ -23,8 +24,6 @@ import java.security.KeyPair import java.security.PublicKey import java.time.Instant -// TODO: Get rid of the "initial args" concept and just use the class c'tors, now we are using Quasar. - /** * This asset trading protocol implements a "delivery vs payment" type swap. It has two parties (B and S for buyer * and seller) and the following steps: @@ -49,64 +48,46 @@ import java.time.Instant * * To see an example of how to use this class, look at the unit tests. */ -abstract class TwoPartyTradeProtocol { - class SellerInitialArgs( - val assetToSell: StateAndRef, - val price: Amount, - val myKeyPair: KeyPair, - val buyerSessionID: Long - ) +object TwoPartyTradeProtocol { + val TRADE_TOPIC = "com.r3cev.protocols.trade" - abstract fun runSeller(otherSide: SingleMessageRecipient, args: SellerInitialArgs): Seller - - class BuyerInitialArgs( - val acceptablePrice: Amount, - val typeToBuy: Class, - val sessionID: Long - ) - - abstract fun runBuyer(otherSide: SingleMessageRecipient, args: BuyerInitialArgs): Buyer - - abstract class Buyer : ProtocolStateMachine>() - abstract class Seller : ProtocolStateMachine>() - - companion object { - @JvmStatic fun create(smm: StateMachineManager, timestampingAuthority: LegallyIdentifiableNode): TwoPartyTradeProtocol { - return TwoPartyTradeProtocolImpl(smm, timestampingAuthority) - } - } -} - -private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager, - private val timestampingAuthority: LegallyIdentifiableNode) : TwoPartyTradeProtocol() { - companion object { - val TRADE_TOPIC = "com.r3cev.protocols.trade" + fun runSeller(smm: StateMachineManager, timestampingAuthority: LegallyIdentifiableNode, + otherSide: SingleMessageRecipient, assetToSell: StateAndRef, price: Amount, + myKeyPair: KeyPair, buyerSessionID: Long): ListenableFuture> { + val seller = Seller(otherSide, timestampingAuthority, assetToSell, price, myKeyPair, buyerSessionID) + smm.add("$TRADE_TOPIC.seller", seller) + return seller.resultFuture } - // This object is serialised to the network and is the first protocol message the seller sends to the buyer. - class SellerTradeInfo( - val assetForSale: StateAndRef, - val price: Amount, - val sellerOwnerKey: PublicKey, - val sessionID: Long - ) + fun runBuyer(smm: StateMachineManager, timestampingAuthority: LegallyIdentifiableNode, + otherSide: SingleMessageRecipient, acceptablePrice: Amount, typeToBuy: Class, + sessionID: Long): ListenableFuture> { + val buyer = Buyer(otherSide, timestampingAuthority.identity, acceptablePrice, typeToBuy, sessionID) + smm.add("$TRADE_TOPIC.buyer", buyer) + return buyer.resultFuture + } - class SellerImpl(private val otherSide: SingleMessageRecipient, private val timestampingAuthority: LegallyIdentifiableNode) : Seller() { + class Seller(val otherSide: SingleMessageRecipient, + val timestampingAuthority: LegallyIdentifiableNode, + val assetToSell: StateAndRef, + val price: Amount, + val myKeyPair: KeyPair, + val buyerSessionID: Long) : ProtocolStateMachine>() { @Suspendable - override fun call(args: SellerInitialArgs): Pair { + override fun call(): Pair { val sessionID = random63BitValue() // Make the first message we'll send to kick off the protocol. - val hello = SellerTradeInfo(args.assetToSell, args.price, args.myKeyPair.public, sessionID) + val hello = SellerTradeInfo(assetToSell, price, myKeyPair.public, sessionID) - val partialTX = sendAndReceive(TRADE_TOPIC, otherSide, args.buyerSessionID, sessionID, hello) + val partialTX = sendAndReceive(TRADE_TOPIC, otherSide, buyerSessionID, sessionID, hello) logger.trace { "Received partially signed transaction" } partialTX.verifySignatures() val wtx: WireTransaction = partialTX.txBits.deserialize() requireThat { - "transaction sends us the right amount of cash" by (wtx.outputStates.sumCashBy(args.myKeyPair.public) == args.price) + "transaction sends us the right amount of cash" by (wtx.outputStates.sumCashBy(myKeyPair.public) == price) // There are all sorts of funny games a malicious secondary might play here, we should fix them: // // - This tx may attempt to send some assets we aren't intending to sell to the secondary, if @@ -122,7 +103,7 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager, // Sign with our key and get the timestamping authorities key as well. // These two steps could be done in parallel, in theory. - val ourSignature = args.myKeyPair.signWithECDSA(partialTX.txBits) + val ourSignature = myKeyPair.signWithECDSA(partialTX.txBits) val tsaSig = TimestamperClient(this, timestampingAuthority).timestamp(partialTX.txBits) val fullySigned = partialTX.withAdditionalSignature(tsaSig).withAdditionalSignature(ourSignature) @@ -130,23 +111,36 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager, logger.trace { "Built finished transaction, sending back to secondary!" } - send(TRADE_TOPIC, otherSide, args.buyerSessionID, fullySigned) + send(TRADE_TOPIC, otherSide, buyerSessionID, fullySigned) return Pair(wtx, fullySigned.verifyToLedgerTransaction(serviceHub.identityService)) } } - class UnacceptablePriceException(val givenPrice: Amount) : Exception() - class AssetMismatchException(val expectedTypeName: String, val typeName: String) : Exception() { + // This object is serialised to the network and is the first protocol message the seller sends to the buyer. + private class SellerTradeInfo( + val assetForSale: StateAndRef, + val price: Amount, + val sellerOwnerKey: PublicKey, + val sessionID: Long + ) + + + private class UnacceptablePriceException(val givenPrice: Amount) : Exception() + private class AssetMismatchException(val expectedTypeName: String, val typeName: String) : Exception() { override fun toString() = "The submitted asset didn't match the expected type: $expectedTypeName vs $typeName" } // The buyer's side of the protocol. See note above Seller to learn about the caveats here. - class BuyerImpl(private val otherSide: SingleMessageRecipient, private val timestampingAuthority: Party) : Buyer() { + class Buyer(val otherSide: SingleMessageRecipient, + val timestampingAuthority: Party, + val acceptablePrice: Amount, + val typeToBuy: Class, + val sessionID: Long) : ProtocolStateMachine>() { @Suspendable - override fun call(args: BuyerInitialArgs): Pair { + override fun call(): Pair { // Wait for a trade request to come in on our pre-provided session ID. - val tradeRequest = receive(TRADE_TOPIC, args.sessionID) + val tradeRequest = receive(TRADE_TOPIC, sessionID) // What is the seller trying to sell us? val assetTypeName = tradeRequest.assetForSale.state.javaClass.name @@ -154,10 +148,10 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager, // Check the start message for acceptability. check(tradeRequest.sessionID > 0) - if (tradeRequest.price > args.acceptablePrice) + if (tradeRequest.price > acceptablePrice) throw UnacceptablePriceException(tradeRequest.price) - if (!args.typeToBuy.isInstance(tradeRequest.assetForSale.state)) - throw AssetMismatchException(args.typeToBuy.name, assetTypeName) + if (!typeToBuy.isInstance(tradeRequest.assetForSale.state)) + throw AssetMismatchException(typeToBuy.name, assetTypeName) // TODO: Either look up the stateref here in our local db, or accept a long chain of states and // validate them to audit the other side and ensure it actually owns the state we are being offered! @@ -198,7 +192,7 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager, // TODO: Protect against a malicious buyer sending us back a different transaction to the one we built. val fullySigned = sendAndReceive(TRADE_TOPIC, otherSide, tradeRequest.sessionID, - args.sessionID, stx) + sessionID, stx) logger.trace { "Got fully signed transaction, verifying ... "} @@ -209,12 +203,4 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager, return Pair(fullySigned.tx, ltx) } } - - override fun runSeller(otherSide: SingleMessageRecipient, args: SellerInitialArgs): Seller { - return smm.add(args, "$TRADE_TOPIC.seller", SellerImpl(otherSide, timestampingAuthority)) - } - - override fun runBuyer(otherSide: SingleMessageRecipient, args: BuyerInitialArgs): Buyer { - return smm.add(args, "$TRADE_TOPIC.buyer", BuyerImpl(otherSide, timestampingAuthority.identity)) - } } \ No newline at end of file diff --git a/src/main/kotlin/core/messaging/StateMachines.kt b/src/main/kotlin/core/messaging/StateMachines.kt index 207433725d..7eadc3e087 100644 --- a/src/main/kotlin/core/messaging/StateMachines.kt +++ b/src/main/kotlin/core/messaging/StateMachines.kt @@ -29,6 +29,7 @@ import org.slf4j.Logger import org.slf4j.LoggerFactory import java.io.ByteArrayOutputStream import java.util.* +import java.util.concurrent.Callable import java.util.concurrent.Executor import javax.annotation.concurrent.ThreadSafe @@ -57,10 +58,10 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) private val checkpointsMap = serviceHub.storageService.getMap("state machines") // A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines // property. - private val _stateMachines = Collections.synchronizedList(ArrayList>()) + private val _stateMachines = Collections.synchronizedList(ArrayList>()) /** Returns a snapshot of the currently registered state machines. */ - val stateMachines: List> get() { + val stateMachines: List> get() { synchronized(_stateMachines) { return ArrayList(_stateMachines) } @@ -110,10 +111,10 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) } } - private fun deserializeFiber(bits: ByteArray): ProtocolStateMachine<*, *> { + private fun deserializeFiber(bits: ByteArray): ProtocolStateMachine<*> { val deserializer = Fiber.getFiberSerializer() as KryoSerializer val kryo = createKryo(deserializer.kryo) - val psm = kryo.readClassAndObject(Input(bits)) as ProtocolStateMachine<*, *> + val psm = kryo.readClassAndObject(Input(bits)) as ProtocolStateMachine<*> return psm } @@ -123,9 +124,9 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) * The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is * restarted with checkpointed state machines in the storage service. */ - fun , I> add(initialArgs: I, loggerName: String, fiber: T): T { + fun > add(loggerName: String, fiber: T): T { val logger = LoggerFactory.getLogger(loggerName) - iterateStateMachine(fiber, serviceHub.networkService, logger, initialArgs, null) { + iterateStateMachine(fiber, serviceHub.networkService, logger, null, null) { it.start() } return fiber @@ -141,8 +142,8 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) return key } - private fun iterateStateMachine(psm: ProtocolStateMachine<*, *>, net: MessagingService, logger: Logger, - obj: Any?, prevCheckpointKey: SecureHash?, resumeFunc: (ProtocolStateMachine<*, *>) -> Unit) { + private fun iterateStateMachine(psm: ProtocolStateMachine<*>, net: MessagingService, logger: Logger, + obj: Any?, prevCheckpointKey: SecureHash?, resumeFunc: (ProtocolStateMachine<*>) -> Unit) { val onSuspend = fun(request: FiberRequest, serFiber: ByteArray) { // We have a request to do something: send, receive, or send-and-receive. if (request is FiberRequest.ExpectingResponse<*>) { @@ -181,7 +182,7 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) } } - private fun checkpointAndSetupMessageHandler(logger: Logger, net: MessagingService, psm: ProtocolStateMachine<*,*>, + private fun checkpointAndSetupMessageHandler(logger: Logger, net: MessagingService, psm: ProtocolStateMachine<*>, responseType: Class<*>, topic: String, prevCheckpointKey: SecureHash?, serialisedFiber: ByteArray) { val checkpoint = Checkpoint(serialisedFiber, logger.name, topic, responseType.name) @@ -201,8 +202,9 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) object SameThreadFiberScheduler : FiberExecutorScheduler("Same thread scheduler", MoreExecutors.directExecutor()) /** - * The base class that should be used by any object that wishes to act as a protocol state machine. The type variable - * C is the type of the initial arguments. R is the type of the return. + * The base class that should be used by any object that wishes to act as a protocol state machine. A PSM is + * a kind of "fiber", and a fiber in turn is a bit like a thread, but a thread that can be suspended to the heap, + * serialised to disk, and resumed on demand. * * Sub-classes should override the [call] method and return whatever the final result of the protocol is. Inside the * call method, the rules of normal object oriented programming are a little different: @@ -216,11 +218,15 @@ object SameThreadFiberScheduler : FiberExecutorScheduler("Same thread scheduler" * via the [serviceHub] property which is provided. Don't try and keep data you got from a service across calls to * send/receive/sendAndReceive because the world might change in arbitrary ways out from underneath you, for instance, * if the node is restarted or reconfigured! - * - Don't pass initial data in using a constructor. This object will be instantiated using reflection so you cannot - * define your own constructor. Instead define a separate class that holds your initial arguments, and take it as - * the argument to [call]. + * + * Note that the result of the [call] method can be obtained in a couple of different ways. One is to call the get + * method, as the PSM is a [Future]. But that will block the calling thread until the result is ready, which may not + * be what you want (unless you know it's finished already). So you can also use the [resultFuture] property, which is + * a [ListenableFuture] and will let you register a callback. + * + * Once created, a PSM should be passed to a [StateMachineManager] which will start it and manage its execution. */ -abstract class ProtocolStateMachine : Fiber("protocol", SameThreadFiberScheduler) { +abstract class ProtocolStateMachine : Fiber("protocol", SameThreadFiberScheduler), Callable { // These fields shouldn't be serialised, so they are marked @Transient. @Transient private var suspendFunc: ((result: FiberRequest, serFiber: ByteArray) -> Unit)? = null @Transient private var resumeWithObject: Any? = null @@ -245,12 +251,12 @@ abstract class ProtocolStateMachine : Fiber("protocol", SameThreadFiber this.serviceHub = serviceHub } - @Suspendable - abstract fun call(args: C): R + // This line may look useless, but it's needed to convince the Quasar bytecode rewriter to do the right thing. + @Suspendable override abstract fun call(): R @Suspendable @Suppress("UNCHECKED_CAST") override fun run(): R { - val result = call(resumeWithObject as C) + val result = call() if (result != null) (resultFuture as SettableFuture).set(result) return result diff --git a/src/main/kotlin/core/node/TimestamperNodeService.kt b/src/main/kotlin/core/node/TimestamperNodeService.kt index 21f4ced7b8..b7fcf63bb9 100644 --- a/src/main/kotlin/core/node/TimestamperNodeService.kt +++ b/src/main/kotlin/core/node/TimestamperNodeService.kt @@ -107,7 +107,7 @@ class TimestamperNodeService(private val net: MessagingService, } @ThreadSafe -class TimestamperClient(private val psm: ProtocolStateMachine<*, *>, private val node: LegallyIdentifiableNode) : TimestamperService { +class TimestamperClient(private val psm: ProtocolStateMachine<*>, private val node: LegallyIdentifiableNode) : TimestamperService { override val identity: Party = node.identity @Suspendable diff --git a/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt index 792b73f779..26ba415f25 100644 --- a/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt +++ b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt @@ -71,32 +71,29 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { net = bobsNode ) - val tpSeller = TwoPartyTradeProtocol.create(StateMachineManager(alicesServices, backgroundThread), timestamper) - val tpBuyer = TwoPartyTradeProtocol.create(StateMachineManager(bobsServices, backgroundThread), timestamper) - val buyerSessionID = random63BitValue() - val aliceResult = tpSeller.runSeller( + val aliceResult = TwoPartyTradeProtocol.runSeller( + StateMachineManager(alicesServices, backgroundThread), + timestamper, bobsAddress, - TwoPartyTradeProtocol.SellerInitialArgs( - lookup("alice's paper"), - 1000.DOLLARS, - ALICE_KEY, - buyerSessionID - ) + lookup("alice's paper"), + 1000.DOLLARS, + ALICE_KEY, + buyerSessionID ) - val bobResult = tpBuyer.runBuyer( + val bobResult = TwoPartyTradeProtocol.runBuyer( + StateMachineManager(bobsServices, backgroundThread), + timestamper, alicesAddress, - TwoPartyTradeProtocol.BuyerInitialArgs( - 1000.DOLLARS, - CommercialPaper.State::class.java, - buyerSessionID - ) + 1000.DOLLARS, + CommercialPaper.State::class.java, + buyerSessionID ) - assertEquals(aliceResult.resultFuture.get(), bobResult.resultFuture.get()) + assertEquals(aliceResult.get(), bobResult.get()) - txns.add(aliceResult.resultFuture.get().second) + txns.add(aliceResult.get().second) verify() } backgroundThread.shutdown() @@ -128,28 +125,26 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { storage = bobsStorage ) - val tpSeller = TwoPartyTradeProtocol.create(StateMachineManager(alicesServices, MoreExecutors.directExecutor()), timestamper.first) val smmBuyer = StateMachineManager(bobsServices, MoreExecutors.directExecutor()) - val tpBuyer = TwoPartyTradeProtocol.create(smmBuyer, timestamper.first) val buyerSessionID = random63BitValue() - tpSeller.runSeller( + TwoPartyTradeProtocol.runSeller( + StateMachineManager(alicesServices, MoreExecutors.directExecutor()), + timestamper.first, bobsAddress, - TwoPartyTradeProtocol.SellerInitialArgs( - lookup("alice's paper"), - 1000.DOLLARS, - ALICE_KEY, - buyerSessionID - ) + lookup("alice's paper"), + 1000.DOLLARS, + ALICE_KEY, + buyerSessionID ) - tpBuyer.runBuyer( + TwoPartyTradeProtocol.runBuyer( + smmBuyer, + timestamper.first, alicesAddress, - TwoPartyTradeProtocol.BuyerInitialArgs( - 1000.DOLLARS, - CommercialPaper.State::class.java, - buyerSessionID - ) + 1000.DOLLARS, + CommercialPaper.State::class.java, + buyerSessionID ) // Everything is on this thread so we can now step through the protocol one step at a time. diff --git a/src/test/kotlin/core/node/TimestamperNodeServiceTest.kt b/src/test/kotlin/core/node/TimestamperNodeServiceTest.kt index dcb6ed2d2d..9012692d7d 100644 --- a/src/test/kotlin/core/node/TimestamperNodeServiceTest.kt +++ b/src/test/kotlin/core/node/TimestamperNodeServiceTest.kt @@ -54,9 +54,9 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() { service = TimestamperNodeService(serviceNode.second, Party("Unit test suite", ALICE), ALICE_KEY) } - class TestPSM(val server: LegallyIdentifiableNode, val now: Instant) : ProtocolStateMachine() { + class TestPSM(val server: LegallyIdentifiableNode, val now: Instant) : ProtocolStateMachine() { @Suspendable - override fun call(args: Any?): Boolean { + override fun call(): Boolean { val client = TimestamperClient(this, server) val ptx = TransactionBuilder().apply { addInputState(ContractStateRef(SecureHash.randomSHA256(), 0)) @@ -77,7 +77,7 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() { val smm = StateMachineManager(MockServices(net = myNode.second), RunOnCallerThread) val logName = TimestamperNodeService.TIMESTAMPING_PROTOCOL_TOPIC val psm = TestPSM(myNode.second.networkMap.timestampingNodes[0], clock.instant()) - smm.add(serviceNode.first, logName, psm) + smm.add(logName, psm) psm } assertTrue(psm.isDone)