Two party trading: big changes to support and test serialisation, refactorings ...

This commit is contained in:
Mike Hearn 2015-12-11 14:54:59 +01:00
parent 65c5fa7502
commit 89ba996a3c
4 changed files with 395 additions and 199 deletions

View File

@ -9,19 +9,12 @@
package contracts.protocols package contracts.protocols
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import contracts.Cash import contracts.Cash
import contracts.sumCashBy import contracts.sumCashBy
import core.* import core.*
import core.messaging.MessagingSystem import core.messaging.*
import core.messaging.SingleMessageRecipient
import core.serialization.SerializeableWithKryo
import core.serialization.THREAD_LOCAL_KRYO
import core.serialization.deserialize import core.serialization.deserialize
import core.serialization.registerDataClass
import core.utilities.continuations.*
import core.utilities.trace import core.utilities.trace
import org.slf4j.LoggerFactory
import java.security.KeyPair import java.security.KeyPair
import java.security.PrivateKey import java.security.PrivateKey
import java.security.PublicKey import java.security.PublicKey
@ -47,18 +40,17 @@ import java.security.SecureRandom
* To see an example of how to use this class, look at the unit tests. * To see an example of how to use this class, look at the unit tests.
*/ */
abstract class TwoPartyTradeProtocol { abstract class TwoPartyTradeProtocol {
// TODO: Replace some args with the context objects
abstract fun runSeller( abstract fun runSeller(
net: MessagingSystem,
otherSide: SingleMessageRecipient, otherSide: SingleMessageRecipient,
assetToSell: StateAndRef<OwnableState>, assetToSell: StateAndRef<OwnableState>,
price: Amount, price: Amount,
myKey: KeyPair, myKey: KeyPair,
partyKeyMap: Map<PublicKey, Party>, partyKeyMap: Map<PublicKey, Party>,
timestamper: TimestamperService timestamper: TimestamperService
): ListenableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>> ): ListenableFuture<out Pair<TimestampedWireTransaction, LedgerTransaction>>
abstract fun runBuyer( abstract fun runBuyer(
net: MessagingSystem,
otherSide: SingleMessageRecipient, otherSide: SingleMessageRecipient,
acceptablePrice: Amount, acceptablePrice: Amount,
typeToSell: Class<out OwnableState>, typeToSell: Class<out OwnableState>,
@ -66,24 +58,20 @@ abstract class TwoPartyTradeProtocol {
myKeys: Map<PublicKey, PrivateKey>, myKeys: Map<PublicKey, PrivateKey>,
timestamper: TimestamperService, timestamper: TimestamperService,
partyKeyMap: Map<PublicKey, Party> partyKeyMap: Map<PublicKey, Party>
): ListenableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>> ): ListenableFuture<out Pair<TimestampedWireTransaction, LedgerTransaction>>
companion object { class BuyerInitialArgs(
@JvmStatic fun create(): TwoPartyTradeProtocol { val acceptablePrice: Amount,
return TwoPartyTradeProtocolImpl() val typeToSell: String
} )
}
}
private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() { class BuyerContext(
companion object { val wallet: List<StateAndRef<Cash.State>>,
val TRADE_TOPIC = "com.r3cev.protocols.trade" val myKeys: Map<PublicKey, PrivateKey>,
fun makeSessionID() = Math.abs(SecureRandom.getInstanceStrong().nextLong()) val timestamper: TimestamperService,
} val partyKeyMap: Map<PublicKey, Party>,
val initialArgs: BuyerInitialArgs?
init { )
THREAD_LOCAL_KRYO.get().registerDataClass<TwoPartyTradeProtocolImpl.SellerTradeInfo>()
}
// This wraps some of the arguments passed to runSeller that are persistent across the lifetime of the trade and // This wraps some of the arguments passed to runSeller that are persistent across the lifetime of the trade and
// can be serialised. // can be serialised.
@ -91,7 +79,7 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
val assetToSell: StateAndRef<OwnableState>, val assetToSell: StateAndRef<OwnableState>,
val price: Amount, val price: Amount,
val myKeyPair: KeyPair val myKeyPair: KeyPair
) : SerializeableWithKryo )
// This wraps the things which the seller needs, but which might change whilst the continuation is suspended, // This wraps the things which the seller needs, but which might change whilst the continuation is suspended,
// e.g. due to a VM restart, networking issue, configuration file reload etc. It also contains the initial args // e.g. due to a VM restart, networking issue, configuration file reload etc. It also contains the initial args
@ -99,17 +87,32 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
class SellerContext( class SellerContext(
val timestamper: TimestamperService, val timestamper: TimestamperService,
val partyKeyMap: Map<PublicKey, Party>, val partyKeyMap: Map<PublicKey, Party>,
val initialArgs: SellerInitialArgs?, val initialArgs: SellerInitialArgs?
val resultFuture: SettableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>> = SettableFuture.create()
) )
abstract class Buyer : ProtocolStateMachine<BuyerContext, Pair<TimestampedWireTransaction, LedgerTransaction>>()
abstract class Seller : ProtocolStateMachine<SellerContext, Pair<TimestampedWireTransaction, LedgerTransaction>>()
companion object {
@JvmStatic fun create(smm: StateMachineManager): TwoPartyTradeProtocol {
return TwoPartyTradeProtocolImpl(smm)
}
}
}
private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) : TwoPartyTradeProtocol() {
companion object {
val TRADE_TOPIC = "com.r3cev.protocols.trade"
fun makeSessionID() = Math.abs(SecureRandom.getInstanceStrong().nextLong())
}
// This object is serialised to the network and is the first protocol message the seller sends to the buyer. // This object is serialised to the network and is the first protocol message the seller sends to the buyer.
class SellerTradeInfo( class SellerTradeInfo(
val assetForSale: StateAndRef<OwnableState>, val assetForSale: StateAndRef<OwnableState>,
val price: Amount, val price: Amount,
val primaryOwnerKey: PublicKey, val sellerOwnerKey: PublicKey,
val sessionID: Long val sessionID: Long
) : SerializeableWithKryo )
// The seller's side of the protocol. IMPORTANT: This class is loaded in a separate classloader and auto-mangled // The seller's side of the protocol. IMPORTANT: This class is loaded in a separate classloader and auto-mangled
// by JavaFlow. Therefore, we cannot cast the object to Seller and poke it directly because the class we'd be // by JavaFlow. Therefore, we cannot cast the object to Seller and poke it directly because the class we'd be
@ -117,8 +120,8 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
// interaction with this class must be through either interfaces, or objects passed to and from the continuation // interaction with this class must be through either interfaces, or objects passed to and from the continuation
// by the state machine framework. Please refer to the documentation website (docs/build/html) to learn more about // by the state machine framework. Please refer to the documentation website (docs/build/html) to learn more about
// the protocol state machine framework. // the protocol state machine framework.
class Seller : ProtocolStateMachine<SellerContext> { class SellerImpl : Seller() {
override fun run() { override fun call(): Pair<TimestampedWireTransaction, LedgerTransaction> {
val sessionID = makeSessionID() val sessionID = makeSessionID()
val args = context().initialArgs!! val args = context().initialArgs!!
@ -153,7 +156,7 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
val timestamped: TimestampedWireTransaction = fullySigned.toTimestampedTransaction(ctx2.timestamper) val timestamped: TimestampedWireTransaction = fullySigned.toTimestampedTransaction(ctx2.timestamper)
logger().trace { "Built finished transaction, sending back to secondary!" } logger().trace { "Built finished transaction, sending back to secondary!" }
send(TRADE_TOPIC, sessionID, timestamped) send(TRADE_TOPIC, sessionID, timestamped)
ctx2.resultFuture.set(Pair(timestamped, timestamped.verifyToLedgerTransaction(ctx2.timestamper, ctx2.partyKeyMap))) return Pair(timestamped, timestamped.verifyToLedgerTransaction(ctx2.timestamper, ctx2.partyKeyMap))
} }
} }
@ -162,20 +165,11 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
override fun toString() = "The submitted asset didn't match the expected type: $expectedTypeName vs $typeName" override fun toString() = "The submitted asset didn't match the expected type: $expectedTypeName vs $typeName"
} }
class BuyerContext(
val acceptablePrice: Amount,
val typeToSell: Class<out OwnableState>,
val wallet: List<StateAndRef<Cash.State>>,
val myKeys: Map<PublicKey, PrivateKey>,
val timestamper: TimestamperService,
val partyKeyMap: Map<PublicKey, Party>,
val resultFuture: SettableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>> = SettableFuture.create()
)
// The buyer's side of the protocol. See note above Seller to learn about the caveats here. // The buyer's side of the protocol. See note above Seller to learn about the caveats here.
class Buyer : ProtocolStateMachine<BuyerContext> { class BuyerImpl : Buyer() {
override fun run() { override fun call(): Pair<TimestampedWireTransaction, LedgerTransaction> {
val acceptablePrice = context().initialArgs!!.acceptablePrice
val typeToSell = context().initialArgs!!.typeToSell
// Start a new scope here so we can't accidentally reuse 'ctx' after doing the sendAndReceive below, // Start a new scope here so we can't accidentally reuse 'ctx' after doing the sendAndReceive below,
// as the context object we're meant to use might change each time we suspend (e.g. due to VM restart). // as the context object we're meant to use might change each time we suspend (e.g. due to VM restart).
val (stx, theirSessionID) = run { val (stx, theirSessionID) = run {
@ -187,10 +181,10 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
// Check the start message for acceptability. // Check the start message for acceptability.
check(tradeRequest.sessionID > 0) check(tradeRequest.sessionID > 0)
if (tradeRequest.price > ctx.acceptablePrice) if (tradeRequest.price > acceptablePrice)
throw UnacceptablePriceException(tradeRequest.price) throw UnacceptablePriceException(tradeRequest.price)
if (!ctx.typeToSell.isInstance(tradeRequest.assetForSale.state)) if (!Class.forName(typeToSell).isInstance(tradeRequest.assetForSale.state))
throw AssetMismatchException(ctx.typeToSell.name, assetTypeName) throw AssetMismatchException(typeToSell, assetTypeName)
// TODO: Either look up the stateref here in our local db, or accept a long chain of states and // TODO: Either look up the stateref here in our local db, or accept a long chain of states and
// validate them to audit the other side and ensure it actually owns the state we are being offered! // validate them to audit the other side and ensure it actually owns the state we are being offered!
@ -199,7 +193,7 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
// Generate the shared transaction that both sides will sign, using the data we have. // Generate the shared transaction that both sides will sign, using the data we have.
val ptx = PartialTransaction() val ptx = PartialTransaction()
// Add input and output states for the movement of cash. // Add input and output states for the movement of cash.
val cashSigningPubKeys = Cash().craftSpend(ptx, tradeRequest.price, tradeRequest.primaryOwnerKey, ctx.wallet) val cashSigningPubKeys = Cash().craftSpend(ptx, tradeRequest.price, tradeRequest.sellerOwnerKey, ctx.wallet)
// Add inputs/outputs/a command for the movement of the asset. // Add inputs/outputs/a command for the movement of the asset.
ptx.addInputState(tradeRequest.assetForSale.ref) ptx.addInputState(tradeRequest.assetForSale.ref)
// Just pick some arbitrary public key for now (this provides poor privacy). // Just pick some arbitrary public key for now (this provides poor privacy).
@ -219,33 +213,29 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
} }
// TODO: Could run verify() here to make sure the only signature missing is the primaries. // TODO: Could run verify() here to make sure the only signature missing is the primaries.
logger().trace { "Sending partially signed transaction to primary" } logger().trace { "Sending partially signed transaction to seller" }
// We'll just reuse the session ID the primary selected here for convenience. // We'll just reuse the session ID the seller selected here for convenience.
val (ctx, fullySigned) = sendAndReceive<TimestampedWireTransaction, BuyerContext>(TRADE_TOPIC, theirSessionID, theirSessionID, stx) val (ctx, fullySigned) = sendAndReceive<TimestampedWireTransaction, BuyerContext>(TRADE_TOPIC, theirSessionID, theirSessionID, stx)
logger().trace { "Got fully signed transaction, verifying ... "} logger().trace { "Got fully signed transaction, verifying ... "}
val ltx = fullySigned.verifyToLedgerTransaction(ctx.timestamper, ctx.partyKeyMap) val ltx = fullySigned.verifyToLedgerTransaction(ctx.timestamper, ctx.partyKeyMap)
logger().trace { "Fully signed transaction was valid. Trade complete! :-)" } logger().trace { "Fully signed transaction was valid. Trade complete! :-)" }
ctx.resultFuture.set(Pair(fullySigned, ltx)) return Pair(fullySigned, ltx)
} }
} }
override fun runSeller(net: MessagingSystem, otherSide: SingleMessageRecipient, assetToSell: StateAndRef<OwnableState>, override fun runSeller(otherSide: SingleMessageRecipient, assetToSell: StateAndRef<OwnableState>,
price: Amount, myKey: KeyPair, partyKeyMap: Map<PublicKey, Party>, price: Amount, myKey: KeyPair, partyKeyMap: Map<PublicKey, Party>,
timestamper: TimestamperService): ListenableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>> { timestamper: TimestamperService): ListenableFuture<out Pair<TimestampedWireTransaction, LedgerTransaction>> {
val args = SellerInitialArgs(assetToSell, price, myKey) val args = SellerInitialArgs(assetToSell, price, myKey)
val context = SellerContext(timestamper, partyKeyMap, args) val context = SellerContext(timestamper, partyKeyMap, args)
val logger = LoggerFactory.getLogger("$TRADE_TOPIC.primary") return smm.add(otherSide, context, "$TRADE_TOPIC.seller", SellerImpl::class.java)
loadContinuationClass<Seller>(javaClass.classLoader).iterateStateMachine(net, otherSide, context, context, logger)
return context.resultFuture
} }
override fun runBuyer(net: MessagingSystem, otherSide: SingleMessageRecipient, acceptablePrice: Amount, override fun runBuyer(otherSide: SingleMessageRecipient, acceptablePrice: Amount,
typeToSell: Class<out OwnableState>, wallet: List<StateAndRef<Cash.State>>, typeToSell: Class<out OwnableState>, wallet: List<StateAndRef<Cash.State>>,
myKeys: Map<PublicKey, PrivateKey>, timestamper: TimestamperService, myKeys: Map<PublicKey, PrivateKey>, timestamper: TimestamperService,
partyKeyMap: Map<PublicKey, Party>): ListenableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>> { partyKeyMap: Map<PublicKey, Party>): ListenableFuture<out Pair<TimestampedWireTransaction, LedgerTransaction>> {
val context = BuyerContext(acceptablePrice, typeToSell, wallet, myKeys, timestamper, partyKeyMap) val context = BuyerContext(wallet, myKeys, timestamper, partyKeyMap, BuyerInitialArgs(acceptablePrice, typeToSell.name))
val logger = LoggerFactory.getLogger("$TRADE_TOPIC.secondary") return smm.add(otherSide, context, "$TRADE_TOPIC.buyer", BuyerImpl::class.java)
loadContinuationClass<Buyer>(javaClass.classLoader).iterateStateMachine(net, otherSide, context, context, logger)
return context.resultFuture
} }
} }

