diff --git a/src/main/kotlin/contracts/protocols/TwoPartyTradeProtocol.kt b/src/main/kotlin/contracts/protocols/TwoPartyTradeProtocol.kt index cba636f959..b5ecb7c01c 100644 --- a/src/main/kotlin/contracts/protocols/TwoPartyTradeProtocol.kt +++ b/src/main/kotlin/contracts/protocols/TwoPartyTradeProtocol.kt @@ -16,20 +16,23 @@ import core.messaging.* import core.serialization.deserialize import core.utilities.trace import java.security.KeyPair -import java.security.PrivateKey import java.security.PublicKey import java.security.SecureRandom /** * This asset trading protocol has two parties (B and S for buyer and seller) and the following steps: * - * 1. B sends the [StateAndRef] pointing to what they want to sell to S, along with info about the price. - * 2. S sends to B a [SignedWireTransaction] that includes the state as input, S's cash as input, the state with the new - * owner key as output, and any change cash as output. It contains a single signature from S but isn't valid because - * it lacks a signature from B authorising movement of the asset. - * 3. B signs it and hands the now finalised SignedWireTransaction back to S. + * 1. S sends the [StateAndRef] pointing to what they want to sell to B, along with info about the price they require + * B to pay. For example this has probably been agreed on an exchange. + * 2. B sends to S a [SignedWireTransaction] that includes the state as input, B's cash as input, the state with the new + * owner key as output, and any change cash as output. It contains a single signature from B but isn't valid because + * it lacks a signature from S authorising movement of the asset. + * 3. S signs it and hands the now finalised SignedWireTransaction back to B. * - * They both end the protocol being in posession of a validly signed contract. + * Assuming no malicious termination, they both end the protocol being in posession of a valid, signed transaction + * that represents an atomic asset swap. + * + * Note that it's the *seller* who initiates contact with the buyer, not vice-versa as you might imagine. * * To get an implementation of this class, use the static [TwoPartyTradeProtocol.create] method. Then use either * the [runBuyer] or [runSeller] methods, depending on which side of the trade your node is taking. These methods @@ -40,58 +43,27 @@ import java.security.SecureRandom * To see an example of how to use this class, look at the unit tests. */ abstract class TwoPartyTradeProtocol { - // TODO: Replace some args with the context objects - abstract fun runSeller( - otherSide: SingleMessageRecipient, - assetToSell: StateAndRef, - price: Amount, - myKey: KeyPair, - partyKeyMap: Map, - timestamper: TimestamperService - ): ListenableFuture> - - abstract fun runBuyer( - otherSide: SingleMessageRecipient, - acceptablePrice: Amount, - typeToSell: Class, - wallet: List>, - myKeys: Map, - timestamper: TimestamperService, - partyKeyMap: Map - ): ListenableFuture> - - class BuyerInitialArgs( - val acceptablePrice: Amount, - val typeToSell: String - ) - - class BuyerContext( - val wallet: List>, - val myKeys: Map, - val timestamper: TimestamperService, - val partyKeyMap: Map, - val initialArgs: BuyerInitialArgs? - ) - - // This wraps some of the arguments passed to runSeller that are persistent across the lifetime of the trade and - // can be serialised. class SellerInitialArgs( val assetToSell: StateAndRef, val price: Amount, val myKeyPair: KeyPair ) - // This wraps the things which the seller needs, but which might change whilst the continuation is suspended, - // e.g. due to a VM restart, networking issue, configuration file reload etc. It also contains the initial args - // and the future that the code will fill out when done. - class SellerContext( - val timestamper: TimestamperService, - val partyKeyMap: Map, - val initialArgs: SellerInitialArgs? + abstract fun runSeller(otherSide: SingleMessageRecipient, + args: SellerInitialArgs): ListenableFuture> + + class BuyerInitialArgs( + val acceptablePrice: Amount, + val typeToBuy: Class ) - abstract class Buyer : ProtocolStateMachine>() - abstract class Seller : ProtocolStateMachine>() + abstract fun runBuyer( + otherSide: SingleMessageRecipient, + args: BuyerInitialArgs + ): ListenableFuture> + + abstract class Buyer : ProtocolStateMachine>() + abstract class Seller : ProtocolStateMachine>() companion object { @JvmStatic fun create(smm: StateMachineManager): TwoPartyTradeProtocol { @@ -100,6 +72,7 @@ abstract class TwoPartyTradeProtocol { } } +/** The implementation of the [TwoPartyTradeProtocol] base class. */ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) : TwoPartyTradeProtocol() { companion object { val TRADE_TOPIC = "com.r3cev.protocols.trade" @@ -117,22 +90,22 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) : // The seller's side of the protocol. IMPORTANT: This class is loaded in a separate classloader and auto-mangled // by JavaFlow. Therefore, we cannot cast the object to Seller and poke it directly because the class we'd be // trying to poke at is different to the one we saw at compile time, so we'd get ClassCastExceptions. All - // interaction with this class must be through either interfaces, or objects passed to and from the continuation - // by the state machine framework. Please refer to the documentation website (docs/build/html) to learn more about - // the protocol state machine framework. + // interaction with this class must be through either interfaces, the supertype, or objects passed to and from + // the continuation by the state machine framework. Please refer to the documentation website (docs/build/html) to + // learn more about the protocol state machine framework. class SellerImpl : Seller() { - override fun call(): Pair { + override fun call(args: SellerInitialArgs): Pair { val sessionID = makeSessionID() - val args = context().initialArgs!! + // Make the first message we'll send to kick off the protocol. val hello = SellerTradeInfo(args.assetToSell, args.price, args.myKeyPair.public, sessionID) - // Zero is a special session ID that is used to start a trade (i.e. before a session is started). - var (ctx2, offerMsg) = sendAndReceive(TRADE_TOPIC, 0, sessionID, hello) + + // Zero is a special session ID that is being listened to by the buyer (i.e. before a session is started). + val partialTX = sendAndReceive(TRADE_TOPIC, 0, sessionID, hello) logger().trace { "Received partially signed transaction" } - val partialTx = offerMsg - partialTx.verifySignatures() - val wtx = partialTx.txBits.deserialize() + partialTX.verifySignatures() + val wtx = partialTX.txBits.deserialize() requireThat { "transaction sends us the right amount of cash" by (wtx.outputStates.sumCashBy(args.myKeyPair.public) == args.price) @@ -149,14 +122,16 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) : // express protocol state machines on top of the messaging layer. } - val ourSignature = args.myKeyPair.signWithECDSA(partialTx.txBits.bits) - val fullySigned: SignedWireTransaction = partialTx.copy(sigs = partialTx.sigs + ourSignature) + val ourSignature = args.myKeyPair.signWithECDSA(partialTX.txBits.bits) + val fullySigned: SignedWireTransaction = partialTX.copy(sigs = partialTX.sigs + ourSignature) // We should run it through our full TransactionGroup of all transactions here. fullySigned.verify() - val timestamped: TimestampedWireTransaction = fullySigned.toTimestampedTransaction(ctx2.timestamper) + val timestamped: TimestampedWireTransaction = fullySigned.toTimestampedTransaction(serviceHub.timestampingService) logger().trace { "Built finished transaction, sending back to secondary!" } + send(TRADE_TOPIC, sessionID, timestamped) - return Pair(timestamped, timestamped.verifyToLedgerTransaction(ctx2.timestamper, ctx2.partyKeyMap)) + + return Pair(timestamped, timestamped.verifyToLedgerTransaction(serviceHub.timestampingService, serviceHub.identityService)) } } @@ -167,75 +142,72 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) : // The buyer's side of the protocol. See note above Seller to learn about the caveats here. class BuyerImpl : Buyer() { - override fun call(): Pair { - val acceptablePrice = context().initialArgs!!.acceptablePrice - val typeToSell = context().initialArgs!!.typeToSell - // Start a new scope here so we can't accidentally reuse 'ctx' after doing the sendAndReceive below, - // as the context object we're meant to use might change each time we suspend (e.g. due to VM restart). - val (stx, theirSessionID) = run { - // Wait for a trade request to come in. - val (ctx, tradeRequest) = receive(TRADE_TOPIC, 0) - val assetTypeName = tradeRequest.assetForSale.state.javaClass.name + override fun call(args: BuyerInitialArgs): Pair { + // Wait for a trade request to come in on special session ID zero. + val tradeRequest = receive(TRADE_TOPIC, 0) - logger().trace { "Got trade request for a $assetTypeName" } + // What is the seller trying to sell us? + val assetTypeName = tradeRequest.assetForSale.state.javaClass.name + logger().trace { "Got trade request for a $assetTypeName" } - // Check the start message for acceptability. - check(tradeRequest.sessionID > 0) - if (tradeRequest.price > acceptablePrice) - throw UnacceptablePriceException(tradeRequest.price) - if (!Class.forName(typeToSell).isInstance(tradeRequest.assetForSale.state)) - throw AssetMismatchException(typeToSell, assetTypeName) + // Check the start message for acceptability. + check(tradeRequest.sessionID > 0) + if (tradeRequest.price > args.acceptablePrice) + throw UnacceptablePriceException(tradeRequest.price) + if (!args.typeToBuy.isInstance(tradeRequest.assetForSale.state)) + throw AssetMismatchException(args.typeToBuy.name, assetTypeName) - // TODO: Either look up the stateref here in our local db, or accept a long chain of states and - // validate them to audit the other side and ensure it actually owns the state we are being offered! - // For now, just assume validity! + // TODO: Either look up the stateref here in our local db, or accept a long chain of states and + // validate them to audit the other side and ensure it actually owns the state we are being offered! + // For now, just assume validity! - // Generate the shared transaction that both sides will sign, using the data we have. - val ptx = PartialTransaction() - // Add input and output states for the movement of cash. - val cashSigningPubKeys = Cash().craftSpend(ptx, tradeRequest.price, tradeRequest.sellerOwnerKey, ctx.wallet) - // Add inputs/outputs/a command for the movement of the asset. - ptx.addInputState(tradeRequest.assetForSale.ref) - // Just pick some arbitrary public key for now (this provides poor privacy). - val (command, state) = tradeRequest.assetForSale.state.withNewOwner(ctx.myKeys.keys.first()) - ptx.addOutputState(state) - ptx.addArg(WireCommand(command, tradeRequest.assetForSale.state.owner)) + // Generate the shared transaction that both sides will sign, using the data we have. + val ptx = PartialTransaction() + // Add input and output states for the movement of cash, by using the Cash contract to generate the states. + val wallet = serviceHub.walletService.currentWallet + val cashStates = wallet.statesOfType() + val cashSigningPubKeys = Cash().craftSpend(ptx, tradeRequest.price, tradeRequest.sellerOwnerKey, cashStates) + // Add inputs/outputs/a command for the movement of the asset. + ptx.addInputState(tradeRequest.assetForSale.ref) + // Just pick some new public key for now. + val freshKey = serviceHub.keyManagementService.freshKey() + val (command, state) = tradeRequest.assetForSale.state.withNewOwner(freshKey.public) + ptx.addOutputState(state) + ptx.addArg(WireCommand(command, tradeRequest.assetForSale.state.owner)) - for (k in cashSigningPubKeys) { - // TODO: This error case should be removed through the introduction of a Wallet class. - val priv = ctx.myKeys[k] ?: throw IllegalStateException("Coin in wallet with no known privkey") - ptx.signWith(KeyPair(k, priv)) - } - - val stx = ptx.toSignedTransaction(checkSufficientSignatures = false) - stx.verifySignatures() // Verifies that we generated a signed transaction correctly. - Pair(stx, tradeRequest.sessionID) + // Now sign the transaction with whatever keys we need to move the cash. + for (k in cashSigningPubKeys) { + val priv = serviceHub.keyManagementService.toPrivate(k) + ptx.signWith(KeyPair(k, priv)) } - // TODO: Could run verify() here to make sure the only signature missing is the primaries. + val stx = ptx.toSignedTransaction(checkSufficientSignatures = false) + stx.verifySignatures() // Verifies that we generated a signed transaction correctly. + + // TODO: Could run verify() here to make sure the only signature missing is the sellers. + logger().trace { "Sending partially signed transaction to seller" } + // We'll just reuse the session ID the seller selected here for convenience. - val (ctx, fullySigned) = sendAndReceive(TRADE_TOPIC, theirSessionID, theirSessionID, stx) + // TODO: Protect against the buyer terminating here and leaving us in the lurch without the final tx. + val fullySigned = sendAndReceive(TRADE_TOPIC, + tradeRequest.sessionID, tradeRequest.sessionID, stx) + logger().trace { "Got fully signed transaction, verifying ... "} - val ltx = fullySigned.verifyToLedgerTransaction(ctx.timestamper, ctx.partyKeyMap) + + val ltx = fullySigned.verifyToLedgerTransaction(serviceHub.timestampingService, serviceHub.identityService) + logger().trace { "Fully signed transaction was valid. Trade complete! :-)" } + return Pair(fullySigned, ltx) } } - override fun runSeller(otherSide: SingleMessageRecipient, assetToSell: StateAndRef, - price: Amount, myKey: KeyPair, partyKeyMap: Map, - timestamper: TimestamperService): ListenableFuture> { - val args = SellerInitialArgs(assetToSell, price, myKey) - val context = SellerContext(timestamper, partyKeyMap, args) - return smm.add(otherSide, context, "$TRADE_TOPIC.seller", SellerImpl::class.java) + override fun runSeller(otherSide: SingleMessageRecipient, args: SellerInitialArgs): ListenableFuture> { + return smm.add(otherSide, args, "$TRADE_TOPIC.seller", SellerImpl::class.java) } - override fun runBuyer(otherSide: SingleMessageRecipient, acceptablePrice: Amount, - typeToSell: Class, wallet: List>, - myKeys: Map, timestamper: TimestamperService, - partyKeyMap: Map): ListenableFuture> { - val context = BuyerContext(wallet, myKeys, timestamper, partyKeyMap, BuyerInitialArgs(acceptablePrice, typeToSell.name)) - return smm.add(otherSide, context, "$TRADE_TOPIC.buyer", BuyerImpl::class.java) + override fun runBuyer(otherSide: SingleMessageRecipient, args: BuyerInitialArgs): ListenableFuture> { + return smm.add(otherSide, args, "$TRADE_TOPIC.buyer", BuyerImpl::class.java) } } \ No newline at end of file diff --git a/src/main/kotlin/core/Services.kt b/src/main/kotlin/core/Services.kt index d92966cd3b..c8f185bd64 100644 --- a/src/main/kotlin/core/Services.kt +++ b/src/main/kotlin/core/Services.kt @@ -8,6 +8,8 @@ package core +import core.messaging.MessagingSystem +import java.security.KeyPair import java.security.PrivateKey import java.security.PublicKey import java.time.Instant @@ -18,12 +20,15 @@ import java.time.Instant */ /** - * A wallet (name may be temporary) wraps a set of private keys, and a set of states that are known about and that can - * be influenced by those keys, for instance, because we own them. This class represents an immutable, stable state - * of a wallet: it is guaranteed not to change out from underneath you, even though the canonical currently-best-known - * wallet may change as we learn new transactions from our peers. + * A wallet (name may be temporary) wraps a set of states that are useful for us to keep track of, for instance, + * because we own them. This class represents an immutable, stable state of a wallet: it is guaranteed not to + * change out from underneath you, even though the canonical currently-best-known wallet may change as we learn + * about new transactions from our peers and generate new transactiont that consume states ourselves. */ -data class Wallet(val states: List>, val keys: Map) +data class Wallet(val states: List>) { + @Suppress("UNCHECKED_CAST") + inline fun statesOfType() = states.filter { it.state is T } as List> +} /** * A [WalletService] is responsible for securely and safely persisting the current state of a wallet to storage. The @@ -32,16 +37,36 @@ data class Wallet(val states: List>, val keys: Map + + fun toPrivate(publicKey: PublicKey) = keys[publicKey] ?: throw IllegalStateException("No private key known for requested public key") + + /** Generates a new random key and adds it to the exposed map. */ + fun freshKey(): KeyPair +} + /** * An identity service maintains an bidirectional map of [Party]s to their associated public keys and thus supports * lookup of a party given its key. This is obviously very incomplete and does not reflect everything a real identity * service would provide. */ interface IdentityService { - fun partyFromKey(key: PublicKey): Party + fun partyFromKey(key: PublicKey): Party? } /** @@ -56,6 +81,15 @@ interface TimestamperService { fun verifyTimestamp(hash: SecureHash, signedTimestamp: ByteArray): Instant } +/** + * A sketch of an interface to a simple key/value storage system. Intended for persistence of simple blobs like + * transactions, serialised protocol state machines and so on. Again, this isn't intended to imply lack of SQL or + * anything like that, this interface is only big enough to support the prototyping work. + */ +interface StorageService { + fun getMap(tableName: String): MutableMap +} + /** * A service hub simply vends references to the other services a node has. Some of those services may be missing or * mocked out. This class is useful to pass to chunks of pluggable code that might have need of many different kinds of @@ -63,6 +97,9 @@ interface TimestamperService { */ interface ServiceHub { val walletService: WalletService + val keyManagementService: KeyManagementService val identityService: IdentityService val timestampingService: TimestamperService + val storageService: StorageService + val networkService: MessagingSystem // TODO: Rename class to be consistent. } \ No newline at end of file diff --git a/src/main/kotlin/core/Transactions.kt b/src/main/kotlin/core/Transactions.kt index 8ec475c3d8..ed402d1707 100644 --- a/src/main/kotlin/core/Transactions.kt +++ b/src/main/kotlin/core/Transactions.kt @@ -58,9 +58,9 @@ data class WireTransaction(val inputStates: List, val commands: List) { fun serializeForSignature(): ByteArray = serialize() - fun toLedgerTransaction(timestamp: Instant?, partyKeyMap: Map, originalHash: SecureHash): LedgerTransaction { + fun toLedgerTransaction(timestamp: Instant?, identityService: IdentityService, originalHash: SecureHash): LedgerTransaction { val authenticatedArgs = commands.map { - val institutions = it.pubkeys.mapNotNull { pk -> partyKeyMap[pk] } + val institutions = it.pubkeys.mapNotNull { pk -> identityService.partyFromKey(pk) } AuthenticatedObject(it.pubkeys, institutions, it.command) } return LedgerTransaction(inputStates, outputStates, authenticatedArgs, timestamp, originalHash) @@ -191,11 +191,11 @@ data class TimestampedWireTransaction( ) { val transactionID: SecureHash = serialize().sha256() - fun verifyToLedgerTransaction(timestamper: TimestamperService, partyKeyMap: Map): LedgerTransaction { + fun verifyToLedgerTransaction(timestamper: TimestamperService, identityService: IdentityService): LedgerTransaction { val stx: SignedWireTransaction = signedWireTX.deserialize() val wtx: WireTransaction = stx.verify() val instant: Instant? = if (timestamp != null) timestamper.verifyTimestamp(signedWireTX.sha256(), timestamp.bits) else null - return wtx.toLedgerTransaction(instant, partyKeyMap, transactionID) + return wtx.toLedgerTransaction(instant, identityService, transactionID) } } diff --git a/src/main/kotlin/core/messaging/Messaging.kt b/src/main/kotlin/core/messaging/Messaging.kt index a3d8972b36..74b1e4c4ca 100644 --- a/src/main/kotlin/core/messaging/Messaging.kt +++ b/src/main/kotlin/core/messaging/Messaging.kt @@ -9,9 +9,7 @@ package core.messaging import com.google.common.util.concurrent.ListenableFuture -import core.serialization.deserialize import core.serialization.serialize -import java.time.Duration import java.time.Instant import java.util.concurrent.Executor import javax.annotation.concurrent.ThreadSafe @@ -38,9 +36,6 @@ interface MessagingSystem { * The returned object is an opaque handle that may be used to un-register handlers later with [removeMessageHandler]. * The handle is passed to the callback as well, to avoid race conditions whereby the callback wants to unregister * itself and yet addMessageHandler hasn't returned the handle yet. - * - * If the callback throws an exception then the message is discarded and will not be retried, unless the exception - * is a subclass of [RetryMessageLaterException], in which case the message will be queued and attempted later. */ fun addMessageHandler(topic: String = "", executor: Executor? = null, callback: (Message, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration @@ -86,19 +81,6 @@ fun MessagingSystem.runOnNextMessage(topic: String = "", executor: Executor? = n fun MessagingSystem.send(topic: String, to: MessageRecipients, obj: Any) = send(createMessage(topic, obj.serialize()), to) -/** - * Registers a handler for the given topic that runs the given callback with the message content deserialised to the - * given type, and then removes itself. - */ -inline fun MessagingSystem.runOnNextMessageWith(topic: String = "", - executor: Executor? = null, - noinline callback: (T) -> Unit) { - addMessageHandler(topic, executor) { msg, reg -> - callback(msg.data.deserialize()) - removeMessageHandler(reg) - } -} - /** * This class lets you start up a [MessagingSystem]. Its purpose is to stop you from getting access to the methods * on the messaging system interface until you have successfully started up the system. One of these objects should @@ -114,11 +96,6 @@ interface MessagingSystemBuilder { interface MessageHandlerRegistration -class RetryMessageLaterException : Exception() { - /** If set, the message will be re-queued and retried after the requested interval. */ - var delayPeriod: Duration? = null -} - /** * A message is defined, at this level, to be a (topic, timestamp, byte arrays) triple, where the topic is a string in * Java-style reverse dns form, with "platform." being a prefix reserved by the platform for its own use. Vendor diff --git a/src/main/kotlin/core/messaging/StateMachines.kt b/src/main/kotlin/core/messaging/StateMachines.kt index cc9b9cdba7..6ccc08d031 100644 --- a/src/main/kotlin/core/messaging/StateMachines.kt +++ b/src/main/kotlin/core/messaging/StateMachines.kt @@ -11,6 +11,8 @@ package core.messaging import com.esotericsoftware.kryo.io.Input import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.SettableFuture +import core.SecureHash +import core.ServiceHub import core.serialization.THREAD_LOCAL_KRYO import core.serialization.createKryo import core.serialization.deserialize @@ -23,17 +25,25 @@ import org.objenesis.strategy.InstantiatorStrategy import org.slf4j.Logger import org.slf4j.LoggerFactory import java.util.* -import java.util.concurrent.Callable import java.util.concurrent.Executor /** * A StateMachineManager is responsible for coordination and persistence of multiple [ProtocolStateMachine] objects. + * Each such object represents an instantiation of a (two-party) protocol that has reached a particular point. * * An implementation of this class will persist state machines to long term storage so they can survive process restarts * and, if run with a single-threaded executor, will ensure no two state machines run concurrently with each other * (bad for performance, good for programmer mental health!). + * + * TODO: The framework should do automatic error handling. */ -class StateMachineManager(val net: MessagingSystem, val runInThread: Executor) { +class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) { + private val checkpointsMap = serviceHub.storageService.getMap("state machines") + private val _stateMachines: MutableList> = ArrayList() + + /** Returns a snapshot of the currently registered state machines. */ + val stateMachines: List> get() = ArrayList(_stateMachines) + // This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo). private class Checkpoint( val continuation: Continuation, @@ -43,13 +53,17 @@ class StateMachineManager(val net: MessagingSystem, val runInThread: Executor) { val awaitingObjectOfType: String // java class name ) - constructor(net: MessagingSystem, runInThread: Executor, restoreCheckpoints: List, resumeStateMachine: (ProtocolStateMachine<*,*>) -> Any) : this(net, runInThread) { - for (bytes in restoreCheckpoints) { + init { + restoreCheckpoints() + } + + private fun restoreCheckpoints() { + for (bytes in checkpointsMap.values) { val kryo = createKryo() // Set up Kryo to use the JavaFlow classloader when deserialising, so the magical continuation bytecode // rewriting is performed correctly. - var psm: ProtocolStateMachine<*,*>? = null + var psm: ProtocolStateMachine<*, *>? = null kryo.instantiatorStrategy = object : InstantiatorStrategy { val forwardingTo = kryo.instantiatorStrategy @@ -58,7 +72,9 @@ class StateMachineManager(val net: MessagingSystem, val runInThread: Executor) { // The messing around with types we do here confuses the compiler/IDE a bit and it warns us. @Suppress("UNCHECKED_CAST", "CAST_NEVER_SUCCEEDS") return ObjectInstantiator { - psm = loadContinuationClass(type as Class>).first + val p = loadContinuationClass(type as Class>).first + p.serviceHub = serviceHub + psm = p psm as T } } else { @@ -69,55 +85,47 @@ class StateMachineManager(val net: MessagingSystem, val runInThread: Executor) { val checkpoint = bytes.deserialize(kryo) val continuation = checkpoint.continuation - val transientContext = resumeStateMachine(psm!!) + _stateMachines.add(psm!!) val logger = LoggerFactory.getLogger(checkpoint.loggerName) val awaitingObjectOfType = Class.forName(checkpoint.awaitingObjectOfType) // The act of calling this method re-persists the bytes into the in-memory hashmap so re-saving the // StateMachineManager to disk will work even if some state machines didn't wake up in the intervening time. - setupNextMessageHandler(logger, net, continuation, checkpoint.otherSide, awaitingObjectOfType, - checkpoint.awaitingTopic, transientContext, bytes) + setupNextMessageHandler(logger, serviceHub.networkService, continuation, checkpoint.otherSide, + awaitingObjectOfType, checkpoint.awaitingTopic, bytes) } } - fun add(otherSide: MessageRecipients, transientContext: Any, loggerName: String, continuationClass: Class>): ListenableFuture { + fun add(otherSide: MessageRecipients, initialArgs: T, loggerName: String, + continuationClass: Class>): ListenableFuture { val logger = LoggerFactory.getLogger(loggerName) - val (sm, continuation) = loadContinuationClass(continuationClass) + val (sm, continuation) = loadContinuationClass(continuationClass) + sm.serviceHub = serviceHub + _stateMachines.add(sm) runInThread.execute { // The current state of the continuation is held in the closure attached to the messaging system whenever // the continuation suspends and tells us it expects a response. - iterateStateMachine(continuation, net, otherSide, transientContext, transientContext, logger, null) + iterateStateMachine(continuation, serviceHub.networkService, otherSide, initialArgs, logger, null) } - return sm.resultFuture + @Suppress("UNCHECKED_CAST") + return (sm as ProtocolStateMachine).resultFuture } @Suppress("UNCHECKED_CAST") - private fun loadContinuationClass(continuationClass: Class>): Pair, Continuation> { + private fun loadContinuationClass(continuationClass: Class>): Pair, Continuation> { val url = continuationClass.protectionDomain.codeSource.location val cl = ContinuationClassLoader(arrayOf(url), this.javaClass.classLoader) - val obj = cl.forceLoadClass(continuationClass.name).newInstance() as ProtocolStateMachine<*, R> + val obj = cl.forceLoadClass(continuationClass.name).newInstance() as ProtocolStateMachine<*, *> return Pair(obj, Continuation.startSuspendedWith(obj)) } - private val checkpoints: LinkedList = LinkedList() private fun persistCheckpoint(prev: ByteArray?, new: ByteArray) { - synchronized(checkpoints) { - if (prev == null) { - for (i in checkpoints.size - 1 downTo 0) { - val b = checkpoints[i] - if (Arrays.equals(b, prev)) { - checkpoints[i] = new - return - } - } - } - checkpoints.add(new) - } + if (prev != null) + checkpointsMap.remove(SecureHash.sha256(prev)) + checkpointsMap[SecureHash.sha256(new)] = new } - fun saveToBytes(): LinkedList = synchronized(checkpoints) { LinkedList(checkpoints) } - private fun iterateStateMachine(c: Continuation, net: MessagingSystem, otherSide: MessageRecipients, - transientContext: Any, continuationInput: Any?, logger: Logger, + continuationInput: Any?, logger: Logger, prevPersistedBytes: ByteArray?): Continuation { // This will resume execution of the run() function inside the continuation at the place it left off. val oldLogger = CONTINUATION_LOGGER.get() @@ -141,7 +149,7 @@ class StateMachineManager(val net: MessagingSystem, val runInThread: Executor) { if (req is ContinuationResult.ExpectingResponse<*>) { // Prepare a listener on the network that runs in the background thread when we received a message. val topic = "${req.topic}.${req.sessionIDForReceive}" - setupNextMessageHandler(logger, net, nextState, otherSide, req.responseType, topic, transientContext, prevPersistedBytes) + setupNextMessageHandler(logger, net, nextState, otherSide, req.responseType, topic, prevPersistedBytes) } // If an object to send was provided (not null), send it now. req.obj?.let { @@ -151,7 +159,7 @@ class StateMachineManager(val net: MessagingSystem, val runInThread: Executor) { } if (req is ContinuationResult.NotExpectingResponse) { // We sent a message, but don't expect a response, so re-enter the continuation to let it keep going. - return iterateStateMachine(nextState, net, otherSide, transientContext, transientContext, logger, prevPersistedBytes) + return iterateStateMachine(nextState, net, otherSide, null, logger, prevPersistedBytes) } else { return nextState } @@ -159,13 +167,14 @@ class StateMachineManager(val net: MessagingSystem, val runInThread: Executor) { private fun setupNextMessageHandler(logger: Logger, net: MessagingSystem, nextState: Continuation, otherSide: MessageRecipients, responseType: Class<*>, - topic: String, transientContext: Any, prevPersistedBytes: ByteArray?) { + topic: String, prevPersistedBytes: ByteArray?) { val checkpoint = Checkpoint(nextState, otherSide, logger.name, topic, responseType.name) - persistCheckpoint(prevPersistedBytes, checkpoint.serialize()) + val curPersistedBytes = checkpoint.serialize() + persistCheckpoint(prevPersistedBytes, curPersistedBytes) net.runOnNextMessage(topic, runInThread) { netMsg -> val obj: Any = THREAD_LOCAL_KRYO.get().readObject(Input(netMsg.data), responseType) logger.trace { "<- $topic : message of type ${obj.javaClass.name}" } - iterateStateMachine(nextState, net, otherSide, transientContext, Pair(transientContext, obj), logger, prevPersistedBytes) + iterateStateMachine(nextState, net, otherSide, obj, logger, curPersistedBytes) } } } @@ -173,45 +182,57 @@ class StateMachineManager(val net: MessagingSystem, val runInThread: Executor) { val CONTINUATION_LOGGER = ThreadLocal() /** - * A convenience mixin interface that can be implemented by an object that will act as a continuation. + * The base class that should be used by any object that wishes to act as a protocol state machine. Sub-classes should + * override the [call] method and return whatever the final result of the protocol is. Inside the call method, + * the rules of normal object oriented programming are a little different: * - * A ProtocolStateMachine must implement the run method from [Runnable], and the rest of what this interface - * provides are pre-defined utility methods to ease implementation of such machines. + * - You can call send/receive/sendAndReceive in order to suspend the state machine and request network interaction. + * This does not block a thread and when a state machine is suspended like this, it will be serialised and written + * to stable storage. That means all objects on the stack and referenced from fields must be serialisable as well + * (with Kryo, so they don't have to implement the Java Serializable interface). The state machine may be resumed + * at some arbitrary later point. + * - Because of this, if you need access to data that might change over time, you should request it just-in-time + * via the [serviceHub] property which is provided. Don't try and keep data you got from a service across calls to + * send/receive/sendAndReceive because the world might change in arbitrary ways out from underneath you, for instance, + * if the node is restarted or reconfigured! + * - Don't pass initial data in using a constructor. This object will be instantiated using reflection so you cannot + * define your own constructor. Instead define a separate class that holds your initial arguments, and take it as + * the argument to [call]. */ @Suppress("UNCHECKED_CAST") -abstract class ProtocolStateMachine : Callable, Runnable { - protected fun context(): CONTEXT_TYPE = Continuation.getContext() as CONTEXT_TYPE +abstract class ProtocolStateMachine : Runnable { protected fun logger(): Logger = CONTINUATION_LOGGER.get() // These fields shouldn't be serialised. @Transient private var _resultFuture: SettableFuture = SettableFuture.create() - val resultFuture: ListenableFuture get() = _resultFuture + @Transient lateinit var serviceHub: ServiceHub + + abstract fun call(args: T): R override fun run() { - val r = call() + val r = call(Continuation.getContext() as T) if (r != null) _resultFuture.set(r) } } @Suppress("NOTHING_TO_INLINE", "UNCHECKED_CAST") -inline fun ProtocolStateMachine.send(topic: String, sessionID: Long, obj: S) = - Continuation.suspend(ContinuationResult.NotExpectingResponse(topic, sessionID, obj)) as CONTEXT_TYPE +inline fun ProtocolStateMachine<*, *>.send(topic: String, sessionID: Long, obj: S) = + Continuation.suspend(ContinuationResult.NotExpectingResponse(topic, sessionID, obj)) @Suppress("UNCHECKED_CAST") -inline fun ProtocolStateMachine.sendAndReceive( - topic: String, sessionIDForSend: Long, sessionIDForReceive: Long, obj: Any): Pair { +inline fun ProtocolStateMachine<*, *>.sendAndReceive( + topic: String, sessionIDForSend: Long, sessionIDForReceive: Long, obj: Any): R { return Continuation.suspend(ContinuationResult.ExpectingResponse(topic, sessionIDForSend, sessionIDForReceive, - obj, R::class.java)) as Pair + obj, R::class.java)) as R } @Suppress("UNCHECKED_CAST") -inline fun ProtocolStateMachine.receive( - topic: String, sessionIDForReceive: Long): Pair { - return Continuation.suspend(ContinuationResult.ExpectingResponse(topic, -1, sessionIDForReceive, null, - R::class.java)) as Pair +inline fun ProtocolStateMachine<*, *>.receive( + topic: String, sessionIDForReceive: Long): R { + return Continuation.suspend(ContinuationResult.ExpectingResponse(topic, -1, sessionIDForReceive, null, R::class.java)) as R } open class ContinuationResult(val topic: String, val sessionIDForSend: Long, val sessionIDForReceive: Long, val obj: Any?) { diff --git a/src/test/kotlin/contracts/CommercialPaperTests.kt b/src/test/kotlin/contracts/CommercialPaperTests.kt index ed9040eba6..a527653346 100644 --- a/src/test/kotlin/contracts/CommercialPaperTests.kt +++ b/src/test/kotlin/contracts/CommercialPaperTests.kt @@ -107,7 +107,7 @@ class CommercialPaperTests { val ptx = CommercialPaper().craftIssue(MINI_CORP.ref(123), 10000.DOLLARS, TEST_TX_TIME + 30.days) ptx.signWith(MINI_CORP_KEY) val stx = ptx.toSignedTransaction() - stx.verify().toLedgerTransaction(TEST_TX_TIME, TEST_KEYS_TO_CORP_MAP, SecureHash.randomSHA256()) + stx.verify().toLedgerTransaction(TEST_TX_TIME, MockIdentityService, SecureHash.randomSHA256()) } val (alicesWalletTX, alicesWallet) = cashOutputsToWallet( @@ -124,7 +124,7 @@ class CommercialPaperTests { ptx.signWith(MINI_CORP_KEY) ptx.signWith(ALICE_KEY) val stx = ptx.toSignedTransaction() - stx.verify().toLedgerTransaction(TEST_TX_TIME, TEST_KEYS_TO_CORP_MAP, SecureHash.randomSHA256()) + stx.verify().toLedgerTransaction(TEST_TX_TIME, MockIdentityService, SecureHash.randomSHA256()) } // Won't be validated. @@ -138,7 +138,7 @@ class CommercialPaperTests { CommercialPaper().craftRedeem(ptx, moveTX.outRef(1), corpWallet) ptx.signWith(ALICE_KEY) ptx.signWith(MINI_CORP_KEY) - return ptx.toSignedTransaction().verify().toLedgerTransaction(time, TEST_KEYS_TO_CORP_MAP, SecureHash.randomSHA256()) + return ptx.toSignedTransaction().verify().toLedgerTransaction(time, MockIdentityService, SecureHash.randomSHA256()) } val tooEarlyRedemption = makeRedeemTX(TEST_TX_TIME + 10.days) diff --git a/src/test/kotlin/contracts/CrowdFundTests.kt b/src/test/kotlin/contracts/CrowdFundTests.kt index 6efb208ab7..866e491b8b 100644 --- a/src/test/kotlin/contracts/CrowdFundTests.kt +++ b/src/test/kotlin/contracts/CrowdFundTests.kt @@ -105,7 +105,7 @@ class CrowdFundTests { val ptx = CrowdFund().craftRegister(MINI_CORP.ref(123), 1000.DOLLARS, "crowd funding", TEST_TX_TIME + 7.days) ptx.signWith(MINI_CORP_KEY) val stx = ptx.toSignedTransaction() - stx.verify().toLedgerTransaction(TEST_TX_TIME, TEST_KEYS_TO_CORP_MAP, SecureHash.randomSHA256()) + stx.verify().toLedgerTransaction(TEST_TX_TIME, MockIdentityService, SecureHash.randomSHA256()) } // let's give Alice some funds that she can invest @@ -123,7 +123,7 @@ class CrowdFundTests { ptx.signWith(ALICE_KEY) val stx = ptx.toSignedTransaction() // this verify passes - the transaction contains an output cash, necessary to verify the fund command - stx.verify().toLedgerTransaction(TEST_TX_TIME, TEST_KEYS_TO_CORP_MAP, SecureHash.randomSHA256()) + stx.verify().toLedgerTransaction(TEST_TX_TIME, MockIdentityService, SecureHash.randomSHA256()) } // Won't be validated. @@ -137,7 +137,7 @@ class CrowdFundTests { CrowdFund().craftClose(ptx, pledgeTX.outRef(0), miniCorpWallet) ptx.signWith(MINI_CORP_KEY) val stx = ptx.toSignedTransaction() - return stx.verify().toLedgerTransaction(time, TEST_KEYS_TO_CORP_MAP, SecureHash.randomSHA256()) + return stx.verify().toLedgerTransaction(time, MockIdentityService, SecureHash.randomSHA256()) } val tooEarlyClose = makeFundedTX(TEST_TX_TIME + 6.days) diff --git a/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt index e0c7f316d8..846685a43f 100644 --- a/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt +++ b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt @@ -8,12 +8,14 @@ package core.messaging -import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.MoreExecutors import contracts.Cash import contracts.CommercialPaper import contracts.protocols.TwoPartyTradeProtocol -import core.* +import core.ContractState +import core.DOLLARS +import core.StateAndRef +import core.days import core.testutils.* import org.junit.After import org.junit.Before @@ -25,13 +27,12 @@ import java.util.logging.LogRecord import java.util.logging.Logger import kotlin.test.assertEquals import kotlin.test.assertTrue -import kotlin.test.fail /** - * In this example, Alessia wishes to sell her commercial paper to Boris in return for $1,000,000 and they wish to do + * In this example, Alice wishes to sell her commercial paper to Bob in return for $1,000,000 and they wish to do * it on the ledger atomically. Therefore they must work together to build a transaction. * - * We assume that Alessia and Boris already found each other via some market, and have agreed the details already. + * We assume that Alice and Bob already found each other via some market, and have agreed the details already. */ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { @Before @@ -50,15 +51,10 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { @Test fun cashForCP() { - val (addr1, node1) = makeNode(inBackground = true) - val (addr2, node2) = makeNode(inBackground = true) - val backgroundThread = Executors.newSingleThreadExecutor() - val tpSeller = TwoPartyTradeProtocol.create(StateMachineManager(node1, backgroundThread)) - val tpBuyer = TwoPartyTradeProtocol.create(StateMachineManager(node2, backgroundThread)) transactionGroupFor { - // Bob (S) has some cash, Alice (P) has some commercial paper she wants to sell to Bob. + // Bob (Buyer) has some cash, Alice (Seller) has some commercial paper she wants to sell to Bob. roots { transaction(CommercialPaper.State(MEGA_CORP.ref(1, 2, 3), ALICE, 1200.DOLLARS, TEST_TX_TIME + 7.days) label "alice's paper") transaction(800.DOLLARS.CASH `owned by` BOB label "bob cash1") @@ -67,22 +63,33 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { val bobsWallet = listOf>(lookup("bob cash1"), lookup("bob cash2")) + val (alicesAddress, alicesNode) = makeNode(inBackground = true) + val (bobsAddress, bobsNode) = makeNode(inBackground = true) + + val alicesServices = MockServices(wallet = null, keyManagement = null, net = alicesNode) + val bobsServices = MockServices( + wallet = MockWalletService(bobsWallet), + keyManagement = MockKeyManagementService(mapOf(BOB to BOB_KEY.private)), + net = bobsNode + ) + + val tpSeller = TwoPartyTradeProtocol.create(StateMachineManager(alicesServices, backgroundThread)) + val tpBuyer = TwoPartyTradeProtocol.create(StateMachineManager(bobsServices, backgroundThread)) + val aliceResult = tpSeller.runSeller( - addr2, - lookup("alice's paper"), - 1000.DOLLARS, - ALICE_KEY, - TEST_KEYS_TO_CORP_MAP, - DUMMY_TIMESTAMPER + bobsAddress, + TwoPartyTradeProtocol.SellerInitialArgs( + lookup("alice's paper"), + 1000.DOLLARS, + ALICE_KEY + ) ) val bobResult = tpBuyer.runBuyer( - addr1, - 1000.DOLLARS, - CommercialPaper.State::class.java, - bobsWallet, - mapOf(BOB to BOB_KEY.private), - DUMMY_TIMESTAMPER, - TEST_KEYS_TO_CORP_MAP + alicesAddress, + TwoPartyTradeProtocol.BuyerInitialArgs( + 1000.DOLLARS, + CommercialPaper.State::class.java + ) ) assertEquals(aliceResult.get(), bobResult.get()) @@ -95,14 +102,6 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { @Test fun serializeAndRestore() { - val (addr1, node1) = makeNode(inBackground = false) - var (addr2, node2) = makeNode(inBackground = false) - - val smmSeller = StateMachineManager(node1, MoreExecutors.directExecutor()) - val tpSeller = TwoPartyTradeProtocol.create(smmSeller) - val smmBuyer = StateMachineManager(node2, MoreExecutors.directExecutor()) - val tpBuyer = TwoPartyTradeProtocol.create(smmBuyer) - transactionGroupFor { // Buyer Bob has some cash, Seller Alice has some commercial paper she wants to sell to Bob. roots { @@ -113,56 +112,71 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { val bobsWallet = listOf>(lookup("bob cash1"), lookup("bob cash2")) + val (alicesAddress, alicesNode) = makeNode(inBackground = false) + var (bobsAddress, bobsNode) = makeNode(inBackground = false) + + val bobsStorage = MockStorageService() + + val alicesServices = MockServices(wallet = null, keyManagement = null, net = alicesNode) + var bobsServices = MockServices( + wallet = MockWalletService(bobsWallet), + keyManagement = MockKeyManagementService(mapOf(BOB to BOB_KEY.private)), + net = bobsNode, + storage = bobsStorage + ) + + val tpSeller = TwoPartyTradeProtocol.create(StateMachineManager(alicesServices, MoreExecutors.directExecutor())) + val smmBuyer = StateMachineManager(bobsServices, MoreExecutors.directExecutor()) + val tpBuyer = TwoPartyTradeProtocol.create(smmBuyer) + tpSeller.runSeller( - addr2, - lookup("alice's paper"), - 1000.DOLLARS, - ALICE_KEY, - TEST_KEYS_TO_CORP_MAP, - DUMMY_TIMESTAMPER + bobsAddress, + TwoPartyTradeProtocol.SellerInitialArgs( + lookup("alice's paper"), + 1000.DOLLARS, + ALICE_KEY + ) ) tpBuyer.runBuyer( - addr1, - 1000.DOLLARS, - CommercialPaper.State::class.java, - bobsWallet, - mapOf(BOB to BOB_KEY.private), - DUMMY_TIMESTAMPER, - TEST_KEYS_TO_CORP_MAP + alicesAddress, + TwoPartyTradeProtocol.BuyerInitialArgs( + 1000.DOLLARS, + CommercialPaper.State::class.java + ) ) // Everything is on this thread so we can now step through the protocol one step at a time. // Seller Alice already sent a message to Buyer Bob. Pump once: - node2.pump(false) + bobsNode.pump(false) + // OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature. - val storageBob = smmBuyer.saveToBytes() + // Save the state machine to "disk" (i.e. a variable, here) + assertEquals(1, bobsStorage.getMap("state machines").size) + // .. and let's imagine that Bob's computer has a power cut. He now has nothing now beyond what was on disk. - node2.stop() + bobsNode.stop() // Alice doesn't know that and sends Bob the now finalised transaction. Alice sends a message to a node // that has gone offline. - node1.pump(false) + alicesNode.pump(false) - // ... bring the network back up ... - node2 = network.createNodeWithID(true, addr2.id).start().get() + // ... bring the node back up ... the act of constructing the SMM will re-register the message handlers + // that Bob was waiting on before the reboot occurred. + bobsNode = network.createNodeWithID(true, bobsAddress.id).start().get() + val smm = StateMachineManager( + MockServices(wallet = null, keyManagement = null, net = bobsNode, storage = bobsStorage), + MoreExecutors.directExecutor() + ) + + // Find the future representing the result of this state machine again. + assertEquals(1, smm.stateMachines.size) + var bobFuture = smm.stateMachines.filterIsInstance().first().resultFuture + + // Let Bob process his mailbox. + assertTrue(bobsNode.pump(false)) - // We must provide the state machines with all the stuff that couldn't be saved to disk. - var bobFuture: ListenableFuture>? = null - fun resumeStateMachine(forObj: ProtocolStateMachine<*,*>): Any { - return when (forObj) { - is TwoPartyTradeProtocol.Buyer -> { - bobFuture = forObj.resultFuture - return TwoPartyTradeProtocol.BuyerContext(bobsWallet, mapOf(BOB to BOB_KEY.private), DUMMY_TIMESTAMPER, TEST_KEYS_TO_CORP_MAP, null) - } - else -> fail() - } - } - // The act of constructing this object will re-register the message handlers that Bob was waiting on before - // the reboot occurred. - StateMachineManager(node2, MoreExecutors.directExecutor(), storageBob, ::resumeStateMachine) - assertTrue(node2.pump(false)) // Bob is now finished and has the same transaction as Alice. - val tx: Pair = bobFuture!!.get() + val tx = bobFuture.get() txns.add(tx.second) verify() } diff --git a/src/test/kotlin/core/serialization/TransactionSerializationTests.kt b/src/test/kotlin/core/serialization/TransactionSerializationTests.kt index d528102dca..8c35df1d8b 100644 --- a/src/test/kotlin/core/serialization/TransactionSerializationTests.kt +++ b/src/test/kotlin/core/serialization/TransactionSerializationTests.kt @@ -88,7 +88,7 @@ class TransactionSerializationTests { fun timestamp() { tx.signWith(TestUtils.keypair) val ttx = tx.toSignedTransaction().toTimestampedTransactionWithoutTime() - val ltx = ttx.verifyToLedgerTransaction(DUMMY_TIMESTAMPER, TEST_KEYS_TO_CORP_MAP) + val ltx = ttx.verifyToLedgerTransaction(DUMMY_TIMESTAMPER, MockIdentityService) assertEquals(tx.commands().map { it.command }, ltx.commands.map { it.value }) assertEquals(tx.inputStates(), ltx.inStateRefs) assertEquals(tx.outputStates(), ltx.outStates) @@ -97,7 +97,7 @@ class TransactionSerializationTests { val ltx2: LedgerTransaction = tx. toSignedTransaction(). toTimestampedTransaction(DUMMY_TIMESTAMPER). - verifyToLedgerTransaction(DUMMY_TIMESTAMPER, TEST_KEYS_TO_CORP_MAP) + verifyToLedgerTransaction(DUMMY_TIMESTAMPER, MockIdentityService) assertEquals(TEST_TX_TIME, ltx2.time) } } \ No newline at end of file diff --git a/src/test/kotlin/core/testutils/TestUtils.kt b/src/test/kotlin/core/testutils/TestUtils.kt index f900d7c43f..38714fe69b 100644 --- a/src/test/kotlin/core/testutils/TestUtils.kt +++ b/src/test/kotlin/core/testutils/TestUtils.kt @@ -13,15 +13,19 @@ package core.testutils import com.google.common.io.BaseEncoding import contracts.* import core.* +import core.messaging.MessagingSystem import core.visualiser.GraphVisualiser import java.io.ByteArrayInputStream import java.io.ByteArrayOutputStream import java.io.DataInputStream import java.io.DataOutputStream +import java.security.KeyPair import java.security.KeyPairGenerator +import java.security.PrivateKey import java.security.PublicKey import java.time.Instant import java.util.* +import javax.annotation.concurrent.ThreadSafe import kotlin.test.assertEquals import kotlin.test.assertFailsWith import kotlin.test.fail @@ -89,6 +93,53 @@ class DummyTimestamper(private val time: Instant = TEST_TX_TIME) : TimestamperSe val DUMMY_TIMESTAMPER = DummyTimestamper() +object MockIdentityService : IdentityService { + override fun partyFromKey(key: PublicKey): Party? = TEST_KEYS_TO_CORP_MAP[key] +} + +class MockKeyManagementService( + override val keys: Map, + val nextKeys: MutableList = arrayListOf(KeyPairGenerator.getInstance("EC").genKeyPair()) +) : KeyManagementService { + override fun freshKey() = nextKeys.removeAt(nextKeys.lastIndex) +} + +class MockWalletService(val states: List>) : WalletService { + override val currentWallet = Wallet(states) +} + +@ThreadSafe +class MockStorageService : StorageService { + private val mapOfMaps = HashMap>() + + @Synchronized + override fun getMap(tableName: String): MutableMap { + return mapOfMaps.getOrPut(tableName) { Collections.synchronizedMap(HashMap()) } as MutableMap + } +} + +class MockServices( + val wallet: WalletService?, + val keyManagement: KeyManagementService?, + val net: MessagingSystem?, + val identity: IdentityService? = MockIdentityService, + val storage: StorageService? = MockStorageService(), + val timestamping: TimestamperService? = DUMMY_TIMESTAMPER +) : ServiceHub { + override val walletService: WalletService + get() = wallet ?: throw UnsupportedOperationException() + override val keyManagementService: KeyManagementService + get() = keyManagement ?: throw UnsupportedOperationException() + override val identityService: IdentityService + get() = identity ?: throw UnsupportedOperationException() + override val timestampingService: TimestamperService + get() = timestamping ?: throw UnsupportedOperationException() + override val networkService: MessagingSystem + get() = net ?: throw UnsupportedOperationException() + override val storageService: StorageService + get() = storage ?: throw UnsupportedOperationException() +} + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // // Defines a simple DSL for building pseudo-transactions (not the same as the wire protocol) for testing purposes. @@ -230,7 +281,7 @@ class TransactionGroupDSL(private val stateType: Class) { */ fun toLedgerTransaction(time: Instant): LedgerTransaction { val wireCmds = commands.map { WireCommand(it.value, it.signers) } - return WireTransaction(inStates, outStates.map { it.state }, wireCmds).toLedgerTransaction(time, TEST_KEYS_TO_CORP_MAP, SecureHash.randomSHA256()) + return WireTransaction(inStates, outStates.map { it.state }, wireCmds).toLedgerTransaction(time, MockIdentityService, SecureHash.randomSHA256()) } } @@ -266,7 +317,7 @@ class TransactionGroupDSL(private val stateType: Class) { fun transaction(vararg outputStates: LabeledOutput) { val outs = outputStates.map { it.state } val wtx = WireTransaction(emptyList(), outs, emptyList()) - val ltx = wtx.toLedgerTransaction(TEST_TX_TIME, TEST_KEYS_TO_CORP_MAP, SecureHash.randomSHA256()) + val ltx = wtx.toLedgerTransaction(TEST_TX_TIME, MockIdentityService, SecureHash.randomSHA256()) for ((index, state) in outputStates.withIndex()) { val label = state.label!! labelToRefs[label] = ContractStateRef(ltx.hash, index)