Two party trading: big changes to support and test serialisation, refactorings ...

This commit is contained in:
Mike Hearn 2015-12-11 14:54:59 +01:00
parent 65c5fa7502
commit 89ba996a3c
4 changed files with 395 additions and 199 deletions

View File

@ -9,19 +9,12 @@
package contracts.protocols
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import contracts.Cash
import contracts.sumCashBy
import core.*
import core.messaging.MessagingSystem
import core.messaging.SingleMessageRecipient
import core.serialization.SerializeableWithKryo
import core.serialization.THREAD_LOCAL_KRYO
import core.messaging.*
import core.serialization.deserialize
import core.serialization.registerDataClass
import core.utilities.continuations.*
import core.utilities.trace
import org.slf4j.LoggerFactory
import java.security.KeyPair
import java.security.PrivateKey
import java.security.PublicKey
@ -47,18 +40,17 @@ import java.security.SecureRandom
* To see an example of how to use this class, look at the unit tests.
*/
abstract class TwoPartyTradeProtocol {
// TODO: Replace some args with the context objects
abstract fun runSeller(
net: MessagingSystem,
otherSide: SingleMessageRecipient,
assetToSell: StateAndRef<OwnableState>,
price: Amount,
myKey: KeyPair,
partyKeyMap: Map<PublicKey, Party>,
timestamper: TimestamperService
): ListenableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>>
): ListenableFuture<out Pair<TimestampedWireTransaction, LedgerTransaction>>
abstract fun runBuyer(
net: MessagingSystem,
otherSide: SingleMessageRecipient,
acceptablePrice: Amount,
typeToSell: Class<out OwnableState>,
@ -66,24 +58,20 @@ abstract class TwoPartyTradeProtocol {
myKeys: Map<PublicKey, PrivateKey>,
timestamper: TimestamperService,
partyKeyMap: Map<PublicKey, Party>
): ListenableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>>
): ListenableFuture<out Pair<TimestampedWireTransaction, LedgerTransaction>>
companion object {
@JvmStatic fun create(): TwoPartyTradeProtocol {
return TwoPartyTradeProtocolImpl()
}
}
}
class BuyerInitialArgs(
val acceptablePrice: Amount,
val typeToSell: String
)
private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
companion object {
val TRADE_TOPIC = "com.r3cev.protocols.trade"
fun makeSessionID() = Math.abs(SecureRandom.getInstanceStrong().nextLong())
}
init {
THREAD_LOCAL_KRYO.get().registerDataClass<TwoPartyTradeProtocolImpl.SellerTradeInfo>()
}
class BuyerContext(
val wallet: List<StateAndRef<Cash.State>>,
val myKeys: Map<PublicKey, PrivateKey>,
val timestamper: TimestamperService,
val partyKeyMap: Map<PublicKey, Party>,
val initialArgs: BuyerInitialArgs?
)
// This wraps some of the arguments passed to runSeller that are persistent across the lifetime of the trade and
// can be serialised.
@ -91,7 +79,7 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
val assetToSell: StateAndRef<OwnableState>,
val price: Amount,
val myKeyPair: KeyPair
) : SerializeableWithKryo
)
// This wraps the things which the seller needs, but which might change whilst the continuation is suspended,
// e.g. due to a VM restart, networking issue, configuration file reload etc. It also contains the initial args
@ -99,17 +87,32 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
class SellerContext(
val timestamper: TimestamperService,
val partyKeyMap: Map<PublicKey, Party>,
val initialArgs: SellerInitialArgs?,
val resultFuture: SettableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>> = SettableFuture.create()
val initialArgs: SellerInitialArgs?
)
abstract class Buyer : ProtocolStateMachine<BuyerContext, Pair<TimestampedWireTransaction, LedgerTransaction>>()
abstract class Seller : ProtocolStateMachine<SellerContext, Pair<TimestampedWireTransaction, LedgerTransaction>>()
companion object {
@JvmStatic fun create(smm: StateMachineManager): TwoPartyTradeProtocol {
return TwoPartyTradeProtocolImpl(smm)
}
}
}
private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) : TwoPartyTradeProtocol() {
companion object {
val TRADE_TOPIC = "com.r3cev.protocols.trade"
fun makeSessionID() = Math.abs(SecureRandom.getInstanceStrong().nextLong())
}
// This object is serialised to the network and is the first protocol message the seller sends to the buyer.
class SellerTradeInfo(
val assetForSale: StateAndRef<OwnableState>,
val price: Amount,
val primaryOwnerKey: PublicKey,
val sellerOwnerKey: PublicKey,
val sessionID: Long
) : SerializeableWithKryo
)
// The seller's side of the protocol. IMPORTANT: This class is loaded in a separate classloader and auto-mangled
// by JavaFlow. Therefore, we cannot cast the object to Seller and poke it directly because the class we'd be
@ -117,8 +120,8 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
// interaction with this class must be through either interfaces, or objects passed to and from the continuation
// by the state machine framework. Please refer to the documentation website (docs/build/html) to learn more about
// the protocol state machine framework.
class Seller : ProtocolStateMachine<SellerContext> {
override fun run() {
class SellerImpl : Seller() {
override fun call(): Pair<TimestampedWireTransaction, LedgerTransaction> {
val sessionID = makeSessionID()
val args = context().initialArgs!!
@ -153,7 +156,7 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
val timestamped: TimestampedWireTransaction = fullySigned.toTimestampedTransaction(ctx2.timestamper)
logger().trace { "Built finished transaction, sending back to secondary!" }
send(TRADE_TOPIC, sessionID, timestamped)
ctx2.resultFuture.set(Pair(timestamped, timestamped.verifyToLedgerTransaction(ctx2.timestamper, ctx2.partyKeyMap)))
return Pair(timestamped, timestamped.verifyToLedgerTransaction(ctx2.timestamper, ctx2.partyKeyMap))
}
}
@ -162,20 +165,11 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
override fun toString() = "The submitted asset didn't match the expected type: $expectedTypeName vs $typeName"
}
class BuyerContext(
val acceptablePrice: Amount,
val typeToSell: Class<out OwnableState>,
val wallet: List<StateAndRef<Cash.State>>,
val myKeys: Map<PublicKey, PrivateKey>,
val timestamper: TimestamperService,
val partyKeyMap: Map<PublicKey, Party>,
val resultFuture: SettableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>> = SettableFuture.create()
)
// The buyer's side of the protocol. See note above Seller to learn about the caveats here.
class Buyer : ProtocolStateMachine<BuyerContext> {
override fun run() {
class BuyerImpl : Buyer() {
override fun call(): Pair<TimestampedWireTransaction, LedgerTransaction> {
val acceptablePrice = context().initialArgs!!.acceptablePrice
val typeToSell = context().initialArgs!!.typeToSell
// Start a new scope here so we can't accidentally reuse 'ctx' after doing the sendAndReceive below,
// as the context object we're meant to use might change each time we suspend (e.g. due to VM restart).
val (stx, theirSessionID) = run {
@ -187,10 +181,10 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
// Check the start message for acceptability.
check(tradeRequest.sessionID > 0)
if (tradeRequest.price > ctx.acceptablePrice)
if (tradeRequest.price > acceptablePrice)
throw UnacceptablePriceException(tradeRequest.price)
if (!ctx.typeToSell.isInstance(tradeRequest.assetForSale.state))
throw AssetMismatchException(ctx.typeToSell.name, assetTypeName)
if (!Class.forName(typeToSell).isInstance(tradeRequest.assetForSale.state))
throw AssetMismatchException(typeToSell, assetTypeName)
// TODO: Either look up the stateref here in our local db, or accept a long chain of states and
// validate them to audit the other side and ensure it actually owns the state we are being offered!
@ -199,7 +193,7 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
// Generate the shared transaction that both sides will sign, using the data we have.
val ptx = PartialTransaction()
// Add input and output states for the movement of cash.
val cashSigningPubKeys = Cash().craftSpend(ptx, tradeRequest.price, tradeRequest.primaryOwnerKey, ctx.wallet)
val cashSigningPubKeys = Cash().craftSpend(ptx, tradeRequest.price, tradeRequest.sellerOwnerKey, ctx.wallet)
// Add inputs/outputs/a command for the movement of the asset.
ptx.addInputState(tradeRequest.assetForSale.ref)
// Just pick some arbitrary public key for now (this provides poor privacy).
@ -219,33 +213,29 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
}
// TODO: Could run verify() here to make sure the only signature missing is the primaries.
logger().trace { "Sending partially signed transaction to primary" }
// We'll just reuse the session ID the primary selected here for convenience.
logger().trace { "Sending partially signed transaction to seller" }
// We'll just reuse the session ID the seller selected here for convenience.
val (ctx, fullySigned) = sendAndReceive<TimestampedWireTransaction, BuyerContext>(TRADE_TOPIC, theirSessionID, theirSessionID, stx)
logger().trace { "Got fully signed transaction, verifying ... "}
val ltx = fullySigned.verifyToLedgerTransaction(ctx.timestamper, ctx.partyKeyMap)
logger().trace { "Fully signed transaction was valid. Trade complete! :-)" }
ctx.resultFuture.set(Pair(fullySigned, ltx))
return Pair(fullySigned, ltx)
}
}
override fun runSeller(net: MessagingSystem, otherSide: SingleMessageRecipient, assetToSell: StateAndRef<OwnableState>,
override fun runSeller(otherSide: SingleMessageRecipient, assetToSell: StateAndRef<OwnableState>,
price: Amount, myKey: KeyPair, partyKeyMap: Map<PublicKey, Party>,
timestamper: TimestamperService): ListenableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>> {
timestamper: TimestamperService): ListenableFuture<out Pair<TimestampedWireTransaction, LedgerTransaction>> {
val args = SellerInitialArgs(assetToSell, price, myKey)
val context = SellerContext(timestamper, partyKeyMap, args)
val logger = LoggerFactory.getLogger("$TRADE_TOPIC.primary")
loadContinuationClass<Seller>(javaClass.classLoader).iterateStateMachine(net, otherSide, context, context, logger)
return context.resultFuture
return smm.add(otherSide, context, "$TRADE_TOPIC.seller", SellerImpl::class.java)
}
override fun runBuyer(net: MessagingSystem, otherSide: SingleMessageRecipient, acceptablePrice: Amount,
override fun runBuyer(otherSide: SingleMessageRecipient, acceptablePrice: Amount,
typeToSell: Class<out OwnableState>, wallet: List<StateAndRef<Cash.State>>,
myKeys: Map<PublicKey, PrivateKey>, timestamper: TimestamperService,
partyKeyMap: Map<PublicKey, Party>): ListenableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>> {
val context = BuyerContext(acceptablePrice, typeToSell, wallet, myKeys, timestamper, partyKeyMap)
val logger = LoggerFactory.getLogger("$TRADE_TOPIC.secondary")
loadContinuationClass<Buyer>(javaClass.classLoader).iterateStateMachine(net, otherSide, context, context, logger)
return context.resultFuture
partyKeyMap: Map<PublicKey, Party>): ListenableFuture<out Pair<TimestampedWireTransaction, LedgerTransaction>> {
val context = BuyerContext(wallet, myKeys, timestamper, partyKeyMap, BuyerInitialArgs(acceptablePrice, typeToSell.name))
return smm.add(otherSide, context, "$TRADE_TOPIC.buyer", BuyerImpl::class.java)
}
}

View File

@ -0,0 +1,227 @@
/*
* Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members
* pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms
* set forth therein.
*
* All other rights reserved.
*/
package core.messaging
import com.esotericsoftware.kryo.io.Input
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import core.serialization.THREAD_LOCAL_KRYO
import core.serialization.createKryo
import core.serialization.deserialize
import core.serialization.serialize
import core.utilities.trace
import org.apache.commons.javaflow.Continuation
import org.apache.commons.javaflow.ContinuationClassLoader
import org.objenesis.instantiator.ObjectInstantiator
import org.objenesis.strategy.InstantiatorStrategy
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.util.*
import java.util.concurrent.Callable
import java.util.concurrent.Executor
/**
* A StateMachineManager is responsible for coordination and persistence of multiple [ProtocolStateMachine] objects.
*
* An implementation of this class will persist state machines to long term storage so they can survive process restarts
* and, if run with a single-threaded executor, will ensure no two state machines run concurrently with each other
* (bad for performance, good for programmer mental health!).
*/
class StateMachineManager(val net: MessagingSystem, val runInThread: Executor) {
// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
private class Checkpoint(
val continuation: Continuation,
val otherSide: MessageRecipients,
val loggerName: String,
val awaitingTopic: String,
val awaitingObjectOfType: String // java class name
)
constructor(net: MessagingSystem, runInThread: Executor, restoreCheckpoints: List<ByteArray>, resumeStateMachine: (ProtocolStateMachine<*,*>) -> Any) : this(net, runInThread) {
for (bytes in restoreCheckpoints) {
val kryo = createKryo()
// Set up Kryo to use the JavaFlow classloader when deserialising, so the magical continuation bytecode
// rewriting is performed correctly.
var psm: ProtocolStateMachine<*,*>? = null
kryo.instantiatorStrategy = object : InstantiatorStrategy {
val forwardingTo = kryo.instantiatorStrategy
override fun <T> newInstantiatorOf(type: Class<T>): ObjectInstantiator<T> {
if (ProtocolStateMachine::class.java.isAssignableFrom(type)) {
// The messing around with types we do here confuses the compiler/IDE a bit and it warns us.
@Suppress("UNCHECKED_CAST", "CAST_NEVER_SUCCEEDS")
return ObjectInstantiator<T> {
psm = loadContinuationClass(type as Class<out ProtocolStateMachine<*, Any>>).first
psm as T
}
} else {
return forwardingTo.newInstantiatorOf(type)
}
}
}
val checkpoint = bytes.deserialize<Checkpoint>(kryo)
val continuation = checkpoint.continuation
val transientContext = resumeStateMachine(psm!!)
val logger = LoggerFactory.getLogger(checkpoint.loggerName)
val awaitingObjectOfType = Class.forName(checkpoint.awaitingObjectOfType)
// The act of calling this method re-persists the bytes into the in-memory hashmap so re-saving the
// StateMachineManager to disk will work even if some state machines didn't wake up in the intervening time.
setupNextMessageHandler(logger, net, continuation, checkpoint.otherSide, awaitingObjectOfType,
checkpoint.awaitingTopic, transientContext, bytes)
}
}
fun <R> add(otherSide: MessageRecipients, transientContext: Any, loggerName: String, continuationClass: Class<out ProtocolStateMachine<*, R>>): ListenableFuture<out R> {
val logger = LoggerFactory.getLogger(loggerName)
val (sm, continuation) = loadContinuationClass<R>(continuationClass)
runInThread.execute {
// The current state of the continuation is held in the closure attached to the messaging system whenever
// the continuation suspends and tells us it expects a response.
iterateStateMachine(continuation, net, otherSide, transientContext, transientContext, logger, null)
}
return sm.resultFuture
}
@Suppress("UNCHECKED_CAST")
private fun <R> loadContinuationClass(continuationClass: Class<out ProtocolStateMachine<*, R>>): Pair<ProtocolStateMachine<*,R>, Continuation> {
val url = continuationClass.protectionDomain.codeSource.location
val cl = ContinuationClassLoader(arrayOf(url), this.javaClass.classLoader)
val obj = cl.forceLoadClass(continuationClass.name).newInstance() as ProtocolStateMachine<*, R>
return Pair(obj, Continuation.startSuspendedWith(obj))
}
private val checkpoints: LinkedList<ByteArray> = LinkedList()
private fun persistCheckpoint(prev: ByteArray?, new: ByteArray) {
synchronized(checkpoints) {
if (prev == null) {
for (i in checkpoints.size - 1 downTo 0) {
val b = checkpoints[i]
if (Arrays.equals(b, prev)) {
checkpoints[i] = new
return
}
}
}
checkpoints.add(new)
}
}
fun saveToBytes(): LinkedList<ByteArray> = synchronized(checkpoints) { LinkedList(checkpoints) }
private fun iterateStateMachine(c: Continuation, net: MessagingSystem, otherSide: MessageRecipients,
transientContext: Any, continuationInput: Any?, logger: Logger,
prevPersistedBytes: ByteArray?): Continuation {
// This will resume execution of the run() function inside the continuation at the place it left off.
val oldLogger = CONTINUATION_LOGGER.get()
val nextState: Continuation? = try {
CONTINUATION_LOGGER.set(logger)
Continuation.continueWith(c, continuationInput)
} catch (t: Throwable) {
logger.error("Caught error whilst invoking protocol state machine", t)
throw t
} finally {
CONTINUATION_LOGGER.set(oldLogger)
}
// If continuation returns null, it's finished and the result future has been set.
if (nextState == null)
return c
val req = nextState.value() as? ContinuationResult ?: return c
// Else, it wants us to do something: send, receive, or send-and-receive.
if (req is ContinuationResult.ExpectingResponse<*>) {
// Prepare a listener on the network that runs in the background thread when we received a message.
val topic = "${req.topic}.${req.sessionIDForReceive}"
setupNextMessageHandler(logger, net, nextState, otherSide, req.responseType, topic, transientContext, prevPersistedBytes)
}
// If an object to send was provided (not null), send it now.
req.obj?.let {
val topic = "${req.topic}.${req.sessionIDForSend}"
logger.trace { "-> $topic : message of type ${it.javaClass.name}" }
net.send(net.createMessage(topic, it.serialize()), otherSide)
}
if (req is ContinuationResult.NotExpectingResponse) {
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
return iterateStateMachine(nextState, net, otherSide, transientContext, transientContext, logger, prevPersistedBytes)
} else {
return nextState
}
}
private fun setupNextMessageHandler(logger: Logger, net: MessagingSystem, nextState: Continuation,
otherSide: MessageRecipients, responseType: Class<*>,
topic: String, transientContext: Any, prevPersistedBytes: ByteArray?) {
val checkpoint = Checkpoint(nextState, otherSide, logger.name, topic, responseType.name)
persistCheckpoint(prevPersistedBytes, checkpoint.serialize())
net.runOnNextMessage(topic, runInThread) { netMsg ->
val obj: Any = THREAD_LOCAL_KRYO.get().readObject(Input(netMsg.data), responseType)
logger.trace { "<- $topic : message of type ${obj.javaClass.name}" }
iterateStateMachine(nextState, net, otherSide, transientContext, Pair(transientContext, obj), logger, prevPersistedBytes)
}
}
}
val CONTINUATION_LOGGER = ThreadLocal<Logger>()
/**
* A convenience mixin interface that can be implemented by an object that will act as a continuation.
*
* A ProtocolStateMachine must implement the run method from [Runnable], and the rest of what this interface
* provides are pre-defined utility methods to ease implementation of such machines.
*/
@Suppress("UNCHECKED_CAST")
abstract class ProtocolStateMachine<CONTEXT_TYPE : Any, R> : Callable<R>, Runnable {
protected fun context(): CONTEXT_TYPE = Continuation.getContext() as CONTEXT_TYPE
protected fun logger(): Logger = CONTINUATION_LOGGER.get()
// These fields shouldn't be serialised.
@Transient private var _resultFuture: SettableFuture<R> = SettableFuture.create<R>()
val resultFuture: ListenableFuture<R> get() = _resultFuture
override fun run() {
val r = call()
if (r != null)
_resultFuture.set(r)
}
}
@Suppress("NOTHING_TO_INLINE", "UNCHECKED_CAST")
inline fun <S : Any, CONTEXT_TYPE : Any> ProtocolStateMachine<CONTEXT_TYPE, *>.send(topic: String, sessionID: Long, obj: S) =
Continuation.suspend(ContinuationResult.NotExpectingResponse(topic, sessionID, obj)) as CONTEXT_TYPE
@Suppress("UNCHECKED_CAST")
inline fun <reified R : Any, CONTEXT_TYPE : Any> ProtocolStateMachine<CONTEXT_TYPE, *>.sendAndReceive(
topic: String, sessionIDForSend: Long, sessionIDForReceive: Long, obj: Any): Pair<CONTEXT_TYPE, R> {
return Continuation.suspend(ContinuationResult.ExpectingResponse(topic, sessionIDForSend, sessionIDForReceive,
obj, R::class.java)) as Pair<CONTEXT_TYPE, R>
}
@Suppress("UNCHECKED_CAST")
inline fun <reified R : Any, CONTEXT_TYPE : Any> ProtocolStateMachine<CONTEXT_TYPE, *>.receive(
topic: String, sessionIDForReceive: Long): Pair<CONTEXT_TYPE, R> {
return Continuation.suspend(ContinuationResult.ExpectingResponse(topic, -1, sessionIDForReceive, null,
R::class.java)) as Pair<CONTEXT_TYPE, R>
}
open class ContinuationResult(val topic: String, val sessionIDForSend: Long, val sessionIDForReceive: Long, val obj: Any?) {
class ExpectingResponse<R : Any>(
topic: String,
sessionIDForSend: Long,
sessionIDForReceive: Long,
obj: Any?,
val responseType: Class<R>
) : ContinuationResult(topic, sessionIDForSend, sessionIDForReceive, obj)
class NotExpectingResponse(topic: String, sessionIDForSend: Long, obj: Any?) : ContinuationResult(topic, sessionIDForSend, -1, obj)
}

View File

@ -1,118 +0,0 @@
/*
* Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members
* pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms
* set forth therein.
*
* All other rights reserved.
*/
package core.utilities.continuations
import com.esotericsoftware.kryo.io.Output
import core.messaging.MessagingSystem
import core.messaging.SingleMessageRecipient
import core.messaging.runOnNextMessage
import core.serialization.SerializeableWithKryo
import core.serialization.createKryo
import core.serialization.deserialize
import core.serialization.serialize
import core.utilities.trace
import org.apache.commons.javaflow.Continuation
import org.slf4j.Logger
import java.io.ByteArrayOutputStream
private val CONTINUATION_LOGGER = ThreadLocal<Logger>()
/**
* A convenience mixing interface that can be implemented by an object that will act as a continuation.
*
* A ProtocolStateMachine must implement the run method from [Runnable], and the rest of what this interface
* provides are pre-defined utility methods to ease implementation of such machines.
*/
@Suppress("UNCHECKED_CAST")
interface ProtocolStateMachine<CONTEXT_TYPE : Any> : Runnable {
fun context(): CONTEXT_TYPE = Continuation.getContext() as CONTEXT_TYPE
fun logger(): Logger = CONTINUATION_LOGGER.get()
}
@Suppress("NOTHING_TO_INLINE", "UNCHECKED_CAST")
inline fun <S : SerializeableWithKryo,
CONTEXT_TYPE : Any> ProtocolStateMachine<CONTEXT_TYPE>.send(topic: String, sessionID: Long, obj: S) =
Continuation.suspend(ContinuationResult.NotExpectingResponse(topic, sessionID, obj)) as CONTEXT_TYPE
@Suppress("UNCHECKED_CAST")
inline fun <reified R : SerializeableWithKryo,
CONTEXT_TYPE : Any> ProtocolStateMachine<CONTEXT_TYPE>.sendAndReceive(topic: String, sessionIDForSend: Long,
sessionIDForReceive: Long,
obj: SerializeableWithKryo) =
Continuation.suspend(ContinuationResult.ExpectingResponse(topic, sessionIDForSend, sessionIDForReceive,
obj, R::class.java)) as Pair<CONTEXT_TYPE, R>
@Suppress("UNCHECKED_CAST")
inline fun <reified R : SerializeableWithKryo, CONTEXT_TYPE : Any> receive(topic: String, sessionIDForReceive: Long) =
Continuation.suspend(ContinuationResult.ExpectingResponse(topic, -1, sessionIDForReceive, null,
R::class.java)) as Pair<CONTEXT_TYPE, R>
open class ContinuationResult(val topic: String, val sessionIDForSend: Long, val sessionIDForReceive: Long, val obj: SerializeableWithKryo?) {
class ExpectingResponse<R : SerializeableWithKryo>(
topic: String,
sessionIDForSend: Long,
sessionIDForReceive: Long,
obj: SerializeableWithKryo?,
val responseType: Class<R>
) : ContinuationResult(topic, sessionIDForSend, sessionIDForReceive, obj)
class NotExpectingResponse(topic: String, sessionIDForSend: Long, obj: SerializeableWithKryo?) : ContinuationResult(topic, sessionIDForSend, -1, obj)
}
fun Continuation.iterateStateMachine(net: MessagingSystem, otherSide: SingleMessageRecipient,
transientContext: Any, continuationInput: Any?, logger: Logger): Continuation {
// This will resume execution of the run() function inside the continuation at the place it left off.
val oldLogger = CONTINUATION_LOGGER.get()
val nextState = try {
CONTINUATION_LOGGER.set(logger)
Continuation.continueWith(this, continuationInput)
} catch (t: Throwable) {
logger.error("Caught error whilst invoking protocol state machine", t)
throw t
} finally {
CONTINUATION_LOGGER.set(oldLogger)
}
// If continuation returns null, it's finished.
val req = nextState?.value() as? ContinuationResult ?: return this
// Else, it wants us to do something: send, receive, or send-and-receive. Firstly, checkpoint it, so we can restart
// if something goes wrong.
val bytes = run {
val stream = ByteArrayOutputStream()
Output(stream).use {
createKryo().apply {
isRegistrationRequired = false
writeObject(it, nextState)
}
}
stream.toByteArray()
}
if (req is ContinuationResult.ExpectingResponse<*>) {
val topic = "${req.topic}.${req.sessionIDForReceive}"
net.runOnNextMessage(topic) { netMsg ->
val obj = netMsg.data.deserialize(req.responseType)
logger.trace { "<- $topic : message of type ${obj.javaClass.name}" }
nextState.iterateStateMachine(net, otherSide, transientContext, Pair(transientContext, obj), logger)
}
}
// If an object to send was provided (not null), send it now.
req.obj?.let {
val topic = "${req.topic}.${req.sessionIDForSend}"
logger.trace { "-> $topic : message of type ${it.javaClass.name}" }
net.send(net.createMessage(topic, it.serialize()), otherSide)
}
if (req is ContinuationResult.NotExpectingResponse) {
// We sent a message, but won't get a response, so we must re-enter the continuation to let it keep going.
return nextState.iterateStateMachine(net, otherSide, transientContext, transientContext, logger)
} else {
return nextState
}
}

View File

@ -8,17 +8,24 @@
package core.messaging
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.MoreExecutors
import contracts.Cash
import contracts.CommercialPaper
import contracts.protocols.TwoPartyTradeProtocol
import core.*
import core.testutils.*
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.util.concurrent.Executors
import java.util.logging.Formatter
import java.util.logging.Level
import java.util.logging.LogRecord
import java.util.logging.Logger
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import kotlin.test.fail
/**
* In this example, Alessia wishes to sell her commercial paper to Boris in return for $1,000,000 and they wish to do
@ -27,7 +34,8 @@ import kotlin.test.assertEquals
* We assume that Alessia and Boris already found each other via some market, and have agreed the details already.
*/
class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
init {
@Before
fun initLogging() {
Logger.getLogger("").handlers[0].level = Level.ALL
Logger.getLogger("").handlers[0].formatter = object : Formatter() {
override fun format(record: LogRecord) = "${record.threadID} ${record.loggerName}: ${record.message}\n"
@ -35,12 +43,19 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
Logger.getLogger("com.r3cev.protocols.trade").level = Level.ALL
}
@After
fun stopLogging() {
Logger.getLogger("com.r3cev.protocols.trade").level = Level.INFO
}
@Test
fun cashForCP() {
val (addr1, node1) = makeNode()
val (addr2, node2) = makeNode()
val (addr1, node1) = makeNode(inBackground = true)
val (addr2, node2) = makeNode(inBackground = true)
val tp = TwoPartyTradeProtocol.create()
val backgroundThread = Executors.newSingleThreadExecutor()
val tpSeller = TwoPartyTradeProtocol.create(StateMachineManager(node1, backgroundThread))
val tpBuyer = TwoPartyTradeProtocol.create(StateMachineManager(node2, backgroundThread))
transactionGroupFor<ContractState> {
// Bob (S) has some cash, Alice (P) has some commercial paper she wants to sell to Bob.
@ -51,22 +66,104 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
}
val bobsWallet = listOf<StateAndRef<Cash.State>>(lookup("bob cash1"), lookup("bob cash2"))
val (aliceFuture, bobFuture) = runNetwork {
Pair(
tp.runSeller(node1, addr2, lookup("alice's paper"), 1000.DOLLARS, ALICE_KEY,
TEST_KEYS_TO_CORP_MAP, DUMMY_TIMESTAMPER),
tp.runBuyer(node2, addr1, 1000.DOLLARS, CommercialPaper.State::class.java, bobsWallet,
mapOf(BOB to BOB_KEY.private), DUMMY_TIMESTAMPER, TEST_KEYS_TO_CORP_MAP)
val aliceResult = tpSeller.runSeller(
addr2,
lookup("alice's paper"),
1000.DOLLARS,
ALICE_KEY,
TEST_KEYS_TO_CORP_MAP,
DUMMY_TIMESTAMPER
)
val bobResult = tpBuyer.runBuyer(
addr1,
1000.DOLLARS,
CommercialPaper.State::class.java,
bobsWallet,
mapOf(BOB to BOB_KEY.private),
DUMMY_TIMESTAMPER,
TEST_KEYS_TO_CORP_MAP
)
assertEquals(aliceResult.get(), bobResult.get())
txns.add(aliceResult.get().second)
verify()
}
backgroundThread.shutdown()
}
val aliceResult: Pair<TimestampedWireTransaction, LedgerTransaction> = aliceFuture.get()
val bobResult: Pair<TimestampedWireTransaction, LedgerTransaction> = bobFuture.get()
@Test
fun serializeAndRestore() {
val (addr1, node1) = makeNode(inBackground = false)
var (addr2, node2) = makeNode(inBackground = false)
assertEquals(aliceResult, bobResult)
val smmSeller = StateMachineManager(node1, MoreExecutors.directExecutor())
val tpSeller = TwoPartyTradeProtocol.create(smmSeller)
val smmBuyer = StateMachineManager(node2, MoreExecutors.directExecutor())
val tpBuyer = TwoPartyTradeProtocol.create(smmBuyer)
txns.add(aliceResult.second)
transactionGroupFor<ContractState> {
// Buyer Bob has some cash, Seller Alice has some commercial paper she wants to sell to Bob.
roots {
transaction(CommercialPaper.State(MEGA_CORP.ref(1, 2, 3), ALICE, 1200.DOLLARS, TEST_TX_TIME + 7.days) label "alice's paper")
transaction(800.DOLLARS.CASH `owned by` BOB label "bob cash1")
transaction(300.DOLLARS.CASH `owned by` BOB label "bob cash2")
}
val bobsWallet = listOf<StateAndRef<Cash.State>>(lookup("bob cash1"), lookup("bob cash2"))
tpSeller.runSeller(
addr2,
lookup("alice's paper"),
1000.DOLLARS,
ALICE_KEY,
TEST_KEYS_TO_CORP_MAP,
DUMMY_TIMESTAMPER
)
tpBuyer.runBuyer(
addr1,
1000.DOLLARS,
CommercialPaper.State::class.java,
bobsWallet,
mapOf(BOB to BOB_KEY.private),
DUMMY_TIMESTAMPER,
TEST_KEYS_TO_CORP_MAP
)
// Everything is on this thread so we can now step through the protocol one step at a time.
// Seller Alice already sent a message to Buyer Bob. Pump once:
node2.pump(false)
// OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature.
val storageBob = smmBuyer.saveToBytes()
// .. and let's imagine that Bob's computer has a power cut. He now has nothing now beyond what was on disk.
node2.stop()
// Alice doesn't know that and sends Bob the now finalised transaction. Alice sends a message to a node
// that has gone offline.
node1.pump(false)
// ... bring the network back up ...
node2 = network.createNodeWithID(true, addr2.id).start().get()
// We must provide the state machines with all the stuff that couldn't be saved to disk.
var bobFuture: ListenableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>>? = null
fun resumeStateMachine(forObj: ProtocolStateMachine<*,*>): Any {
return when (forObj) {
is TwoPartyTradeProtocol.Buyer -> {
bobFuture = forObj.resultFuture
return TwoPartyTradeProtocol.BuyerContext(bobsWallet, mapOf(BOB to BOB_KEY.private), DUMMY_TIMESTAMPER, TEST_KEYS_TO_CORP_MAP, null)
}
else -> fail()
}
}
// The act of constructing this object will re-register the message handlers that Bob was waiting on before
// the reboot occurred.
StateMachineManager(node2, MoreExecutors.directExecutor(), storageBob, ::resumeStateMachine)
assertTrue(node2.pump(false))
// Bob is now finished and has the same transaction as Alice.
val tx: Pair<TimestampedWireTransaction, LedgerTransaction> = bobFuture!!.get()
txns.add(tx.second)
verify()
}
}