Trading: move session ID into the initial args to avoid possible confusion at the start.

This commit is contained in:
Mike Hearn 2015-12-14 18:22:00 +01:00
parent f15e24e7be
commit d43bbe8faa
2 changed files with 23 additions and 13 deletions

View File

@ -45,14 +45,16 @@ abstract class TwoPartyTradeProtocol {
class SellerInitialArgs( class SellerInitialArgs(
val assetToSell: StateAndRef<OwnableState>, val assetToSell: StateAndRef<OwnableState>,
val price: Amount, val price: Amount,
val myKeyPair: KeyPair val myKeyPair: KeyPair,
val buyerSessionID: Long
) )
abstract fun runSeller(otherSide: SingleMessageRecipient, args: SellerInitialArgs): Seller abstract fun runSeller(otherSide: SingleMessageRecipient, args: SellerInitialArgs): Seller
class BuyerInitialArgs( class BuyerInitialArgs(
val acceptablePrice: Amount, val acceptablePrice: Amount,
val typeToBuy: Class<out OwnableState> val typeToBuy: Class<out OwnableState>,
val sessionID: Long
) )
abstract fun runBuyer(otherSide: SingleMessageRecipient, args: BuyerInitialArgs): Buyer abstract fun runBuyer(otherSide: SingleMessageRecipient, args: BuyerInitialArgs): Buyer
@ -95,8 +97,7 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) :
// Make the first message we'll send to kick off the protocol. // Make the first message we'll send to kick off the protocol.
val hello = SellerTradeInfo(args.assetToSell, args.price, args.myKeyPair.public, sessionID) val hello = SellerTradeInfo(args.assetToSell, args.price, args.myKeyPair.public, sessionID)
// Zero is a special session ID that is being listened to by the buyer (i.e. before a session is started). val partialTX = sendAndReceive<SignedWireTransaction>(TRADE_TOPIC, args.buyerSessionID, sessionID, hello)
val partialTX = sendAndReceive<SignedWireTransaction>(TRADE_TOPIC, 0, sessionID, hello)
logger().trace { "Received partially signed transaction" } logger().trace { "Received partially signed transaction" }
partialTX.verifySignatures() partialTX.verifySignatures()
@ -124,7 +125,7 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) :
val timestamped: TimestampedWireTransaction = fullySigned.toTimestampedTransaction(serviceHub.timestampingService) val timestamped: TimestampedWireTransaction = fullySigned.toTimestampedTransaction(serviceHub.timestampingService)
logger().trace { "Built finished transaction, sending back to secondary!" } logger().trace { "Built finished transaction, sending back to secondary!" }
send(TRADE_TOPIC, sessionID, timestamped) send(TRADE_TOPIC, args.buyerSessionID, timestamped)
return Pair(timestamped, timestamped.verifyToLedgerTransaction(serviceHub.timestampingService, serviceHub.identityService)) return Pair(timestamped, timestamped.verifyToLedgerTransaction(serviceHub.timestampingService, serviceHub.identityService))
} }
@ -138,8 +139,8 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) :
// The buyer's side of the protocol. See note above Seller to learn about the caveats here. // The buyer's side of the protocol. See note above Seller to learn about the caveats here.
class BuyerImpl : Buyer() { class BuyerImpl : Buyer() {
override fun call(args: BuyerInitialArgs): Pair<TimestampedWireTransaction, LedgerTransaction> { override fun call(args: BuyerInitialArgs): Pair<TimestampedWireTransaction, LedgerTransaction> {
// Wait for a trade request to come in on special session ID zero. // Wait for a trade request to come in on our pre-provided session ID.
val tradeRequest = receive<SellerTradeInfo>(TRADE_TOPIC, 0) val tradeRequest = receive<SellerTradeInfo>(TRADE_TOPIC, args.sessionID)
// What is the seller trying to sell us? // What is the seller trying to sell us?
val assetTypeName = tradeRequest.assetForSale.state.javaClass.name val assetTypeName = tradeRequest.assetForSale.state.javaClass.name
@ -183,10 +184,10 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) :
logger().trace { "Sending partially signed transaction to seller" } logger().trace { "Sending partially signed transaction to seller" }
// We'll just reuse the session ID the seller selected here for convenience.
// TODO: Protect against the buyer terminating here and leaving us in the lurch without the final tx. // TODO: Protect against the buyer terminating here and leaving us in the lurch without the final tx.
// TODO: Protect against a malicious buyer sending us back a different transaction to the one we built.
val fullySigned = sendAndReceive<TimestampedWireTransaction>(TRADE_TOPIC, val fullySigned = sendAndReceive<TimestampedWireTransaction>(TRADE_TOPIC,
tradeRequest.sessionID, tradeRequest.sessionID, stx) tradeRequest.sessionID, args.sessionID, stx)
logger().trace { "Got fully signed transaction, verifying ... "} logger().trace { "Got fully signed transaction, verifying ... "}

View File

@ -20,6 +20,7 @@ import core.testutils.*
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import java.security.SecureRandom
import java.util.concurrent.Executors import java.util.concurrent.Executors
import java.util.logging.Formatter import java.util.logging.Formatter
import java.util.logging.Level import java.util.logging.Level
@ -76,19 +77,23 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
val tpSeller = TwoPartyTradeProtocol.create(StateMachineManager(alicesServices, backgroundThread)) val tpSeller = TwoPartyTradeProtocol.create(StateMachineManager(alicesServices, backgroundThread))
val tpBuyer = TwoPartyTradeProtocol.create(StateMachineManager(bobsServices, backgroundThread)) val tpBuyer = TwoPartyTradeProtocol.create(StateMachineManager(bobsServices, backgroundThread))
val buyerSessionID = SecureRandom.getInstanceStrong().nextLong()
val aliceResult = tpSeller.runSeller( val aliceResult = tpSeller.runSeller(
bobsAddress, bobsAddress,
TwoPartyTradeProtocol.SellerInitialArgs( TwoPartyTradeProtocol.SellerInitialArgs(
lookup("alice's paper"), lookup("alice's paper"),
1000.DOLLARS, 1000.DOLLARS,
ALICE_KEY ALICE_KEY,
buyerSessionID
) )
) )
val bobResult = tpBuyer.runBuyer( val bobResult = tpBuyer.runBuyer(
alicesAddress, alicesAddress,
TwoPartyTradeProtocol.BuyerInitialArgs( TwoPartyTradeProtocol.BuyerInitialArgs(
1000.DOLLARS, 1000.DOLLARS,
CommercialPaper.State::class.java CommercialPaper.State::class.java,
buyerSessionID
) )
) )
@ -129,19 +134,23 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
val smmBuyer = StateMachineManager(bobsServices, MoreExecutors.directExecutor()) val smmBuyer = StateMachineManager(bobsServices, MoreExecutors.directExecutor())
val tpBuyer = TwoPartyTradeProtocol.create(smmBuyer) val tpBuyer = TwoPartyTradeProtocol.create(smmBuyer)
val buyerSessionID = SecureRandom.getInstanceStrong().nextLong()
tpSeller.runSeller( tpSeller.runSeller(
bobsAddress, bobsAddress,
TwoPartyTradeProtocol.SellerInitialArgs( TwoPartyTradeProtocol.SellerInitialArgs(
lookup("alice's paper"), lookup("alice's paper"),
1000.DOLLARS, 1000.DOLLARS,
ALICE_KEY ALICE_KEY,
buyerSessionID
) )
) )
tpBuyer.runBuyer( tpBuyer.runBuyer(
alicesAddress, alicesAddress,
TwoPartyTradeProtocol.BuyerInitialArgs( TwoPartyTradeProtocol.BuyerInitialArgs(
1000.DOLLARS, 1000.DOLLARS,
CommercialPaper.State::class.java CommercialPaper.State::class.java,
buyerSessionID
) )
) )