mirror of
https://github.com/corda/corda.git
synced 2024-12-19 04:57:58 +00:00
Two party trading: big changes to support and test serialisation, refactorings ...
This commit is contained in:
parent
65c5fa7502
commit
89ba996a3c
@ -9,19 +9,12 @@
|
||||
package contracts.protocols
|
||||
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.google.common.util.concurrent.SettableFuture
|
||||
import contracts.Cash
|
||||
import contracts.sumCashBy
|
||||
import core.*
|
||||
import core.messaging.MessagingSystem
|
||||
import core.messaging.SingleMessageRecipient
|
||||
import core.serialization.SerializeableWithKryo
|
||||
import core.serialization.THREAD_LOCAL_KRYO
|
||||
import core.messaging.*
|
||||
import core.serialization.deserialize
|
||||
import core.serialization.registerDataClass
|
||||
import core.utilities.continuations.*
|
||||
import core.utilities.trace
|
||||
import org.slf4j.LoggerFactory
|
||||
import java.security.KeyPair
|
||||
import java.security.PrivateKey
|
||||
import java.security.PublicKey
|
||||
@ -47,18 +40,17 @@ import java.security.SecureRandom
|
||||
* To see an example of how to use this class, look at the unit tests.
|
||||
*/
|
||||
abstract class TwoPartyTradeProtocol {
|
||||
// TODO: Replace some args with the context objects
|
||||
abstract fun runSeller(
|
||||
net: MessagingSystem,
|
||||
otherSide: SingleMessageRecipient,
|
||||
assetToSell: StateAndRef<OwnableState>,
|
||||
price: Amount,
|
||||
myKey: KeyPair,
|
||||
partyKeyMap: Map<PublicKey, Party>,
|
||||
timestamper: TimestamperService
|
||||
): ListenableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>>
|
||||
): ListenableFuture<out Pair<TimestampedWireTransaction, LedgerTransaction>>
|
||||
|
||||
abstract fun runBuyer(
|
||||
net: MessagingSystem,
|
||||
otherSide: SingleMessageRecipient,
|
||||
acceptablePrice: Amount,
|
||||
typeToSell: Class<out OwnableState>,
|
||||
@ -66,24 +58,20 @@ abstract class TwoPartyTradeProtocol {
|
||||
myKeys: Map<PublicKey, PrivateKey>,
|
||||
timestamper: TimestamperService,
|
||||
partyKeyMap: Map<PublicKey, Party>
|
||||
): ListenableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>>
|
||||
): ListenableFuture<out Pair<TimestampedWireTransaction, LedgerTransaction>>
|
||||
|
||||
companion object {
|
||||
@JvmStatic fun create(): TwoPartyTradeProtocol {
|
||||
return TwoPartyTradeProtocolImpl()
|
||||
}
|
||||
}
|
||||
}
|
||||
class BuyerInitialArgs(
|
||||
val acceptablePrice: Amount,
|
||||
val typeToSell: String
|
||||
)
|
||||
|
||||
private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
|
||||
companion object {
|
||||
val TRADE_TOPIC = "com.r3cev.protocols.trade"
|
||||
fun makeSessionID() = Math.abs(SecureRandom.getInstanceStrong().nextLong())
|
||||
}
|
||||
|
||||
init {
|
||||
THREAD_LOCAL_KRYO.get().registerDataClass<TwoPartyTradeProtocolImpl.SellerTradeInfo>()
|
||||
}
|
||||
class BuyerContext(
|
||||
val wallet: List<StateAndRef<Cash.State>>,
|
||||
val myKeys: Map<PublicKey, PrivateKey>,
|
||||
val timestamper: TimestamperService,
|
||||
val partyKeyMap: Map<PublicKey, Party>,
|
||||
val initialArgs: BuyerInitialArgs?
|
||||
)
|
||||
|
||||
// This wraps some of the arguments passed to runSeller that are persistent across the lifetime of the trade and
|
||||
// can be serialised.
|
||||
@ -91,7 +79,7 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
|
||||
val assetToSell: StateAndRef<OwnableState>,
|
||||
val price: Amount,
|
||||
val myKeyPair: KeyPair
|
||||
) : SerializeableWithKryo
|
||||
)
|
||||
|
||||
// This wraps the things which the seller needs, but which might change whilst the continuation is suspended,
|
||||
// e.g. due to a VM restart, networking issue, configuration file reload etc. It also contains the initial args
|
||||
@ -99,17 +87,32 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
|
||||
class SellerContext(
|
||||
val timestamper: TimestamperService,
|
||||
val partyKeyMap: Map<PublicKey, Party>,
|
||||
val initialArgs: SellerInitialArgs?,
|
||||
val resultFuture: SettableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>> = SettableFuture.create()
|
||||
val initialArgs: SellerInitialArgs?
|
||||
)
|
||||
|
||||
abstract class Buyer : ProtocolStateMachine<BuyerContext, Pair<TimestampedWireTransaction, LedgerTransaction>>()
|
||||
abstract class Seller : ProtocolStateMachine<SellerContext, Pair<TimestampedWireTransaction, LedgerTransaction>>()
|
||||
|
||||
companion object {
|
||||
@JvmStatic fun create(smm: StateMachineManager): TwoPartyTradeProtocol {
|
||||
return TwoPartyTradeProtocolImpl(smm)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) : TwoPartyTradeProtocol() {
|
||||
companion object {
|
||||
val TRADE_TOPIC = "com.r3cev.protocols.trade"
|
||||
fun makeSessionID() = Math.abs(SecureRandom.getInstanceStrong().nextLong())
|
||||
}
|
||||
|
||||
// This object is serialised to the network and is the first protocol message the seller sends to the buyer.
|
||||
class SellerTradeInfo(
|
||||
val assetForSale: StateAndRef<OwnableState>,
|
||||
val price: Amount,
|
||||
val primaryOwnerKey: PublicKey,
|
||||
val sellerOwnerKey: PublicKey,
|
||||
val sessionID: Long
|
||||
) : SerializeableWithKryo
|
||||
)
|
||||
|
||||
// The seller's side of the protocol. IMPORTANT: This class is loaded in a separate classloader and auto-mangled
|
||||
// by JavaFlow. Therefore, we cannot cast the object to Seller and poke it directly because the class we'd be
|
||||
@ -117,8 +120,8 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
|
||||
// interaction with this class must be through either interfaces, or objects passed to and from the continuation
|
||||
// by the state machine framework. Please refer to the documentation website (docs/build/html) to learn more about
|
||||
// the protocol state machine framework.
|
||||
class Seller : ProtocolStateMachine<SellerContext> {
|
||||
override fun run() {
|
||||
class SellerImpl : Seller() {
|
||||
override fun call(): Pair<TimestampedWireTransaction, LedgerTransaction> {
|
||||
val sessionID = makeSessionID()
|
||||
val args = context().initialArgs!!
|
||||
|
||||
@ -153,7 +156,7 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
|
||||
val timestamped: TimestampedWireTransaction = fullySigned.toTimestampedTransaction(ctx2.timestamper)
|
||||
logger().trace { "Built finished transaction, sending back to secondary!" }
|
||||
send(TRADE_TOPIC, sessionID, timestamped)
|
||||
ctx2.resultFuture.set(Pair(timestamped, timestamped.verifyToLedgerTransaction(ctx2.timestamper, ctx2.partyKeyMap)))
|
||||
return Pair(timestamped, timestamped.verifyToLedgerTransaction(ctx2.timestamper, ctx2.partyKeyMap))
|
||||
}
|
||||
}
|
||||
|
||||
@ -162,20 +165,11 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
|
||||
override fun toString() = "The submitted asset didn't match the expected type: $expectedTypeName vs $typeName"
|
||||
}
|
||||
|
||||
|
||||
class BuyerContext(
|
||||
val acceptablePrice: Amount,
|
||||
val typeToSell: Class<out OwnableState>,
|
||||
val wallet: List<StateAndRef<Cash.State>>,
|
||||
val myKeys: Map<PublicKey, PrivateKey>,
|
||||
val timestamper: TimestamperService,
|
||||
val partyKeyMap: Map<PublicKey, Party>,
|
||||
val resultFuture: SettableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>> = SettableFuture.create()
|
||||
)
|
||||
|
||||
// The buyer's side of the protocol. See note above Seller to learn about the caveats here.
|
||||
class Buyer : ProtocolStateMachine<BuyerContext> {
|
||||
override fun run() {
|
||||
class BuyerImpl : Buyer() {
|
||||
override fun call(): Pair<TimestampedWireTransaction, LedgerTransaction> {
|
||||
val acceptablePrice = context().initialArgs!!.acceptablePrice
|
||||
val typeToSell = context().initialArgs!!.typeToSell
|
||||
// Start a new scope here so we can't accidentally reuse 'ctx' after doing the sendAndReceive below,
|
||||
// as the context object we're meant to use might change each time we suspend (e.g. due to VM restart).
|
||||
val (stx, theirSessionID) = run {
|
||||
@ -187,10 +181,10 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
|
||||
|
||||
// Check the start message for acceptability.
|
||||
check(tradeRequest.sessionID > 0)
|
||||
if (tradeRequest.price > ctx.acceptablePrice)
|
||||
if (tradeRequest.price > acceptablePrice)
|
||||
throw UnacceptablePriceException(tradeRequest.price)
|
||||
if (!ctx.typeToSell.isInstance(tradeRequest.assetForSale.state))
|
||||
throw AssetMismatchException(ctx.typeToSell.name, assetTypeName)
|
||||
if (!Class.forName(typeToSell).isInstance(tradeRequest.assetForSale.state))
|
||||
throw AssetMismatchException(typeToSell, assetTypeName)
|
||||
|
||||
// TODO: Either look up the stateref here in our local db, or accept a long chain of states and
|
||||
// validate them to audit the other side and ensure it actually owns the state we are being offered!
|
||||
@ -199,7 +193,7 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
|
||||
// Generate the shared transaction that both sides will sign, using the data we have.
|
||||
val ptx = PartialTransaction()
|
||||
// Add input and output states for the movement of cash.
|
||||
val cashSigningPubKeys = Cash().craftSpend(ptx, tradeRequest.price, tradeRequest.primaryOwnerKey, ctx.wallet)
|
||||
val cashSigningPubKeys = Cash().craftSpend(ptx, tradeRequest.price, tradeRequest.sellerOwnerKey, ctx.wallet)
|
||||
// Add inputs/outputs/a command for the movement of the asset.
|
||||
ptx.addInputState(tradeRequest.assetForSale.ref)
|
||||
// Just pick some arbitrary public key for now (this provides poor privacy).
|
||||
@ -219,33 +213,29 @@ private class TwoPartyTradeProtocolImpl : TwoPartyTradeProtocol() {
|
||||
}
|
||||
|
||||
// TODO: Could run verify() here to make sure the only signature missing is the primaries.
|
||||
logger().trace { "Sending partially signed transaction to primary" }
|
||||
// We'll just reuse the session ID the primary selected here for convenience.
|
||||
logger().trace { "Sending partially signed transaction to seller" }
|
||||
// We'll just reuse the session ID the seller selected here for convenience.
|
||||
val (ctx, fullySigned) = sendAndReceive<TimestampedWireTransaction, BuyerContext>(TRADE_TOPIC, theirSessionID, theirSessionID, stx)
|
||||
logger().trace { "Got fully signed transaction, verifying ... "}
|
||||
val ltx = fullySigned.verifyToLedgerTransaction(ctx.timestamper, ctx.partyKeyMap)
|
||||
logger().trace { "Fully signed transaction was valid. Trade complete! :-)" }
|
||||
ctx.resultFuture.set(Pair(fullySigned, ltx))
|
||||
return Pair(fullySigned, ltx)
|
||||
}
|
||||
}
|
||||
|
||||
override fun runSeller(net: MessagingSystem, otherSide: SingleMessageRecipient, assetToSell: StateAndRef<OwnableState>,
|
||||
override fun runSeller(otherSide: SingleMessageRecipient, assetToSell: StateAndRef<OwnableState>,
|
||||
price: Amount, myKey: KeyPair, partyKeyMap: Map<PublicKey, Party>,
|
||||
timestamper: TimestamperService): ListenableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>> {
|
||||
timestamper: TimestamperService): ListenableFuture<out Pair<TimestampedWireTransaction, LedgerTransaction>> {
|
||||
val args = SellerInitialArgs(assetToSell, price, myKey)
|
||||
val context = SellerContext(timestamper, partyKeyMap, args)
|
||||
val logger = LoggerFactory.getLogger("$TRADE_TOPIC.primary")
|
||||
loadContinuationClass<Seller>(javaClass.classLoader).iterateStateMachine(net, otherSide, context, context, logger)
|
||||
return context.resultFuture
|
||||
return smm.add(otherSide, context, "$TRADE_TOPIC.seller", SellerImpl::class.java)
|
||||
}
|
||||
|
||||
override fun runBuyer(net: MessagingSystem, otherSide: SingleMessageRecipient, acceptablePrice: Amount,
|
||||
override fun runBuyer(otherSide: SingleMessageRecipient, acceptablePrice: Amount,
|
||||
typeToSell: Class<out OwnableState>, wallet: List<StateAndRef<Cash.State>>,
|
||||
myKeys: Map<PublicKey, PrivateKey>, timestamper: TimestamperService,
|
||||
partyKeyMap: Map<PublicKey, Party>): ListenableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>> {
|
||||
val context = BuyerContext(acceptablePrice, typeToSell, wallet, myKeys, timestamper, partyKeyMap)
|
||||
val logger = LoggerFactory.getLogger("$TRADE_TOPIC.secondary")
|
||||
loadContinuationClass<Buyer>(javaClass.classLoader).iterateStateMachine(net, otherSide, context, context, logger)
|
||||
return context.resultFuture
|
||||
partyKeyMap: Map<PublicKey, Party>): ListenableFuture<out Pair<TimestampedWireTransaction, LedgerTransaction>> {
|
||||
val context = BuyerContext(wallet, myKeys, timestamper, partyKeyMap, BuyerInitialArgs(acceptablePrice, typeToSell.name))
|
||||
return smm.add(otherSide, context, "$TRADE_TOPIC.buyer", BuyerImpl::class.java)
|
||||
}
|
||||
}
|
227
src/main/kotlin/core/messaging/StateMachines.kt
Normal file
227
src/main/kotlin/core/messaging/StateMachines.kt
Normal file
@ -0,0 +1,227 @@
|
||||
/*
|
||||
* Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members
|
||||
* pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms
|
||||
* set forth therein.
|
||||
*
|
||||
* All other rights reserved.
|
||||
*/
|
||||
|
||||
package core.messaging
|
||||
|
||||
import com.esotericsoftware.kryo.io.Input
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.google.common.util.concurrent.SettableFuture
|
||||
import core.serialization.THREAD_LOCAL_KRYO
|
||||
import core.serialization.createKryo
|
||||
import core.serialization.deserialize
|
||||
import core.serialization.serialize
|
||||
import core.utilities.trace
|
||||
import org.apache.commons.javaflow.Continuation
|
||||
import org.apache.commons.javaflow.ContinuationClassLoader
|
||||
import org.objenesis.instantiator.ObjectInstantiator
|
||||
import org.objenesis.strategy.InstantiatorStrategy
|
||||
import org.slf4j.Logger
|
||||
import org.slf4j.LoggerFactory
|
||||
import java.util.*
|
||||
import java.util.concurrent.Callable
|
||||
import java.util.concurrent.Executor
|
||||
|
||||
/**
|
||||
* A StateMachineManager is responsible for coordination and persistence of multiple [ProtocolStateMachine] objects.
|
||||
*
|
||||
* An implementation of this class will persist state machines to long term storage so they can survive process restarts
|
||||
* and, if run with a single-threaded executor, will ensure no two state machines run concurrently with each other
|
||||
* (bad for performance, good for programmer mental health!).
|
||||
*/
|
||||
class StateMachineManager(val net: MessagingSystem, val runInThread: Executor) {
|
||||
// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
|
||||
private class Checkpoint(
|
||||
val continuation: Continuation,
|
||||
val otherSide: MessageRecipients,
|
||||
val loggerName: String,
|
||||
val awaitingTopic: String,
|
||||
val awaitingObjectOfType: String // java class name
|
||||
)
|
||||
|
||||
constructor(net: MessagingSystem, runInThread: Executor, restoreCheckpoints: List<ByteArray>, resumeStateMachine: (ProtocolStateMachine<*,*>) -> Any) : this(net, runInThread) {
|
||||
for (bytes in restoreCheckpoints) {
|
||||
val kryo = createKryo()
|
||||
|
||||
// Set up Kryo to use the JavaFlow classloader when deserialising, so the magical continuation bytecode
|
||||
// rewriting is performed correctly.
|
||||
var psm: ProtocolStateMachine<*,*>? = null
|
||||
kryo.instantiatorStrategy = object : InstantiatorStrategy {
|
||||
val forwardingTo = kryo.instantiatorStrategy
|
||||
|
||||
override fun <T> newInstantiatorOf(type: Class<T>): ObjectInstantiator<T> {
|
||||
if (ProtocolStateMachine::class.java.isAssignableFrom(type)) {
|
||||
// The messing around with types we do here confuses the compiler/IDE a bit and it warns us.
|
||||
@Suppress("UNCHECKED_CAST", "CAST_NEVER_SUCCEEDS")
|
||||
return ObjectInstantiator<T> {
|
||||
psm = loadContinuationClass(type as Class<out ProtocolStateMachine<*, Any>>).first
|
||||
psm as T
|
||||
}
|
||||
} else {
|
||||
return forwardingTo.newInstantiatorOf(type)
|
||||
}
|
||||
}
|
||||
}
|
||||
val checkpoint = bytes.deserialize<Checkpoint>(kryo)
|
||||
|
||||
val continuation = checkpoint.continuation
|
||||
val transientContext = resumeStateMachine(psm!!)
|
||||
val logger = LoggerFactory.getLogger(checkpoint.loggerName)
|
||||
val awaitingObjectOfType = Class.forName(checkpoint.awaitingObjectOfType)
|
||||
// The act of calling this method re-persists the bytes into the in-memory hashmap so re-saving the
|
||||
// StateMachineManager to disk will work even if some state machines didn't wake up in the intervening time.
|
||||
setupNextMessageHandler(logger, net, continuation, checkpoint.otherSide, awaitingObjectOfType,
|
||||
checkpoint.awaitingTopic, transientContext, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
fun <R> add(otherSide: MessageRecipients, transientContext: Any, loggerName: String, continuationClass: Class<out ProtocolStateMachine<*, R>>): ListenableFuture<out R> {
|
||||
val logger = LoggerFactory.getLogger(loggerName)
|
||||
val (sm, continuation) = loadContinuationClass<R>(continuationClass)
|
||||
runInThread.execute {
|
||||
// The current state of the continuation is held in the closure attached to the messaging system whenever
|
||||
// the continuation suspends and tells us it expects a response.
|
||||
iterateStateMachine(continuation, net, otherSide, transientContext, transientContext, logger, null)
|
||||
}
|
||||
return sm.resultFuture
|
||||
}
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
private fun <R> loadContinuationClass(continuationClass: Class<out ProtocolStateMachine<*, R>>): Pair<ProtocolStateMachine<*,R>, Continuation> {
|
||||
val url = continuationClass.protectionDomain.codeSource.location
|
||||
val cl = ContinuationClassLoader(arrayOf(url), this.javaClass.classLoader)
|
||||
val obj = cl.forceLoadClass(continuationClass.name).newInstance() as ProtocolStateMachine<*, R>
|
||||
return Pair(obj, Continuation.startSuspendedWith(obj))
|
||||
}
|
||||
|
||||
private val checkpoints: LinkedList<ByteArray> = LinkedList()
|
||||
private fun persistCheckpoint(prev: ByteArray?, new: ByteArray) {
|
||||
synchronized(checkpoints) {
|
||||
if (prev == null) {
|
||||
for (i in checkpoints.size - 1 downTo 0) {
|
||||
val b = checkpoints[i]
|
||||
if (Arrays.equals(b, prev)) {
|
||||
checkpoints[i] = new
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
checkpoints.add(new)
|
||||
}
|
||||
}
|
||||
|
||||
fun saveToBytes(): LinkedList<ByteArray> = synchronized(checkpoints) { LinkedList(checkpoints) }
|
||||
|
||||
private fun iterateStateMachine(c: Continuation, net: MessagingSystem, otherSide: MessageRecipients,
|
||||
transientContext: Any, continuationInput: Any?, logger: Logger,
|
||||
prevPersistedBytes: ByteArray?): Continuation {
|
||||
// This will resume execution of the run() function inside the continuation at the place it left off.
|
||||
val oldLogger = CONTINUATION_LOGGER.get()
|
||||
val nextState: Continuation? = try {
|
||||
CONTINUATION_LOGGER.set(logger)
|
||||
Continuation.continueWith(c, continuationInput)
|
||||
} catch (t: Throwable) {
|
||||
logger.error("Caught error whilst invoking protocol state machine", t)
|
||||
throw t
|
||||
} finally {
|
||||
CONTINUATION_LOGGER.set(oldLogger)
|
||||
}
|
||||
|
||||
// If continuation returns null, it's finished and the result future has been set.
|
||||
if (nextState == null)
|
||||
return c
|
||||
|
||||
val req = nextState.value() as? ContinuationResult ?: return c
|
||||
|
||||
// Else, it wants us to do something: send, receive, or send-and-receive.
|
||||
if (req is ContinuationResult.ExpectingResponse<*>) {
|
||||
// Prepare a listener on the network that runs in the background thread when we received a message.
|
||||
val topic = "${req.topic}.${req.sessionIDForReceive}"
|
||||
setupNextMessageHandler(logger, net, nextState, otherSide, req.responseType, topic, transientContext, prevPersistedBytes)
|
||||
}
|
||||
// If an object to send was provided (not null), send it now.
|
||||
req.obj?.let {
|
||||
val topic = "${req.topic}.${req.sessionIDForSend}"
|
||||
logger.trace { "-> $topic : message of type ${it.javaClass.name}" }
|
||||
net.send(net.createMessage(topic, it.serialize()), otherSide)
|
||||
}
|
||||
if (req is ContinuationResult.NotExpectingResponse) {
|
||||
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
|
||||
return iterateStateMachine(nextState, net, otherSide, transientContext, transientContext, logger, prevPersistedBytes)
|
||||
} else {
|
||||
return nextState
|
||||
}
|
||||
}
|
||||
|
||||
private fun setupNextMessageHandler(logger: Logger, net: MessagingSystem, nextState: Continuation,
|
||||
otherSide: MessageRecipients, responseType: Class<*>,
|
||||
topic: String, transientContext: Any, prevPersistedBytes: ByteArray?) {
|
||||
val checkpoint = Checkpoint(nextState, otherSide, logger.name, topic, responseType.name)
|
||||
persistCheckpoint(prevPersistedBytes, checkpoint.serialize())
|
||||
net.runOnNextMessage(topic, runInThread) { netMsg ->
|
||||
val obj: Any = THREAD_LOCAL_KRYO.get().readObject(Input(netMsg.data), responseType)
|
||||
logger.trace { "<- $topic : message of type ${obj.javaClass.name}" }
|
||||
iterateStateMachine(nextState, net, otherSide, transientContext, Pair(transientContext, obj), logger, prevPersistedBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val CONTINUATION_LOGGER = ThreadLocal<Logger>()
|
||||
|
||||
/**
|
||||
* A convenience mixin interface that can be implemented by an object that will act as a continuation.
|
||||
*
|
||||
* A ProtocolStateMachine must implement the run method from [Runnable], and the rest of what this interface
|
||||
* provides are pre-defined utility methods to ease implementation of such machines.
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
abstract class ProtocolStateMachine<CONTEXT_TYPE : Any, R> : Callable<R>, Runnable {
|
||||
protected fun context(): CONTEXT_TYPE = Continuation.getContext() as CONTEXT_TYPE
|
||||
protected fun logger(): Logger = CONTINUATION_LOGGER.get()
|
||||
|
||||
// These fields shouldn't be serialised.
|
||||
@Transient private var _resultFuture: SettableFuture<R> = SettableFuture.create<R>()
|
||||
|
||||
val resultFuture: ListenableFuture<R> get() = _resultFuture
|
||||
|
||||
override fun run() {
|
||||
val r = call()
|
||||
if (r != null)
|
||||
_resultFuture.set(r)
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("NOTHING_TO_INLINE", "UNCHECKED_CAST")
|
||||
inline fun <S : Any, CONTEXT_TYPE : Any> ProtocolStateMachine<CONTEXT_TYPE, *>.send(topic: String, sessionID: Long, obj: S) =
|
||||
Continuation.suspend(ContinuationResult.NotExpectingResponse(topic, sessionID, obj)) as CONTEXT_TYPE
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
inline fun <reified R : Any, CONTEXT_TYPE : Any> ProtocolStateMachine<CONTEXT_TYPE, *>.sendAndReceive(
|
||||
topic: String, sessionIDForSend: Long, sessionIDForReceive: Long, obj: Any): Pair<CONTEXT_TYPE, R> {
|
||||
return Continuation.suspend(ContinuationResult.ExpectingResponse(topic, sessionIDForSend, sessionIDForReceive,
|
||||
obj, R::class.java)) as Pair<CONTEXT_TYPE, R>
|
||||
}
|
||||
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
inline fun <reified R : Any, CONTEXT_TYPE : Any> ProtocolStateMachine<CONTEXT_TYPE, *>.receive(
|
||||
topic: String, sessionIDForReceive: Long): Pair<CONTEXT_TYPE, R> {
|
||||
return Continuation.suspend(ContinuationResult.ExpectingResponse(topic, -1, sessionIDForReceive, null,
|
||||
R::class.java)) as Pair<CONTEXT_TYPE, R>
|
||||
}
|
||||
|
||||
open class ContinuationResult(val topic: String, val sessionIDForSend: Long, val sessionIDForReceive: Long, val obj: Any?) {
|
||||
class ExpectingResponse<R : Any>(
|
||||
topic: String,
|
||||
sessionIDForSend: Long,
|
||||
sessionIDForReceive: Long,
|
||||
obj: Any?,
|
||||
val responseType: Class<R>
|
||||
) : ContinuationResult(topic, sessionIDForSend, sessionIDForReceive, obj)
|
||||
|
||||
class NotExpectingResponse(topic: String, sessionIDForSend: Long, obj: Any?) : ContinuationResult(topic, sessionIDForSend, -1, obj)
|
||||
}
|
@ -1,118 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members
|
||||
* pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms
|
||||
* set forth therein.
|
||||
*
|
||||
* All other rights reserved.
|
||||
*/
|
||||
|
||||
package core.utilities.continuations
|
||||
|
||||
import com.esotericsoftware.kryo.io.Output
|
||||
import core.messaging.MessagingSystem
|
||||
import core.messaging.SingleMessageRecipient
|
||||
import core.messaging.runOnNextMessage
|
||||
import core.serialization.SerializeableWithKryo
|
||||
import core.serialization.createKryo
|
||||
import core.serialization.deserialize
|
||||
import core.serialization.serialize
|
||||
import core.utilities.trace
|
||||
import org.apache.commons.javaflow.Continuation
|
||||
import org.slf4j.Logger
|
||||
import java.io.ByteArrayOutputStream
|
||||
|
||||
private val CONTINUATION_LOGGER = ThreadLocal<Logger>()
|
||||
|
||||
/**
|
||||
* A convenience mixing interface that can be implemented by an object that will act as a continuation.
|
||||
*
|
||||
* A ProtocolStateMachine must implement the run method from [Runnable], and the rest of what this interface
|
||||
* provides are pre-defined utility methods to ease implementation of such machines.
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
interface ProtocolStateMachine<CONTEXT_TYPE : Any> : Runnable {
|
||||
fun context(): CONTEXT_TYPE = Continuation.getContext() as CONTEXT_TYPE
|
||||
fun logger(): Logger = CONTINUATION_LOGGER.get()
|
||||
}
|
||||
|
||||
@Suppress("NOTHING_TO_INLINE", "UNCHECKED_CAST")
|
||||
inline fun <S : SerializeableWithKryo,
|
||||
CONTEXT_TYPE : Any> ProtocolStateMachine<CONTEXT_TYPE>.send(topic: String, sessionID: Long, obj: S) =
|
||||
Continuation.suspend(ContinuationResult.NotExpectingResponse(topic, sessionID, obj)) as CONTEXT_TYPE
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
inline fun <reified R : SerializeableWithKryo,
|
||||
CONTEXT_TYPE : Any> ProtocolStateMachine<CONTEXT_TYPE>.sendAndReceive(topic: String, sessionIDForSend: Long,
|
||||
sessionIDForReceive: Long,
|
||||
obj: SerializeableWithKryo) =
|
||||
Continuation.suspend(ContinuationResult.ExpectingResponse(topic, sessionIDForSend, sessionIDForReceive,
|
||||
obj, R::class.java)) as Pair<CONTEXT_TYPE, R>
|
||||
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
inline fun <reified R : SerializeableWithKryo, CONTEXT_TYPE : Any> receive(topic: String, sessionIDForReceive: Long) =
|
||||
Continuation.suspend(ContinuationResult.ExpectingResponse(topic, -1, sessionIDForReceive, null,
|
||||
R::class.java)) as Pair<CONTEXT_TYPE, R>
|
||||
|
||||
open class ContinuationResult(val topic: String, val sessionIDForSend: Long, val sessionIDForReceive: Long, val obj: SerializeableWithKryo?) {
|
||||
class ExpectingResponse<R : SerializeableWithKryo>(
|
||||
topic: String,
|
||||
sessionIDForSend: Long,
|
||||
sessionIDForReceive: Long,
|
||||
obj: SerializeableWithKryo?,
|
||||
val responseType: Class<R>
|
||||
) : ContinuationResult(topic, sessionIDForSend, sessionIDForReceive, obj)
|
||||
|
||||
class NotExpectingResponse(topic: String, sessionIDForSend: Long, obj: SerializeableWithKryo?) : ContinuationResult(topic, sessionIDForSend, -1, obj)
|
||||
}
|
||||
|
||||
fun Continuation.iterateStateMachine(net: MessagingSystem, otherSide: SingleMessageRecipient,
|
||||
transientContext: Any, continuationInput: Any?, logger: Logger): Continuation {
|
||||
// This will resume execution of the run() function inside the continuation at the place it left off.
|
||||
val oldLogger = CONTINUATION_LOGGER.get()
|
||||
val nextState = try {
|
||||
CONTINUATION_LOGGER.set(logger)
|
||||
Continuation.continueWith(this, continuationInput)
|
||||
} catch (t: Throwable) {
|
||||
logger.error("Caught error whilst invoking protocol state machine", t)
|
||||
throw t
|
||||
} finally {
|
||||
CONTINUATION_LOGGER.set(oldLogger)
|
||||
}
|
||||
// If continuation returns null, it's finished.
|
||||
val req = nextState?.value() as? ContinuationResult ?: return this
|
||||
|
||||
// Else, it wants us to do something: send, receive, or send-and-receive. Firstly, checkpoint it, so we can restart
|
||||
// if something goes wrong.
|
||||
val bytes = run {
|
||||
val stream = ByteArrayOutputStream()
|
||||
Output(stream).use {
|
||||
createKryo().apply {
|
||||
isRegistrationRequired = false
|
||||
writeObject(it, nextState)
|
||||
}
|
||||
}
|
||||
stream.toByteArray()
|
||||
}
|
||||
|
||||
if (req is ContinuationResult.ExpectingResponse<*>) {
|
||||
val topic = "${req.topic}.${req.sessionIDForReceive}"
|
||||
net.runOnNextMessage(topic) { netMsg ->
|
||||
val obj = netMsg.data.deserialize(req.responseType)
|
||||
logger.trace { "<- $topic : message of type ${obj.javaClass.name}" }
|
||||
nextState.iterateStateMachine(net, otherSide, transientContext, Pair(transientContext, obj), logger)
|
||||
}
|
||||
}
|
||||
// If an object to send was provided (not null), send it now.
|
||||
req.obj?.let {
|
||||
val topic = "${req.topic}.${req.sessionIDForSend}"
|
||||
logger.trace { "-> $topic : message of type ${it.javaClass.name}" }
|
||||
net.send(net.createMessage(topic, it.serialize()), otherSide)
|
||||
}
|
||||
if (req is ContinuationResult.NotExpectingResponse) {
|
||||
// We sent a message, but won't get a response, so we must re-enter the continuation to let it keep going.
|
||||
return nextState.iterateStateMachine(net, otherSide, transientContext, transientContext, logger)
|
||||
} else {
|
||||
return nextState
|
||||
}
|
||||
}
|
@ -8,17 +8,24 @@
|
||||
|
||||
package core.messaging
|
||||
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.google.common.util.concurrent.MoreExecutors
|
||||
import contracts.Cash
|
||||
import contracts.CommercialPaper
|
||||
import contracts.protocols.TwoPartyTradeProtocol
|
||||
import core.*
|
||||
import core.testutils.*
|
||||
import org.junit.After
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
import java.util.concurrent.Executors
|
||||
import java.util.logging.Formatter
|
||||
import java.util.logging.Level
|
||||
import java.util.logging.LogRecord
|
||||
import java.util.logging.Logger
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
import kotlin.test.fail
|
||||
|
||||
/**
|
||||
* In this example, Alessia wishes to sell her commercial paper to Boris in return for $1,000,000 and they wish to do
|
||||
@ -27,7 +34,8 @@ import kotlin.test.assertEquals
|
||||
* We assume that Alessia and Boris already found each other via some market, and have agreed the details already.
|
||||
*/
|
||||
class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
|
||||
init {
|
||||
@Before
|
||||
fun initLogging() {
|
||||
Logger.getLogger("").handlers[0].level = Level.ALL
|
||||
Logger.getLogger("").handlers[0].formatter = object : Formatter() {
|
||||
override fun format(record: LogRecord) = "${record.threadID} ${record.loggerName}: ${record.message}\n"
|
||||
@ -35,12 +43,19 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
|
||||
Logger.getLogger("com.r3cev.protocols.trade").level = Level.ALL
|
||||
}
|
||||
|
||||
@After
|
||||
fun stopLogging() {
|
||||
Logger.getLogger("com.r3cev.protocols.trade").level = Level.INFO
|
||||
}
|
||||
|
||||
@Test
|
||||
fun cashForCP() {
|
||||
val (addr1, node1) = makeNode()
|
||||
val (addr2, node2) = makeNode()
|
||||
val (addr1, node1) = makeNode(inBackground = true)
|
||||
val (addr2, node2) = makeNode(inBackground = true)
|
||||
|
||||
val tp = TwoPartyTradeProtocol.create()
|
||||
val backgroundThread = Executors.newSingleThreadExecutor()
|
||||
val tpSeller = TwoPartyTradeProtocol.create(StateMachineManager(node1, backgroundThread))
|
||||
val tpBuyer = TwoPartyTradeProtocol.create(StateMachineManager(node2, backgroundThread))
|
||||
|
||||
transactionGroupFor<ContractState> {
|
||||
// Bob (S) has some cash, Alice (P) has some commercial paper she wants to sell to Bob.
|
||||
@ -51,22 +66,104 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
|
||||
}
|
||||
|
||||
val bobsWallet = listOf<StateAndRef<Cash.State>>(lookup("bob cash1"), lookup("bob cash2"))
|
||||
val (aliceFuture, bobFuture) = runNetwork {
|
||||
Pair(
|
||||
tp.runSeller(node1, addr2, lookup("alice's paper"), 1000.DOLLARS, ALICE_KEY,
|
||||
TEST_KEYS_TO_CORP_MAP, DUMMY_TIMESTAMPER),
|
||||
tp.runBuyer(node2, addr1, 1000.DOLLARS, CommercialPaper.State::class.java, bobsWallet,
|
||||
mapOf(BOB to BOB_KEY.private), DUMMY_TIMESTAMPER, TEST_KEYS_TO_CORP_MAP)
|
||||
)
|
||||
|
||||
val aliceResult = tpSeller.runSeller(
|
||||
addr2,
|
||||
lookup("alice's paper"),
|
||||
1000.DOLLARS,
|
||||
ALICE_KEY,
|
||||
TEST_KEYS_TO_CORP_MAP,
|
||||
DUMMY_TIMESTAMPER
|
||||
)
|
||||
val bobResult = tpBuyer.runBuyer(
|
||||
addr1,
|
||||
1000.DOLLARS,
|
||||
CommercialPaper.State::class.java,
|
||||
bobsWallet,
|
||||
mapOf(BOB to BOB_KEY.private),
|
||||
DUMMY_TIMESTAMPER,
|
||||
TEST_KEYS_TO_CORP_MAP
|
||||
)
|
||||
|
||||
assertEquals(aliceResult.get(), bobResult.get())
|
||||
|
||||
txns.add(aliceResult.get().second)
|
||||
verify()
|
||||
}
|
||||
backgroundThread.shutdown()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun serializeAndRestore() {
|
||||
val (addr1, node1) = makeNode(inBackground = false)
|
||||
var (addr2, node2) = makeNode(inBackground = false)
|
||||
|
||||
val smmSeller = StateMachineManager(node1, MoreExecutors.directExecutor())
|
||||
val tpSeller = TwoPartyTradeProtocol.create(smmSeller)
|
||||
val smmBuyer = StateMachineManager(node2, MoreExecutors.directExecutor())
|
||||
val tpBuyer = TwoPartyTradeProtocol.create(smmBuyer)
|
||||
|
||||
transactionGroupFor<ContractState> {
|
||||
// Buyer Bob has some cash, Seller Alice has some commercial paper she wants to sell to Bob.
|
||||
roots {
|
||||
transaction(CommercialPaper.State(MEGA_CORP.ref(1, 2, 3), ALICE, 1200.DOLLARS, TEST_TX_TIME + 7.days) label "alice's paper")
|
||||
transaction(800.DOLLARS.CASH `owned by` BOB label "bob cash1")
|
||||
transaction(300.DOLLARS.CASH `owned by` BOB label "bob cash2")
|
||||
}
|
||||
|
||||
val aliceResult: Pair<TimestampedWireTransaction, LedgerTransaction> = aliceFuture.get()
|
||||
val bobResult: Pair<TimestampedWireTransaction, LedgerTransaction> = bobFuture.get()
|
||||
val bobsWallet = listOf<StateAndRef<Cash.State>>(lookup("bob cash1"), lookup("bob cash2"))
|
||||
|
||||
assertEquals(aliceResult, bobResult)
|
||||
tpSeller.runSeller(
|
||||
addr2,
|
||||
lookup("alice's paper"),
|
||||
1000.DOLLARS,
|
||||
ALICE_KEY,
|
||||
TEST_KEYS_TO_CORP_MAP,
|
||||
DUMMY_TIMESTAMPER
|
||||
)
|
||||
tpBuyer.runBuyer(
|
||||
addr1,
|
||||
1000.DOLLARS,
|
||||
CommercialPaper.State::class.java,
|
||||
bobsWallet,
|
||||
mapOf(BOB to BOB_KEY.private),
|
||||
DUMMY_TIMESTAMPER,
|
||||
TEST_KEYS_TO_CORP_MAP
|
||||
)
|
||||
|
||||
txns.add(aliceResult.second)
|
||||
// Everything is on this thread so we can now step through the protocol one step at a time.
|
||||
// Seller Alice already sent a message to Buyer Bob. Pump once:
|
||||
node2.pump(false)
|
||||
// OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature.
|
||||
val storageBob = smmBuyer.saveToBytes()
|
||||
// .. and let's imagine that Bob's computer has a power cut. He now has nothing now beyond what was on disk.
|
||||
node2.stop()
|
||||
|
||||
// Alice doesn't know that and sends Bob the now finalised transaction. Alice sends a message to a node
|
||||
// that has gone offline.
|
||||
node1.pump(false)
|
||||
|
||||
// ... bring the network back up ...
|
||||
node2 = network.createNodeWithID(true, addr2.id).start().get()
|
||||
|
||||
// We must provide the state machines with all the stuff that couldn't be saved to disk.
|
||||
var bobFuture: ListenableFuture<Pair<TimestampedWireTransaction, LedgerTransaction>>? = null
|
||||
fun resumeStateMachine(forObj: ProtocolStateMachine<*,*>): Any {
|
||||
return when (forObj) {
|
||||
is TwoPartyTradeProtocol.Buyer -> {
|
||||
bobFuture = forObj.resultFuture
|
||||
return TwoPartyTradeProtocol.BuyerContext(bobsWallet, mapOf(BOB to BOB_KEY.private), DUMMY_TIMESTAMPER, TEST_KEYS_TO_CORP_MAP, null)
|
||||
}
|
||||
else -> fail()
|
||||
}
|
||||
}
|
||||
// The act of constructing this object will re-register the message handlers that Bob was waiting on before
|
||||
// the reboot occurred.
|
||||
StateMachineManager(node2, MoreExecutors.directExecutor(), storageBob, ::resumeStateMachine)
|
||||
assertTrue(node2.pump(false))
|
||||
// Bob is now finished and has the same transaction as Alice.
|
||||
val tx: Pair<TimestampedWireTransaction, LedgerTransaction> = bobFuture!!.get()
|
||||
txns.add(tx.second)
|
||||
verify()
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user