View File

@ -0,0 +1,227 @@
/*
* Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members
* pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms
* set forth therein.
*
* All other rights reserved.
*/
package core.messaging
import com.esotericsoftware.kryo.io.Input
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import core.serialization.THREAD_LOCAL_KRYO
import core.serialization.createKryo
import core.serialization.deserialize
import core.serialization.serialize
import core.utilities.trace
import org.apache.commons.javaflow.Continuation
import org.apache.commons.javaflow.ContinuationClassLoader
import org.objenesis.instantiator.ObjectInstantiator
import org.objenesis.strategy.InstantiatorStrategy
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.util.*
import java.util.concurrent.Callable
import java.util.concurrent.Executor
/**
* A StateMachineManager is responsible for coordination and persistence of multiple [ProtocolStateMachine] objects.
*
* An implementation of this class will persist state machines to long term storage so they can survive process restarts
* and, if run with a single-threaded executor, will ensure no two state machines run concurrently with each other
* (bad for performance, good for programmer mental health!).
*/
class StateMachineManager(val net: MessagingSystem, val runInThread: Executor) {
// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
private class Checkpoint(
val continuation: Continuation,
val otherSide: MessageRecipients,
val loggerName: String,
val awaitingTopic: String,
val awaitingObjectOfType: String // java class name
)
constructor(net: MessagingSystem, runInThread: Executor, restoreCheckpoints: List<ByteArray>, resumeStateMachine: (ProtocolStateMachine<*,*>) -> Any) : this(net, runInThread) {
for (bytes in restoreCheckpoints) {
val kryo = createKryo()
// Set up Kryo to use the JavaFlow classloader when deserialising, so the magical continuation bytecode
// rewriting is performed correctly.
var psm: ProtocolStateMachine<*,*>? = null
kryo.instantiatorStrategy = object : InstantiatorStrategy {
val forwardingTo = kryo.instantiatorStrategy
override fun <T> newInstantiatorOf(type: Class<T>): ObjectInstantiator<T> {
if (ProtocolStateMachine::class.java.isAssignableFrom(type)) {
// The messing around with types we do here confuses the compiler/IDE a bit and it warns us.
@Suppress("UNCHECKED_CAST", "CAST_NEVER_SUCCEEDS")
return ObjectInstantiator<T> {
psm = loadContinuationClass(type as Class<out ProtocolStateMachine<*, Any>>).first
psm as T
}
} else {
return forwardingTo.newInstantiatorOf(type)
}
}
}
val checkpoint = bytes.deserialize<Checkpoint>(kryo)
val continuation = checkpoint.continuation
val transientContext = resumeStateMachine(psm!!)
val logger = LoggerFactory.getLogger(checkpoint.loggerName)
val awaitingObjectOfType = Class.forName(checkpoint.awaitingObjectOfType)
// The act of calling this method re-persists the bytes into the in-memory hashmap so re-saving the
// StateMachineManager to disk will work even if some state machines didn't wake up in the intervening time.
setupNextMessageHandler(logger, net, continuation, checkpoint.otherSide, awaitingObjectOfType,
checkpoint.awaitingTopic, transientContext, bytes)
}
}
fun <R> add(otherSide: MessageRecipients, transientContext: Any, loggerName: String, continuationClass: Class<out ProtocolStateMachine<*, R>>): ListenableFuture<out R> {
val logger = LoggerFactory.getLogger(loggerName)
val (sm, continuation) = loadContinuationClass<R>(continuationClass)
runInThread.execute {
// The current state of the continuation is held in the closure attached to the messaging system whenever
// the continuation suspends and tells us it expects a response.
iterateStateMachine(continuation, net, otherSide, transientContext, transientContext, logger, null)
}
return sm.resultFuture
}
@Suppress("UNCHECKED_CAST")
private fun <R> loadContinuationClass(continuationClass: Class<out ProtocolStateMachine<*, R>>): Pair<ProtocolStateMachine<*,R>, Continuation> {
val url = continuationClass.protectionDomain.codeSource.location
val cl = ContinuationClassLoader(arrayOf(url), this.javaClass.classLoader)
val obj = cl.forceLoadClass(continuationClass.name).newInstance() as ProtocolStateMachine<*, R>
return Pair(obj, Continuation.startSuspendedWith(obj))
}
private val checkpoints: LinkedList<ByteArray> = LinkedList()
private fun persistCheckpoint(prev: ByteArray?, new: ByteArray) {
synchronized(checkpoints) {
if (prev == null) {
for (i in checkpoints.size - 1 downTo 0) {
val b = checkpoints[i]
if (Arrays.equals(b, prev)) {
checkpoints[i] = new
return
}
}
}
checkpoints.add(new)
}
}
fun saveToBytes(): LinkedList<ByteArray> = synchronized(checkpoints) { LinkedList(checkpoints) }
private fun iterateStateMachine(c: Continuation, net: MessagingSystem, otherSide: MessageRecipients,
transientContext: Any, continuationInput: Any?, logger: Logger,
prevPersistedBytes: ByteArray?): Continuation {
// This will resume execution of the run() function inside the continuation at the place it left off.
val oldLogger = CONTINUATION_LOGGER.get()
val nextState: Continuation? = try {
CONTINUATION_LOGGER.set(logger)
Continuation.continueWith(c, continuationInput)
} catch (t: Throwable) {
logger.error("Caught error whilst invoking protocol state machine", t)
throw t
} finally {
CONTINUATION_LOGGER.set(oldLogger)
}
// If continuation returns null, it's finished and the result future has been set.
if (nextState == null)
return c
val req = nextState.value() as? ContinuationResult ?: return c
// Else, it wants us to do something: send, receive, or send-and-receive.
if (req is ContinuationResult.ExpectingResponse<*>) {
// Prepare a listener on the network that runs in the background thread when we received a message.
val topic = "${req.topic}.${req.sessionIDForReceive}"
setupNextMessageHandler(logger, net, nextState, otherSide, req.responseType, topic, transientContext, prevPersistedBytes)
}
// If an object to send was provided (not null), send it now.
req.obj?.let {
val topic = "${req.topic}.${req.sessionIDForSend}"
logger.trace { "-> $topic : message of type ${it.javaClass.name}" }
net.send(net.createMessage(topic, it.serialize()), otherSide)
}
if (req is ContinuationResult.NotExpectingResponse) {
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
return iterateStateMachine(nextState, net, otherSide, transientContext, transientContext, logger, prevPersistedBytes)
} else {
return nextState
}
}
private fun setupNextMessageHandler(logger: Logger, net: MessagingSystem, nextState: Continuation,
otherSide: MessageRecipients, responseType: Class<*>,
topic: String, transientContext: Any, prevPersistedBytes: ByteArray?) {
val checkpoint = Checkpoint(nextState, otherSide, logger.name, topic, responseType.name)
persistCheckpoint(prevPersistedBytes, checkpoint.serialize())
net.runOnNextMessage(topic, runInThread) { netMsg ->
val obj: Any = THREAD_LOCAL_KRYO.get().readObject(Input(netMsg.data), responseType)
logger.trace { "<- $topic : message of type ${obj.javaClass.name}" }
iterateStateMachine(nextState, net, otherSide, transientContext, Pair(transientContext, obj), logger, prevPersistedBytes)
}
}
}
val CONTINUATION_LOGGER = ThreadLocal<Logger>()
/**
* A convenience mixin interface that can be implemented by an object that will act as a continuation.
*
* A ProtocolStateMachine must implement the run method from [Runnable], and the rest of what this interface
* provides are pre-defined utility methods to ease implementation of such machines.
*/
@Suppress("UNCHECKED_CAST")
abstract class ProtocolStateMachine<CONTEXT_TYPE : Any, R> : Callable<R>, Runnable {
protected fun context(): CONTEXT_TYPE = Continuation.getContext() as CONTEXT_TYPE
protected fun logger(): Logger = CONTINUATION_LOGGER.get()
// These fields shouldn't be serialised.
@Transient private var _resultFuture: SettableFuture<R> = SettableFuture.create<R>()
val resultFuture: ListenableFuture<R> get() = _resultFuture
override fun run() {
val r = call()
if (r != null)
_resultFuture.set(r)
}
}
@Suppress("NOTHING_TO_INLINE", "UNCHECKED_CAST")
inline fun <S : Any, CONTEXT_TYPE : Any> ProtocolStateMachine<CONTEXT_TYPE, *>.send(topic: String, sessionID: Long, obj: S) =
Continuation.suspend(ContinuationResult.NotExpectingResponse(topic, sessionID, obj)) as CONTEXT_TYPE
@Suppress("UNCHECKED_CAST")
inline fun <reified R : Any, CONTEXT_TYPE : Any> ProtocolStateMachine<CONTEXT_TYPE, *>.sendAndReceive(
topic: String, sessionIDForSend: Long, sessionIDForReceive: Long, obj: Any): Pair<CONTEXT_TYPE, R> {
return Continuation.suspend(ContinuationResult.ExpectingResponse(topic, sessionIDForSend, sessionIDForReceive,
obj, R::class.java)) as Pair<CONTEXT_TYPE, R>
}
@Suppress("UNCHECKED_CAST")
inline fun <reified R : Any, CONTEXT_TYPE : Any> ProtocolStateMachine<CONTEXT_TYPE, *>.receive(
topic: String, sessionIDForReceive: Long): Pair<CONTEXT_TYPE, R> {
return Continuation.suspend(ContinuationResult.ExpectingResponse(topic, -1, sessionIDForReceive, null,
R::class.java)) as Pair<CONTEXT_TYPE, R>
}
open class ContinuationResult(val topic: String, val sessionIDForSend: Long, val sessionIDForReceive: Long, val obj: Any?) {
class ExpectingResponse<R : Any>(
topic: String,
sessionIDForSend: Long,
sessionIDForReceive: Long,
obj: Any?,
val responseType: Class<R>
) : ContinuationResult(topic, sessionIDForSend, sessionIDForReceive, obj)
class NotExpectingResponse(topic: String, sessionIDForSend: Long, obj: Any?) : ContinuationResult(topic, sessionIDForSend, -1, obj)
}

