Protocol frameworks: separate the fiber object from the logic object to make it easier to compose subprotocols together.

This commit is contained in:
Mike Hearn 2016-02-16 19:06:45 +01:00
parent dc520392b8
commit 299e1af15e
5 changed files with 122 additions and 86 deletions

@ -18,7 +18,7 @@ import core.crypto.SecureHash
import core.crypto.signWithECDSA
import core.messaging.*
import core.node.DataVendingService
import core.node.TimestamperClient
import core.node.TimestampingProtocol
import core.utilities.trace
import java.security.KeyPair
import java.security.PublicKey
@ -55,16 +55,14 @@ object TwoPartyTradeProtocol {
otherSide: SingleMessageRecipient, assetToSell: StateAndRef<OwnableState>, price: Amount,
myKeyPair: KeyPair, buyerSessionID: Long): ListenableFuture<SignedTransaction> {
val seller = Seller(otherSide, timestampingAuthority, assetToSell, price, myKeyPair, buyerSessionID)
smm.add("$TRADE_TOPIC.seller", seller)
return seller.resultFuture
return smm.add("$TRADE_TOPIC.seller", seller)
}
fun runBuyer(smm: StateMachineManager, timestampingAuthority: LegallyIdentifiableNode,
otherSide: SingleMessageRecipient, acceptablePrice: Amount, typeToBuy: Class<out OwnableState>,
sessionID: Long): ListenableFuture<SignedTransaction> {
val buyer = Buyer(otherSide, timestampingAuthority.identity, acceptablePrice, typeToBuy, sessionID)
smm.add("$TRADE_TOPIC.buyer", buyer)
return buyer.resultFuture
return smm.add("$TRADE_TOPIC.buyer", buyer)
}
class UnacceptablePriceException(val givenPrice: Amount) : Exception()
@ -88,14 +86,14 @@ object TwoPartyTradeProtocol {
val assetToSell: StateAndRef<OwnableState>,
val price: Amount,
val myKeyPair: KeyPair,
val buyerSessionID: Long) : ProtocolStateMachine<SignedTransaction>() {
val buyerSessionID: Long) : ProtocolLogic<SignedTransaction>() {
@Suspendable
override fun call(): SignedTransaction {
val partialTX: SignedTransaction = receiveAndCheckProposedTransaction()
// These two steps could be done in parallel, in theory. Our framework doesn't support that yet though.
val ourSignature = signWithOurKey(partialTX)
val tsaSig = timestamp(partialTX)
val tsaSig = subProtocol(TimestampingProtocol(timestampingAuthority, partialTX.txBits))
val signedTransaction = sendSignatures(partialTX, ourSignature, tsaSig)
@ -103,7 +101,7 @@ object TwoPartyTradeProtocol {
}
@Suspendable
open fun receiveAndCheckProposedTransaction(): SignedTransaction {
private fun receiveAndCheckProposedTransaction(): SignedTransaction {
val sessionID = random63BitValue()
// Make the first message we'll send to kick off the protocol.
@ -137,7 +135,7 @@ object TwoPartyTradeProtocol {
}
@Suspendable
open fun checkDependencies(txToCheck: SignedTransaction) {
private fun checkDependencies(txToCheck: SignedTransaction) {
val toVerify = HashSet<LedgerTransaction>()
val alreadyVerified = HashSet<LedgerTransaction>()
val downloadedSignedTxns = ArrayList<SignedTransaction>()
@ -249,15 +247,10 @@ object TwoPartyTradeProtocol {
}
}
open fun signWithOurKey(partialTX: SignedTransaction) = myKeyPair.signWithECDSA(partialTX.txBits)
private fun signWithOurKey(partialTX: SignedTransaction) = myKeyPair.signWithECDSA(partialTX.txBits)
@Suspendable
open fun timestamp(partialTX: SignedTransaction): DigitalSignature.LegallyIdentifiable {
return TimestamperClient(this, timestampingAuthority).timestamp(partialTX.txBits)
}
@Suspendable
open fun sendSignatures(partialTX: SignedTransaction, ourSignature: DigitalSignature.WithKey,
private fun sendSignatures(partialTX: SignedTransaction, ourSignature: DigitalSignature.WithKey,
tsaSig: DigitalSignature.LegallyIdentifiable): SignedTransaction {
val fullySigned = partialTX + tsaSig + ourSignature
@ -272,7 +265,8 @@ object TwoPartyTradeProtocol {
val timestampingAuthority: Party,
val acceptablePrice: Amount,
val typeToBuy: Class<out OwnableState>,
val sessionID: Long) : ProtocolStateMachine<SignedTransaction>() {
val sessionID: Long) : ProtocolLogic<SignedTransaction>() {
@Suspendable
override fun call(): SignedTransaction {
val tradeRequest = receiveAndValidateTradeRequest()
@ -289,9 +283,9 @@ object TwoPartyTradeProtocol {
}
@Suspendable
open fun receiveAndValidateTradeRequest(): SellerTradeInfo {
private fun receiveAndValidateTradeRequest(): SellerTradeInfo {
// Wait for a trade request to come in on our pre-provided session ID.
val maybeTradeRequest = receive(TRADE_TOPIC, sessionID, SellerTradeInfo::class.java)
val maybeTradeRequest = receive<SellerTradeInfo>(TRADE_TOPIC, sessionID)
val tradeRequest = maybeTradeRequest.validate {
// What is the seller trying to sell us?
@ -315,15 +309,15 @@ object TwoPartyTradeProtocol {
}
@Suspendable
open fun swapSignaturesWithSeller(stx: SignedTransaction, theirSessionID: Long): SignaturesFromSeller {
private fun swapSignaturesWithSeller(stx: SignedTransaction, theirSessionID: Long): SignaturesFromSeller {
logger.trace { "Sending partially signed transaction to seller" }
// TODO: Protect against the seller terminating here and leaving us in the lurch without the final tx.
return sendAndReceive(TRADE_TOPIC, otherSide, theirSessionID, sessionID, stx, SignaturesFromSeller::class.java).validate { it }
return sendAndReceive<SignaturesFromSeller>(TRADE_TOPIC, otherSide, theirSessionID, sessionID, stx).validate { it }
}
open fun signWithOurKeys(cashSigningPubKeys: List<PublicKey>, ptx: TransactionBuilder): SignedTransaction {
private fun signWithOurKeys(cashSigningPubKeys: List<PublicKey>, ptx: TransactionBuilder): SignedTransaction {
// Now sign the transaction with whatever keys we need to move the cash.
for (k in cashSigningPubKeys) {
val priv = serviceHub.keyManagementService.toPrivate(k)
@ -338,7 +332,7 @@ object TwoPartyTradeProtocol {
return stx
}
open fun assembleSharedTX(tradeRequest: SellerTradeInfo): Pair<TransactionBuilder, List<PublicKey>> {
private fun assembleSharedTX(tradeRequest: SellerTradeInfo): Pair<TransactionBuilder, List<PublicKey>> {
val ptx = TransactionBuilder()
// Add input and output states for the movement of cash, by using the Cash contract to generate the states.
val wallet = serviceHub.walletService.currentWallet

@ -32,7 +32,6 @@ import java.io.ByteArrayOutputStream
import java.io.PrintWriter
import java.io.StringWriter
import java.util.*
import java.util.concurrent.Callable
import java.util.concurrent.Executor
import javax.annotation.concurrent.ThreadSafe
@ -62,7 +61,7 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
private val checkpointsMap = serviceHub.storageService.getMap<SecureHash, ByteArray>("state machines")
// A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines
// property.
private val _stateMachines = Collections.synchronizedList(ArrayList<ProtocolStateMachine<*>>())
private val _stateMachines = Collections.synchronizedList(ArrayList<ProtocolLogic<*>>())
// This is a workaround for something Gradle does to us during unit tests. It replaces stderr with its own
// class that inserts itself into a ThreadLocal. That then gets caught in fiber serialisation, which we don't
@ -73,10 +72,11 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
// ever recover.
val checkpointing: Boolean get() = !System.err.javaClass.name.contains("LinePerThreadBufferingOutputStream")
/** Returns a snapshot of the currently registered state machines. */
val stateMachines: List<ProtocolStateMachine<*>> get() {
/** Returns a list of all state machines executing the given protocol logic at the top level (subprotocols do not count) */
fun <T> findStateMachines(klass: Class<out ProtocolLogic<T>>): List<Pair<ProtocolLogic<T>, ListenableFuture<T>>> {
synchronized(_stateMachines) {
return ArrayList(_stateMachines)
@Suppress("UNCHECKED_CAST")
return _stateMachines.filterIsInstance(klass).map { it to (it.psm as ProtocolStateMachine<T>).resultFuture }
}
}
@ -113,7 +113,7 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
// Grab the Kryo engine configured by Quasar for its own stuff, and then do our own configuration on top
// so we can deserialised the nested stream that holds the fiber.
val psm = deserializeFiber(checkpoint.serialisedFiber)
_stateMachines.add(psm)
_stateMachines.add(psm.logic)
val logger = LoggerFactory.getLogger(checkpoint.loggerName)
val awaitingObjectOfType = Class.forName(checkpoint.awaitingObjectOfType)
val topic = checkpoint.awaitingTopic
@ -155,12 +155,13 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
* The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is
* restarted with checkpointed state machines in the storage service.
*/
fun <T : ProtocolStateMachine<*>> add(loggerName: String, fiber: T): T {
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> {
val logger = LoggerFactory.getLogger(loggerName)
val fiber = ProtocolStateMachine(logic)
iterateStateMachine(fiber, serviceHub.networkService, logger, null, null) {
it.start()
}
return fiber
return fiber.resultFuture
}
private fun persistCheckpoint(prevCheckpointKey: SecureHash?, new: ByteArray): SecureHash {
@ -206,7 +207,7 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
// We're back! Check if the fiber is finished and if so, clean up.
if (psm.isTerminated) {
_stateMachines.remove(psm)
_stateMachines.remove(psm.logic)
checkpointsMap.remove(prevCheckpointKey)
}
}
@ -236,35 +237,83 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
object SameThreadFiberScheduler : FiberExecutorScheduler("Same thread scheduler", MoreExecutors.directExecutor())
/**
* The base class that should be used by any object that wishes to act as a protocol state machine. A PSM is
* a kind of "fiber", and a fiber in turn is a bit like a thread, but a thread that can be suspended to the heap,
* serialised to disk, and resumed on demand.
* A sub-class of [ProtocolLogic<T>] implements a protocol flow using direct, straight line blocking code. Thus you
* can write complex protocol logic in an ordinary fashion, without having to think about callbacks, restarting after
* a node crash, how many instances of your protocol there are running and so on.
*
* Sub-classes should override the [call] method and return whatever the final result of the protocol is. Inside the
* call method, the rules of normal object oriented programming are a little different:
* Invoking the network will cause the call stack to be suspended onto the heap and then serialized to a database using
* the Quasar fibers framework. Because of this, if you need access to data that might change over time, you should
* request it just-in-time via the [serviceHub] property which is provided. Don't try and keep data you got from a
* service across calls to send/receive/sendAndReceive because the world might change in arbitrary ways out from
* underneath you, for instance, if the node is restarted or reconfigured!
*
* - You can call send/receive/sendAndReceive in order to suspend the state machine and request network interaction.
* This does not block a thread and when a state machine is suspended like this, it will be serialised and written
* to stable storage. That means all objects on the stack and referenced from fields must be serialisable as well
* (with Kryo, so they don't have to implement the Java Serializable interface). The state machine may be resumed
* at some arbitrary later point.
* - Because of this, if you need access to data that might change over time, you should request it just-in-time
* via the [serviceHub] property which is provided. Don't try and keep data you got from a service across calls to
* send/receive/sendAndReceive because the world might change in arbitrary ways out from underneath you, for instance,
* if the node is restarted or reconfigured!
* Additionally, be aware of what data you pin either via the stack or in your [ProtocolLogic] implementation. Very large
* objects or datasets will hurt performance by increasing the amount of data stored in each checkpoint.
*
* The result of the [call] method can be obtained by using the [resultFuture] property, which is a [ListenableFuture]
* and will let you register a callback to be informed when the protocol has completed. Note that the PSM class is also
* a future, but not a listenable one.
*
* Once created, a PSM should be passed to a [StateMachineManager] which will start it and manage its execution.
* If you'd like to use another ProtocolLogic class as a component of your own, construct it on the fly and then pass
* it to the [subProtocol] method. It will return the result of that protocol when it completes.
*/
abstract class ProtocolStateMachine<R> : Fiber<R>("protocol", SameThreadFiberScheduler), Callable<R> {
abstract class ProtocolLogic<T> {
/** Reference to the [Fiber] instance that is the top level controller for the entire flow. */
lateinit var psm: ProtocolStateMachine<*>
/** This is where you should log things to. */
val logger: Logger get() = psm.logger
/** Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts */
val serviceHub: ServiceHub get() = psm.serviceHub
// Kotlin helpers that allow the use of generic types.
inline fun <reified T : Any> sendAndReceive(topic: String, destination: MessageRecipients, sessionIDForSend: Long,
sessionIDForReceive: Long, obj: Any): UntrustworthyData<T> {
return psm.sendAndReceive(topic, destination, sessionIDForSend, sessionIDForReceive, obj, T::class.java)
}
inline fun <reified T : Any> receive(topic: String, sessionIDForReceive: Long): UntrustworthyData<T> {
return psm.receive(topic, sessionIDForReceive, T::class.java)
}
@Suspendable fun send(topic: String, destination: MessageRecipients, sessionID: Long, obj: Any) {
psm.send(topic, destination, sessionID, obj)
}
/**
* Invokes the given subprotocol by simply passing through this [ProtocolLogic]s reference to the
* [ProtocolStateMachine] and then calling the [call] method.
*/
@Suspendable fun <R> subProtocol(subLogic: ProtocolLogic<R>): R {
subLogic.psm = psm
return subLogic.call()
}
@Suspendable
abstract fun call(): T
}
/**
* A ProtocolStateMachine instance is a suspendable fiber that delegates all actual logic to a [ProtocolLogic] instance.
* For any given flow there is only one PSM, even if that protocol invokes subprotocols.
*
* These classes are created by the [StateMachineManager] when a new protocol is started at the topmost level. If
* a protocol invokes a sub-protocol, then it will pass along the PSM to the child. The call method of the topmost
* logic element gets to return the value that the entire state machine resolves to.
*/
class ProtocolStateMachine<R>(val logic: ProtocolLogic<R>) : Fiber<R>("protocol", SameThreadFiberScheduler) {
// These fields shouldn't be serialised, so they are marked @Transient.
@Transient private var suspendFunc: ((result: FiberRequest, serFiber: ByteArray) -> Unit)? = null
@Transient private var resumeWithObject: Any? = null
@Transient lateinit var serviceHub: ServiceHub
@Transient protected lateinit var logger: Logger
@Transient lateinit var logger: Logger
init {
logic.psm = this
}
fun prepareForResumeWith(serviceHub: ServiceHub, withObject: Any?, logger: Logger,
suspendFunc: (FiberRequest, ByteArray) -> Unit) {
this.suspendFunc = suspendFunc
this.logger = logger
this.resumeWithObject = withObject
this.serviceHub = serviceHub
}
@Transient private var _resultFuture: SettableFuture<R>? = SettableFuture.create<R>()
/** This future will complete when the call method returns. */
@ -276,21 +325,10 @@ abstract class ProtocolStateMachine<R> : Fiber<R>("protocol", SameThreadFiberSch
}
}
fun prepareForResumeWith(serviceHub: ServiceHub, withObject: Any?, logger: Logger,
suspendFunc: (FiberRequest, ByteArray) -> Unit) {
this.suspendFunc = suspendFunc
this.logger = logger
this.resumeWithObject = withObject
this.serviceHub = serviceHub
}
// This line may look useless, but it's needed to convince the Quasar bytecode rewriter to do the right thing.
@Suspendable override abstract fun call(): R
@Suspendable @Suppress("UNCHECKED_CAST")
override fun run(): R {
try {
val result = call()
val result = logic.call()
if (result != null)
_resultFuture?.set(result)
return result
@ -335,15 +373,6 @@ abstract class ProtocolStateMachine<R> : Fiber<R>("protocol", SameThreadFiberSch
val result = FiberRequest.NotExpectingResponse(topic, destination, sessionID, obj)
Fiber.parkAndSerialize { fiber, writer -> suspendFunc!!(result, writer.write(fiber)) }
}
// Kotlin helpers that allow the use of generic types.
inline fun <reified T : Any> sendAndReceive(topic: String, destination: MessageRecipients, sessionIDForSend: Long,
sessionIDForReceive: Long, obj: Any): UntrustworthyData<T> {
return sendAndReceive(topic, destination, sessionIDForSend, sessionIDForReceive, obj, T::class.java)
}
inline fun <reified T : Any> receive(topic: String, sessionIDForReceive: Long): UntrustworthyData<T> {
return receive(topic, sessionIDForReceive, T::class.java)
}
}
/**

@ -13,10 +13,7 @@ import co.paralleluniverse.fibers.Suspendable
import core.*
import core.crypto.DigitalSignature
import core.crypto.signWithECDSA
import core.messaging.LegallyIdentifiableNode
import core.messaging.MessageRecipients
import core.messaging.MessagingService
import core.messaging.ProtocolStateMachine
import core.messaging.*
import core.serialization.SerializedBytes
import core.serialization.deserialize
import core.serialization.serialize
@ -95,7 +92,6 @@ class TimestamperNodeService(private val net: MessagingService,
}
}
@ThreadSafe
class TimestamperClient(private val psm: ProtocolStateMachine<*>, private val node: LegallyIdentifiableNode) : TimestamperService {
override val identity: Party = node.identity
@ -116,3 +112,22 @@ class TimestamperClient(private val psm: ProtocolStateMachine<*>, private val no
}
}
class TimestampingProtocol(private val node: LegallyIdentifiableNode,
private val wtxBytes: SerializedBytes<WireTransaction>) : ProtocolLogic<DigitalSignature.LegallyIdentifiable>() {
@Suspendable
override fun call(): DigitalSignature.LegallyIdentifiable {
val sessionID = random63BitValue()
val replyTopic = "${TimestamperNodeService.TIMESTAMPING_PROTOCOL_TOPIC}.$sessionID"
val req = TimestampingMessages.Request(wtxBytes, serviceHub.networkService.myAddress, replyTopic)
val maybeSignature = sendAndReceive<DigitalSignature.LegallyIdentifiable>(
TimestamperNodeService.TIMESTAMPING_PROTOCOL_TOPIC, node.address, 0, sessionID, req)
// Check that the timestamping authority gave us back a valid signature and didn't break somehow
maybeSignature.validate { sig ->
sig.verifyWithECDSA(wtxBytes)
return sig
}
}
}

@ -168,8 +168,7 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
)
// Find the future representing the result of this state machine again.
assertEquals(1, smm.stateMachines.size)
var bobFuture = smm.stateMachines.filterIsInstance<TwoPartyTradeProtocol.Buyer>().first().resultFuture
var bobFuture = smm.findStateMachines(TwoPartyTradeProtocol.Buyer::class.java).single().second
// Let Bob process his mailbox.
assertTrue(bobsNode.pump(false))
@ -179,7 +178,7 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
txns.add(stx.tx)
verify()
assertTrue(smm.stateMachines.isEmpty())
assertTrue(smm.findStateMachines(TwoPartyTradeProtocol.Buyer::class.java).isEmpty())
}
}
@ -239,7 +238,7 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
@Test
fun `dependency with error`() {
transactionGroupFor<ContractState> {
val (bobsWallet, fakeTxns) = fillUp(withError = true)
val bobsWallet = fillUp(withError = true).first
val (alicesAddress, alicesNode) = makeNode(inBackground = true)
val (bobsAddress, bobsNode) = makeNode(inBackground = true)

@ -58,10 +58,10 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() {
service = TimestamperNodeService(serviceNode.second, Party("Unit test suite", ALICE), ALICE_KEY)
}
class TestPSM(val server: LegallyIdentifiableNode, val now: Instant) : ProtocolStateMachine<Boolean>() {
class TestPSM(val server: LegallyIdentifiableNode, val now: Instant) : ProtocolLogic<Boolean>() {
@Suspendable
override fun call(): Boolean {
val client = TimestamperClient(this, server)
val client = TimestamperClient(psm, server)
val ptx = TransactionBuilder().apply {
addInputState(StateRef(SecureHash.randomSHA256(), 0))
addOutputState(100.DOLLARS.CASH)
@ -82,7 +82,6 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() {
val logName = TimestamperNodeService.TIMESTAMPING_PROTOCOL_TOPIC
val psm = TestPSM(mockServices.networkMapService.timestampingNodes[0], clock.instant())
smm.add(logName, psm)
psm
}
assertTrue(psm.isDone)
}