Protocols: simplify the two party (dvp) protocol some more, now that we've switched to Quasar. There's no longer any need to define InitialArgs objects.

This commit is contained in:
Mike Hearn 2016-01-08 17:26:02 +01:00
parent c59603c26f
commit 78849f44d2
5 changed files with 106 additions and 119 deletions

View File

@ -9,6 +9,7 @@
package contracts.protocols
import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.ListenableFuture
import contracts.Cash
import contracts.sumCashBy
import core.*
@ -23,8 +24,6 @@ import java.security.KeyPair
import java.security.PublicKey
import java.time.Instant
// TODO: Get rid of the "initial args" concept and just use the class c'tors, now we are using Quasar.
/**
* This asset trading protocol implements a "delivery vs payment" type swap. It has two parties (B and S for buyer
* and seller) and the following steps:
@ -49,64 +48,46 @@ import java.time.Instant
*
* To see an example of how to use this class, look at the unit tests.
*/
abstract class TwoPartyTradeProtocol {
class SellerInitialArgs(
val assetToSell: StateAndRef<OwnableState>,
val price: Amount,
val myKeyPair: KeyPair,
val buyerSessionID: Long
)
object TwoPartyTradeProtocol {
val TRADE_TOPIC = "com.r3cev.protocols.trade"
abstract fun runSeller(otherSide: SingleMessageRecipient, args: SellerInitialArgs): Seller
class BuyerInitialArgs(
val acceptablePrice: Amount,
val typeToBuy: Class<out OwnableState>,
val sessionID: Long
)
abstract fun runBuyer(otherSide: SingleMessageRecipient, args: BuyerInitialArgs): Buyer
abstract class Buyer : ProtocolStateMachine<BuyerInitialArgs, Pair<WireTransaction, LedgerTransaction>>()
abstract class Seller : ProtocolStateMachine<SellerInitialArgs, Pair<WireTransaction, LedgerTransaction>>()
companion object {
@JvmStatic fun create(smm: StateMachineManager, timestampingAuthority: LegallyIdentifiableNode): TwoPartyTradeProtocol {
return TwoPartyTradeProtocolImpl(smm, timestampingAuthority)
}
}
}
private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager,
private val timestampingAuthority: LegallyIdentifiableNode) : TwoPartyTradeProtocol() {
companion object {
val TRADE_TOPIC = "com.r3cev.protocols.trade"
fun runSeller(smm: StateMachineManager, timestampingAuthority: LegallyIdentifiableNode,
otherSide: SingleMessageRecipient, assetToSell: StateAndRef<OwnableState>, price: Amount,
myKeyPair: KeyPair, buyerSessionID: Long): ListenableFuture<Pair<WireTransaction, LedgerTransaction>> {
val seller = Seller(otherSide, timestampingAuthority, assetToSell, price, myKeyPair, buyerSessionID)
smm.add("$TRADE_TOPIC.seller", seller)
return seller.resultFuture
}
// This object is serialised to the network and is the first protocol message the seller sends to the buyer.
class SellerTradeInfo(
val assetForSale: StateAndRef<OwnableState>,
val price: Amount,
val sellerOwnerKey: PublicKey,
val sessionID: Long
)
fun runBuyer(smm: StateMachineManager, timestampingAuthority: LegallyIdentifiableNode,
otherSide: SingleMessageRecipient, acceptablePrice: Amount, typeToBuy: Class<out OwnableState>,
sessionID: Long): ListenableFuture<Pair<WireTransaction, LedgerTransaction>> {
val buyer = Buyer(otherSide, timestampingAuthority.identity, acceptablePrice, typeToBuy, sessionID)
smm.add("$TRADE_TOPIC.buyer", buyer)
return buyer.resultFuture
}
class SellerImpl(private val otherSide: SingleMessageRecipient, private val timestampingAuthority: LegallyIdentifiableNode) : Seller() {
class Seller(val otherSide: SingleMessageRecipient,
val timestampingAuthority: LegallyIdentifiableNode,
val assetToSell: StateAndRef<OwnableState>,
val price: Amount,
val myKeyPair: KeyPair,
val buyerSessionID: Long) : ProtocolStateMachine<Pair<WireTransaction, LedgerTransaction>>() {
@Suspendable
override fun call(args: SellerInitialArgs): Pair<WireTransaction, LedgerTransaction> {
override fun call(): Pair<WireTransaction, LedgerTransaction> {
val sessionID = random63BitValue()
// Make the first message we'll send to kick off the protocol.
val hello = SellerTradeInfo(args.assetToSell, args.price, args.myKeyPair.public, sessionID)
val hello = SellerTradeInfo(assetToSell, price, myKeyPair.public, sessionID)
val partialTX = sendAndReceive<SignedWireTransaction>(TRADE_TOPIC, otherSide, args.buyerSessionID, sessionID, hello)
val partialTX = sendAndReceive<SignedWireTransaction>(TRADE_TOPIC, otherSide, buyerSessionID, sessionID, hello)
logger.trace { "Received partially signed transaction" }
partialTX.verifySignatures()
val wtx: WireTransaction = partialTX.txBits.deserialize()
requireThat {
"transaction sends us the right amount of cash" by (wtx.outputStates.sumCashBy(args.myKeyPair.public) == args.price)
"transaction sends us the right amount of cash" by (wtx.outputStates.sumCashBy(myKeyPair.public) == price)
// There are all sorts of funny games a malicious secondary might play here, we should fix them:
//
// - This tx may attempt to send some assets we aren't intending to sell to the secondary, if
@ -122,7 +103,7 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager,
// Sign with our key and get the timestamping authorities key as well.
// These two steps could be done in parallel, in theory.
val ourSignature = args.myKeyPair.signWithECDSA(partialTX.txBits)
val ourSignature = myKeyPair.signWithECDSA(partialTX.txBits)
val tsaSig = TimestamperClient(this, timestampingAuthority).timestamp(partialTX.txBits)
val fullySigned = partialTX.withAdditionalSignature(tsaSig).withAdditionalSignature(ourSignature)
@ -130,23 +111,36 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager,
logger.trace { "Built finished transaction, sending back to secondary!" }
send(TRADE_TOPIC, otherSide, args.buyerSessionID, fullySigned)
send(TRADE_TOPIC, otherSide, buyerSessionID, fullySigned)
return Pair(wtx, fullySigned.verifyToLedgerTransaction(serviceHub.identityService))
}
}
class UnacceptablePriceException(val givenPrice: Amount) : Exception()
class AssetMismatchException(val expectedTypeName: String, val typeName: String) : Exception() {
// This object is serialised to the network and is the first protocol message the seller sends to the buyer.
private class SellerTradeInfo(
val assetForSale: StateAndRef<OwnableState>,
val price: Amount,
val sellerOwnerKey: PublicKey,
val sessionID: Long
)
private class UnacceptablePriceException(val givenPrice: Amount) : Exception()
private class AssetMismatchException(val expectedTypeName: String, val typeName: String) : Exception() {
override fun toString() = "The submitted asset didn't match the expected type: $expectedTypeName vs $typeName"
}
// The buyer's side of the protocol. See note above Seller to learn about the caveats here.
class BuyerImpl(private val otherSide: SingleMessageRecipient, private val timestampingAuthority: Party) : Buyer() {
class Buyer(val otherSide: SingleMessageRecipient,
val timestampingAuthority: Party,
val acceptablePrice: Amount,
val typeToBuy: Class<out OwnableState>,
val sessionID: Long) : ProtocolStateMachine<Pair<WireTransaction, LedgerTransaction>>() {
@Suspendable
override fun call(args: BuyerInitialArgs): Pair<WireTransaction, LedgerTransaction> {
override fun call(): Pair<WireTransaction, LedgerTransaction> {
// Wait for a trade request to come in on our pre-provided session ID.
val tradeRequest = receive<SellerTradeInfo>(TRADE_TOPIC, args.sessionID)
val tradeRequest = receive<SellerTradeInfo>(TRADE_TOPIC, sessionID)
// What is the seller trying to sell us?
val assetTypeName = tradeRequest.assetForSale.state.javaClass.name
@ -154,10 +148,10 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager,
// Check the start message for acceptability.
check(tradeRequest.sessionID > 0)
if (tradeRequest.price > args.acceptablePrice)
if (tradeRequest.price > acceptablePrice)
throw UnacceptablePriceException(tradeRequest.price)
if (!args.typeToBuy.isInstance(tradeRequest.assetForSale.state))
throw AssetMismatchException(args.typeToBuy.name, assetTypeName)
if (!typeToBuy.isInstance(tradeRequest.assetForSale.state))
throw AssetMismatchException(typeToBuy.name, assetTypeName)
// TODO: Either look up the stateref here in our local db, or accept a long chain of states and
// validate them to audit the other side and ensure it actually owns the state we are being offered!
@ -198,7 +192,7 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager,
// TODO: Protect against a malicious buyer sending us back a different transaction to the one we built.
val fullySigned = sendAndReceive<SignedWireTransaction>(TRADE_TOPIC, otherSide, tradeRequest.sessionID,
args.sessionID, stx)
sessionID, stx)
logger.trace { "Got fully signed transaction, verifying ... "}
@ -209,12 +203,4 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager,
return Pair(fullySigned.tx, ltx)
}
}
override fun runSeller(otherSide: SingleMessageRecipient, args: SellerInitialArgs): Seller {
return smm.add(args, "$TRADE_TOPIC.seller", SellerImpl(otherSide, timestampingAuthority))
}
override fun runBuyer(otherSide: SingleMessageRecipient, args: BuyerInitialArgs): Buyer {
return smm.add(args, "$TRADE_TOPIC.buyer", BuyerImpl(otherSide, timestampingAuthority.identity))
}
}

View File

@ -29,6 +29,7 @@ import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.io.ByteArrayOutputStream
import java.util.*
import java.util.concurrent.Callable
import java.util.concurrent.Executor
import javax.annotation.concurrent.ThreadSafe
@ -57,10 +58,10 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
private val checkpointsMap = serviceHub.storageService.getMap<SecureHash, ByteArray>("state machines")
// A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines
// property.
private val _stateMachines = Collections.synchronizedList(ArrayList<ProtocolStateMachine<*,*>>())
private val _stateMachines = Collections.synchronizedList(ArrayList<ProtocolStateMachine<*>>())
/** Returns a snapshot of the currently registered state machines. */
val stateMachines: List<ProtocolStateMachine<*,*>> get() {
val stateMachines: List<ProtocolStateMachine<*>> get() {
synchronized(_stateMachines) {
return ArrayList(_stateMachines)
}
@ -110,10 +111,10 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
}
}
private fun deserializeFiber(bits: ByteArray): ProtocolStateMachine<*, *> {
private fun deserializeFiber(bits: ByteArray): ProtocolStateMachine<*> {
val deserializer = Fiber.getFiberSerializer() as KryoSerializer
val kryo = createKryo(deserializer.kryo)
val psm = kryo.readClassAndObject(Input(bits)) as ProtocolStateMachine<*, *>
val psm = kryo.readClassAndObject(Input(bits)) as ProtocolStateMachine<*>
return psm
}
@ -123,9 +124,9 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
* The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is
* restarted with checkpointed state machines in the storage service.
*/
fun <T : ProtocolStateMachine<I, *>, I> add(initialArgs: I, loggerName: String, fiber: T): T {
fun <T : ProtocolStateMachine<*>> add(loggerName: String, fiber: T): T {
val logger = LoggerFactory.getLogger(loggerName)
iterateStateMachine(fiber, serviceHub.networkService, logger, initialArgs, null) {
iterateStateMachine(fiber, serviceHub.networkService, logger, null, null) {
it.start()
}
return fiber
@ -141,8 +142,8 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
return key
}
private fun iterateStateMachine(psm: ProtocolStateMachine<*, *>, net: MessagingService, logger: Logger,
obj: Any?, prevCheckpointKey: SecureHash?, resumeFunc: (ProtocolStateMachine<*, *>) -> Unit) {
private fun iterateStateMachine(psm: ProtocolStateMachine<*>, net: MessagingService, logger: Logger,
obj: Any?, prevCheckpointKey: SecureHash?, resumeFunc: (ProtocolStateMachine<*>) -> Unit) {
val onSuspend = fun(request: FiberRequest, serFiber: ByteArray) {
// We have a request to do something: send, receive, or send-and-receive.
if (request is FiberRequest.ExpectingResponse<*>) {
@ -181,7 +182,7 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
}
}
private fun checkpointAndSetupMessageHandler(logger: Logger, net: MessagingService, psm: ProtocolStateMachine<*,*>,
private fun checkpointAndSetupMessageHandler(logger: Logger, net: MessagingService, psm: ProtocolStateMachine<*>,
responseType: Class<*>, topic: String, prevCheckpointKey: SecureHash?,
serialisedFiber: ByteArray) {
val checkpoint = Checkpoint(serialisedFiber, logger.name, topic, responseType.name)
@ -201,8 +202,9 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
object SameThreadFiberScheduler : FiberExecutorScheduler("Same thread scheduler", MoreExecutors.directExecutor())
/**
* The base class that should be used by any object that wishes to act as a protocol state machine. The type variable
* C is the type of the initial arguments. R is the type of the return.
* The base class that should be used by any object that wishes to act as a protocol state machine. A PSM is
* a kind of "fiber", and a fiber in turn is a bit like a thread, but a thread that can be suspended to the heap,
* serialised to disk, and resumed on demand.
*
* Sub-classes should override the [call] method and return whatever the final result of the protocol is. Inside the
* call method, the rules of normal object oriented programming are a little different:
@ -216,11 +218,15 @@ object SameThreadFiberScheduler : FiberExecutorScheduler("Same thread scheduler"
* via the [serviceHub] property which is provided. Don't try and keep data you got from a service across calls to
* send/receive/sendAndReceive because the world might change in arbitrary ways out from underneath you, for instance,
* if the node is restarted or reconfigured!
* - Don't pass initial data in using a constructor. This object will be instantiated using reflection so you cannot
* define your own constructor. Instead define a separate class that holds your initial arguments, and take it as
* the argument to [call].
*
* Note that the result of the [call] method can be obtained in a couple of different ways. One is to call the get
* method, as the PSM is a [Future]. But that will block the calling thread until the result is ready, which may not
* be what you want (unless you know it's finished already). So you can also use the [resultFuture] property, which is
* a [ListenableFuture] and will let you register a callback.
*
* Once created, a PSM should be passed to a [StateMachineManager] which will start it and manage its execution.
*/
abstract class ProtocolStateMachine<C, R> : Fiber<R>("protocol", SameThreadFiberScheduler) {
abstract class ProtocolStateMachine<R> : Fiber<R>("protocol", SameThreadFiberScheduler), Callable<R> {
// These fields shouldn't be serialised, so they are marked @Transient.
@Transient private var suspendFunc: ((result: FiberRequest, serFiber: ByteArray) -> Unit)? = null
@Transient private var resumeWithObject: Any? = null
@ -245,12 +251,12 @@ abstract class ProtocolStateMachine<C, R> : Fiber<R>("protocol", SameThreadFiber
this.serviceHub = serviceHub
}
@Suspendable
abstract fun call(args: C): R
// This line may look useless, but it's needed to convince the Quasar bytecode rewriter to do the right thing.
@Suspendable override abstract fun call(): R
@Suspendable @Suppress("UNCHECKED_CAST")
override fun run(): R {
val result = call(resumeWithObject as C)
val result = call()
if (result != null)
(resultFuture as SettableFuture<R>).set(result)
return result

View File

@ -107,7 +107,7 @@ class TimestamperNodeService(private val net: MessagingService,
}
@ThreadSafe
class TimestamperClient(private val psm: ProtocolStateMachine<*, *>, private val node: LegallyIdentifiableNode) : TimestamperService {
class TimestamperClient(private val psm: ProtocolStateMachine<*>, private val node: LegallyIdentifiableNode) : TimestamperService {
override val identity: Party = node.identity
@Suspendable

View File

@ -71,32 +71,29 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
net = bobsNode
)
val tpSeller = TwoPartyTradeProtocol.create(StateMachineManager(alicesServices, backgroundThread), timestamper)
val tpBuyer = TwoPartyTradeProtocol.create(StateMachineManager(bobsServices, backgroundThread), timestamper)
val buyerSessionID = random63BitValue()
val aliceResult = tpSeller.runSeller(
val aliceResult = TwoPartyTradeProtocol.runSeller(
StateMachineManager(alicesServices, backgroundThread),
timestamper,
bobsAddress,
TwoPartyTradeProtocol.SellerInitialArgs(
lookup("alice's paper"),
1000.DOLLARS,
ALICE_KEY,
buyerSessionID
)
lookup("alice's paper"),
1000.DOLLARS,
ALICE_KEY,
buyerSessionID
)
val bobResult = tpBuyer.runBuyer(
val bobResult = TwoPartyTradeProtocol.runBuyer(
StateMachineManager(bobsServices, backgroundThread),
timestamper,
alicesAddress,
TwoPartyTradeProtocol.BuyerInitialArgs(
1000.DOLLARS,
CommercialPaper.State::class.java,
buyerSessionID
)
1000.DOLLARS,
CommercialPaper.State::class.java,
buyerSessionID
)
assertEquals(aliceResult.resultFuture.get(), bobResult.resultFuture.get())
assertEquals(aliceResult.get(), bobResult.get())
txns.add(aliceResult.resultFuture.get().second)
txns.add(aliceResult.get().second)
verify()
}
backgroundThread.shutdown()
@ -128,28 +125,26 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
storage = bobsStorage
)
val tpSeller = TwoPartyTradeProtocol.create(StateMachineManager(alicesServices, MoreExecutors.directExecutor()), timestamper.first)
val smmBuyer = StateMachineManager(bobsServices, MoreExecutors.directExecutor())
val tpBuyer = TwoPartyTradeProtocol.create(smmBuyer, timestamper.first)
val buyerSessionID = random63BitValue()
tpSeller.runSeller(
TwoPartyTradeProtocol.runSeller(
StateMachineManager(alicesServices, MoreExecutors.directExecutor()),
timestamper.first,
bobsAddress,
TwoPartyTradeProtocol.SellerInitialArgs(
lookup("alice's paper"),
1000.DOLLARS,
ALICE_KEY,
buyerSessionID
)
lookup("alice's paper"),
1000.DOLLARS,
ALICE_KEY,
buyerSessionID
)
tpBuyer.runBuyer(
TwoPartyTradeProtocol.runBuyer(
smmBuyer,
timestamper.first,
alicesAddress,
TwoPartyTradeProtocol.BuyerInitialArgs(
1000.DOLLARS,
CommercialPaper.State::class.java,
buyerSessionID
)
1000.DOLLARS,
CommercialPaper.State::class.java,
buyerSessionID
)
// Everything is on this thread so we can now step through the protocol one step at a time.

View File

@ -54,9 +54,9 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() {
service = TimestamperNodeService(serviceNode.second, Party("Unit test suite", ALICE), ALICE_KEY)
}
class TestPSM(val server: LegallyIdentifiableNode, val now: Instant) : ProtocolStateMachine<Any?, Boolean>() {
class TestPSM(val server: LegallyIdentifiableNode, val now: Instant) : ProtocolStateMachine<Boolean>() {
@Suspendable
override fun call(args: Any?): Boolean {
override fun call(): Boolean {
val client = TimestamperClient(this, server)
val ptx = TransactionBuilder().apply {
addInputState(ContractStateRef(SecureHash.randomSHA256(), 0))
@ -77,7 +77,7 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() {
val smm = StateMachineManager(MockServices(net = myNode.second), RunOnCallerThread)
val logName = TimestamperNodeService.TIMESTAMPING_PROTOCOL_TOPIC
val psm = TestPSM(myNode.second.networkMap.timestampingNodes[0], clock.instant())
smm.add(serviceNode.first, logName, psm)
smm.add(logName, psm)
psm
}
assertTrue(psm.isDone)