View File

@ -1,118 +0,0 @@
/*
* Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members
* pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms
* set forth therein.
*
* All other rights reserved.
*/
package core.utilities.continuations
import com.esotericsoftware.kryo.io.Output
import core.messaging.MessagingSystem
import core.messaging.SingleMessageRecipient
import core.messaging.runOnNextMessage
import core.serialization.SerializeableWithKryo
import core.serialization.createKryo
import core.serialization.deserialize
import core.serialization.serialize
import core.utilities.trace
import org.apache.commons.javaflow.Continuation
import org.slf4j.Logger
import java.io.ByteArrayOutputStream
private val CONTINUATION_LOGGER = ThreadLocal<Logger>()
/**
* A convenience mixing interface that can be implemented by an object that will act as a continuation.
*
* A ProtocolStateMachine must implement the run method from [Runnable], and the rest of what this interface
* provides are pre-defined utility methods to ease implementation of such machines.
*/
@Suppress("UNCHECKED_CAST")
interface ProtocolStateMachine<CONTEXT_TYPE : Any> : Runnable {
fun context(): CONTEXT_TYPE = Continuation.getContext() as CONTEXT_TYPE
fun logger(): Logger = CONTINUATION_LOGGER.get()
}
@Suppress("NOTHING_TO_INLINE", "UNCHECKED_CAST")
inline fun <S : SerializeableWithKryo,
CONTEXT_TYPE : Any> ProtocolStateMachine<CONTEXT_TYPE>.send(topic: String, sessionID: Long, obj: S) =
Continuation.suspend(ContinuationResult.NotExpectingResponse(topic, sessionID, obj)) as CONTEXT_TYPE
@Suppress("UNCHECKED_CAST")
inline fun <reified R : SerializeableWithKryo,
CONTEXT_TYPE : Any> ProtocolStateMachine<CONTEXT_TYPE>.sendAndReceive(topic: String, sessionIDForSend: Long,
sessionIDForReceive: Long,
obj: SerializeableWithKryo) =
Continuation.suspend(ContinuationResult.ExpectingResponse(topic, sessionIDForSend, sessionIDForReceive,
obj, R::class.java)) as Pair<CONTEXT_TYPE, R>
@Suppress("UNCHECKED_CAST")
inline fun <reified R : SerializeableWithKryo, CONTEXT_TYPE : Any> receive(topic: String, sessionIDForReceive: Long) =
Continuation.suspend(ContinuationResult.ExpectingResponse(topic, -1, sessionIDForReceive, null,
R::class.java)) as Pair<CONTEXT_TYPE, R>
open class ContinuationResult(val topic: String, val sessionIDForSend: Long, val sessionIDForReceive: Long, val obj: SerializeableWithKryo?) {
class ExpectingResponse<R : SerializeableWithKryo>(
topic: String,
sessionIDForSend: Long,
sessionIDForReceive: Long,
obj: SerializeableWithKryo?,
val responseType: Class<R>
) : ContinuationResult(topic, sessionIDForSend, sessionIDForReceive, obj)
class NotExpectingResponse(topic: String, sessionIDForSend: Long, obj: SerializeableWithKryo?) : ContinuationResult(topic, sessionIDForSend, -1, obj)
}
fun Continuation.iterateStateMachine(net: MessagingSystem, otherSide: SingleMessageRecipient,
transientContext: Any, continuationInput: Any?, logger: Logger): Continuation {
// This will resume execution of the run() function inside the continuation at the place it left off.
val oldLogger = CONTINUATION_LOGGER.get()
val nextState = try {
CONTINUATION_LOGGER.set(logger)
Continuation.continueWith(this, continuationInput)
} catch (t: Throwable) {
logger.error("Caught error whilst invoking protocol state machine", t)
throw t
} finally {
CONTINUATION_LOGGER.set(oldLogger)
}
// If continuation returns null, it's finished.
val req = nextState?.value() as? ContinuationResult ?: return this
// Else, it wants us to do something: send, receive, or send-and-receive. Firstly, checkpoint it, so we can restart
// if something goes wrong.
val bytes = run {
val stream = ByteArrayOutputStream()
Output(stream).use {
createKryo().apply {
isRegistrationRequired = false
writeObject(it, nextState)
}
}
stream.toByteArray()
}
if (req is ContinuationResult.ExpectingResponse<*>) {
val topic = "${req.topic}.${req.sessionIDForReceive}"
net.runOnNextMessage(topic) { netMsg ->
val obj = netMsg.data.deserialize(req.responseType)
logger.trace { "<- $topic : message of type ${obj.javaClass.name}" }
nextState.iterateStateMachine(net, otherSide, transientContext, Pair(transientContext, obj), logger)
}
}
// If an object to send was provided (not null), send it now.
req.obj?.let {
val topic = "${req.topic}.${req.sessionIDForSend}"
logger.trace { "-> $topic : message of type ${it.javaClass.name}" }
net.send(net.createMessage(topic, it.serialize()), otherSide)
}
if (req is ContinuationResult.NotExpectingResponse) {
// We sent a message, but won't get a response, so we must re-enter the continuation to let it keep going.
return nextState.iterateStateMachine(net, otherSide, transientContext, transientContext, logger)
} else {
return nextState
}
}

