diff --git a/src/main/kotlin/contracts/protocols/TwoPartyTradeProtocol.kt b/src/main/kotlin/contracts/protocols/TwoPartyTradeProtocol.kt index 230621d8f1..8e1b3f42d0 100644 --- a/src/main/kotlin/contracts/protocols/TwoPartyTradeProtocol.kt +++ b/src/main/kotlin/contracts/protocols/TwoPartyTradeProtocol.kt @@ -12,18 +12,22 @@ import co.paralleluniverse.fibers.Suspendable import contracts.Cash import contracts.sumCashBy import core.* +import core.messaging.LegallyIdentifiableNode import core.messaging.ProtocolStateMachine import core.messaging.SingleMessageRecipient import core.messaging.StateMachineManager +import core.node.TimestamperClient import core.serialization.deserialize import core.utilities.trace import java.security.KeyPair import java.security.PublicKey +import java.time.Instant // TODO: Get rid of the "initial args" concept and just use the class c'tors, now we are using Quasar. /** - * This asset trading protocol has two parties (B and S for buyer and seller) and the following steps: + * This asset trading protocol implements a "delivery vs payment" type swap. It has two parties (B and S for buyer + * and seller) and the following steps: * * 1. S sends the [StateAndRef] pointing to what they want to sell to B, along with info about the price they require * B to pay. For example this has probably been agreed on an exchange. @@ -67,14 +71,14 @@ abstract class TwoPartyTradeProtocol { abstract class Seller : ProtocolStateMachine>() companion object { - @JvmStatic fun create(smm: StateMachineManager): TwoPartyTradeProtocol { - return TwoPartyTradeProtocolImpl(smm) + @JvmStatic fun create(smm: StateMachineManager, timestampingAuthority: LegallyIdentifiableNode): TwoPartyTradeProtocol { + return TwoPartyTradeProtocolImpl(smm, timestampingAuthority) } } } -/** The implementation of the [TwoPartyTradeProtocol] base class. */ -private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) : TwoPartyTradeProtocol() { +private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager, + private val timestampingAuthority: LegallyIdentifiableNode) : TwoPartyTradeProtocol() { companion object { val TRADE_TOPIC = "com.r3cev.protocols.trade" } @@ -87,13 +91,7 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) : val sessionID: Long ) - // The seller's side of the protocol. IMPORTANT: This class is loaded in a separate classloader and auto-mangled - // by JavaFlow. Therefore, we cannot cast the object to Seller and poke it directly because the class we'd be - // trying to poke at is different to the one we saw at compile time, so we'd get ClassCastExceptions. All - // interaction with this class must be through either interfaces, the supertype, or objects passed to and from - // the continuation by the state machine framework. Please refer to the documentation website (docs/build/html) to - // learn more about the protocol state machine framework. - class SellerImpl : Seller() { + class SellerImpl(private val otherSide: SingleMessageRecipient, private val timestampingAuthority: LegallyIdentifiableNode) : Seller() { @Suspendable override fun call(args: SellerInitialArgs): Pair { val sessionID = random63BitValue() @@ -101,7 +99,7 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) : // Make the first message we'll send to kick off the protocol. val hello = SellerTradeInfo(args.assetToSell, args.price, args.myKeyPair.public, sessionID) - val partialTX = sendAndReceive(TRADE_TOPIC, args.buyerSessionID, sessionID, hello) + val partialTX = sendAndReceive(TRADE_TOPIC, otherSide, args.buyerSessionID, sessionID, hello) logger.trace { "Received partially signed transaction" } partialTX.verifySignatures() @@ -122,13 +120,17 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) : // express protocol state machines on top of the messaging layer. } + // Sign with our key and get the timestamping authorities key as well. + // These two steps could be done in parallel, in theory. val ourSignature = args.myKeyPair.signWithECDSA(partialTX.txBits) - val fullySigned: SignedWireTransaction = partialTX.copy(sigs = partialTX.sigs + ourSignature) + val tsaSig = TimestamperClient(this, timestampingAuthority).timestamp(partialTX.txBits) + val fullySigned = partialTX.withAdditionalSignature(tsaSig).withAdditionalSignature(ourSignature) + // We should run it through our full TransactionGroup of all transactions here. - fullySigned.verify() + logger.trace { "Built finished transaction, sending back to secondary!" } - send(TRADE_TOPIC, args.buyerSessionID, fullySigned) + send(TRADE_TOPIC, otherSide, args.buyerSessionID, fullySigned) return Pair(wtx, fullySigned.verifyToLedgerTransaction(serviceHub.identityService)) } @@ -140,7 +142,7 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) : } // The buyer's side of the protocol. See note above Seller to learn about the caveats here. - class BuyerImpl : Buyer() { + class BuyerImpl(private val otherSide: SingleMessageRecipient, private val timestampingAuthority: Party) : Buyer() { @Suspendable override fun call(args: BuyerInitialArgs): Pair { // Wait for a trade request to come in on our pre-provided session ID. @@ -175,6 +177,10 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) : ptx.addOutputState(state) ptx.addCommand(command, tradeRequest.assetForSale.state.owner) + // And add a request for timestamping: it may be that none of the contracts need this! But it can't hurt + // to have one. + ptx.setTime(Instant.now(), timestampingAuthority, 30.seconds) + // Now sign the transaction with whatever keys we need to move the cash. for (k in cashSigningPubKeys) { val priv = serviceHub.keyManagementService.toPrivate(k) @@ -190,7 +196,9 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) : // TODO: Protect against the buyer terminating here and leaving us in the lurch without the final tx. // TODO: Protect against a malicious buyer sending us back a different transaction to the one we built. - val fullySigned = sendAndReceive(TRADE_TOPIC, tradeRequest.sessionID, args.sessionID, stx) + + val fullySigned = sendAndReceive(TRADE_TOPIC, otherSide, tradeRequest.sessionID, + args.sessionID, stx) logger.trace { "Got fully signed transaction, verifying ... "} @@ -203,10 +211,10 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) : } override fun runSeller(otherSide: SingleMessageRecipient, args: SellerInitialArgs): Seller { - return smm.add(otherSide, args, "$TRADE_TOPIC.seller", SellerImpl::class.java) + return smm.add(args, "$TRADE_TOPIC.seller", SellerImpl(otherSide, timestampingAuthority)) } override fun runBuyer(otherSide: SingleMessageRecipient, args: BuyerInitialArgs): Buyer { - return smm.add(otherSide, args, "$TRADE_TOPIC.buyer", BuyerImpl::class.java) + return smm.add(args, "$TRADE_TOPIC.buyer", BuyerImpl(otherSide, timestampingAuthority.identity)) } } \ No newline at end of file diff --git a/src/main/kotlin/core/Crypto.kt b/src/main/kotlin/core/Crypto.kt index 1628096467..a5d30f742b 100644 --- a/src/main/kotlin/core/Crypto.kt +++ b/src/main/kotlin/core/Crypto.kt @@ -57,6 +57,7 @@ open class DigitalSignature(bits: ByteArray, val covering: Int = 0) : OpaqueByte /** A digital signature that identifies who the public key is owned by. */ open class WithKey(val by: PublicKey, bits: ByteArray, covering: Int = 0) : DigitalSignature(bits, covering) { fun verifyWithECDSA(content: ByteArray) = by.verifyWithECDSA(content, this) + fun verifyWithECDSA(content: OpaqueBytes) = by.verifyWithECDSA(content.bits, this) } class LegallyIdentifiable(val signer: Party, bits: ByteArray, covering: Int) : WithKey(signer.owningKey, bits, covering) diff --git a/src/main/kotlin/core/Services.kt b/src/main/kotlin/core/Services.kt index eca9519f7a..18e6979107 100644 --- a/src/main/kotlin/core/Services.kt +++ b/src/main/kotlin/core/Services.kt @@ -8,6 +8,7 @@ package core +import co.paralleluniverse.fibers.Suspendable import core.messaging.MessagingService import core.serialization.SerializedBytes import java.security.KeyPair @@ -78,6 +79,7 @@ interface IdentityService { * themselves. */ interface TimestamperService { + @Suspendable fun timestamp(wtxBytes: SerializedBytes): DigitalSignature.LegallyIdentifiable /** The name+pubkey that this timestamper will sign with. */ @@ -99,6 +101,13 @@ object DummyTimestampingAuthority { */ interface StorageService { fun getMap(tableName: String): MutableMap + + /** + * Returns the legal identity that this node is configured with. Assumed to be initialised when the node is + * first installed. + */ + val myLegalIdentity: Party + val myLegalIdentityKey: KeyPair } /** @@ -110,7 +119,6 @@ interface ServiceHub { val walletService: WalletService val keyManagementService: KeyManagementService val identityService: IdentityService - val timestampingService: TimestamperService val storageService: StorageService val networkService: MessagingService } \ No newline at end of file diff --git a/src/main/kotlin/core/Transactions.kt b/src/main/kotlin/core/Transactions.kt index 46379aa3b2..c4d28be2c4 100644 --- a/src/main/kotlin/core/Transactions.kt +++ b/src/main/kotlin/core/Transactions.kt @@ -8,6 +8,8 @@ package core +import co.paralleluniverse.fibers.Suspendable +import core.node.TimestampingError import core.serialization.SerializedBytes import core.serialization.deserialize import core.serialization.serialize @@ -106,15 +108,11 @@ data class SignedWireTransaction(val txBits: SerializedBytes, v verify() return tx.toLedgerTransaction(identityService, id) } + + /** Returns the same transaction but with an additional (unchecked) signature */ + fun withAdditionalSignature(sig: DigitalSignature.WithKey) = copy(sigs = sigs + sig) } - -/** - * Thrown if an attempt is made to timestamp a transaction using a trusted timestamper, but the time on the transaction - * is too far in the past or future relative to the local clock and thus the timestamper would reject it. - */ -class NotOnTimeException : Exception() - /** A mutable transaction that's in the process of being built, before all signatures are present. */ class TransactionBuilder(private val inputStates: MutableList = arrayListOf(), private val outputStates: MutableList = arrayListOf(), @@ -163,6 +161,20 @@ class TransactionBuilder(private val inputStates: MutableList currentSigs.add(key.signWithECDSA(data.bits)) } + /** + * Checks that the given signature matches one of the commands and that it is a correct signature over the tx, then + * adds it. + * + * @throws SignatureException if the signature didn't match the transaction contents + * @throws IllegalArgumentException if the signature key doesn't appear in any command. + */ + fun checkAndAddSignature(sig: DigitalSignature.WithKey) { + require(commands.count { it.pubkeys.contains(sig.by) } > 0) { "Signature key doesn't match any command" } + val data = toWireTransaction().serialize() + sig.verifyWithECDSA(data.bits) + currentSigs.add(sig) + } + /** * Uses the given timestamper service to request a signature over the WireTransaction be added. There must always be * at least one such signature, but others may be added as well. You may want to have multiple redundant timestamps @@ -173,15 +185,14 @@ class TransactionBuilder(private val inputStates: MutableList * * The signature of the trusted timestamper merely asserts that the time field of this transaction is valid. */ + @Suspendable fun timestamp(timestamper: TimestamperService, clock: Clock = Clock.systemUTC()) { - // TODO: Once we switch to a more advanced bytecode rewriting framework, we can call into a real implementation. - check(timestamper.javaClass.simpleName == "DummyTimestamper") val t = time ?: throw IllegalStateException("Timestamping requested but no time was inserted into the transaction") // Obviously this is just a hard-coded dummy value for now. val maxExpectedLatency = 5.seconds if (Duration.between(clock.instant(), t.before) > maxExpectedLatency) - throw NotOnTimeException() + throw TimestampingError.NotOnTimeException() // The timestamper may also throw NotOnTimeException if our clocks are desynchronised or if we are right on the // boundary of t.notAfter and network latency pushes us over the edge. By "synchronised" here we mean relative diff --git a/src/main/kotlin/core/Utils.kt b/src/main/kotlin/core/Utils.kt index 7f72633da4..29e08313bb 100644 --- a/src/main/kotlin/core/Utils.kt +++ b/src/main/kotlin/core/Utils.kt @@ -44,4 +44,6 @@ fun SettableFuture.setFrom(logger: Logger? = null, block: () -> T): Setta } // Simple infix function to add back null safety that the JDK lacks: timeA until timeB -infix fun Temporal.until(endExclusive: Temporal) = Duration.between(this, endExclusive) \ No newline at end of file +infix fun Temporal.until(endExclusive: Temporal) = Duration.between(this, endExclusive) + +val RunOnCallerThread = MoreExecutors.directExecutor() \ No newline at end of file diff --git a/src/main/kotlin/core/messaging/InMemoryNetwork.kt b/src/main/kotlin/core/messaging/InMemoryNetwork.kt index d646fce809..b8ebf04d3c 100644 --- a/src/main/kotlin/core/messaging/InMemoryNetwork.kt +++ b/src/main/kotlin/core/messaging/InMemoryNetwork.kt @@ -11,7 +11,11 @@ package core.messaging import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.MoreExecutors +import core.Party +import core.node.TimestamperNodeService import core.sha256 +import core.utilities.loggerFor +import java.security.KeyPairGenerator import java.time.Instant import java.util.* import java.util.concurrent.Executor @@ -31,13 +35,15 @@ import kotlin.concurrent.thread @ThreadSafe public class InMemoryNetwork { private var counter = 0 // -1 means stopped. - private val networkMap = HashMap() + private val handleNodeMap = HashMap() // All messages are kept here until the messages are pumped off the queue by a caller to the node class. // Queues are created on-demand when a message is sent to an address: the receiving node doesn't have to have // been created yet. If the node identified by the given handle has gone away/been shut down then messages // stack up here waiting for it to come back. The intent of this is to simulate a reliable messaging network. private val messageQueues = HashMap>() + val nodes: List @Synchronized get() = handleNodeMap.values.toList() + /** * Creates a node and returns the new object that identifies its location on the network to senders, and the * [Node] that the recipient/in-memory node uses to receive messages and send messages itself. @@ -69,7 +75,7 @@ public class InMemoryNetwork { is AllPossibleRecipients -> { // This means all possible recipients _that the network knows about at the time_, not literally everyone // who joins into the indefinite future. - for (handle in networkMap.keys) + for (handle in handleNodeMap.keys) getQueueForHandle(handle).add(message) } else -> throw IllegalArgumentException("Unknown type of recipient handle") @@ -78,7 +84,7 @@ public class InMemoryNetwork { @Synchronized private fun netNodeHasShutdown(handle: Handle) { - networkMap.remove(handle) + handleNodeMap.remove(handle) } @Synchronized @@ -90,11 +96,11 @@ public class InMemoryNetwork { fun stop() { // toArrayList here just copies the collection, which we need because node.stop() will delete itself from // the network map by calling netNodeHasShutdown. So we would get a CoModException if we didn't copy first. - for (node in networkMap.values.toArrayList()) + for (node in handleNodeMap.values.toArrayList()) node.stop() counter = -1 - networkMap.clear() + handleNodeMap.clear() messageQueues.clear() } @@ -102,7 +108,7 @@ public class InMemoryNetwork { override fun start(): ListenableFuture { synchronized(this@InMemoryNetwork) { val node = Node(manuallyPumped, id) - networkMap[id] = node + handleNodeMap[id] = node return Futures.immediateFuture(node) } } @@ -114,6 +120,20 @@ public class InMemoryNetwork { override fun hashCode() = id.hashCode() } + private var timestampingAdvert: LegallyIdentifiableNode? = null + + @Synchronized + fun setupTimestampingNode(manuallyPumped: Boolean): Pair { + check(timestampingAdvert == null) + val (handle, builder) = createNode(manuallyPumped) + val node = builder.start().get() + val key = KeyPairGenerator.getInstance("EC").genKeyPair() + val identity = Party("Unit test timestamping authority", key.public) + TimestamperNodeService(node, identity, key) + timestampingAdvert = LegallyIdentifiableNode(handle, identity) + return Pair(timestampingAdvert!!, node) + } + /** * An [Node] provides a [MessagingService] that isn't backed by any kind of network or disk storage * system, but just uses regular queues on the heap instead. It is intended for unit testing and developer convenience @@ -132,6 +152,10 @@ public class InMemoryNetwork { override val myAddress: SingleMessageRecipient = handle + override val networkMap: NetworkMap get() = object : NetworkMap { + override val timestampingNodes = if (timestampingAdvert != null) listOf(timestampingAdvert!!) else emptyList() + } + protected val backgroundThread = if (manuallyPumped) null else thread(isDaemon = true, name = "In-memory message dispatcher ") { while (!currentThread.isInterrupted) { try { @@ -228,7 +252,13 @@ public class InMemoryNetwork { for (handler in deliverTo) { // Now deliver via the requested executor, or on this thread if no executor was provided at registration time. - (handler.executor ?: MoreExecutors.directExecutor()).execute { handler.callback(message, handler) } + (handler.executor ?: MoreExecutors.directExecutor()).execute { + try { + handler.callback(message, handler) + } catch(e: Exception) { + loggerFor().error("Caught exception in handler for $this/${handler.topic}", e) + } + } } return true diff --git a/src/main/kotlin/core/messaging/Messaging.kt b/src/main/kotlin/core/messaging/Messaging.kt index e7a1b6b3c2..493de6cf50 100644 --- a/src/main/kotlin/core/messaging/Messaging.kt +++ b/src/main/kotlin/core/messaging/Messaging.kt @@ -66,8 +66,11 @@ interface MessagingService { */ fun createMessage(topic: String, data: ByteArray): Message - /** Returns an address that refers to this node */ + /** Returns an address that refers to this node. */ val myAddress: SingleMessageRecipient + + /** Allows you to look up services and nodes that are available on the network. */ + val networkMap: NetworkMap } /** diff --git a/src/main/kotlin/core/messaging/NetworkMap.kt b/src/main/kotlin/core/messaging/NetworkMap.kt new file mode 100644 index 0000000000..555081b2b3 --- /dev/null +++ b/src/main/kotlin/core/messaging/NetworkMap.kt @@ -0,0 +1,27 @@ +/* + * Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members + * pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms + * set forth therein. + * + * All other rights reserved. + */ + +package core.messaging + +import core.Party + +/** Info about a network node that has is operated by some sort of verified identity. */ +data class LegallyIdentifiableNode(val address: SingleMessageRecipient, val identity: Party) + +/** + * A NetworkMap allows you to look up various types of services provided by nodes on the network, and find node + * addresses given legal identities (NB: not all nodes may have legal identities). + * + * A real implementation would probably do RPCs to a lookup service which might in turn be backed by a ZooKeeper + * cluster or equivalent. + * + * For now, this class is truly minimal. + */ +interface NetworkMap { + val timestampingNodes: List +} diff --git a/src/main/kotlin/core/messaging/StateMachines.kt b/src/main/kotlin/core/messaging/StateMachines.kt index 8f1717a07d..207433725d 100644 --- a/src/main/kotlin/core/messaging/StateMachines.kt +++ b/src/main/kotlin/core/messaging/StateMachines.kt @@ -44,6 +44,7 @@ import javax.annotation.concurrent.ThreadSafe * a bytecode rewriting engine called JavaFlow, to ensure the code can be suspended and resumed at any point. * * TODO: The framework should propagate exceptions and handle error handling automatically. + * TODO: Session IDs should be set up and propagated automatically, on demand. * TODO: This needs extension to the >2 party case. * TODO: Consider the issue of continuation identity more deeply: is it a safe assumption that a serialised * continuation is always unique? @@ -75,7 +76,6 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) // This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo). private class Checkpoint( val serialisedFiber: ByteArray, - val otherSide: MessageRecipients, val loggerName: String, val awaitingTopic: String, val awaitingObjectOfType: String // java class name @@ -103,7 +103,7 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) serviceHub.networkService.runOnNextMessage(topic, runInThread) { netMsg -> val obj: Any = THREAD_LOCAL_KRYO.get().readObject(Input(netMsg.data), awaitingObjectOfType) logger.trace { "<- $topic : message of type ${obj.javaClass.name}" } - iterateStateMachine(psm, serviceHub.networkService, logger, obj, checkpoint.otherSide, checkpointKey) { + iterateStateMachine(psm, serviceHub.networkService, logger, obj, checkpointKey) { Fiber.unparkDeserialized(it, SameThreadFiberScheduler) } } @@ -118,16 +118,14 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) } /** - * Kicks off a brand new state machine of the given class. It will send messages to the network node identified by - * the [otherSide] parameter, log with the named logger, and the [initialArgs] object will be passed to the call - * method of the [ProtocolStateMachine] object that is created. The state machine will be persisted when it suspends - * and will be removed once it completes. + * Kicks off a brand new state machine of the given class. It will log with the named logger, and the + * [initialArgs] object will be passed to the call method of the [ProtocolStateMachine] object. + * The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is + * restarted with checkpointed state machines in the storage service. */ - fun , I> add(otherSide: MessageRecipients, initialArgs: I, loggerName: String, - klass: Class): T { + fun , I> add(initialArgs: I, loggerName: String, fiber: T): T { val logger = LoggerFactory.getLogger(loggerName) - val fiber = klass.newInstance() - iterateStateMachine(fiber, serviceHub.networkService, logger, initialArgs, otherSide, null) { + iterateStateMachine(fiber, serviceHub.networkService, logger, initialArgs, null) { it.start() } return fiber @@ -144,24 +142,23 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) } private fun iterateStateMachine(psm: ProtocolStateMachine<*, *>, net: MessagingService, logger: Logger, - obj: Any?, otherSide: MessageRecipients, prevCheckpointKey: SecureHash?, - resumeFunc: (ProtocolStateMachine<*, *>) -> Unit) { + obj: Any?, prevCheckpointKey: SecureHash?, resumeFunc: (ProtocolStateMachine<*, *>) -> Unit) { val onSuspend = fun(request: FiberRequest, serFiber: ByteArray) { // We have a request to do something: send, receive, or send-and-receive. if (request is FiberRequest.ExpectingResponse<*>) { // Prepare a listener on the network that runs in the background thread when we received a message. - checkpointAndSetupMessageHandler(logger, net, psm, otherSide, request.responseType, + checkpointAndSetupMessageHandler(logger, net, psm, request.responseType, "${request.topic}.${request.sessionIDForReceive}", prevCheckpointKey, serFiber) } // If an object to send was provided (not null), send it now. request.obj?.let { val topic = "${request.topic}.${request.sessionIDForSend}" - logger.trace { "-> $topic : message of type ${it.javaClass.name}" } - net.send(net.createMessage(topic, it.serialize().bits), otherSide) + logger.trace { "-> ${request.destination}/$topic : message of type ${it.javaClass.name}" } + net.send(net.createMessage(topic, it.serialize().bits), request.destination!!) } if (request is FiberRequest.NotExpectingResponse) { // We sent a message, but don't expect a response, so re-enter the continuation to let it keep going. - iterateStateMachine(psm, net, logger, null, otherSide, prevCheckpointKey) { + iterateStateMachine(psm, net, logger, null, prevCheckpointKey) { Fiber.unpark(it, QUASAR_UNBLOCKER) } } @@ -185,16 +182,16 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) } private fun checkpointAndSetupMessageHandler(logger: Logger, net: MessagingService, psm: ProtocolStateMachine<*,*>, - otherSide: MessageRecipients, responseType: Class<*>, - topic: String, prevCheckpointKey: SecureHash?, serialisedFiber: ByteArray) { - val checkpoint = Checkpoint(serialisedFiber, otherSide, logger.name, topic, responseType.name) + responseType: Class<*>, topic: String, prevCheckpointKey: SecureHash?, + serialisedFiber: ByteArray) { + val checkpoint = Checkpoint(serialisedFiber, logger.name, topic, responseType.name) val curPersistedBytes = checkpoint.serialize().bits persistCheckpoint(prevCheckpointKey, curPersistedBytes) val newCheckpointKey = curPersistedBytes.sha256() net.runOnNextMessage(topic, runInThread) { netMsg -> val obj: Any = THREAD_LOCAL_KRYO.get().readObject(Input(netMsg.data), responseType) logger.trace { "<- $topic : message of type ${obj.javaClass.name}" } - iterateStateMachine(psm, net, logger, obj, otherSide, newCheckpointKey) { + iterateStateMachine(psm, net, logger, obj, newCheckpointKey) { Fiber.unpark(it, QUASAR_UNBLOCKER) } } @@ -204,9 +201,11 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) object SameThreadFiberScheduler : FiberExecutorScheduler("Same thread scheduler", MoreExecutors.directExecutor()) /** - * The base class that should be used by any object that wishes to act as a protocol state machine. Sub-classes should - * override the [call] method and return whatever the final result of the protocol is. Inside the call method, - * the rules of normal object oriented programming are a little different: + * The base class that should be used by any object that wishes to act as a protocol state machine. The type variable + * C is the type of the initial arguments. R is the type of the return. + * + * Sub-classes should override the [call] method and return whatever the final result of the protocol is. Inside the + * call method, the rules of normal object oriented programming are a little different: * * - You can call send/receive/sendAndReceive in order to suspend the state machine and request network interaction. * This does not block a thread and when a state machine is suspended like this, it will be serialised and written @@ -225,11 +224,18 @@ abstract class ProtocolStateMachine : Fiber("protocol", SameThreadFiber // These fields shouldn't be serialised, so they are marked @Transient. @Transient private var suspendFunc: ((result: FiberRequest, serFiber: ByteArray) -> Unit)? = null @Transient private var resumeWithObject: Any? = null - @Transient protected lateinit var serviceHub: ServiceHub + @Transient lateinit var serviceHub: ServiceHub @Transient protected lateinit var logger: Logger - @Transient private var _resultFuture: SettableFuture = SettableFuture.create() + @Transient private var _resultFuture: SettableFuture? = SettableFuture.create() + /** This future will complete when the call method returns. */ - val resultFuture: ListenableFuture get() = _resultFuture + val resultFuture: ListenableFuture get() { + return _resultFuture ?: run { + val f = SettableFuture.create() + _resultFuture = f + return f + } + } fun prepareForResumeWith(serviceHub: ServiceHub, withObject: Any?, logger: Logger, suspendFunc: (FiberRequest, ByteArray) -> Unit) { @@ -244,9 +250,9 @@ abstract class ProtocolStateMachine : Fiber("protocol", SameThreadFiber @Suspendable @Suppress("UNCHECKED_CAST") override fun run(): R { - val result = call(resumeWithObject!! as C) + val result = call(resumeWithObject as C) if (result != null) - _resultFuture.set(result) + (resultFuture as SettableFuture).set(result) return result } @@ -268,42 +274,46 @@ abstract class ProtocolStateMachine : Fiber("protocol", SameThreadFiber } @Suspendable @Suppress("UNCHECKED_CAST") - protected fun sendAndReceive(topic: String, sessionIDForSend: Long, sessionIDForReceive: Long, - obj: Any, recvType: Class): T { - val result = FiberRequest.ExpectingResponse(topic, sessionIDForSend, sessionIDForReceive, obj, recvType) - return suspendAndExpectReceive(result) + fun sendAndReceive(topic: String, destination: MessageRecipients, sessionIDForSend: Long, sessionIDForReceive: Long, + obj: Any, recvType: Class): T { + val result = FiberRequest.ExpectingResponse(topic, destination, sessionIDForSend, sessionIDForReceive, obj, recvType) + return suspendAndExpectReceive(result) } @Suspendable - protected fun receive(topic: String, sessionIDForReceive: Long, recvType: Class): T { - val result = FiberRequest.ExpectingResponse(topic, -1, sessionIDForReceive, null, recvType) - return suspendAndExpectReceive(result) + fun receive(topic: String, sessionIDForReceive: Long, recvType: Class): T { + val result = FiberRequest.ExpectingResponse(topic, null, -1, sessionIDForReceive, null, recvType) + return suspendAndExpectReceive(result) } @Suspendable - protected fun send(topic: String, sessionID: Long, obj: Any) { - val result = FiberRequest.NotExpectingResponse(topic, sessionID, obj) + fun send(topic: String, destination: MessageRecipients, sessionID: Long, obj: Any) { + val result = FiberRequest.NotExpectingResponse(topic, destination, sessionID, obj) Fiber.parkAndSerialize { fiber, writer -> suspendFunc!!(result, writer.write(fiber)) } } // Convenience functions for Kotlin users. - inline protected fun sendAndReceive(topic: String, sessionIDForSend: Long, - sessionIDForReceive: Long, obj: Any): R { - return sendAndReceive(topic, sessionIDForSend, sessionIDForReceive, obj, R::class.java) + inline fun sendAndReceive(topic: String, destination: MessageRecipients, sessionIDForSend: Long, + sessionIDForReceive: Long, obj: Any): R { + return sendAndReceive(topic, destination, sessionIDForSend, sessionIDForReceive, obj, R::class.java) } - inline protected fun receive(topic: String, sessionIDForReceive: Long): R { + inline fun receive(topic: String, sessionIDForReceive: Long): R { return receive(topic, sessionIDForReceive, R::class.java) } } -open class FiberRequest(val topic: String, val sessionIDForSend: Long, val sessionIDForReceive: Long, val obj: Any?) { +// TODO: Clean this up +open class FiberRequest(val topic: String, val destination: MessageRecipients?, + val sessionIDForSend: Long, val sessionIDForReceive: Long, val obj: Any?) { class ExpectingResponse( topic: String, + destination: MessageRecipients?, sessionIDForSend: Long, sessionIDForReceive: Long, obj: Any?, val responseType: Class - ) : FiberRequest(topic, sessionIDForSend, sessionIDForReceive, obj) + ) : FiberRequest(topic, destination, sessionIDForSend, sessionIDForReceive, obj) - class NotExpectingResponse(topic: String, sessionIDForSend: Long, obj: Any?) : FiberRequest(topic, sessionIDForSend, -1, obj) + class NotExpectingResponse(topic: String, destination: MessageRecipients, sessionIDForSend: Long, obj: Any?) + : FiberRequest(topic, destination, sessionIDForSend, -1, obj) } \ No newline at end of file diff --git a/src/main/kotlin/core/node/TimestamperNodeService.kt b/src/main/kotlin/core/node/TimestamperNodeService.kt new file mode 100644 index 0000000000..21f4ced7b8 --- /dev/null +++ b/src/main/kotlin/core/node/TimestamperNodeService.kt @@ -0,0 +1,125 @@ +/* + * Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members + * pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms + * set forth therein. + * + * All other rights reserved. + */ + +package core.node + +import co.paralleluniverse.common.util.VisibleForTesting +import co.paralleluniverse.fibers.Suspendable +import core.* +import core.messaging.LegallyIdentifiableNode +import core.messaging.MessageRecipients +import core.messaging.MessagingService +import core.messaging.ProtocolStateMachine +import core.serialization.SerializedBytes +import core.serialization.deserialize +import core.serialization.serialize +import org.slf4j.LoggerFactory +import java.security.KeyPair +import java.time.Clock +import java.time.Duration +import javax.annotation.concurrent.ThreadSafe + +class TimestampingMessages { + // TODO: Improve the messaging api to have a notion of sender+replyTo topic (optional?) + data class Request(val tx: SerializedBytes, val replyTo: MessageRecipients, val replyToTopic: String) +} + +sealed class TimestampingError : Exception() { + class RequiresExactlyOneCommand : TimestampingError() + /** + * Thrown if an attempt is made to timestamp a transaction using a trusted timestamper, but the time on the + * transaction is too far in the past or future relative to the local clock and thus the timestamper would reject + * it. + */ + class NotOnTimeException : TimestampingError() + + /** Thrown if the command in the transaction doesn't list this timestamping authorities public key as a signer */ + class NotForMe : TimestampingError() +} + +/** + * This class implements the server side of the timestamping protocol, using the local clock. A future version might + * add features like checking against other NTP servers to make sure the clock hasn't drifted by too much. + * + * See the doc site to learn more about timestamping authorities (nodes) and the role they play in the data model. + */ +@ThreadSafe +class TimestamperNodeService(private val net: MessagingService, + private val identity: Party, + private val signingKey: KeyPair, + private val clock: Clock = Clock.systemDefaultZone(), + val tolerance: Duration = 30.seconds) { + companion object { + val TIMESTAMPING_PROTOCOL_TOPIC = "dlg.timestamping.request" + + private val logger = LoggerFactory.getLogger(TimestamperNodeService::class.java) + } + + init { + require(identity.owningKey == signingKey.public) + net.addMessageHandler(TIMESTAMPING_PROTOCOL_TOPIC + ".0", null) { message, r -> + try { + val req = message.data.deserialize() + val signature = processRequest(req) + val msg = net.createMessage(req.replyToTopic, signature.serialize().bits) + net.send(msg, req.replyTo) + } catch(e: TimestampingError) { + logger.warn("Failure during timestamping request due to bad request: ${e.javaClass.name}") + } catch(e: Exception) { + logger.error("Exception during timestamping", e) + } + } + } + + @VisibleForTesting + fun processRequest(req: TimestampingMessages.Request): DigitalSignature.LegallyIdentifiable { + // We don't bother verifying signatures anything about the transaction here: we simply don't need to see anything + // except the relevant command, and a future privacy upgrade should ensure we only get a torn-off command + // rather than the full transaction. + val tx = req.tx.deserialize() + val cmd = tx.commands.filter { it.data is TimestampCommand }.singleOrNull() + if (cmd == null) + throw TimestampingError.RequiresExactlyOneCommand() + if (!cmd.pubkeys.contains(identity.owningKey)) + throw TimestampingError.NotForMe() + val tsCommand = cmd.data as TimestampCommand + + val before = tsCommand.before + val after = tsCommand.after + + val now = clock.instant() + + // We don't need to test for (before == null && after == null) or backwards bounds because the TimestampCommand + // constructor already checks that. + + if (before != null && before until now > tolerance) + throw TimestampingError.NotOnTimeException() + if (after != null && now until after > tolerance) + throw TimestampingError.NotOnTimeException() + + return signingKey.signWithECDSA(req.tx.bits, identity) + } +} + +@ThreadSafe +class TimestamperClient(private val psm: ProtocolStateMachine<*, *>, private val node: LegallyIdentifiableNode) : TimestamperService { + override val identity: Party = node.identity + + @Suspendable + override fun timestamp(wtxBytes: SerializedBytes): DigitalSignature.LegallyIdentifiable { + val sessionID = random63BitValue() + val replyTopic = "${TimestamperNodeService.TIMESTAMPING_PROTOCOL_TOPIC}.$sessionID" + val req = TimestampingMessages.Request(wtxBytes, psm.serviceHub.networkService.myAddress, replyTopic) + val signature = psm.sendAndReceive( + TimestamperNodeService.TIMESTAMPING_PROTOCOL_TOPIC, node.address, 0, sessionID, req) + // Check that the timestamping authority gave us back a valid signature and didn't break somehow + signature.verifyWithECDSA(wtxBytes) + return signature + } +} + diff --git a/src/main/kotlin/core/serialization/Kryo.kt b/src/main/kotlin/core/serialization/Kryo.kt index d47bce6a54..b267fd4b2a 100644 --- a/src/main/kotlin/core/serialization/Kryo.kt +++ b/src/main/kotlin/core/serialization/Kryo.kt @@ -8,6 +8,8 @@ package core.serialization +import co.paralleluniverse.fibers.Fiber +import co.paralleluniverse.io.serialization.kryo.KryoSerializer import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.KryoException import com.esotericsoftware.kryo.Serializer @@ -16,13 +18,13 @@ import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.serializers.JavaSerializer import core.SecureHash import core.SignedWireTransaction -import core.TimestampCommand import core.sha256 import de.javakaffee.kryoserializers.ArraysAsListSerializer import org.objenesis.strategy.StdInstantiatorStrategy import java.io.ByteArrayOutputStream import java.lang.reflect.InvocationTargetException import java.security.KeyPairGenerator +import java.time.Instant import java.util.* import kotlin.reflect.KClass import kotlin.reflect.KMutableProperty @@ -163,17 +165,30 @@ fun createKryo(k: Kryo = Kryo()): Kryo { register(Arrays.asList( "" ).javaClass, ArraysAsListSerializer()); - val keyPair = KeyPairGenerator.getInstance("EC").genKeyPair() + // Because we like to stick a Kryo object in a ThreadLocal to speed things up a bit, we can end up trying to + // serialise the Kryo object itself when suspending a fiber. That's dumb, useless AND can cause crashes, so + // we avoid it here. + register(Kryo::class.java, object : Serializer() { + override fun write(kryo: Kryo, output: Output, obj: Kryo) { + } + + override fun read(kryo: Kryo, input: Input, type: Class): Kryo { + return createKryo((Fiber.getFiberSerializer() as KryoSerializer).kryo) + } + }) + + // Some things where the JRE provides an efficient custom serialisation. val ser = JavaSerializer() + val keyPair = KeyPairGenerator.getInstance("EC").genKeyPair() register(keyPair.public.javaClass, ser) register(keyPair.private.javaClass, ser) + register(Instant::class.java, ser) // Some classes have to be handled with the ImmutableClassSerializer because they need to have their // constructors be invoked (typically for lazy members). val immutables = listOf( SignedWireTransaction::class, - SerializedBytes::class, - TimestampCommand::class + SerializedBytes::class ) immutables.forEach { diff --git a/src/test/kotlin/contracts/CommercialPaperTests.kt b/src/test/kotlin/contracts/CommercialPaperTests.kt index c1f1177c63..26a173ecaa 100644 --- a/src/test/kotlin/contracts/CommercialPaperTests.kt +++ b/src/test/kotlin/contracts/CommercialPaperTests.kt @@ -9,6 +9,7 @@ package contracts import core.* +import core.node.TimestampingError import core.testutils.* import org.junit.Test import java.time.Clock @@ -81,7 +82,7 @@ class CommercialPaperTests { CommercialPaper().craftIssue(MINI_CORP.ref(123), 10000.DOLLARS, TEST_TX_TIME + 30.days).apply { setTime(TEST_TX_TIME, DummyTimestampingAuthority.identity, 30.seconds) signWith(MINI_CORP_KEY) - assertFailsWith(NotOnTimeException::class) { + assertFailsWith(TimestampingError.NotOnTimeException::class) { timestamp(DummyTimestamper(Clock.fixed(TEST_TX_TIME + 5.hours, ZoneOffset.UTC))) } } @@ -89,7 +90,7 @@ class CommercialPaperTests { CommercialPaper().craftIssue(MINI_CORP.ref(123), 10000.DOLLARS, TEST_TX_TIME + 30.days).apply { setTime(TEST_TX_TIME, DummyTimestampingAuthority.identity, 30.seconds) signWith(MINI_CORP_KEY) - assertFailsWith(NotOnTimeException::class) { + assertFailsWith(TimestampingError.NotOnTimeException::class) { val tsaClock = Clock.fixed(TEST_TX_TIME - 5.hours, ZoneOffset.UTC) timestamp(DummyTimestamper(tsaClock), Clock.fixed(TEST_TX_TIME, ZoneOffset.UTC)) } diff --git a/src/test/kotlin/core/MockServices.kt b/src/test/kotlin/core/MockServices.kt new file mode 100644 index 0000000000..7aeff3ff63 --- /dev/null +++ b/src/test/kotlin/core/MockServices.kt @@ -0,0 +1,91 @@ +/* + * Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members + * pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms + * set forth therein. + * + * All other rights reserved. + */ + +package core + +import core.messaging.MessagingService +import core.node.TimestampingError +import core.serialization.SerializedBytes +import core.serialization.deserialize +import core.testutils.TEST_KEYS_TO_CORP_MAP +import core.testutils.TEST_TX_TIME +import java.security.KeyPair +import java.security.KeyPairGenerator +import java.security.PrivateKey +import java.security.PublicKey +import java.time.Clock +import java.time.Duration +import java.time.ZoneId +import java.util.* +import javax.annotation.concurrent.ThreadSafe + +/** + * A test/mock timestamping service that doesn't use any signatures or security. It timestamps with + * the provided clock which defaults to [TEST_TX_TIME], an arbitrary point on the timeline. + */ +class DummyTimestamper(var clock: Clock = Clock.fixed(TEST_TX_TIME, ZoneId.systemDefault()), + val tolerance: Duration = 30.seconds) : TimestamperService { + override val identity = DummyTimestampingAuthority.identity + + override fun timestamp(wtxBytes: SerializedBytes): DigitalSignature.LegallyIdentifiable { + val wtx = wtxBytes.deserialize() + val timestamp = wtx.commands.mapNotNull { it.data as? TimestampCommand }.single() + if (timestamp.before!! until clock.instant() > tolerance) + throw TimestampingError.NotOnTimeException() + return DummyTimestampingAuthority.key.signWithECDSA(wtxBytes.bits, identity) + } +} + +val DUMMY_TIMESTAMPER = DummyTimestamper() + +object MockIdentityService : IdentityService { + override fun partyFromKey(key: PublicKey): Party? = TEST_KEYS_TO_CORP_MAP[key] +} + +class MockKeyManagementService( + override val keys: Map, + val nextKeys: MutableList = arrayListOf(KeyPairGenerator.getInstance("EC").genKeyPair()) +) : KeyManagementService { + override fun freshKey() = nextKeys.removeAt(nextKeys.lastIndex) +} + +class MockWalletService(val states: List>) : WalletService { + override val currentWallet = Wallet(states) +} + +@ThreadSafe +class MockStorageService : StorageService { + override val myLegalIdentityKey: KeyPair = KeyPairGenerator.getInstance("EC").genKeyPair() + override val myLegalIdentity: Party = Party("Unit test party", myLegalIdentityKey.public) + + private val mapOfMaps = HashMap>() + + @Synchronized + override fun getMap(tableName: String): MutableMap { + return mapOfMaps.getOrPut(tableName) { Collections.synchronizedMap(HashMap()) } as MutableMap + } +} + +class MockServices( + val wallet: WalletService? = null, + val keyManagement: KeyManagementService? = null, + val net: MessagingService? = null, + val identity: IdentityService? = MockIdentityService, + val storage: StorageService? = MockStorageService() +) : ServiceHub { + override val walletService: WalletService + get() = wallet ?: throw UnsupportedOperationException() + override val keyManagementService: KeyManagementService + get() = keyManagement ?: throw UnsupportedOperationException() + override val identityService: IdentityService + get() = identity ?: throw UnsupportedOperationException() + override val networkService: MessagingService + get() = net ?: throw UnsupportedOperationException() + override val storageService: StorageService + get() = storage ?: throw UnsupportedOperationException() +} diff --git a/src/test/kotlin/core/messaging/InMemoryMessagingTests.kt b/src/test/kotlin/core/messaging/InMemoryMessagingTests.kt index db3b212918..5d4b523095 100644 --- a/src/test/kotlin/core/messaging/InMemoryMessagingTests.kt +++ b/src/test/kotlin/core/messaging/InMemoryMessagingTests.kt @@ -43,7 +43,7 @@ open class TestWithInMemoryNetwork { network.stop() } - fun pumpAll(blocking: Boolean) = nodes.values.map { it.pump(blocking) } + fun pumpAll(blocking: Boolean) = network.nodes.map { it.pump(blocking) } // Keep calling "pump" in rounds until every node in the network reports that it had nothing to do fun runNetwork(body: () -> T): T { diff --git a/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt index d7c9629702..792b73f779 100644 --- a/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt +++ b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt @@ -62,16 +62,17 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { val (alicesAddress, alicesNode) = makeNode(inBackground = true) val (bobsAddress, bobsNode) = makeNode(inBackground = true) + val timestamper = network.setupTimestampingNode(false).first - val alicesServices = MockServices(wallet = null, keyManagement = null, net = alicesNode) + val alicesServices = MockServices(net = alicesNode) val bobsServices = MockServices( wallet = MockWalletService(bobsWallet), keyManagement = MockKeyManagementService(mapOf(BOB to BOB_KEY.private)), net = bobsNode ) - val tpSeller = TwoPartyTradeProtocol.create(StateMachineManager(alicesServices, backgroundThread)) - val tpBuyer = TwoPartyTradeProtocol.create(StateMachineManager(bobsServices, backgroundThread)) + val tpSeller = TwoPartyTradeProtocol.create(StateMachineManager(alicesServices, backgroundThread), timestamper) + val tpBuyer = TwoPartyTradeProtocol.create(StateMachineManager(bobsServices, backgroundThread), timestamper) val buyerSessionID = random63BitValue() @@ -115,6 +116,7 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { val (alicesAddress, alicesNode) = makeNode(inBackground = false) var (bobsAddress, bobsNode) = makeNode(inBackground = false) + val timestamper = network.setupTimestampingNode(true) val bobsStorage = MockStorageService() @@ -126,9 +128,9 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { storage = bobsStorage ) - val tpSeller = TwoPartyTradeProtocol.create(StateMachineManager(alicesServices, MoreExecutors.directExecutor())) + val tpSeller = TwoPartyTradeProtocol.create(StateMachineManager(alicesServices, MoreExecutors.directExecutor()), timestamper.first) val smmBuyer = StateMachineManager(bobsServices, MoreExecutors.directExecutor()) - val tpBuyer = TwoPartyTradeProtocol.create(smmBuyer) + val tpBuyer = TwoPartyTradeProtocol.create(smmBuyer, timestamper.first) val buyerSessionID = random63BitValue() @@ -161,9 +163,11 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { // .. and let's imagine that Bob's computer has a power cut. He now has nothing now beyond what was on disk. bobsNode.stop() - // Alice doesn't know that and sends Bob the now finalised transaction. Alice sends a message to a node - // that has gone offline. - alicesNode.pump(false) + // Alice doesn't know that and carries on: first timestamping and then sending Bob the now finalised + // transaction. Alice sends a message to a node that has gone offline. + assertTrue(alicesNode.pump(false)) + assertTrue(timestamper.second.pump(false)) + assertTrue(alicesNode.pump(false)) // ... bring the node back up ... the act of constructing the SMM will re-register the message handlers // that Bob was waiting on before the reboot occurred. diff --git a/src/test/kotlin/core/node/TimestamperNodeServiceTest.kt b/src/test/kotlin/core/node/TimestamperNodeServiceTest.kt new file mode 100644 index 0000000000..dcb6ed2d2d --- /dev/null +++ b/src/test/kotlin/core/node/TimestamperNodeServiceTest.kt @@ -0,0 +1,131 @@ +/* + * Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members + * pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms + * set forth therein. + * + * All other rights reserved. + */ + +package core.node + +import co.paralleluniverse.fibers.Suspendable +import core.* +import core.messaging.* +import core.serialization.serialize +import core.testutils.ALICE +import core.testutils.ALICE_KEY +import core.testutils.CASH +import core.utilities.BriefLogFormatter +import org.junit.Before +import org.junit.Test +import java.security.PublicKey +import java.time.Clock +import java.time.Instant +import java.time.ZoneId +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue + +class TimestamperNodeServiceTest : TestWithInMemoryNetwork() { + lateinit var myNode: Pair + lateinit var serviceNode: Pair + lateinit var service: TimestamperNodeService + + val ptx = TransactionBuilder().apply { + addInputState(ContractStateRef(SecureHash.randomSHA256(), 0)) + addOutputState(100.DOLLARS.CASH) + } + + val clock = Clock.fixed(Instant.now(), ZoneId.systemDefault()) + lateinit var mockServices: ServiceHub + lateinit var serverKey: PublicKey + + init { + BriefLogFormatter.initVerbose("dlg.timestamping.request") + } + + @Before + fun setup() { + myNode = makeNode() + serviceNode = makeNode() + mockServices = MockServices(net = serviceNode.second, storage = MockStorageService()) + serverKey = network.setupTimestampingNode(true).first.identity.owningKey + + // And a separate one to be tested directly, to make the unit tests a bit faster. + service = TimestamperNodeService(serviceNode.second, Party("Unit test suite", ALICE), ALICE_KEY) + } + + class TestPSM(val server: LegallyIdentifiableNode, val now: Instant) : ProtocolStateMachine() { + @Suspendable + override fun call(args: Any?): Boolean { + val client = TimestamperClient(this, server) + val ptx = TransactionBuilder().apply { + addInputState(ContractStateRef(SecureHash.randomSHA256(), 0)) + addOutputState(100.DOLLARS.CASH) + } + ptx.addCommand(TimestampCommand(now - 20.seconds, now + 20.seconds), server.identity.owningKey) + val wtx = ptx.toWireTransaction() + // This line will invoke sendAndReceive to interact with the network. + val sig = client.timestamp(wtx.serialize()) + ptx.checkAndAddSignature(sig) + return true + } + } + + @Test + fun successWithNetwork() { + val psm = runNetwork { + val smm = StateMachineManager(MockServices(net = myNode.second), RunOnCallerThread) + val logName = TimestamperNodeService.TIMESTAMPING_PROTOCOL_TOPIC + val psm = TestPSM(myNode.second.networkMap.timestampingNodes[0], clock.instant()) + smm.add(serviceNode.first, logName, psm) + psm + } + assertTrue(psm.isDone) + } + + @Test + fun wrongCommands() { + // Zero commands is not OK. + assertFailsWith(TimestampingError.RequiresExactlyOneCommand::class) { + val wtx = ptx.toWireTransaction() + service.processRequest(TimestampingMessages.Request(wtx.serialize(), myNode.first, "ignored")) + } + // More than one command is not OK. + assertFailsWith(TimestampingError.RequiresExactlyOneCommand::class) { + ptx.addCommand(TimestampCommand(clock.instant(), 30.seconds), ALICE) + ptx.addCommand(TimestampCommand(clock.instant(), 40.seconds), ALICE) + val wtx = ptx.toWireTransaction() + service.processRequest(TimestampingMessages.Request(wtx.serialize(), myNode.first, "ignored")) + } + } + + @Test + fun tooEarly() { + assertFailsWith(TimestampingError.NotOnTimeException::class) { + val now = clock.instant() + ptx.addCommand(TimestampCommand(now - 60.seconds, now - 40.seconds), ALICE) + val wtx = ptx.toWireTransaction() + service.processRequest(TimestampingMessages.Request(wtx.serialize(), myNode.first, "ignored")) + } + } + + @Test + fun tooLate() { + assertFailsWith(TimestampingError.NotOnTimeException::class) { + val now = clock.instant() + ptx.addCommand(TimestampCommand(now - 60.seconds, now - 40.seconds), ALICE) + val wtx = ptx.toWireTransaction() + service.processRequest(TimestampingMessages.Request(wtx.serialize(), myNode.first, "ignored")) + } + } + + @Test + fun success() { + val now = clock.instant() + ptx.addCommand(TimestampCommand(now - 20.seconds, now + 20.seconds), ALICE) + val wtx = ptx.toWireTransaction() + val sig = service.processRequest(TimestampingMessages.Request(wtx.serialize(), myNode.first, "ignored")) + ptx.checkAndAddSignature(sig) + ptx.toSignedTransaction(false).verifySignatures() + } +} \ No newline at end of file diff --git a/src/test/kotlin/core/testutils/TestUtils.kt b/src/test/kotlin/core/testutils/TestUtils.kt index 73e167eef7..ea8271977a 100644 --- a/src/test/kotlin/core/testutils/TestUtils.kt +++ b/src/test/kotlin/core/testutils/TestUtils.kt @@ -12,20 +12,11 @@ package core.testutils import contracts.* import core.* -import core.messaging.MessagingService -import core.serialization.SerializedBytes -import core.serialization.deserialize import core.visualiser.GraphVisualiser -import java.security.KeyPair import java.security.KeyPairGenerator -import java.security.PrivateKey import java.security.PublicKey -import java.time.Clock -import java.time.Duration import java.time.Instant -import java.time.ZoneId import java.util.* -import javax.annotation.concurrent.ThreadSafe import kotlin.test.assertEquals import kotlin.test.assertFailsWith import kotlin.test.fail @@ -66,72 +57,6 @@ val TEST_PROGRAM_MAP: Map = mapOf( DUMMY_PROGRAM_ID to DummyContract ) -/** - * A test/mock timestamping service that doesn't use any signatures or security. It timestamps with - * the provided clock which defaults to [TEST_TX_TIME], an arbitrary point on the timeline. - */ -class DummyTimestamper(var clock: Clock = Clock.fixed(TEST_TX_TIME, ZoneId.systemDefault()), - val tolerance: Duration = 30.seconds) : TimestamperService { - override val identity = DummyTimestampingAuthority.identity - - override fun timestamp(wtxBytes: SerializedBytes): DigitalSignature.LegallyIdentifiable { - val wtx = wtxBytes.deserialize() - val timestamp = wtx.commands.mapNotNull { it.data as? TimestampCommand }.single() - if (Duration.between(timestamp.before, clock.instant()) > tolerance) - throw NotOnTimeException() - return DummyTimestampingAuthority.key.signWithECDSA(wtxBytes.bits, identity) - } -} - -val DUMMY_TIMESTAMPER = DummyTimestamper() - -object MockIdentityService : IdentityService { - override fun partyFromKey(key: PublicKey): Party? = TEST_KEYS_TO_CORP_MAP[key] -} - -class MockKeyManagementService( - override val keys: Map, - val nextKeys: MutableList = arrayListOf(KeyPairGenerator.getInstance("EC").genKeyPair()) -) : KeyManagementService { - override fun freshKey() = nextKeys.removeAt(nextKeys.lastIndex) -} - -class MockWalletService(val states: List>) : WalletService { - override val currentWallet = Wallet(states) -} - -@ThreadSafe -class MockStorageService : StorageService { - private val mapOfMaps = HashMap>() - - @Synchronized - override fun getMap(tableName: String): MutableMap { - return mapOfMaps.getOrPut(tableName) { Collections.synchronizedMap(HashMap()) } as MutableMap - } -} - -class MockServices( - val wallet: WalletService?, - val keyManagement: KeyManagementService?, - val net: MessagingService?, - val identity: IdentityService? = MockIdentityService, - val storage: StorageService? = MockStorageService(), - val timestamping: TimestamperService? = DUMMY_TIMESTAMPER -) : ServiceHub { - override val walletService: WalletService - get() = wallet ?: throw UnsupportedOperationException() - override val keyManagementService: KeyManagementService - get() = keyManagement ?: throw UnsupportedOperationException() - override val identityService: IdentityService - get() = identity ?: throw UnsupportedOperationException() - override val timestampingService: TimestamperService - get() = timestamping ?: throw UnsupportedOperationException() - override val networkService: MessagingService - get() = net ?: throw UnsupportedOperationException() - override val storageService: StorageService - get() = storage ?: throw UnsupportedOperationException() -} - //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // // Defines a simple DSL for building pseudo-transactions (not the same as the wire protocol) for testing purposes.