From 299e1af15e7274ca92b51513e53d5f20dc9707be Mon Sep 17 00:00:00 2001 From: Mike Hearn Date: Tue, 16 Feb 2016 19:06:45 +0100 Subject: [PATCH] Protocol frameworks: separate the fiber object from the logic object to make it easier to compose subprotocols together. --- .../protocols/TwoPartyTradeProtocol.kt | 40 +++--- .../kotlin/core/messaging/StateMachines.kt | 131 +++++++++++------- .../core/node/TimestamperNodeService.kt | 25 +++- .../messaging/TwoPartyTradeProtocolTests.kt | 7 +- .../core/node/TimestamperNodeServiceTest.kt | 5 +- 5 files changed, 122 insertions(+), 86 deletions(-) diff --git a/src/main/kotlin/contracts/protocols/TwoPartyTradeProtocol.kt b/src/main/kotlin/contracts/protocols/TwoPartyTradeProtocol.kt index a7cf5e8af8..47f6dc748a 100644 --- a/src/main/kotlin/contracts/protocols/TwoPartyTradeProtocol.kt +++ b/src/main/kotlin/contracts/protocols/TwoPartyTradeProtocol.kt @@ -18,7 +18,7 @@ import core.crypto.SecureHash import core.crypto.signWithECDSA import core.messaging.* import core.node.DataVendingService -import core.node.TimestamperClient +import core.node.TimestampingProtocol import core.utilities.trace import java.security.KeyPair import java.security.PublicKey @@ -55,16 +55,14 @@ object TwoPartyTradeProtocol { otherSide: SingleMessageRecipient, assetToSell: StateAndRef, price: Amount, myKeyPair: KeyPair, buyerSessionID: Long): ListenableFuture { val seller = Seller(otherSide, timestampingAuthority, assetToSell, price, myKeyPair, buyerSessionID) - smm.add("$TRADE_TOPIC.seller", seller) - return seller.resultFuture + return smm.add("$TRADE_TOPIC.seller", seller) } fun runBuyer(smm: StateMachineManager, timestampingAuthority: LegallyIdentifiableNode, otherSide: SingleMessageRecipient, acceptablePrice: Amount, typeToBuy: Class, sessionID: Long): ListenableFuture { val buyer = Buyer(otherSide, timestampingAuthority.identity, acceptablePrice, typeToBuy, sessionID) - smm.add("$TRADE_TOPIC.buyer", buyer) - return buyer.resultFuture + return smm.add("$TRADE_TOPIC.buyer", buyer) } class UnacceptablePriceException(val givenPrice: Amount) : Exception() @@ -88,14 +86,14 @@ object TwoPartyTradeProtocol { val assetToSell: StateAndRef, val price: Amount, val myKeyPair: KeyPair, - val buyerSessionID: Long) : ProtocolStateMachine() { + val buyerSessionID: Long) : ProtocolLogic() { @Suspendable override fun call(): SignedTransaction { val partialTX: SignedTransaction = receiveAndCheckProposedTransaction() // These two steps could be done in parallel, in theory. Our framework doesn't support that yet though. val ourSignature = signWithOurKey(partialTX) - val tsaSig = timestamp(partialTX) + val tsaSig = subProtocol(TimestampingProtocol(timestampingAuthority, partialTX.txBits)) val signedTransaction = sendSignatures(partialTX, ourSignature, tsaSig) @@ -103,7 +101,7 @@ object TwoPartyTradeProtocol { } @Suspendable - open fun receiveAndCheckProposedTransaction(): SignedTransaction { + private fun receiveAndCheckProposedTransaction(): SignedTransaction { val sessionID = random63BitValue() // Make the first message we'll send to kick off the protocol. @@ -137,7 +135,7 @@ object TwoPartyTradeProtocol { } @Suspendable - open fun checkDependencies(txToCheck: SignedTransaction) { + private fun checkDependencies(txToCheck: SignedTransaction) { val toVerify = HashSet() val alreadyVerified = HashSet() val downloadedSignedTxns = ArrayList() @@ -249,15 +247,10 @@ object TwoPartyTradeProtocol { } } - open fun signWithOurKey(partialTX: SignedTransaction) = myKeyPair.signWithECDSA(partialTX.txBits) + private fun signWithOurKey(partialTX: SignedTransaction) = myKeyPair.signWithECDSA(partialTX.txBits) @Suspendable - open fun timestamp(partialTX: SignedTransaction): DigitalSignature.LegallyIdentifiable { - return TimestamperClient(this, timestampingAuthority).timestamp(partialTX.txBits) - } - - @Suspendable - open fun sendSignatures(partialTX: SignedTransaction, ourSignature: DigitalSignature.WithKey, + private fun sendSignatures(partialTX: SignedTransaction, ourSignature: DigitalSignature.WithKey, tsaSig: DigitalSignature.LegallyIdentifiable): SignedTransaction { val fullySigned = partialTX + tsaSig + ourSignature @@ -272,7 +265,8 @@ object TwoPartyTradeProtocol { val timestampingAuthority: Party, val acceptablePrice: Amount, val typeToBuy: Class, - val sessionID: Long) : ProtocolStateMachine() { + val sessionID: Long) : ProtocolLogic() { + @Suspendable override fun call(): SignedTransaction { val tradeRequest = receiveAndValidateTradeRequest() @@ -289,9 +283,9 @@ object TwoPartyTradeProtocol { } @Suspendable - open fun receiveAndValidateTradeRequest(): SellerTradeInfo { + private fun receiveAndValidateTradeRequest(): SellerTradeInfo { // Wait for a trade request to come in on our pre-provided session ID. - val maybeTradeRequest = receive(TRADE_TOPIC, sessionID, SellerTradeInfo::class.java) + val maybeTradeRequest = receive(TRADE_TOPIC, sessionID) val tradeRequest = maybeTradeRequest.validate { // What is the seller trying to sell us? @@ -315,15 +309,15 @@ object TwoPartyTradeProtocol { } @Suspendable - open fun swapSignaturesWithSeller(stx: SignedTransaction, theirSessionID: Long): SignaturesFromSeller { + private fun swapSignaturesWithSeller(stx: SignedTransaction, theirSessionID: Long): SignaturesFromSeller { logger.trace { "Sending partially signed transaction to seller" } // TODO: Protect against the seller terminating here and leaving us in the lurch without the final tx. - return sendAndReceive(TRADE_TOPIC, otherSide, theirSessionID, sessionID, stx, SignaturesFromSeller::class.java).validate { it } + return sendAndReceive(TRADE_TOPIC, otherSide, theirSessionID, sessionID, stx).validate { it } } - open fun signWithOurKeys(cashSigningPubKeys: List, ptx: TransactionBuilder): SignedTransaction { + private fun signWithOurKeys(cashSigningPubKeys: List, ptx: TransactionBuilder): SignedTransaction { // Now sign the transaction with whatever keys we need to move the cash. for (k in cashSigningPubKeys) { val priv = serviceHub.keyManagementService.toPrivate(k) @@ -338,7 +332,7 @@ object TwoPartyTradeProtocol { return stx } - open fun assembleSharedTX(tradeRequest: SellerTradeInfo): Pair> { + private fun assembleSharedTX(tradeRequest: SellerTradeInfo): Pair> { val ptx = TransactionBuilder() // Add input and output states for the movement of cash, by using the Cash contract to generate the states. val wallet = serviceHub.walletService.currentWallet diff --git a/src/main/kotlin/core/messaging/StateMachines.kt b/src/main/kotlin/core/messaging/StateMachines.kt index 8c284d45de..c6648a22bc 100644 --- a/src/main/kotlin/core/messaging/StateMachines.kt +++ b/src/main/kotlin/core/messaging/StateMachines.kt @@ -32,7 +32,6 @@ import java.io.ByteArrayOutputStream import java.io.PrintWriter import java.io.StringWriter import java.util.* -import java.util.concurrent.Callable import java.util.concurrent.Executor import javax.annotation.concurrent.ThreadSafe @@ -62,7 +61,7 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) private val checkpointsMap = serviceHub.storageService.getMap("state machines") // A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines // property. - private val _stateMachines = Collections.synchronizedList(ArrayList>()) + private val _stateMachines = Collections.synchronizedList(ArrayList>()) // This is a workaround for something Gradle does to us during unit tests. It replaces stderr with its own // class that inserts itself into a ThreadLocal. That then gets caught in fiber serialisation, which we don't @@ -73,10 +72,11 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) // ever recover. val checkpointing: Boolean get() = !System.err.javaClass.name.contains("LinePerThreadBufferingOutputStream") - /** Returns a snapshot of the currently registered state machines. */ - val stateMachines: List> get() { + /** Returns a list of all state machines executing the given protocol logic at the top level (subprotocols do not count) */ + fun findStateMachines(klass: Class>): List, ListenableFuture>> { synchronized(_stateMachines) { - return ArrayList(_stateMachines) + @Suppress("UNCHECKED_CAST") + return _stateMachines.filterIsInstance(klass).map { it to (it.psm as ProtocolStateMachine).resultFuture } } } @@ -113,7 +113,7 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) // Grab the Kryo engine configured by Quasar for its own stuff, and then do our own configuration on top // so we can deserialised the nested stream that holds the fiber. val psm = deserializeFiber(checkpoint.serialisedFiber) - _stateMachines.add(psm) + _stateMachines.add(psm.logic) val logger = LoggerFactory.getLogger(checkpoint.loggerName) val awaitingObjectOfType = Class.forName(checkpoint.awaitingObjectOfType) val topic = checkpoint.awaitingTopic @@ -155,12 +155,13 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) * The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is * restarted with checkpointed state machines in the storage service. */ - fun > add(loggerName: String, fiber: T): T { + fun add(loggerName: String, logic: ProtocolLogic): ListenableFuture { val logger = LoggerFactory.getLogger(loggerName) + val fiber = ProtocolStateMachine(logic) iterateStateMachine(fiber, serviceHub.networkService, logger, null, null) { it.start() } - return fiber + return fiber.resultFuture } private fun persistCheckpoint(prevCheckpointKey: SecureHash?, new: ByteArray): SecureHash { @@ -206,7 +207,7 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) // We're back! Check if the fiber is finished and if so, clean up. if (psm.isTerminated) { - _stateMachines.remove(psm) + _stateMachines.remove(psm.logic) checkpointsMap.remove(prevCheckpointKey) } } @@ -236,35 +237,83 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) object SameThreadFiberScheduler : FiberExecutorScheduler("Same thread scheduler", MoreExecutors.directExecutor()) /** - * The base class that should be used by any object that wishes to act as a protocol state machine. A PSM is - * a kind of "fiber", and a fiber in turn is a bit like a thread, but a thread that can be suspended to the heap, - * serialised to disk, and resumed on demand. + * A sub-class of [ProtocolLogic] implements a protocol flow using direct, straight line blocking code. Thus you + * can write complex protocol logic in an ordinary fashion, without having to think about callbacks, restarting after + * a node crash, how many instances of your protocol there are running and so on. * - * Sub-classes should override the [call] method and return whatever the final result of the protocol is. Inside the - * call method, the rules of normal object oriented programming are a little different: + * Invoking the network will cause the call stack to be suspended onto the heap and then serialized to a database using + * the Quasar fibers framework. Because of this, if you need access to data that might change over time, you should + * request it just-in-time via the [serviceHub] property which is provided. Don't try and keep data you got from a + * service across calls to send/receive/sendAndReceive because the world might change in arbitrary ways out from + * underneath you, for instance, if the node is restarted or reconfigured! * - * - You can call send/receive/sendAndReceive in order to suspend the state machine and request network interaction. - * This does not block a thread and when a state machine is suspended like this, it will be serialised and written - * to stable storage. That means all objects on the stack and referenced from fields must be serialisable as well - * (with Kryo, so they don't have to implement the Java Serializable interface). The state machine may be resumed - * at some arbitrary later point. - * - Because of this, if you need access to data that might change over time, you should request it just-in-time - * via the [serviceHub] property which is provided. Don't try and keep data you got from a service across calls to - * send/receive/sendAndReceive because the world might change in arbitrary ways out from underneath you, for instance, - * if the node is restarted or reconfigured! + * Additionally, be aware of what data you pin either via the stack or in your [ProtocolLogic] implementation. Very large + * objects or datasets will hurt performance by increasing the amount of data stored in each checkpoint. * - * The result of the [call] method can be obtained by using the [resultFuture] property, which is a [ListenableFuture] - * and will let you register a callback to be informed when the protocol has completed. Note that the PSM class is also - * a future, but not a listenable one. - * - * Once created, a PSM should be passed to a [StateMachineManager] which will start it and manage its execution. + * If you'd like to use another ProtocolLogic class as a component of your own, construct it on the fly and then pass + * it to the [subProtocol] method. It will return the result of that protocol when it completes. */ -abstract class ProtocolStateMachine : Fiber("protocol", SameThreadFiberScheduler), Callable { +abstract class ProtocolLogic { + /** Reference to the [Fiber] instance that is the top level controller for the entire flow. */ + lateinit var psm: ProtocolStateMachine<*> + + /** This is where you should log things to. */ + val logger: Logger get() = psm.logger + /** Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts */ + val serviceHub: ServiceHub get() = psm.serviceHub + + // Kotlin helpers that allow the use of generic types. + inline fun sendAndReceive(topic: String, destination: MessageRecipients, sessionIDForSend: Long, + sessionIDForReceive: Long, obj: Any): UntrustworthyData { + return psm.sendAndReceive(topic, destination, sessionIDForSend, sessionIDForReceive, obj, T::class.java) + } + inline fun receive(topic: String, sessionIDForReceive: Long): UntrustworthyData { + return psm.receive(topic, sessionIDForReceive, T::class.java) + } + @Suspendable fun send(topic: String, destination: MessageRecipients, sessionID: Long, obj: Any) { + psm.send(topic, destination, sessionID, obj) + } + + /** + * Invokes the given subprotocol by simply passing through this [ProtocolLogic]s reference to the + * [ProtocolStateMachine] and then calling the [call] method. + */ + @Suspendable fun subProtocol(subLogic: ProtocolLogic): R { + subLogic.psm = psm + return subLogic.call() + } + + @Suspendable + abstract fun call(): T +} + +/** + * A ProtocolStateMachine instance is a suspendable fiber that delegates all actual logic to a [ProtocolLogic] instance. + * For any given flow there is only one PSM, even if that protocol invokes subprotocols. + * + * These classes are created by the [StateMachineManager] when a new protocol is started at the topmost level. If + * a protocol invokes a sub-protocol, then it will pass along the PSM to the child. The call method of the topmost + * logic element gets to return the value that the entire state machine resolves to. + */ +class ProtocolStateMachine(val logic: ProtocolLogic) : Fiber("protocol", SameThreadFiberScheduler) { // These fields shouldn't be serialised, so they are marked @Transient. @Transient private var suspendFunc: ((result: FiberRequest, serFiber: ByteArray) -> Unit)? = null @Transient private var resumeWithObject: Any? = null @Transient lateinit var serviceHub: ServiceHub - @Transient protected lateinit var logger: Logger + @Transient lateinit var logger: Logger + + init { + logic.psm = this + } + + fun prepareForResumeWith(serviceHub: ServiceHub, withObject: Any?, logger: Logger, + suspendFunc: (FiberRequest, ByteArray) -> Unit) { + this.suspendFunc = suspendFunc + this.logger = logger + this.resumeWithObject = withObject + this.serviceHub = serviceHub + } + @Transient private var _resultFuture: SettableFuture? = SettableFuture.create() /** This future will complete when the call method returns. */ @@ -276,21 +325,10 @@ abstract class ProtocolStateMachine : Fiber("protocol", SameThreadFiberSch } } - fun prepareForResumeWith(serviceHub: ServiceHub, withObject: Any?, logger: Logger, - suspendFunc: (FiberRequest, ByteArray) -> Unit) { - this.suspendFunc = suspendFunc - this.logger = logger - this.resumeWithObject = withObject - this.serviceHub = serviceHub - } - - // This line may look useless, but it's needed to convince the Quasar bytecode rewriter to do the right thing. - @Suspendable override abstract fun call(): R - @Suspendable @Suppress("UNCHECKED_CAST") override fun run(): R { try { - val result = call() + val result = logic.call() if (result != null) _resultFuture?.set(result) return result @@ -335,15 +373,6 @@ abstract class ProtocolStateMachine : Fiber("protocol", SameThreadFiberSch val result = FiberRequest.NotExpectingResponse(topic, destination, sessionID, obj) Fiber.parkAndSerialize { fiber, writer -> suspendFunc!!(result, writer.write(fiber)) } } - - // Kotlin helpers that allow the use of generic types. - inline fun sendAndReceive(topic: String, destination: MessageRecipients, sessionIDForSend: Long, - sessionIDForReceive: Long, obj: Any): UntrustworthyData { - return sendAndReceive(topic, destination, sessionIDForSend, sessionIDForReceive, obj, T::class.java) - } - inline fun receive(topic: String, sessionIDForReceive: Long): UntrustworthyData { - return receive(topic, sessionIDForReceive, T::class.java) - } } /** diff --git a/src/main/kotlin/core/node/TimestamperNodeService.kt b/src/main/kotlin/core/node/TimestamperNodeService.kt index f9ecc56b98..1b9a03d0df 100644 --- a/src/main/kotlin/core/node/TimestamperNodeService.kt +++ b/src/main/kotlin/core/node/TimestamperNodeService.kt @@ -13,10 +13,7 @@ import co.paralleluniverse.fibers.Suspendable import core.* import core.crypto.DigitalSignature import core.crypto.signWithECDSA -import core.messaging.LegallyIdentifiableNode -import core.messaging.MessageRecipients -import core.messaging.MessagingService -import core.messaging.ProtocolStateMachine +import core.messaging.* import core.serialization.SerializedBytes import core.serialization.deserialize import core.serialization.serialize @@ -95,7 +92,6 @@ class TimestamperNodeService(private val net: MessagingService, } } -@ThreadSafe class TimestamperClient(private val psm: ProtocolStateMachine<*>, private val node: LegallyIdentifiableNode) : TimestamperService { override val identity: Party = node.identity @@ -116,3 +112,22 @@ class TimestamperClient(private val psm: ProtocolStateMachine<*>, private val no } } +class TimestampingProtocol(private val node: LegallyIdentifiableNode, + private val wtxBytes: SerializedBytes) : ProtocolLogic() { + @Suspendable + override fun call(): DigitalSignature.LegallyIdentifiable { + val sessionID = random63BitValue() + val replyTopic = "${TimestamperNodeService.TIMESTAMPING_PROTOCOL_TOPIC}.$sessionID" + val req = TimestampingMessages.Request(wtxBytes, serviceHub.networkService.myAddress, replyTopic) + + val maybeSignature = sendAndReceive( + TimestamperNodeService.TIMESTAMPING_PROTOCOL_TOPIC, node.address, 0, sessionID, req) + + // Check that the timestamping authority gave us back a valid signature and didn't break somehow + maybeSignature.validate { sig -> + sig.verifyWithECDSA(wtxBytes) + return sig + } + } +} + diff --git a/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt index 9840f49577..4101b655c4 100644 --- a/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt +++ b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt @@ -168,8 +168,7 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { ) // Find the future representing the result of this state machine again. - assertEquals(1, smm.stateMachines.size) - var bobFuture = smm.stateMachines.filterIsInstance().first().resultFuture + var bobFuture = smm.findStateMachines(TwoPartyTradeProtocol.Buyer::class.java).single().second // Let Bob process his mailbox. assertTrue(bobsNode.pump(false)) @@ -179,7 +178,7 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { txns.add(stx.tx) verify() - assertTrue(smm.stateMachines.isEmpty()) + assertTrue(smm.findStateMachines(TwoPartyTradeProtocol.Buyer::class.java).isEmpty()) } } @@ -239,7 +238,7 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { @Test fun `dependency with error`() { transactionGroupFor { - val (bobsWallet, fakeTxns) = fillUp(withError = true) + val bobsWallet = fillUp(withError = true).first val (alicesAddress, alicesNode) = makeNode(inBackground = true) val (bobsAddress, bobsNode) = makeNode(inBackground = true) diff --git a/src/test/kotlin/core/node/TimestamperNodeServiceTest.kt b/src/test/kotlin/core/node/TimestamperNodeServiceTest.kt index 42a08dbae8..0e611e403c 100644 --- a/src/test/kotlin/core/node/TimestamperNodeServiceTest.kt +++ b/src/test/kotlin/core/node/TimestamperNodeServiceTest.kt @@ -58,10 +58,10 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() { service = TimestamperNodeService(serviceNode.second, Party("Unit test suite", ALICE), ALICE_KEY) } - class TestPSM(val server: LegallyIdentifiableNode, val now: Instant) : ProtocolStateMachine() { + class TestPSM(val server: LegallyIdentifiableNode, val now: Instant) : ProtocolLogic() { @Suspendable override fun call(): Boolean { - val client = TimestamperClient(this, server) + val client = TimestamperClient(psm, server) val ptx = TransactionBuilder().apply { addInputState(StateRef(SecureHash.randomSHA256(), 0)) addOutputState(100.DOLLARS.CASH) @@ -82,7 +82,6 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() { val logName = TimestamperNodeService.TIMESTAMPING_PROTOCOL_TOPIC val psm = TestPSM(mockServices.networkMapService.timestampingNodes[0], clock.instant()) smm.add(logName, psm) - psm } assertTrue(psm.isDone) }