Trading: move session ID into the initial args to avoid possible confusion at the start.

This commit is contained in:
Mike Hearn 2015-12-14 18:22:00 +01:00
parent f15e24e7be
commit d43bbe8faa
2 changed files with 23 additions and 13 deletions

View File

@ -45,14 +45,16 @@ abstract class TwoPartyTradeProtocol {
class SellerInitialArgs(
val assetToSell: StateAndRef<OwnableState>,
val price: Amount,
val myKeyPair: KeyPair
val myKeyPair: KeyPair,
val buyerSessionID: Long
)
abstract fun runSeller(otherSide: SingleMessageRecipient, args: SellerInitialArgs): Seller
class BuyerInitialArgs(
val acceptablePrice: Amount,
val typeToBuy: Class<out OwnableState>
val typeToBuy: Class<out OwnableState>,
val sessionID: Long
)
abstract fun runBuyer(otherSide: SingleMessageRecipient, args: BuyerInitialArgs): Buyer
@ -95,8 +97,7 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) :
// Make the first message we'll send to kick off the protocol.
val hello = SellerTradeInfo(args.assetToSell, args.price, args.myKeyPair.public, sessionID)
// Zero is a special session ID that is being listened to by the buyer (i.e. before a session is started).
val partialTX = sendAndReceive<SignedWireTransaction>(TRADE_TOPIC, 0, sessionID, hello)
val partialTX = sendAndReceive<SignedWireTransaction>(TRADE_TOPIC, args.buyerSessionID, sessionID, hello)
logger().trace { "Received partially signed transaction" }
partialTX.verifySignatures()
@ -124,7 +125,7 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) :
val timestamped: TimestampedWireTransaction = fullySigned.toTimestampedTransaction(serviceHub.timestampingService)
logger().trace { "Built finished transaction, sending back to secondary!" }
send(TRADE_TOPIC, sessionID, timestamped)
send(TRADE_TOPIC, args.buyerSessionID, timestamped)
return Pair(timestamped, timestamped.verifyToLedgerTransaction(serviceHub.timestampingService, serviceHub.identityService))
}
@ -138,8 +139,8 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) :
// The buyer's side of the protocol. See note above Seller to learn about the caveats here.
class BuyerImpl : Buyer() {
override fun call(args: BuyerInitialArgs): Pair<TimestampedWireTransaction, LedgerTransaction> {
// Wait for a trade request to come in on special session ID zero.
val tradeRequest = receive<SellerTradeInfo>(TRADE_TOPIC, 0)
// Wait for a trade request to come in on our pre-provided session ID.
val tradeRequest = receive<SellerTradeInfo>(TRADE_TOPIC, args.sessionID)
// What is the seller trying to sell us?
val assetTypeName = tradeRequest.assetForSale.state.javaClass.name
@ -183,10 +184,10 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) :
logger().trace { "Sending partially signed transaction to seller" }
// We'll just reuse the session ID the seller selected here for convenience.
// TODO: Protect against the buyer terminating here and leaving us in the lurch without the final tx.
// TODO: Protect against a malicious buyer sending us back a different transaction to the one we built.
val fullySigned = sendAndReceive<TimestampedWireTransaction>(TRADE_TOPIC,
tradeRequest.sessionID, tradeRequest.sessionID, stx)
tradeRequest.sessionID, args.sessionID, stx)
logger().trace { "Got fully signed transaction, verifying ... "}

View File

@ -20,6 +20,7 @@ import core.testutils.*
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.security.SecureRandom
import java.util.concurrent.Executors
import java.util.logging.Formatter
import java.util.logging.Level
@ -76,19 +77,23 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
val tpSeller = TwoPartyTradeProtocol.create(StateMachineManager(alicesServices, backgroundThread))
val tpBuyer = TwoPartyTradeProtocol.create(StateMachineManager(bobsServices, backgroundThread))
val buyerSessionID = SecureRandom.getInstanceStrong().nextLong()
val aliceResult = tpSeller.runSeller(
bobsAddress,
TwoPartyTradeProtocol.SellerInitialArgs(
lookup("alice's paper"),
1000.DOLLARS,
ALICE_KEY
ALICE_KEY,
buyerSessionID
)
)
val bobResult = tpBuyer.runBuyer(
alicesAddress,
TwoPartyTradeProtocol.BuyerInitialArgs(
1000.DOLLARS,
CommercialPaper.State::class.java
CommercialPaper.State::class.java,
buyerSessionID
)
)
@ -129,19 +134,23 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
val smmBuyer = StateMachineManager(bobsServices, MoreExecutors.directExecutor())
val tpBuyer = TwoPartyTradeProtocol.create(smmBuyer)
val buyerSessionID = SecureRandom.getInstanceStrong().nextLong()
tpSeller.runSeller(
bobsAddress,
TwoPartyTradeProtocol.SellerInitialArgs(
lookup("alice's paper"),
1000.DOLLARS,
ALICE_KEY
ALICE_KEY,
buyerSessionID
)
)
tpBuyer.runBuyer(
alicesAddress,
TwoPartyTradeProtocol.BuyerInitialArgs(
1000.DOLLARS,
CommercialPaper.State::class.java
CommercialPaper.State::class.java,
buyerSessionID
)
)