From 0ca47156bcb4a97d21b9aa3771df5d3d505c8f5f Mon Sep 17 00:00:00 2001 From: Mike Hearn Date: Mon, 7 Dec 2015 19:45:53 +0100 Subject: [PATCH] Trading protocol work in progress --- build.gradle | 20 ++ src/main/kotlin/contracts/Cash.kt | 12 +- src/main/kotlin/contracts/CommercialPaper.kt | 5 +- .../protocols/TwoPartyTradeProtocolImpl.kt | 251 ++++++++++++++++++ src/main/kotlin/core/Utils.kt | 13 + .../kotlin/core/messaging/InMemoryNetwork.kt | 8 +- src/main/kotlin/core/messaging/Messaging.kt | 36 ++- src/main/kotlin/core/utilities/Logging.kt | 4 +- .../continuations/ContinuationsSupport.kt | 31 +++ .../continuations/ProtocolStateMachines.kt | 118 ++++++++ .../core/messaging/InMemoryMessagingTests.kt | 51 ++-- .../messaging/TwoPartyTradeProtocolTests.kt | 73 +++++ src/test/kotlin/core/testutils/TestUtils.kt | 6 +- 13 files changed, 587 insertions(+), 41 deletions(-) create mode 100644 src/main/kotlin/contracts/protocols/TwoPartyTradeProtocolImpl.kt create mode 100644 src/main/kotlin/core/utilities/continuations/ContinuationsSupport.kt create mode 100644 src/main/kotlin/core/utilities/continuations/ProtocolStateMachines.kt create mode 100644 src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt diff --git a/build.gradle b/build.gradle index e7a03c7352..f681dbf1d2 100644 --- a/build.gradle +++ b/build.gradle @@ -24,6 +24,9 @@ buildscript { repositories { mavenCentral() + maven { + url 'http://oss.sonatype.org/content/repositories/snapshots' + } jcenter() } @@ -39,6 +42,23 @@ dependencies { // Logging compile "org.slf4j:slf4j-jdk14:1.7.13" + // For the continuations in the state machine tests. Note: JavaFlow is old and unmaintained but still seems to work + // just fine, once the patch here is applied to update it to a Java8 compatible asm: + // + // https://github.com/playframework/play1/commit/e0e28e6780a48c000e7ed536962f1f284cef9437 + // + // Obviously using this year-old upload to Maven Central by the Maven Play Plugin team is a short term hack for + // experimenting. Using this for real would mean forking JavaFlow and taking over maintenance (luckily it's small + // and Java is stable, so this is unlikely to be a big burden). We have to manually force an in-place upgrade to + // asm 5.0.3 here (javaflow wants 5.0.2) in order to avoid version conflicts. This is also something that should be + // fixed in any fork. Sadly, even Jigsaw doesn't solve this problem out of the box. + compile "com.google.code.maven-play-plugin.org.apache.commons:commons-javaflow:1590792-patched-play-1.3.0" + compile "org.ow2.asm:asm:5.0.3" + compile "org.ow2.asm:asm-analysis:5.0.3" + compile "org.ow2.asm:asm-tree:5.0.3" + compile "org.ow2.asm:asm-commons:5.0.3" + compile "org.ow2.asm:asm-util:5.0.3" + // For visualisation compile "org.graphstream:gs-core:1.3" compile "org.graphstream:gs-ui:1.3" diff --git a/src/main/kotlin/contracts/Cash.kt b/src/main/kotlin/contracts/Cash.kt index 53597ab0cb..3779033614 100644 --- a/src/main/kotlin/contracts/Cash.kt +++ b/src/main/kotlin/contracts/Cash.kt @@ -56,10 +56,12 @@ class Cash : Contract { val amount: Amount, /** There must be a MoveCommand signed by this key to claim the amount */ - val owner: PublicKey - ) : ContractState { + override val owner: PublicKey + ) : OwnableState { override val programRef = CASH_PROGRAM_ID override fun toString() = "Cash($amount at $deposit owned by $owner)" + + override fun withNewOwner(newOwner: PublicKey) = Pair(Commands.Move(), copy(owner = newOwner)) } // Just for grouping @@ -165,7 +167,7 @@ class Cash : Contract { */ @Throws(InsufficientBalanceException::class) fun craftSpend(tx: PartialTransaction, amount: Amount, to: PublicKey, - cashStates: List>, onlyFromParties: Set? = null) { + cashStates: List>, onlyFromParties: Set? = null): List { // Discussion // // This code is analogous to the Wallet.send() set of methods in bitcoinj, and has the same general outline. @@ -229,7 +231,9 @@ class Cash : Contract { for (state in gathered) tx.addInputState(state.ref) for (state in outputs) tx.addOutputState(state) // What if we already have a move command with the right keys? Filter it out here or in platform code? - tx.addArg(WireCommand(Commands.Move(), keysUsed.toList())) + val keysList = keysUsed.toList() + tx.addArg(WireCommand(Commands.Move(), keysList)) + return keysList } } diff --git a/src/main/kotlin/contracts/CommercialPaper.kt b/src/main/kotlin/contracts/CommercialPaper.kt index 804de02784..fd74e8b79e 100644 --- a/src/main/kotlin/contracts/CommercialPaper.kt +++ b/src/main/kotlin/contracts/CommercialPaper.kt @@ -38,13 +38,14 @@ class CommercialPaper : Contract { data class State( val issuance: PartyReference, - val owner: PublicKey, + override val owner: PublicKey, val faceValue: Amount, val maturityDate: Instant - ) : ContractState { + ) : OwnableState { override val programRef = CP_PROGRAM_ID fun withoutOwner() = copy(owner = NullPublicKey) + override fun withNewOwner(newOwner: PublicKey) = Pair(Commands.Move(), copy(owner = newOwner)) } interface Commands : Command { diff --git a/src/main/kotlin/contracts/protocols/TwoPartyTradeProtocolImpl.kt b/src/main/kotlin/contracts/protocols/TwoPartyTradeProtocolImpl.kt new file mode 100644 index 0000000000..2522aacf76 --- /dev/null +++ b/src/main/kotlin/contracts/protocols/TwoPartyTradeProtocolImpl.kt @@ -0,0 +1,251 @@ +/* + * Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members + * pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms + * set forth therein. + * + * All other rights reserved. + */ + +package contracts.protocols + +import com.google.common.util.concurrent.ListenableFuture +import com.google.common.util.concurrent.SettableFuture +import contracts.Cash +import contracts.sumCashBy +import core.* +import core.messaging.MessagingSystem +import core.messaging.SingleMessageRecipient +import core.serialization.SerializeableWithKryo +import core.serialization.THREAD_LOCAL_KRYO +import core.serialization.deserialize +import core.serialization.registerDataClass +import core.utilities.continuations.* +import core.utilities.trace +import org.slf4j.LoggerFactory +import java.security.KeyPair +import java.security.PrivateKey +import java.security.PublicKey +import java.security.SecureRandom + +/** + * This asset trading protocol has two parties (B and S for buyer and seller) and the following steps: + * + * 1. B sends the [StateAndRef] pointing to what they want to sell to S, along with info about the price. + * 2. S sends to B a [SignedWireTransaction] that includes the state as input, S's cash as input, the state with the new + * owner key as output, and any change cash as output. It contains a single signature from S but isn't valid because + * it lacks a signature from B authorising movement of the asset. + * 3. B signs it and hands the now finalised SignedWireTransaction back to S. + * + * They both end the protocol being in posession of a validly signed contract. + * + * To get an implementation of this class, use the static [TwoPartyTradeProtocol.create] method. Then use either + * the [runBuyer] or [runSeller] methods, depending on which side of the trade your node is taking. These methods + * return a future which will complete once the trade is over and a fully signed transaction is available: you can + * either block your thread waiting for the protocol to complete by using [ListenableFuture.get] or more usefully, + * register a callback that will be invoked when the time comes. + * + * To see an example of how to use this class, look at the unit tests. + */ +abstract class TwoPartyTradeProtocol { + abstract fun runSeller( + net: MessagingSystem, + otherSide: SingleMessageRecipient, + assetToSell: StateAndRef, + price: Amount, + myKey: KeyPair, + partyKeyMap: Map, + timestamper: TimestamperService + ): ListenableFuture> + + abstract fun runBuyer( + net: MessagingSystem, + otherSide: SingleMessageRecipient, + acceptablePrice: Amount, + typeToSell: Class, + wallet: List>, + myKeys: Map, + timestamper: TimestamperService, + partyKeyMap: Map + ): ListenableFuture> + + companion object { + @JvmStatic fun create(): TwoPartyTradeProtocol { + return TwoPartyTradeProtocolImpl() + } + } +} + +private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() { + companion object { + val TRADE_TOPIC = "com.r3cev.protocols.trade" + fun makeSessionID() = Math.abs(SecureRandom.getInstanceStrong().nextLong()) + } + + init { + THREAD_LOCAL_KRYO.get().registerDataClass() + } + + // This wraps some of the arguments passed to runSeller that are persistent across the lifetime of the trade and + // can be serialised. + class SellerInitialArgs( + val assetToSell: StateAndRef, + val price: Amount, + val myKeyPair: KeyPair + ) : SerializeableWithKryo + + // This wraps the things which the seller needs, but which might change whilst the continuation is suspended, + // e.g. due to a VM restart, networking issue, configuration file reload etc. It also contains the initial args + // and the future that the code will fill out when done. + class SellerContext( + val timestamper: TimestamperService, + val partyKeyMap: Map, + val initialArgs: SellerInitialArgs?, + val resultFuture: SettableFuture> = SettableFuture.create() + ) + + // This object is serialised to the network and is the first protocol message the seller sends to the buyer. + class SellerTradeInfo( + val assetForSale: StateAndRef, + val price: Amount, + val primaryOwnerKey: PublicKey, + val sessionID: Long + ) : SerializeableWithKryo + + // The seller's side of the protocol. IMPORTANT: This class is loaded in a separate classloader and auto-mangled + // by JavaFlow. Therefore, we cannot cast the object to Seller and poke it directly because the class we'd be + // trying to poke at is different to the one we saw at compile time, so we'd get ClassCastExceptions. All + // interaction with this class must be through either interfaces, or objects passed to and from the continuation + // by the state machine framework. Please refer to the documentation website (docs/build/html) to learn more about + // the protocol state machine framework. + class Seller : ProtocolStateMachine { + override fun run() { + val sessionID = makeSessionID() + val args = context().initialArgs!! + + val hello = SellerTradeInfo(args.assetToSell, args.price, args.myKeyPair.public, sessionID) + // Zero is a special session ID that is used to start a trade (i.e. before a session is started). + var (ctx2, offerMsg) = sendAndReceive(TRADE_TOPIC, 0, sessionID, hello) + logger().trace { "Received partially signed transaction" } + + val partialTx = offerMsg + partialTx.verifySignatures() + val wtx = partialTx.txBits.deserialize() + + requireThat { + "transaction sends us the right amount of cash" by (wtx.outputStates.sumCashBy(args.myKeyPair.public) == args.price) + // There are all sorts of funny games a malicious secondary might play here, we should fix them: + // + // - This tx may attempt to send some assets we aren't intending to sell to the secondary, if + // we're reusing keys! So don't reuse keys! + // - This tx may not be valid according to the contracts of the input states, so we must resolve + // and fully audit the transaction chains to convince ourselves that it is actually valid. + // - This tx may include output states that impose odd conditions on the movement of the cash, + // once we implement state pairing. + // + // but the goal of this code is not to be fully secure, but rather, just to find good ways to + // express protocol state machines on top of the messaging layer. + } + + val ourSignature = args.myKeyPair.signWithECDSA(partialTx.txBits.bits) + val fullySigned: SignedWireTransaction = partialTx.copy(sigs = partialTx.sigs + ourSignature) + // We should run it through our full TransactionGroup of all transactions here. + fullySigned.verify() + val timestamped: TimestampedWireTransaction = fullySigned.toTimestampedTransaction(ctx2.timestamper) + logger().trace { "Built finished transaction, sending back to secondary!" } + send(TRADE_TOPIC, sessionID, timestamped) + ctx2.resultFuture.set(Pair(timestamped, timestamped.verifyToLedgerTransaction(ctx2.timestamper, ctx2.partyKeyMap))) + } + } + + class UnacceptablePriceException(val givenPrice: Amount) : Exception() + class AssetMismatchException(val expectedTypeName: String, val typeName: String) : Exception() { + override fun toString() = "The submitted asset didn't match the expected type: $expectedTypeName vs $typeName" + } + + + class BuyerContext( + val acceptablePrice: Amount, + val typeToSell: Class, + val wallet: List>, + val myKeys: Map, + val timestamper: TimestamperService, + val partyKeyMap: Map, + val resultFuture: SettableFuture> = SettableFuture.create() + ) + + // The buyer's side of the protocol. See note above Seller to learn about the caveats here. + class Buyer : ProtocolStateMachine { + override fun run() { + // Start a new scope here so we can't accidentally reuse 'ctx' after doing the sendAndReceive below, + // as the context object we're meant to use might change each time we suspend (e.g. due to VM restart). + val (stx, theirSessionID) = run { + // Wait for a trade request to come in. + val (ctx, tradeRequest) = receive(TRADE_TOPIC, 0) + val assetTypeName = tradeRequest.assetForSale.state.javaClass.name + + logger().trace { "Got trade request for a $assetTypeName" } + + // Check the start message for acceptability. + check(tradeRequest.sessionID > 0) + if (tradeRequest.price > ctx.acceptablePrice) + throw UnacceptablePriceException(tradeRequest.price) + if (!ctx.typeToSell.isInstance(tradeRequest.assetForSale.state)) + throw AssetMismatchException(ctx.typeToSell.name, assetTypeName) + + // TODO: Either look up the stateref here in our local db, or accept a long chain of states and + // validate them to audit the other side and ensure it actually owns the state we are being offered! + // For now, just assume validity! + + // Generate the shared transaction that both sides will sign, using the data we have. + val ptx = PartialTransaction() + // Add input and output states for the movement of cash. + val cashSigningPubKeys = Cash().craftSpend(ptx, tradeRequest.price, tradeRequest.primaryOwnerKey, ctx.wallet) + // Add inputs/outputs/a command for the movement of the asset. + ptx.addInputState(tradeRequest.assetForSale.ref) + // Just pick some arbitrary public key for now (this provides poor privacy). + val (command, state) = tradeRequest.assetForSale.state.withNewOwner(ctx.myKeys.keys.first()) + ptx.addOutputState(state) + ptx.addArg(WireCommand(command, tradeRequest.assetForSale.state.owner)) + + for (k in cashSigningPubKeys) { + // TODO: This error case should be removed through the introduction of a Wallet class. + val priv = ctx.myKeys[k] ?: throw IllegalStateException("Coin in wallet with no known privkey") + ptx.signWith(KeyPair(k, priv)) + } + + val stx = ptx.toSignedTransaction(checkSufficientSignatures = false) + stx.verifySignatures() // Verifies that we generated a signed transaction correctly. + Pair(stx, tradeRequest.sessionID) + } + + // TODO: Could run verify() here to make sure the only signature missing is the primaries. + logger().trace { "Sending partially signed transaction to primary" } + // We'll just reuse the session ID the primary selected here for convenience. + val (ctx, fullySigned) = sendAndReceive(TRADE_TOPIC, theirSessionID, theirSessionID, stx) + logger().trace { "Got fully signed transaction, verifying ... "} + val ltx = fullySigned.verifyToLedgerTransaction(ctx.timestamper, ctx.partyKeyMap) + logger().trace { "Fully signed transaction was valid. Trade complete! :-)" } + ctx.resultFuture.set(Pair(fullySigned, ltx)) + } + } + + override fun runSeller(net: MessagingSystem, otherSide: SingleMessageRecipient, assetToSell: StateAndRef, + price: Amount, myKey: KeyPair, partyKeyMap: Map, + timestamper: TimestamperService): ListenableFuture> { + val args = SellerInitialArgs(assetToSell, price, myKey) + val context = SellerContext(timestamper, partyKeyMap, args) + val logger = LoggerFactory.getLogger("$TRADE_TOPIC.primary") + loadContinuationClass(javaClass.classLoader).iterateStateMachine(net, otherSide, context, context, logger) + return context.resultFuture + } + + override fun runBuyer(net: MessagingSystem, otherSide: SingleMessageRecipient, acceptablePrice: Amount, + typeToSell: Class, wallet: List>, + myKeys: Map, timestamper: TimestamperService, + partyKeyMap: Map): ListenableFuture> { + val context = BuyerContext(acceptablePrice, typeToSell, wallet, myKeys, timestamper, partyKeyMap) + val logger = LoggerFactory.getLogger("$TRADE_TOPIC.secondary") + loadContinuationClass(javaClass.classLoader).iterateStateMachine(net, otherSide, context, context, logger) + return context.resultFuture + } +} \ No newline at end of file diff --git a/src/main/kotlin/core/Utils.kt b/src/main/kotlin/core/Utils.kt index 1ce443b211..ed77c5e44b 100644 --- a/src/main/kotlin/core/Utils.kt +++ b/src/main/kotlin/core/Utils.kt @@ -9,6 +9,8 @@ package core import com.google.common.io.BaseEncoding +import com.google.common.util.concurrent.SettableFuture +import org.slf4j.Logger import java.time.Duration import java.util.* @@ -38,3 +40,14 @@ val Int.days: Duration get() = Duration.ofDays(this.toLong()) val Int.hours: Duration get() = Duration.ofHours(this.toLong()) val Int.minutes: Duration get() = Duration.ofMinutes(this.toLong()) val Int.seconds: Duration get() = Duration.ofSeconds(this.toLong()) + +/** Executes the given block and sets the future to either the result, or any exception that was thrown. */ +fun SettableFuture.setFrom(logger: Logger? = null, block: () -> T): SettableFuture { + try { + set(block()) + } catch (e: Exception) { + logger?.error("Caught exception", e) + setException(e) + } + return this +} \ No newline at end of file diff --git a/src/main/kotlin/core/messaging/InMemoryNetwork.kt b/src/main/kotlin/core/messaging/InMemoryNetwork.kt index 45da29a228..a9f07638cb 100644 --- a/src/main/kotlin/core/messaging/InMemoryNetwork.kt +++ b/src/main/kotlin/core/messaging/InMemoryNetwork.kt @@ -89,7 +89,7 @@ public class InMemoryNetwork { * An instance can be obtained by creating a builder and then using the start method. */ inner class Node(private val manuallyPumped: Boolean): MessagingSystem { - inner class Handler(val executor: Executor?, val topic: String, val callback: (Message) -> Unit) : MessageHandlerRegistration + inner class Handler(val executor: Executor?, val topic: String, val callback: (Message, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration @GuardedBy("this") protected val handlers: MutableList = ArrayList() @GuardedBy("this") @@ -101,7 +101,7 @@ public class InMemoryNetwork { } @Synchronized - override fun addMessageHandler(executor: Executor?, topic: String, callback: (Message) -> Unit): MessageHandlerRegistration { + override fun addMessageHandler(topic: String, executor: Executor?, callback: (Message, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration { check(running) return Handler(executor, topic, callback).apply { handlers.add(this) } } @@ -115,7 +115,7 @@ public class InMemoryNetwork { @Synchronized override fun send(message: Message, target: MessageRecipients) { check(running) - L.trace { "Sending $message to '$target'" } + L.trace { "Sending message of topic '${message.topic}' to '$target'" } when (target) { is InMemoryNodeHandle -> { val node = networkMap[target] ?: throw IllegalArgumentException("Unknown message recipient: $target") @@ -172,7 +172,7 @@ public class InMemoryNetwork { for (handler in deliverTo) { // Now deliver via the requested executor, or on this thread if no executor was provided at registration time. - (handler.executor ?: MoreExecutors.directExecutor()).execute { handler.callback(message) } + (handler.executor ?: MoreExecutors.directExecutor()).execute { handler.callback(message, handler) } } return true diff --git a/src/main/kotlin/core/messaging/Messaging.kt b/src/main/kotlin/core/messaging/Messaging.kt index 854a8f24d1..58e6021f19 100644 --- a/src/main/kotlin/core/messaging/Messaging.kt +++ b/src/main/kotlin/core/messaging/Messaging.kt @@ -9,6 +9,9 @@ package core.messaging import com.google.common.util.concurrent.ListenableFuture +import core.serialization.SerializeableWithKryo +import core.serialization.deserialize +import core.serialization.serialize import java.time.Duration import java.time.Instant import java.util.concurrent.Executor @@ -33,12 +36,14 @@ interface MessagingSystem { * If no executor is received then the callback will run on threads provided by the messaging system, and the * callback is expected to be thread safe as a result. * - * The returned object is an opaque handle that may be used to un-register handlers later with [addMessageHandler]. + * The returned object is an opaque handle that may be used to un-register handlers later with [removeMessageHandler]. + * The handle is passed to the callback as well, to avoid race conditions whereby the callback wants to unregister + * itself and yet addMessageHandler hasn't returned the handle yet. * * If the callback throws an exception then the message is discarded and will not be retried, unless the exception * is a subclass of [RetryMessageLaterException], in which case the message will be queued and attempted later. */ - fun addMessageHandler(executor: Executor? = null, topic: String = "", callback: (Message) -> Unit): MessageHandlerRegistration + fun addMessageHandler(topic: String = "", executor: Executor? = null, callback: (Message, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration /** * Removes a handler given the object returned from [addMessageHandler]. The callback will no longer be invoked once @@ -68,6 +73,33 @@ interface MessagingSystem { fun createMessage(topic: String, data: ByteArray): Message } +/** + * Registers a handler for the given topic that runs the given callback with the message and then removes itself. This + * is useful for one-shot handlers that aren't supposed to stick around permanently. Note that this callback doesn't + * take the registration object, unlike the callback to [MessagingSystem.addMessageHandler]. + */ +fun MessagingSystem.runOnNextMessage(topic: String = "", executor: Executor? = null, callback: (Message) -> Unit) { + addMessageHandler(topic, executor) { msg, reg -> + callback(msg) + removeMessageHandler(reg) + } +} + +fun MessagingSystem.send(topic: String, to: MessageRecipients, obj: SerializeableWithKryo) = send(createMessage(topic, obj.serialize()), to) + +/** + * Registers a handler for the given topic that runs the given callback with the message content deserialised to the + * given type, and then removes itself. + */ +inline fun MessagingSystem.runOnNextMessageWith(topic: String = "", + executor: Executor? = null, + noinline callback: (T) -> Unit) { + addMessageHandler(topic, executor) { msg, reg -> + callback(msg.data.deserialize()) + removeMessageHandler(reg) + } +} + /** * This class lets you start up a [MessagingSystem]. Its purpose is to stop you from getting access to the methods * on the messaging system interface until you have successfully started up the system. One of these objects should diff --git a/src/main/kotlin/core/utilities/Logging.kt b/src/main/kotlin/core/utilities/Logging.kt index 808a8ff63d..be32a84f38 100644 --- a/src/main/kotlin/core/utilities/Logging.kt +++ b/src/main/kotlin/core/utilities/Logging.kt @@ -65,10 +65,10 @@ class BriefLogFormatter : Formatter() { handlers[0].formatter = BriefLogFormatter() } - fun initVerbose() { + fun initVerbose(packageSpec: String = "") { init() - loggerRef.level = Level.ALL loggerRef.handlers[0].level = Level.ALL + Logger.getLogger(packageSpec).level = Level.ALL } } } diff --git a/src/main/kotlin/core/utilities/continuations/ContinuationsSupport.kt b/src/main/kotlin/core/utilities/continuations/ContinuationsSupport.kt new file mode 100644 index 0000000000..31052d28cf --- /dev/null +++ b/src/main/kotlin/core/utilities/continuations/ContinuationsSupport.kt @@ -0,0 +1,31 @@ +/* + * Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members + * pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms + * set forth therein. + * + * All other rights reserved. + */ + +package core.utilities.continuations + +import org.apache.commons.javaflow.Continuation +import org.apache.commons.javaflow.ContinuationClassLoader + +/** + * A "continuation" is an object that represents a suspended execution of a function. They allow you to write code + * that suspends itself half way through, bundles up everything that was on the stack into a (potentially serialisable) + * object, and then be resumed from the exact same spot later. Continuations are not natively supported by the JVM + * but we can use the Apache JavaFlow library which implements them using bytecode rewriting. + * + * The primary benefit of using continuations is that state machine/protocol code that would otherwise be very + * convoluted and hard to read becomes very clear and straightforward. + * + * TODO: Document classloader interactions and gotchas here. + */ +inline fun loadContinuationClass(classLoader: ClassLoader): Continuation { + val klass = T::class.java + val url = klass.protectionDomain.codeSource.location + val cl = ContinuationClassLoader(arrayOf(url), classLoader) + val obj = cl.forceLoadClass(klass.name).newInstance() as Runnable + return Continuation.startSuspendedWith(obj) +} diff --git a/src/main/kotlin/core/utilities/continuations/ProtocolStateMachines.kt b/src/main/kotlin/core/utilities/continuations/ProtocolStateMachines.kt new file mode 100644 index 0000000000..e88d73bea2 --- /dev/null +++ b/src/main/kotlin/core/utilities/continuations/ProtocolStateMachines.kt @@ -0,0 +1,118 @@ +/* + * Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members + * pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms + * set forth therein. + * + * All other rights reserved. + */ + +package core.utilities.continuations + +import com.esotericsoftware.kryo.io.Output +import core.messaging.MessagingSystem +import core.messaging.SingleMessageRecipient +import core.messaging.runOnNextMessage +import core.serialization.SerializeableWithKryo +import core.serialization.createKryo +import core.serialization.deserialize +import core.serialization.serialize +import core.utilities.trace +import org.apache.commons.javaflow.Continuation +import org.slf4j.Logger +import java.io.ByteArrayOutputStream + +private val CONTINUATION_LOGGER = ThreadLocal() + +/** + * A convenience mixing interface that can be implemented by an object that will act as a continuation. + * + * A ProtocolStateMachine must implement the run method from [Runnable], and the rest of what this interface + * provides are pre-defined utility methods to ease implementation of such machines. + */ +@Suppress("UNCHECKED_CAST") +interface ProtocolStateMachine : Runnable { + fun context(): CONTEXT_TYPE = Continuation.getContext() as CONTEXT_TYPE + fun logger(): Logger = CONTINUATION_LOGGER.get() +} + +@Suppress("NOTHING_TO_INLINE", "UNCHECKED_CAST") +inline fun ProtocolStateMachine.send(topic: String, sessionID: Long, obj: S) = + Continuation.suspend(ContinuationResult.NotExpectingResponse(topic, sessionID, obj)) as CONTEXT_TYPE + +@Suppress("UNCHECKED_CAST") +inline fun ProtocolStateMachine.sendAndReceive(topic: String, sessionIDForSend: Long, + sessionIDForReceive: Long, + obj: SerializeableWithKryo) = + Continuation.suspend(ContinuationResult.ExpectingResponse(topic, sessionIDForSend, sessionIDForReceive, + obj, R::class.java)) as Pair + + +@Suppress("UNCHECKED_CAST") +inline fun receive(topic: String, sessionIDForReceive: Long) = + Continuation.suspend(ContinuationResult.ExpectingResponse(topic, -1, sessionIDForReceive, null, + R::class.java)) as Pair + +open class ContinuationResult(val topic: String, val sessionIDForSend: Long, val sessionIDForReceive: Long, val obj: SerializeableWithKryo?) { + class ExpectingResponse( + topic: String, + sessionIDForSend: Long, + sessionIDForReceive: Long, + obj: SerializeableWithKryo?, + val responseType: Class + ) : ContinuationResult(topic, sessionIDForSend, sessionIDForReceive, obj) + + class NotExpectingResponse(topic: String, sessionIDForSend: Long, obj: SerializeableWithKryo?) : ContinuationResult(topic, sessionIDForSend, -1, obj) +} + +fun Continuation.iterateStateMachine(net: MessagingSystem, otherSide: SingleMessageRecipient, + transientContext: Any, continuationInput: Any?, logger: Logger): Continuation { + // This will resume execution of the run() function inside the continuation at the place it left off. + val oldLogger = CONTINUATION_LOGGER.get() + val nextState = try { + CONTINUATION_LOGGER.set(logger) + Continuation.continueWith(this, continuationInput) + } catch (t: Throwable) { + logger.error("Caught error whilst invoking protocol state machine", t) + throw t + } finally { + CONTINUATION_LOGGER.set(oldLogger) + } + // If continuation returns null, it's finished. + val req = nextState?.value() as? ContinuationResult ?: return this + + // Else, it wants us to do something: send, receive, or send-and-receive. Firstly, checkpoint it, so we can restart + // if something goes wrong. + val bytes = run { + val stream = ByteArrayOutputStream() + Output(stream).use { + createKryo().apply { + isRegistrationRequired = false + writeObject(it, nextState) + } + } + stream.toByteArray() + } + + if (req is ContinuationResult.ExpectingResponse<*>) { + val topic = "${req.topic}.${req.sessionIDForReceive}" + net.runOnNextMessage(topic) { netMsg -> + val obj = netMsg.data.deserialize(req.responseType) + logger.trace { "<- $topic : message of type ${obj.javaClass.name}" } + nextState.iterateStateMachine(net, otherSide, transientContext, Pair(transientContext, obj), logger) + } + } + // If an object to send was provided (not null), send it now. + req.obj?.let { + val topic = "${req.topic}.${req.sessionIDForSend}" + logger.trace { "-> $topic : message of type ${it.javaClass.name}" } + net.send(net.createMessage(topic, it.serialize()), otherSide) + } + if (req is ContinuationResult.NotExpectingResponse) { + // We sent a message, but won't get a response, so we must re-enter the continuation to let it keep going. + return nextState.iterateStateMachine(net, otherSide, transientContext, transientContext, logger) + } else { + return nextState + } +} diff --git a/src/test/kotlin/core/messaging/InMemoryMessagingTests.kt b/src/test/kotlin/core/messaging/InMemoryMessagingTests.kt index 7a4a2d3ad1..971b79c78c 100644 --- a/src/test/kotlin/core/messaging/InMemoryMessagingTests.kt +++ b/src/test/kotlin/core/messaging/InMemoryMessagingTests.kt @@ -18,14 +18,10 @@ import kotlin.test.assertEquals import kotlin.test.assertFails import kotlin.test.assertTrue -class InMemoryMessagingTests { +open class TestWithInMemoryNetwork { val nodes: MutableMap = HashMap() lateinit var network: InMemoryNetwork - init { - // BriefLogFormatter.initVerbose() - } - fun makeNode(): Pair { // The manuallyPumped = true bit means that we must call the pump method on the system in order to val (address, builder) = network.createNode(manuallyPumped = true) @@ -34,29 +30,32 @@ class InMemoryMessagingTests { return Pair(address, node) } - fun pumpAll() { - nodes.values.forEach { it.pump(false) } - } - - // Utilities to help define messaging rounds. - fun roundWithPumpings(times: Int, body: () -> Unit) { - body() - repeat(times) { pumpAll() } - } - - fun round(body: () -> Unit) = roundWithPumpings(1, body) - @Before - fun before() { + fun setupNetwork() { network = InMemoryNetwork() nodes.clear() } @After - fun after() { + fun stopNetwork() { network.stop() } + fun pumpAll() = nodes.values.map { it.pump(false) } + + // Keep calling "pump" in rounds until every node in the network reports that it had nothing to do. + fun runNetwork(body: () -> T): T { + val result = body() + while (pumpAll().any { it }) {} + return result + } +} + +class InMemoryMessagingTests : TestWithInMemoryNetwork() { + init { + // BriefLogFormatter.initVerbose() + } + @Test fun topicStringValidation() { TopicStringValidator.check("this.is.ok") @@ -82,19 +81,19 @@ class InMemoryMessagingTests { var finalDelivery: Message? = null with(node2) { - addMessageHandler { - send(it, addr3) + addMessageHandler { msg, registration -> + send(msg, addr3) } } with(node3) { - addMessageHandler { - finalDelivery = it + addMessageHandler { msg, registration -> + finalDelivery = msg } } // Node 1 sends a message and it should end up in finalDelivery, after we pump each node. - roundWithPumpings(2) { + runNetwork { node1.send(node1.createMessage("test.topic", bits), addr2) } @@ -110,8 +109,8 @@ class InMemoryMessagingTests { val bits = "test-content".toByteArray() var counter = 0 - listOf(node1, node2, node3).forEach { it.addMessageHandler { counter++ } } - round { + listOf(node1, node2, node3).forEach { it.addMessageHandler { msg, registration -> counter++ } } + runNetwork { node1.send(node2.createMessage("test.topic", bits), network.entireNetwork) } assertEquals(3, counter) diff --git a/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt new file mode 100644 index 0000000000..d6ada06cd8 --- /dev/null +++ b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt @@ -0,0 +1,73 @@ +/* + * Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members + * pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms + * set forth therein. + * + * All other rights reserved. + */ + +package core.messaging + +import contracts.Cash +import contracts.CommercialPaper +import contracts.protocols.TwoPartyTradeProtocol +import core.* +import core.testutils.* +import org.junit.Test +import java.util.logging.Formatter +import java.util.logging.Level +import java.util.logging.LogRecord +import java.util.logging.Logger +import kotlin.test.assertEquals + +/** + * In this example, Alessia wishes to sell her commercial paper to Boris in return for $1,000,000 and they wish to do + * it on the ledger atomically. Therefore they must work together to build a transaction. + * + * We assume that Alessia and Boris already found each other via some market, and have agreed the details already. + */ +class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { + init { + Logger.getLogger("").handlers[0].level = Level.ALL + Logger.getLogger("").handlers[0].formatter = object : Formatter() { + override fun format(record: LogRecord) = "${record.threadID} ${record.loggerName}: ${record.message}\n" + } + Logger.getLogger("com.r3cev.protocols.trade").level = Level.ALL + } + + @Test + fun cashForCP() { + val (addr1, node1) = makeNode() + val (addr2, node2) = makeNode() + + val tp = TwoPartyTradeProtocol.create() + + transactionGroupFor { + // Bob (S) has some cash, Alice (P) has some commercial paper she wants to sell to Bob. + roots { + transaction(CommercialPaper.State(MEGA_CORP.ref(1, 2, 3), ALICE, 1200.DOLLARS, TEST_TX_TIME + 7.days) label "alice's paper") + transaction(800.DOLLARS.CASH `owned by` BOB label "bob cash1") + transaction(300.DOLLARS.CASH `owned by` BOB label "bob cash2") + } + + val bobsWallet = listOf>(lookup("bob cash1"), lookup("bob cash2")) + val (aliceFuture, bobFuture) = runNetwork { + Pair( + tp.runSeller(node1, addr2, lookup("alice's paper"), 1000.DOLLARS, ALICE_KEY, + TEST_KEYS_TO_CORP_MAP, DUMMY_TIMESTAMPER), + tp.runBuyer(node2, addr1, 1000.DOLLARS, CommercialPaper.State::class.java, bobsWallet, + mapOf(BOB to BOB_KEY.private), DUMMY_TIMESTAMPER, TEST_KEYS_TO_CORP_MAP) + ) + } + + val aliceResult: Pair = aliceFuture.get() + val bobResult: Pair = bobFuture.get() + + assertEquals(aliceResult, bobResult) + + txns.add(aliceResult.second) + + verify() + } + } +} diff --git a/src/test/kotlin/core/testutils/TestUtils.kt b/src/test/kotlin/core/testutils/TestUtils.kt index 82fc527a7c..f900d7c43f 100644 --- a/src/test/kotlin/core/testutils/TestUtils.kt +++ b/src/test/kotlin/core/testutils/TestUtils.kt @@ -220,7 +220,7 @@ class TransactionGroupDSL(private val stateType: Class) { private val inStates = ArrayList() fun input(label: String) { - inStates.add(labelToRefs[label] ?: throw IllegalArgumentException("Unknown label \"$label\"")) + inStates.add(label.outputRef) } @@ -235,6 +235,9 @@ class TransactionGroupDSL(private val stateType: Class) { } val String.output: T get() = labelToOutputs[this] ?: throw IllegalArgumentException("State with label '$this' was not found") + val String.outputRef: ContractStateRef get() = labelToRefs[this] ?: throw IllegalArgumentException("Unknown label \"$this\"") + + fun lookup(label: String) = StateAndRef(label.output as C, label.outputRef) private inner class InternalLedgerTransactionDSL : LedgerTransactionDSL() { fun finaliseAndInsertLabels(time: Instant): LedgerTransaction { @@ -268,6 +271,7 @@ class TransactionGroupDSL(private val stateType: Class) { val label = state.label!! labelToRefs[label] = ContractStateRef(ltx.hash, index) outputsToLabels[state.state] = label + labelToOutputs[label] = state.state as T } rootTxns.add(ltx) }