View File

@ -8,17 +8,24 @@
package core.messaging package core.messaging
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.MoreExecutors
import contracts.Cash import contracts.Cash
import contracts.CommercialPaper import contracts.CommercialPaper
import contracts.protocols.TwoPartyTradeProtocol import contracts.protocols.TwoPartyTradeProtocol
import core.* import core.*
import core.testutils.* import core.testutils.*
import org.junit.After
import org.junit.Before
import org.junit.Test import org.junit.Test
import java.util.concurrent.Executors
import java.util.logging.Formatter import java.util.logging.Formatter
import java.util.logging.Level import java.util.logging.Level
import java.util.logging.LogRecord import java.util.logging.LogRecord
import java.util.logging.Logger import java.util.logging.Logger
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertTrue
import kotlin.test.fail
/** /**
* In this example, Alessia wishes to sell her commercial paper to Boris in return for $1,000,000 and they wish to do * In this example, Alessia wishes to sell her commercial paper to Boris in return for $1,000,000 and they wish to do
@ -27,7 +34,8 @@ import kotlin.test.assertEquals
* We assume that Alessia and Boris already found each other via some market, and have agreed the details already. * We assume that Alessia and Boris already found each other via some market, and have agreed the details already.
*/ */
class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
init { @Before
fun initLogging() {
Logger.getLogger("").handlers[0].level = Level.ALL Logger.getLogger("").handlers[0].level = Level.ALL
Logger.getLogger("").handlers[0].formatter = object : Formatter() { Logger.getLogger("").handlers[0].formatter = object : Formatter() {
override fun format(record: LogRecord) = "${record.threadID} ${record.loggerName}: ${record.message}\n" override fun format(record: LogRecord) = "${record.threadID} ${record.loggerName}: ${record.message}\n"
@ -35,12 +43,19 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
Logger.getLogger("com.r3cev.protocols.trade").level = Level.ALL Logger.getLogger("com.r3cev.protocols.trade").level = Level.ALL
} }
@After
fun stopLogging() {
Logger.getLogger("com.r3cev.protocols.trade").level = Level.INFO
}
@Test @Test
fun cashForCP() { fun cashForCP() {
val (addr1, node1) = makeNode() val (addr1, node1) = makeNode(inBackground = true)
val (addr2, node2) = makeNode() val (addr2, node2) = makeNode(inBackground = true)
val tp = TwoPartyTradeProtocol.create() val backgroundThread = Executors.newSingleThreadExecutor()
val tpSeller = TwoPartyTradeProtocol.create(StateMachineManager(node1, backgroundThread))
val tpBuyer = TwoPartyTradeProtocol.create(StateMachineManager(node2, backgroundThread))
transactionGroupFor<ContractState> { transactionGroupFor<ContractState> {
// Bob (S) has some cash, Alice (P) has some commercial paper she wants to sell to Bob. // Bob (S) has some cash, Alice (P) has some commercial paper she wants to sell to Bob.
@ -51,22 +66,104 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
} }
val bobsWallet = listOf<StateAndRef<Cash.State>>(lookup("bob cash1"), lookup("bob cash2")) val bobsWallet = listOf<StateAndRef<Cash.State>>(lookup("bob cash1"), lookup("bob cash2"))
val (aliceFuture, bobFuture) = runNetwork {
Pair( val aliceResult = tpSeller.runSeller(
tp.runSeller(node1, addr2, lookup("alice's paper"), 1000.DOLLARS, ALICE_KEY, addr2,
TEST_KEYS_TO_CORP_MAP, DUMMY_TIMESTAMPER), lookup("alice's paper"),
tp.runBuyer(node2, addr1, 1000.DOLLARS, CommercialPaper.State::class.java, bobsWallet, 1000.DOLLARS,
mapOf(BOB to BOB_KEY.private), DUMMY_TIMESTAMPER, TEST_KEYS_TO_CORP_MAP) ALICE_KEY,
TEST_KEYS_TO_CORP_MAP,
DUMMY_TIMESTAMPER
) )
val bobResult = tpBuyer.runBuyer(
addr1,
1000.DOLLARS,
CommercialPaper.State::class.java,
bobsWallet,
mapOf(BOB to BOB_KEY.private),
DUMMY_TIMESTAMPER,
TEST_KEYS_TO_CORP_MAP
)
assertEquals(aliceResult.get(), bobResult.get())
txns.add(aliceResult.get().second)
verify()
}
backgroundThread.shutdown()
} }
val aliceResult: Pair<TimestampedWireTransaction, LedgerTransaction> = aliceFuture.get() @Test
val bobResult: Pair<TimestampedWireTransaction, LedgerTransaction> = bobFuture.get() fun serializeAndRestore() {
val (addr1, node1) = makeNode(inBackground = false)
var (addr2, node2) = makeNode(inBackground = false)
assertEquals(aliceResult, bobResult) val smmSeller = StateMachineManager(node1, MoreExecutors.directExecutor())
val tpSeller = TwoPartyTradeProtocol.create(smmSeller)
val smmBuyer = StateMachineManager(node2, MoreExecutors.directExecutor())
val tpBuyer = TwoPartyTradeProtocol.create(smmBuyer)
txns.add(aliceResult.second) transactionGroupFor<ContractState> {
// Buyer Bob has some cash, Seller Alice has some commercial paper she wants to sell to Bob.
roots {
transaction(CommercialPaper.State(MEGA_CORP.ref(1, 2, 3), ALICE, 1200.DOLLARS, TEST_TX_TIME + 7.days) label "alice's paper")
transaction(800.DOLLARS.CASH `owned by` BOB label "bob cash1")
transaction(300.DOLLARS.CASH `owned by` BOB label "bob cash2")
}
val bobsWallet = listOf<StateAndRef<Cash.State>>(lookup("bob cash1"), lookup("bob cash2"))
tpSeller.runSeller(
addr2,
lookup("alice's paper"),
1000.DOLLARS,
ALICE_KEY,
TEST_KEYS_TO_CORP_MAP,
DUMMY_TIMESTAMPER
)
tpBuyer.runBuyer(
addr1,
1000.DOLLARS,
CommercialPaper.State::class.java,
bobsWallet,
mapOf(BOB to BOB_KEY.private),
DUMMY_TIMESTAMPER,
TEST_KEYS_TO_CORP_MAP
)
// Everything is on this thread so we can now step through the protocol one step at a time.
// Seller Alice already sent a message to Buyer Bob. Pump once:
node2.pump(false)
// OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature.
val storageBob = smmBuyer.saveToBytes()
// .. and let's imagine that Bob's computer has a power cut. He now has nothing now beyond what was on disk.
node2.stop()
// Alice doesn't know that and sends Bob the now finalised transaction. Alice sends a message to a node
// that has gone offline.
node1.pump(false)
// ... bring the network back up ...
node2 = network.createNodeWithID(true, addr2.id).start().get()
// We must provide the state machines with all the stuff that couldn't be saved to disk.
var bobFuture: ListenableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>>? = null
fun resumeStateMachine(forObj: ProtocolStateMachine<*,*>): Any {
return when (forObj) {
is TwoPartyTradeProtocol.Buyer -> {
bobFuture = forObj.resultFuture
return TwoPartyTradeProtocol.BuyerContext(bobsWallet, mapOf(BOB to BOB_KEY.private), DUMMY_TIMESTAMPER, TEST_KEYS_TO_CORP_MAP, null)
}
else -> fail()
}
}
// The act of constructing this object will re-register the message handlers that Bob was waiting on before
// the reboot occurred.
StateMachineManager(node2, MoreExecutors.directExecutor(), storageBob, ::resumeStateMachine)
assertTrue(node2.pump(false))
// Bob is now finished and has the same transaction as Alice.
val tx: Pair<TimestampedWireTransaction, LedgerTransaction> = bobFuture!!.get()
txns.add(tx.second)
verify() verify()
} }
} }