mirror of
https://github.com/corda/corda.git
synced 2025-04-19 08:36:39 +00:00
Trading protocol work in progress
This commit is contained in:
parent
fed0ae5629
commit
0ca47156bc
20
build.gradle
20
build.gradle
@ -24,6 +24,9 @@ buildscript {
|
||||
|
||||
repositories {
|
||||
mavenCentral()
|
||||
maven {
|
||||
url 'http://oss.sonatype.org/content/repositories/snapshots'
|
||||
}
|
||||
jcenter()
|
||||
}
|
||||
|
||||
@ -39,6 +42,23 @@ dependencies {
|
||||
// Logging
|
||||
compile "org.slf4j:slf4j-jdk14:1.7.13"
|
||||
|
||||
// For the continuations in the state machine tests. Note: JavaFlow is old and unmaintained but still seems to work
|
||||
// just fine, once the patch here is applied to update it to a Java8 compatible asm:
|
||||
//
|
||||
// https://github.com/playframework/play1/commit/e0e28e6780a48c000e7ed536962f1f284cef9437
|
||||
//
|
||||
// Obviously using this year-old upload to Maven Central by the Maven Play Plugin team is a short term hack for
|
||||
// experimenting. Using this for real would mean forking JavaFlow and taking over maintenance (luckily it's small
|
||||
// and Java is stable, so this is unlikely to be a big burden). We have to manually force an in-place upgrade to
|
||||
// asm 5.0.3 here (javaflow wants 5.0.2) in order to avoid version conflicts. This is also something that should be
|
||||
// fixed in any fork. Sadly, even Jigsaw doesn't solve this problem out of the box.
|
||||
compile "com.google.code.maven-play-plugin.org.apache.commons:commons-javaflow:1590792-patched-play-1.3.0"
|
||||
compile "org.ow2.asm:asm:5.0.3"
|
||||
compile "org.ow2.asm:asm-analysis:5.0.3"
|
||||
compile "org.ow2.asm:asm-tree:5.0.3"
|
||||
compile "org.ow2.asm:asm-commons:5.0.3"
|
||||
compile "org.ow2.asm:asm-util:5.0.3"
|
||||
|
||||
// For visualisation
|
||||
compile "org.graphstream:gs-core:1.3"
|
||||
compile "org.graphstream:gs-ui:1.3"
|
||||
|
@ -56,10 +56,12 @@ class Cash : Contract {
|
||||
val amount: Amount,
|
||||
|
||||
/** There must be a MoveCommand signed by this key to claim the amount */
|
||||
val owner: PublicKey
|
||||
) : ContractState {
|
||||
override val owner: PublicKey
|
||||
) : OwnableState {
|
||||
override val programRef = CASH_PROGRAM_ID
|
||||
override fun toString() = "Cash($amount at $deposit owned by $owner)"
|
||||
|
||||
override fun withNewOwner(newOwner: PublicKey) = Pair(Commands.Move(), copy(owner = newOwner))
|
||||
}
|
||||
|
||||
// Just for grouping
|
||||
@ -165,7 +167,7 @@ class Cash : Contract {
|
||||
*/
|
||||
@Throws(InsufficientBalanceException::class)
|
||||
fun craftSpend(tx: PartialTransaction, amount: Amount, to: PublicKey,
|
||||
cashStates: List<StateAndRef<State>>, onlyFromParties: Set<Party>? = null) {
|
||||
cashStates: List<StateAndRef<Cash.State>>, onlyFromParties: Set<Party>? = null): List<PublicKey> {
|
||||
// Discussion
|
||||
//
|
||||
// This code is analogous to the Wallet.send() set of methods in bitcoinj, and has the same general outline.
|
||||
@ -229,7 +231,9 @@ class Cash : Contract {
|
||||
for (state in gathered) tx.addInputState(state.ref)
|
||||
for (state in outputs) tx.addOutputState(state)
|
||||
// What if we already have a move command with the right keys? Filter it out here or in platform code?
|
||||
tx.addArg(WireCommand(Commands.Move(), keysUsed.toList()))
|
||||
val keysList = keysUsed.toList()
|
||||
tx.addArg(WireCommand(Commands.Move(), keysList))
|
||||
return keysList
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -38,13 +38,14 @@ class CommercialPaper : Contract {
|
||||
|
||||
data class State(
|
||||
val issuance: PartyReference,
|
||||
val owner: PublicKey,
|
||||
override val owner: PublicKey,
|
||||
val faceValue: Amount,
|
||||
val maturityDate: Instant
|
||||
) : ContractState {
|
||||
) : OwnableState {
|
||||
override val programRef = CP_PROGRAM_ID
|
||||
|
||||
fun withoutOwner() = copy(owner = NullPublicKey)
|
||||
override fun withNewOwner(newOwner: PublicKey) = Pair(Commands.Move(), copy(owner = newOwner))
|
||||
}
|
||||
|
||||
interface Commands : Command {
|
||||
|
251
src/main/kotlin/contracts/protocols/TwoPartyTradeProtocolImpl.kt
Normal file
251
src/main/kotlin/contracts/protocols/TwoPartyTradeProtocolImpl.kt
Normal file
@ -0,0 +1,251 @@
|
||||
/*
|
||||
* Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members
|
||||
* pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms
|
||||
* set forth therein.
|
||||
*
|
||||
* All other rights reserved.
|
||||
*/
|
||||
|
||||
package contracts.protocols
|
||||
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.google.common.util.concurrent.SettableFuture
|
||||
import contracts.Cash
|
||||
import contracts.sumCashBy
|
||||
import core.*
|
||||
import core.messaging.MessagingSystem
|
||||
import core.messaging.SingleMessageRecipient
|
||||
import core.serialization.SerializeableWithKryo
|
||||
import core.serialization.THREAD_LOCAL_KRYO
|
||||
import core.serialization.deserialize
|
||||
import core.serialization.registerDataClass
|
||||
import core.utilities.continuations.*
|
||||
import core.utilities.trace
|
||||
import org.slf4j.LoggerFactory
|
||||
import java.security.KeyPair
|
||||
import java.security.PrivateKey
|
||||
import java.security.PublicKey
|
||||
import java.security.SecureRandom
|
||||
|
||||
/**
|
||||
* This asset trading protocol has two parties (B and S for buyer and seller) and the following steps:
|
||||
*
|
||||
* 1. B sends the [StateAndRef] pointing to what they want to sell to S, along with info about the price.
|
||||
* 2. S sends to B a [SignedWireTransaction] that includes the state as input, S's cash as input, the state with the new
|
||||
* owner key as output, and any change cash as output. It contains a single signature from S but isn't valid because
|
||||
* it lacks a signature from B authorising movement of the asset.
|
||||
* 3. B signs it and hands the now finalised SignedWireTransaction back to S.
|
||||
*
|
||||
* They both end the protocol being in posession of a validly signed contract.
|
||||
*
|
||||
* To get an implementation of this class, use the static [TwoPartyTradeProtocol.create] method. Then use either
|
||||
* the [runBuyer] or [runSeller] methods, depending on which side of the trade your node is taking. These methods
|
||||
* return a future which will complete once the trade is over and a fully signed transaction is available: you can
|
||||
* either block your thread waiting for the protocol to complete by using [ListenableFuture.get] or more usefully,
|
||||
* register a callback that will be invoked when the time comes.
|
||||
*
|
||||
* To see an example of how to use this class, look at the unit tests.
|
||||
*/
|
||||
abstract class TwoPartyTradeProtocol {
|
||||
abstract fun runSeller(
|
||||
net: MessagingSystem,
|
||||
otherSide: SingleMessageRecipient,
|
||||
assetToSell: StateAndRef<OwnableState>,
|
||||
price: Amount,
|
||||
myKey: KeyPair,
|
||||
partyKeyMap: Map<PublicKey, Party>,
|
||||
timestamper: TimestamperService
|
||||
): ListenableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>>
|
||||
|
||||
abstract fun runBuyer(
|
||||
net: MessagingSystem,
|
||||
otherSide: SingleMessageRecipient,
|
||||
acceptablePrice: Amount,
|
||||
typeToSell: Class<out OwnableState>,
|
||||
wallet: List<StateAndRef<Cash.State>>,
|
||||
myKeys: Map<PublicKey, PrivateKey>,
|
||||
timestamper: TimestamperService,
|
||||
partyKeyMap: Map<PublicKey, Party>
|
||||
): ListenableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>>
|
||||
|
||||
companion object {
|
||||
@JvmStatic fun create(): TwoPartyTradeProtocol {
|
||||
return TwoPartyTradeProtocolImpl()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
|
||||
companion object {
|
||||
val TRADE_TOPIC = "com.r3cev.protocols.trade"
|
||||
fun makeSessionID() = Math.abs(SecureRandom.getInstanceStrong().nextLong())
|
||||
}
|
||||
|
||||
init {
|
||||
THREAD_LOCAL_KRYO.get().registerDataClass<TwoPartyTradeProtocolImpl.SellerTradeInfo>()
|
||||
}
|
||||
|
||||
// This wraps some of the arguments passed to runSeller that are persistent across the lifetime of the trade and
|
||||
// can be serialised.
|
||||
class SellerInitialArgs(
|
||||
val assetToSell: StateAndRef<OwnableState>,
|
||||
val price: Amount,
|
||||
val myKeyPair: KeyPair
|
||||
) : SerializeableWithKryo
|
||||
|
||||
// This wraps the things which the seller needs, but which might change whilst the continuation is suspended,
|
||||
// e.g. due to a VM restart, networking issue, configuration file reload etc. It also contains the initial args
|
||||
// and the future that the code will fill out when done.
|
||||
class SellerContext(
|
||||
val timestamper: TimestamperService,
|
||||
val partyKeyMap: Map<PublicKey, Party>,
|
||||
val initialArgs: SellerInitialArgs?,
|
||||
val resultFuture: SettableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>> = SettableFuture.create()
|
||||
)
|
||||
|
||||
// This object is serialised to the network and is the first protocol message the seller sends to the buyer.
|
||||
class SellerTradeInfo(
|
||||
val assetForSale: StateAndRef<OwnableState>,
|
||||
val price: Amount,
|
||||
val primaryOwnerKey: PublicKey,
|
||||
val sessionID: Long
|
||||
) : SerializeableWithKryo
|
||||
|
||||
// The seller's side of the protocol. IMPORTANT: This class is loaded in a separate classloader and auto-mangled
|
||||
// by JavaFlow. Therefore, we cannot cast the object to Seller and poke it directly because the class we'd be
|
||||
// trying to poke at is different to the one we saw at compile time, so we'd get ClassCastExceptions. All
|
||||
// interaction with this class must be through either interfaces, or objects passed to and from the continuation
|
||||
// by the state machine framework. Please refer to the documentation website (docs/build/html) to learn more about
|
||||
// the protocol state machine framework.
|
||||
class Seller : ProtocolStateMachine<SellerContext> {
|
||||
override fun run() {
|
||||
val sessionID = makeSessionID()
|
||||
val args = context().initialArgs!!
|
||||
|
||||
val hello = SellerTradeInfo(args.assetToSell, args.price, args.myKeyPair.public, sessionID)
|
||||
// Zero is a special session ID that is used to start a trade (i.e. before a session is started).
|
||||
var (ctx2, offerMsg) = sendAndReceive<SignedWireTransaction, SellerContext>(TRADE_TOPIC, 0, sessionID, hello)
|
||||
logger().trace { "Received partially signed transaction" }
|
||||
|
||||
val partialTx = offerMsg
|
||||
partialTx.verifySignatures()
|
||||
val wtx = partialTx.txBits.deserialize<WireTransaction>()
|
||||
|
||||
requireThat {
|
||||
"transaction sends us the right amount of cash" by (wtx.outputStates.sumCashBy(args.myKeyPair.public) == args.price)
|
||||
// There are all sorts of funny games a malicious secondary might play here, we should fix them:
|
||||
//
|
||||
// - This tx may attempt to send some assets we aren't intending to sell to the secondary, if
|
||||
// we're reusing keys! So don't reuse keys!
|
||||
// - This tx may not be valid according to the contracts of the input states, so we must resolve
|
||||
// and fully audit the transaction chains to convince ourselves that it is actually valid.
|
||||
// - This tx may include output states that impose odd conditions on the movement of the cash,
|
||||
// once we implement state pairing.
|
||||
//
|
||||
// but the goal of this code is not to be fully secure, but rather, just to find good ways to
|
||||
// express protocol state machines on top of the messaging layer.
|
||||
}
|
||||
|
||||
val ourSignature = args.myKeyPair.signWithECDSA(partialTx.txBits.bits)
|
||||
val fullySigned: SignedWireTransaction = partialTx.copy(sigs = partialTx.sigs + ourSignature)
|
||||
// We should run it through our full TransactionGroup of all transactions here.
|
||||
fullySigned.verify()
|
||||
val timestamped: TimestampedWireTransaction = fullySigned.toTimestampedTransaction(ctx2.timestamper)
|
||||
logger().trace { "Built finished transaction, sending back to secondary!" }
|
||||
send(TRADE_TOPIC, sessionID, timestamped)
|
||||
ctx2.resultFuture.set(Pair(timestamped, timestamped.verifyToLedgerTransaction(ctx2.timestamper, ctx2.partyKeyMap)))
|
||||
}
|
||||
}
|
||||
|
||||
class UnacceptablePriceException(val givenPrice: Amount) : Exception()
|
||||
class AssetMismatchException(val expectedTypeName: String, val typeName: String) : Exception() {
|
||||
override fun toString() = "The submitted asset didn't match the expected type: $expectedTypeName vs $typeName"
|
||||
}
|
||||
|
||||
|
||||
class BuyerContext(
|
||||
val acceptablePrice: Amount,
|
||||
val typeToSell: Class<out OwnableState>,
|
||||
val wallet: List<StateAndRef<Cash.State>>,
|
||||
val myKeys: Map<PublicKey, PrivateKey>,
|
||||
val timestamper: TimestamperService,
|
||||
val partyKeyMap: Map<PublicKey, Party>,
|
||||
val resultFuture: SettableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>> = SettableFuture.create()
|
||||
)
|
||||
|
||||
// The buyer's side of the protocol. See note above Seller to learn about the caveats here.
|
||||
class Buyer : ProtocolStateMachine<BuyerContext> {
|
||||
override fun run() {
|
||||
// Start a new scope here so we can't accidentally reuse 'ctx' after doing the sendAndReceive below,
|
||||
// as the context object we're meant to use might change each time we suspend (e.g. due to VM restart).
|
||||
val (stx, theirSessionID) = run {
|
||||
// Wait for a trade request to come in.
|
||||
val (ctx, tradeRequest) = receive<SellerTradeInfo, BuyerContext>(TRADE_TOPIC, 0)
|
||||
val assetTypeName = tradeRequest.assetForSale.state.javaClass.name
|
||||
|
||||
logger().trace { "Got trade request for a $assetTypeName" }
|
||||
|
||||
// Check the start message for acceptability.
|
||||
check(tradeRequest.sessionID > 0)
|
||||
if (tradeRequest.price > ctx.acceptablePrice)
|
||||
throw UnacceptablePriceException(tradeRequest.price)
|
||||
if (!ctx.typeToSell.isInstance(tradeRequest.assetForSale.state))
|
||||
throw AssetMismatchException(ctx.typeToSell.name, assetTypeName)
|
||||
|
||||
// TODO: Either look up the stateref here in our local db, or accept a long chain of states and
|
||||
// validate them to audit the other side and ensure it actually owns the state we are being offered!
|
||||
// For now, just assume validity!
|
||||
|
||||
// Generate the shared transaction that both sides will sign, using the data we have.
|
||||
val ptx = PartialTransaction()
|
||||
// Add input and output states for the movement of cash.
|
||||
val cashSigningPubKeys = Cash().craftSpend(ptx, tradeRequest.price, tradeRequest.primaryOwnerKey, ctx.wallet)
|
||||
// Add inputs/outputs/a command for the movement of the asset.
|
||||
ptx.addInputState(tradeRequest.assetForSale.ref)
|
||||
// Just pick some arbitrary public key for now (this provides poor privacy).
|
||||
val (command, state) = tradeRequest.assetForSale.state.withNewOwner(ctx.myKeys.keys.first())
|
||||
ptx.addOutputState(state)
|
||||
ptx.addArg(WireCommand(command, tradeRequest.assetForSale.state.owner))
|
||||
|
||||
for (k in cashSigningPubKeys) {
|
||||
// TODO: This error case should be removed through the introduction of a Wallet class.
|
||||
val priv = ctx.myKeys[k] ?: throw IllegalStateException("Coin in wallet with no known privkey")
|
||||
ptx.signWith(KeyPair(k, priv))
|
||||
}
|
||||
|
||||
val stx = ptx.toSignedTransaction(checkSufficientSignatures = false)
|
||||
stx.verifySignatures() // Verifies that we generated a signed transaction correctly.
|
||||
Pair(stx, tradeRequest.sessionID)
|
||||
}
|
||||
|
||||
// TODO: Could run verify() here to make sure the only signature missing is the primaries.
|
||||
logger().trace { "Sending partially signed transaction to primary" }
|
||||
// We'll just reuse the session ID the primary selected here for convenience.
|
||||
val (ctx, fullySigned) = sendAndReceive<TimestampedWireTransaction, BuyerContext>(TRADE_TOPIC, theirSessionID, theirSessionID, stx)
|
||||
logger().trace { "Got fully signed transaction, verifying ... "}
|
||||
val ltx = fullySigned.verifyToLedgerTransaction(ctx.timestamper, ctx.partyKeyMap)
|
||||
logger().trace { "Fully signed transaction was valid. Trade complete! :-)" }
|
||||
ctx.resultFuture.set(Pair(fullySigned, ltx))
|
||||
}
|
||||
}
|
||||
|
||||
override fun runSeller(net: MessagingSystem, otherSide: SingleMessageRecipient, assetToSell: StateAndRef<OwnableState>,
|
||||
price: Amount, myKey: KeyPair, partyKeyMap: Map<PublicKey, Party>,
|
||||
timestamper: TimestamperService): ListenableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>> {
|
||||
val args = SellerInitialArgs(assetToSell, price, myKey)
|
||||
val context = SellerContext(timestamper, partyKeyMap, args)
|
||||
val logger = LoggerFactory.getLogger("$TRADE_TOPIC.primary")
|
||||
loadContinuationClass<Seller>(javaClass.classLoader).iterateStateMachine(net, otherSide, context, context, logger)
|
||||
return context.resultFuture
|
||||
}
|
||||
|
||||
override fun runBuyer(net: MessagingSystem, otherSide: SingleMessageRecipient, acceptablePrice: Amount,
|
||||
typeToSell: Class<out OwnableState>, wallet: List<StateAndRef<Cash.State>>,
|
||||
myKeys: Map<PublicKey, PrivateKey>, timestamper: TimestamperService,
|
||||
partyKeyMap: Map<PublicKey, Party>): ListenableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>> {
|
||||
val context = BuyerContext(acceptablePrice, typeToSell, wallet, myKeys, timestamper, partyKeyMap)
|
||||
val logger = LoggerFactory.getLogger("$TRADE_TOPIC.secondary")
|
||||
loadContinuationClass<Buyer>(javaClass.classLoader).iterateStateMachine(net, otherSide, context, context, logger)
|
||||
return context.resultFuture
|
||||
}
|
||||
}
|
@ -9,6 +9,8 @@
|
||||
package core
|
||||
|
||||
import com.google.common.io.BaseEncoding
|
||||
import com.google.common.util.concurrent.SettableFuture
|
||||
import org.slf4j.Logger
|
||||
import java.time.Duration
|
||||
import java.util.*
|
||||
|
||||
@ -38,3 +40,14 @@ val Int.days: Duration get() = Duration.ofDays(this.toLong())
|
||||
val Int.hours: Duration get() = Duration.ofHours(this.toLong())
|
||||
val Int.minutes: Duration get() = Duration.ofMinutes(this.toLong())
|
||||
val Int.seconds: Duration get() = Duration.ofSeconds(this.toLong())
|
||||
|
||||
/** Executes the given block and sets the future to either the result, or any exception that was thrown. */
|
||||
fun <T> SettableFuture<T>.setFrom(logger: Logger? = null, block: () -> T): SettableFuture<T> {
|
||||
try {
|
||||
set(block())
|
||||
} catch (e: Exception) {
|
||||
logger?.error("Caught exception", e)
|
||||
setException(e)
|
||||
}
|
||||
return this
|
||||
}
|
@ -89,7 +89,7 @@ public class InMemoryNetwork {
|
||||
* An instance can be obtained by creating a builder and then using the start method.
|
||||
*/
|
||||
inner class Node(private val manuallyPumped: Boolean): MessagingSystem {
|
||||
inner class Handler(val executor: Executor?, val topic: String, val callback: (Message) -> Unit) : MessageHandlerRegistration
|
||||
inner class Handler(val executor: Executor?, val topic: String, val callback: (Message, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration
|
||||
@GuardedBy("this")
|
||||
protected val handlers: MutableList<Handler> = ArrayList()
|
||||
@GuardedBy("this")
|
||||
@ -101,7 +101,7 @@ public class InMemoryNetwork {
|
||||
}
|
||||
|
||||
@Synchronized
|
||||
override fun addMessageHandler(executor: Executor?, topic: String, callback: (Message) -> Unit): MessageHandlerRegistration {
|
||||
override fun addMessageHandler(topic: String, executor: Executor?, callback: (Message, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
|
||||
check(running)
|
||||
return Handler(executor, topic, callback).apply { handlers.add(this) }
|
||||
}
|
||||
@ -115,7 +115,7 @@ public class InMemoryNetwork {
|
||||
@Synchronized
|
||||
override fun send(message: Message, target: MessageRecipients) {
|
||||
check(running)
|
||||
L.trace { "Sending $message to '$target'" }
|
||||
L.trace { "Sending message of topic '${message.topic}' to '$target'" }
|
||||
when (target) {
|
||||
is InMemoryNodeHandle -> {
|
||||
val node = networkMap[target] ?: throw IllegalArgumentException("Unknown message recipient: $target")
|
||||
@ -172,7 +172,7 @@ public class InMemoryNetwork {
|
||||
|
||||
for (handler in deliverTo) {
|
||||
// Now deliver via the requested executor, or on this thread if no executor was provided at registration time.
|
||||
(handler.executor ?: MoreExecutors.directExecutor()).execute { handler.callback(message) }
|
||||
(handler.executor ?: MoreExecutors.directExecutor()).execute { handler.callback(message, handler) }
|
||||
}
|
||||
|
||||
return true
|
||||
|
@ -9,6 +9,9 @@
|
||||
package core.messaging
|
||||
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import core.serialization.SerializeableWithKryo
|
||||
import core.serialization.deserialize
|
||||
import core.serialization.serialize
|
||||
import java.time.Duration
|
||||
import java.time.Instant
|
||||
import java.util.concurrent.Executor
|
||||
@ -33,12 +36,14 @@ interface MessagingSystem {
|
||||
* If no executor is received then the callback will run on threads provided by the messaging system, and the
|
||||
* callback is expected to be thread safe as a result.
|
||||
*
|
||||
* The returned object is an opaque handle that may be used to un-register handlers later with [addMessageHandler].
|
||||
* The returned object is an opaque handle that may be used to un-register handlers later with [removeMessageHandler].
|
||||
* The handle is passed to the callback as well, to avoid race conditions whereby the callback wants to unregister
|
||||
* itself and yet addMessageHandler hasn't returned the handle yet.
|
||||
*
|
||||
* If the callback throws an exception then the message is discarded and will not be retried, unless the exception
|
||||
* is a subclass of [RetryMessageLaterException], in which case the message will be queued and attempted later.
|
||||
*/
|
||||
fun addMessageHandler(executor: Executor? = null, topic: String = "", callback: (Message) -> Unit): MessageHandlerRegistration
|
||||
fun addMessageHandler(topic: String = "", executor: Executor? = null, callback: (Message, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
|
||||
|
||||
/**
|
||||
* Removes a handler given the object returned from [addMessageHandler]. The callback will no longer be invoked once
|
||||
@ -68,6 +73,33 @@ interface MessagingSystem {
|
||||
fun createMessage(topic: String, data: ByteArray): Message
|
||||
}
|
||||
|
||||
/**
|
||||
* Registers a handler for the given topic that runs the given callback with the message and then removes itself. This
|
||||
* is useful for one-shot handlers that aren't supposed to stick around permanently. Note that this callback doesn't
|
||||
* take the registration object, unlike the callback to [MessagingSystem.addMessageHandler].
|
||||
*/
|
||||
fun MessagingSystem.runOnNextMessage(topic: String = "", executor: Executor? = null, callback: (Message) -> Unit) {
|
||||
addMessageHandler(topic, executor) { msg, reg ->
|
||||
callback(msg)
|
||||
removeMessageHandler(reg)
|
||||
}
|
||||
}
|
||||
|
||||
fun MessagingSystem.send(topic: String, to: MessageRecipients, obj: SerializeableWithKryo) = send(createMessage(topic, obj.serialize()), to)
|
||||
|
||||
/**
|
||||
* Registers a handler for the given topic that runs the given callback with the message content deserialised to the
|
||||
* given type, and then removes itself.
|
||||
*/
|
||||
inline fun <reified T : SerializeableWithKryo> MessagingSystem.runOnNextMessageWith(topic: String = "",
|
||||
executor: Executor? = null,
|
||||
noinline callback: (T) -> Unit) {
|
||||
addMessageHandler(topic, executor) { msg, reg ->
|
||||
callback(msg.data.deserialize<T>())
|
||||
removeMessageHandler(reg)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This class lets you start up a [MessagingSystem]. Its purpose is to stop you from getting access to the methods
|
||||
* on the messaging system interface until you have successfully started up the system. One of these objects should
|
||||
|
@ -65,10 +65,10 @@ class BriefLogFormatter : Formatter() {
|
||||
handlers[0].formatter = BriefLogFormatter()
|
||||
}
|
||||
|
||||
fun initVerbose() {
|
||||
fun initVerbose(packageSpec: String = "") {
|
||||
init()
|
||||
loggerRef.level = Level.ALL
|
||||
loggerRef.handlers[0].level = Level.ALL
|
||||
Logger.getLogger(packageSpec).level = Level.ALL
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,31 @@
|
||||
/*
|
||||
* Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members
|
||||
* pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms
|
||||
* set forth therein.
|
||||
*
|
||||
* All other rights reserved.
|
||||
*/
|
||||
|
||||
package core.utilities.continuations
|
||||
|
||||
import org.apache.commons.javaflow.Continuation
|
||||
import org.apache.commons.javaflow.ContinuationClassLoader
|
||||
|
||||
/**
|
||||
* A "continuation" is an object that represents a suspended execution of a function. They allow you to write code
|
||||
* that suspends itself half way through, bundles up everything that was on the stack into a (potentially serialisable)
|
||||
* object, and then be resumed from the exact same spot later. Continuations are not natively supported by the JVM
|
||||
* but we can use the Apache JavaFlow library which implements them using bytecode rewriting.
|
||||
*
|
||||
* The primary benefit of using continuations is that state machine/protocol code that would otherwise be very
|
||||
* convoluted and hard to read becomes very clear and straightforward.
|
||||
*
|
||||
* TODO: Document classloader interactions and gotchas here.
|
||||
*/
|
||||
inline fun <reified T : Runnable> loadContinuationClass(classLoader: ClassLoader): Continuation {
|
||||
val klass = T::class.java
|
||||
val url = klass.protectionDomain.codeSource.location
|
||||
val cl = ContinuationClassLoader(arrayOf(url), classLoader)
|
||||
val obj = cl.forceLoadClass(klass.name).newInstance() as Runnable
|
||||
return Continuation.startSuspendedWith(obj)
|
||||
}
|
@ -0,0 +1,118 @@
|
||||
/*
|
||||
* Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members
|
||||
* pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms
|
||||
* set forth therein.
|
||||
*
|
||||
* All other rights reserved.
|
||||
*/
|
||||
|
||||
package core.utilities.continuations
|
||||
|
||||
import com.esotericsoftware.kryo.io.Output
|
||||
import core.messaging.MessagingSystem
|
||||
import core.messaging.SingleMessageRecipient
|
||||
import core.messaging.runOnNextMessage
|
||||
import core.serialization.SerializeableWithKryo
|
||||
import core.serialization.createKryo
|
||||
import core.serialization.deserialize
|
||||
import core.serialization.serialize
|
||||
import core.utilities.trace
|
||||
import org.apache.commons.javaflow.Continuation
|
||||
import org.slf4j.Logger
|
||||
import java.io.ByteArrayOutputStream
|
||||
|
||||
private val CONTINUATION_LOGGER = ThreadLocal<Logger>()
|
||||
|
||||
/**
|
||||
* A convenience mixing interface that can be implemented by an object that will act as a continuation.
|
||||
*
|
||||
* A ProtocolStateMachine must implement the run method from [Runnable], and the rest of what this interface
|
||||
* provides are pre-defined utility methods to ease implementation of such machines.
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
interface ProtocolStateMachine<CONTEXT_TYPE : Any> : Runnable {
|
||||
fun context(): CONTEXT_TYPE = Continuation.getContext() as CONTEXT_TYPE
|
||||
fun logger(): Logger = CONTINUATION_LOGGER.get()
|
||||
}
|
||||
|
||||
@Suppress("NOTHING_TO_INLINE", "UNCHECKED_CAST")
|
||||
inline fun <S : SerializeableWithKryo,
|
||||
CONTEXT_TYPE : Any> ProtocolStateMachine<CONTEXT_TYPE>.send(topic: String, sessionID: Long, obj: S) =
|
||||
Continuation.suspend(ContinuationResult.NotExpectingResponse(topic, sessionID, obj)) as CONTEXT_TYPE
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
inline fun <reified R : SerializeableWithKryo,
|
||||
CONTEXT_TYPE : Any> ProtocolStateMachine<CONTEXT_TYPE>.sendAndReceive(topic: String, sessionIDForSend: Long,
|
||||
sessionIDForReceive: Long,
|
||||
obj: SerializeableWithKryo) =
|
||||
Continuation.suspend(ContinuationResult.ExpectingResponse(topic, sessionIDForSend, sessionIDForReceive,
|
||||
obj, R::class.java)) as Pair<CONTEXT_TYPE, R>
|
||||
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
inline fun <reified R : SerializeableWithKryo, CONTEXT_TYPE : Any> receive(topic: String, sessionIDForReceive: Long) =
|
||||
Continuation.suspend(ContinuationResult.ExpectingResponse(topic, -1, sessionIDForReceive, null,
|
||||
R::class.java)) as Pair<CONTEXT_TYPE, R>
|
||||
|
||||
open class ContinuationResult(val topic: String, val sessionIDForSend: Long, val sessionIDForReceive: Long, val obj: SerializeableWithKryo?) {
|
||||
class ExpectingResponse<R : SerializeableWithKryo>(
|
||||
topic: String,
|
||||
sessionIDForSend: Long,
|
||||
sessionIDForReceive: Long,
|
||||
obj: SerializeableWithKryo?,
|
||||
val responseType: Class<R>
|
||||
) : ContinuationResult(topic, sessionIDForSend, sessionIDForReceive, obj)
|
||||
|
||||
class NotExpectingResponse(topic: String, sessionIDForSend: Long, obj: SerializeableWithKryo?) : ContinuationResult(topic, sessionIDForSend, -1, obj)
|
||||
}
|
||||
|
||||
fun Continuation.iterateStateMachine(net: MessagingSystem, otherSide: SingleMessageRecipient,
|
||||
transientContext: Any, continuationInput: Any?, logger: Logger): Continuation {
|
||||
// This will resume execution of the run() function inside the continuation at the place it left off.
|
||||
val oldLogger = CONTINUATION_LOGGER.get()
|
||||
val nextState = try {
|
||||
CONTINUATION_LOGGER.set(logger)
|
||||
Continuation.continueWith(this, continuationInput)
|
||||
} catch (t: Throwable) {
|
||||
logger.error("Caught error whilst invoking protocol state machine", t)
|
||||
throw t
|
||||
} finally {
|
||||
CONTINUATION_LOGGER.set(oldLogger)
|
||||
}
|
||||
// If continuation returns null, it's finished.
|
||||
val req = nextState?.value() as? ContinuationResult ?: return this
|
||||
|
||||
// Else, it wants us to do something: send, receive, or send-and-receive. Firstly, checkpoint it, so we can restart
|
||||
// if something goes wrong.
|
||||
val bytes = run {
|
||||
val stream = ByteArrayOutputStream()
|
||||
Output(stream).use {
|
||||
createKryo().apply {
|
||||
isRegistrationRequired = false
|
||||
writeObject(it, nextState)
|
||||
}
|
||||
}
|
||||
stream.toByteArray()
|
||||
}
|
||||
|
||||
if (req is ContinuationResult.ExpectingResponse<*>) {
|
||||
val topic = "${req.topic}.${req.sessionIDForReceive}"
|
||||
net.runOnNextMessage(topic) { netMsg ->
|
||||
val obj = netMsg.data.deserialize(req.responseType)
|
||||
logger.trace { "<- $topic : message of type ${obj.javaClass.name}" }
|
||||
nextState.iterateStateMachine(net, otherSide, transientContext, Pair(transientContext, obj), logger)
|
||||
}
|
||||
}
|
||||
// If an object to send was provided (not null), send it now.
|
||||
req.obj?.let {
|
||||
val topic = "${req.topic}.${req.sessionIDForSend}"
|
||||
logger.trace { "-> $topic : message of type ${it.javaClass.name}" }
|
||||
net.send(net.createMessage(topic, it.serialize()), otherSide)
|
||||
}
|
||||
if (req is ContinuationResult.NotExpectingResponse) {
|
||||
// We sent a message, but won't get a response, so we must re-enter the continuation to let it keep going.
|
||||
return nextState.iterateStateMachine(net, otherSide, transientContext, transientContext, logger)
|
||||
} else {
|
||||
return nextState
|
||||
}
|
||||
}
|
@ -18,14 +18,10 @@ import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFails
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class InMemoryMessagingTests {
|
||||
open class TestWithInMemoryNetwork {
|
||||
val nodes: MutableMap<SingleMessageRecipient, InMemoryNetwork.Node> = HashMap()
|
||||
lateinit var network: InMemoryNetwork
|
||||
|
||||
init {
|
||||
// BriefLogFormatter.initVerbose()
|
||||
}
|
||||
|
||||
fun makeNode(): Pair<SingleMessageRecipient, InMemoryNetwork.Node> {
|
||||
// The manuallyPumped = true bit means that we must call the pump method on the system in order to
|
||||
val (address, builder) = network.createNode(manuallyPumped = true)
|
||||
@ -34,29 +30,32 @@ class InMemoryMessagingTests {
|
||||
return Pair(address, node)
|
||||
}
|
||||
|
||||
fun pumpAll() {
|
||||
nodes.values.forEach { it.pump(false) }
|
||||
}
|
||||
|
||||
// Utilities to help define messaging rounds.
|
||||
fun roundWithPumpings(times: Int, body: () -> Unit) {
|
||||
body()
|
||||
repeat(times) { pumpAll() }
|
||||
}
|
||||
|
||||
fun round(body: () -> Unit) = roundWithPumpings(1, body)
|
||||
|
||||
@Before
|
||||
fun before() {
|
||||
fun setupNetwork() {
|
||||
network = InMemoryNetwork()
|
||||
nodes.clear()
|
||||
}
|
||||
|
||||
@After
|
||||
fun after() {
|
||||
fun stopNetwork() {
|
||||
network.stop()
|
||||
}
|
||||
|
||||
fun pumpAll() = nodes.values.map { it.pump(false) }
|
||||
|
||||
// Keep calling "pump" in rounds until every node in the network reports that it had nothing to do.
|
||||
fun <T> runNetwork(body: () -> T): T {
|
||||
val result = body()
|
||||
while (pumpAll().any { it }) {}
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
class InMemoryMessagingTests : TestWithInMemoryNetwork() {
|
||||
init {
|
||||
// BriefLogFormatter.initVerbose()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun topicStringValidation() {
|
||||
TopicStringValidator.check("this.is.ok")
|
||||
@ -82,19 +81,19 @@ class InMemoryMessagingTests {
|
||||
var finalDelivery: Message? = null
|
||||
|
||||
with(node2) {
|
||||
addMessageHandler {
|
||||
send(it, addr3)
|
||||
addMessageHandler { msg, registration ->
|
||||
send(msg, addr3)
|
||||
}
|
||||
}
|
||||
|
||||
with(node3) {
|
||||
addMessageHandler {
|
||||
finalDelivery = it
|
||||
addMessageHandler { msg, registration ->
|
||||
finalDelivery = msg
|
||||
}
|
||||
}
|
||||
|
||||
// Node 1 sends a message and it should end up in finalDelivery, after we pump each node.
|
||||
roundWithPumpings(2) {
|
||||
runNetwork {
|
||||
node1.send(node1.createMessage("test.topic", bits), addr2)
|
||||
}
|
||||
|
||||
@ -110,8 +109,8 @@ class InMemoryMessagingTests {
|
||||
val bits = "test-content".toByteArray()
|
||||
|
||||
var counter = 0
|
||||
listOf(node1, node2, node3).forEach { it.addMessageHandler { counter++ } }
|
||||
round {
|
||||
listOf(node1, node2, node3).forEach { it.addMessageHandler { msg, registration -> counter++ } }
|
||||
runNetwork {
|
||||
node1.send(node2.createMessage("test.topic", bits), network.entireNetwork)
|
||||
}
|
||||
assertEquals(3, counter)
|
||||
|
73
src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt
Normal file
73
src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt
Normal file
@ -0,0 +1,73 @@
|
||||
/*
|
||||
* Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members
|
||||
* pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms
|
||||
* set forth therein.
|
||||
*
|
||||
* All other rights reserved.
|
||||
*/
|
||||
|
||||
package core.messaging
|
||||
|
||||
import contracts.Cash
|
||||
import contracts.CommercialPaper
|
||||
import contracts.protocols.TwoPartyTradeProtocol
|
||||
import core.*
|
||||
import core.testutils.*
|
||||
import org.junit.Test
|
||||
import java.util.logging.Formatter
|
||||
import java.util.logging.Level
|
||||
import java.util.logging.LogRecord
|
||||
import java.util.logging.Logger
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
/**
|
||||
* In this example, Alessia wishes to sell her commercial paper to Boris in return for $1,000,000 and they wish to do
|
||||
* it on the ledger atomically. Therefore they must work together to build a transaction.
|
||||
*
|
||||
* We assume that Alessia and Boris already found each other via some market, and have agreed the details already.
|
||||
*/
|
||||
class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
|
||||
init {
|
||||
Logger.getLogger("").handlers[0].level = Level.ALL
|
||||
Logger.getLogger("").handlers[0].formatter = object : Formatter() {
|
||||
override fun format(record: LogRecord) = "${record.threadID} ${record.loggerName}: ${record.message}\n"
|
||||
}
|
||||
Logger.getLogger("com.r3cev.protocols.trade").level = Level.ALL
|
||||
}
|
||||
|
||||
@Test
|
||||
fun cashForCP() {
|
||||
val (addr1, node1) = makeNode()
|
||||
val (addr2, node2) = makeNode()
|
||||
|
||||
val tp = TwoPartyTradeProtocol.create()
|
||||
|
||||
transactionGroupFor<ContractState> {
|
||||
// Bob (S) has some cash, Alice (P) has some commercial paper she wants to sell to Bob.
|
||||
roots {
|
||||
transaction(CommercialPaper.State(MEGA_CORP.ref(1, 2, 3), ALICE, 1200.DOLLARS, TEST_TX_TIME + 7.days) label "alice's paper")
|
||||
transaction(800.DOLLARS.CASH `owned by` BOB label "bob cash1")
|
||||
transaction(300.DOLLARS.CASH `owned by` BOB label "bob cash2")
|
||||
}
|
||||
|
||||
val bobsWallet = listOf<StateAndRef<Cash.State>>(lookup("bob cash1"), lookup("bob cash2"))
|
||||
val (aliceFuture, bobFuture) = runNetwork {
|
||||
Pair(
|
||||
tp.runSeller(node1, addr2, lookup("alice's paper"), 1000.DOLLARS, ALICE_KEY,
|
||||
TEST_KEYS_TO_CORP_MAP, DUMMY_TIMESTAMPER),
|
||||
tp.runBuyer(node2, addr1, 1000.DOLLARS, CommercialPaper.State::class.java, bobsWallet,
|
||||
mapOf(BOB to BOB_KEY.private), DUMMY_TIMESTAMPER, TEST_KEYS_TO_CORP_MAP)
|
||||
)
|
||||
}
|
||||
|
||||
val aliceResult: Pair<TimestampedWireTransaction, LedgerTransaction> = aliceFuture.get()
|
||||
val bobResult: Pair<TimestampedWireTransaction, LedgerTransaction> = bobFuture.get()
|
||||
|
||||
assertEquals(aliceResult, bobResult)
|
||||
|
||||
txns.add(aliceResult.second)
|
||||
|
||||
verify()
|
||||
}
|
||||
}
|
||||
}
|
@ -220,7 +220,7 @@ class TransactionGroupDSL<T : ContractState>(private val stateType: Class<T>) {
|
||||
private val inStates = ArrayList<ContractStateRef>()
|
||||
|
||||
fun input(label: String) {
|
||||
inStates.add(labelToRefs[label] ?: throw IllegalArgumentException("Unknown label \"$label\""))
|
||||
inStates.add(label.outputRef)
|
||||
}
|
||||
|
||||
|
||||
@ -235,6 +235,9 @@ class TransactionGroupDSL<T : ContractState>(private val stateType: Class<T>) {
|
||||
}
|
||||
|
||||
val String.output: T get() = labelToOutputs[this] ?: throw IllegalArgumentException("State with label '$this' was not found")
|
||||
val String.outputRef: ContractStateRef get() = labelToRefs[this] ?: throw IllegalArgumentException("Unknown label \"$this\"")
|
||||
|
||||
fun <C : ContractState> lookup(label: String) = StateAndRef(label.output as C, label.outputRef)
|
||||
|
||||
private inner class InternalLedgerTransactionDSL : LedgerTransactionDSL() {
|
||||
fun finaliseAndInsertLabels(time: Instant): LedgerTransaction {
|
||||
@ -268,6 +271,7 @@ class TransactionGroupDSL<T : ContractState>(private val stateType: Class<T>) {
|
||||
val label = state.label!!
|
||||
labelToRefs[label] = ContractStateRef(ltx.hash, index)
|
||||
outputsToLabels[state.state] = label
|
||||
labelToOutputs[label] = state.state as T
|
||||
}
|
||||
rootTxns.add(ltx)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user