From 6c96517f6fec356ef0da8087a6e8a323536de141 Mon Sep 17 00:00:00 2001 From: Andras Slemmer Date: Thu, 22 Sep 2016 13:49:45 +0100 Subject: [PATCH] core, node: Add RPC calls, change RPC init order --- .../r3corda/core/node/services/Services.kt | 19 +++-- .../core/node/services/TransactionStorage.kt | 10 ++- .../r3corda/core/protocols/ProtocolLogic.kt | 9 ++- .../core/testing/InMemoryVaultService.kt | 17 ++-- .../kotlin/com/r3corda/node/driver/Driver.kt | 4 +- .../com/r3corda/node/internal/AbstractNode.kt | 10 ++- .../kotlin/com/r3corda/node/internal/Node.kt | 13 ++-- .../com/r3corda/node/internal/ServerRPCOps.kt | 30 ++++++- .../node/services/messaging/CordaRPCOps.kt | 67 +++++++++++++++- .../services/messaging/NodeMessagingClient.kt | 12 +-- .../persistence/PerFileTransactionStorage.kt | 46 +++++++---- .../statemachine/StateMachineManager.kt | 2 +- .../node/services/vault/NodeVaultService.kt | 78 ++++++++++--------- .../messaging/TwoPartyTradeProtocolTests.kt | 4 +- .../com/r3corda/testing/node/MockNode.kt | 3 +- .../com/r3corda/testing/node/MockServices.kt | 6 +- 16 files changed, 237 insertions(+), 93 deletions(-) diff --git a/core/src/main/kotlin/com/r3corda/core/node/services/Services.kt b/core/src/main/kotlin/com/r3corda/core/node/services/Services.kt index e43d7b19ed..60a8fc1e7e 100644 --- a/core/src/main/kotlin/com/r3corda/core/node/services/Services.kt +++ b/core/src/main/kotlin/com/r3corda/core/node/services/Services.kt @@ -5,6 +5,7 @@ import com.google.common.util.concurrent.SettableFuture import com.r3corda.core.contracts.* import com.r3corda.core.crypto.Party import com.r3corda.core.transactions.WireTransaction +import rx.Observable import java.security.KeyPair import java.security.PrivateKey import java.security.PublicKey @@ -93,6 +94,18 @@ interface VaultService { */ val currentVault: Vault + /** + * Get a synchronous Observable of updates. When observations are pushed to the Observer, the Vault will already incorporate + * the update. + */ + val updates: Observable + + /** + * Atomically get the current vault and a stream of updates. Note that the Observable buffers updates until the + * first subscriber is registered so as to avoid racing with early updates. + */ + fun track(): Pair> + /** * Returns a snapshot of the heads of LinearStates. */ @@ -124,12 +137,6 @@ interface VaultService { /** Same as notifyAll but with a single transaction. */ fun notify(tx: WireTransaction): Vault = notifyAll(listOf(tx)) - /** - * Get a synchronous Observable of updates. When observations are pushed to the Observer, the vault will already - * incorporate the update. - */ - val updates: rx.Observable - /** * Provide a [Future] for when a [StateRef] is consumed, which can be very useful in building tests. */ diff --git a/core/src/main/kotlin/com/r3corda/core/node/services/TransactionStorage.kt b/core/src/main/kotlin/com/r3corda/core/node/services/TransactionStorage.kt index cd5c455fe9..d231d2bf45 100644 --- a/core/src/main/kotlin/com/r3corda/core/node/services/TransactionStorage.kt +++ b/core/src/main/kotlin/com/r3corda/core/node/services/TransactionStorage.kt @@ -1,7 +1,8 @@ package com.r3corda.core.node.services -import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.crypto.SecureHash +import com.r3corda.core.transactions.SignedTransaction +import rx.Observable /** * Thread-safe storage of transactions. @@ -16,7 +17,12 @@ interface ReadOnlyTransactionStorage { * Get a synchronous Observable of updates. When observations are pushed to the Observer, the vault will already * incorporate the update. */ - val updates: rx.Observable + val updates: Observable + + /** + * Returns all currently stored transactions and further fresh ones. + */ + fun track(): Pair, Observable> } /** diff --git a/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogic.kt b/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogic.kt index 29ce07e7a4..bd01afe48e 100644 --- a/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogic.kt +++ b/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogic.kt @@ -9,6 +9,7 @@ import com.r3corda.core.utilities.UntrustworthyData import com.r3corda.core.utilities.debug import com.r3corda.protocols.HandshakeMessage import org.slf4j.Logger +import rx.Observable import java.util.* /** @@ -158,4 +159,10 @@ abstract class ProtocolLogic { private data class Session(val sendSessionId: Long, val receiveSessionId: Long) -} \ No newline at end of file + // TODO this is not threadsafe, needs an atomic get-step-and-subscribe + fun track(): Pair>? { + return progressTracker?.let { + Pair(it.currentStep.toString(), it.changes.map { it.toString() }) + } + } +} diff --git a/core/src/main/kotlin/com/r3corda/core/testing/InMemoryVaultService.kt b/core/src/main/kotlin/com/r3corda/core/testing/InMemoryVaultService.kt index 379f9e7d20..f17c540f0d 100644 --- a/core/src/main/kotlin/com/r3corda/core/testing/InMemoryVaultService.kt +++ b/core/src/main/kotlin/com/r3corda/core/testing/InMemoryVaultService.kt @@ -1,8 +1,8 @@ package com.r3corda.core.testing import com.r3corda.core.ThreadBox +import com.r3corda.core.bufferUntilSubscribed import com.r3corda.core.contracts.* -import com.r3corda.core.crypto.SecureHash import com.r3corda.core.node.ServiceHub import com.r3corda.core.node.services.Vault import com.r3corda.core.node.services.VaultService @@ -30,16 +30,21 @@ open class InMemoryVaultService(protected val services: ServiceHub) : SingletonS // to vault somewhere. protected class InnerState { var vault = Vault(emptyList>()) + val _updatesPublisher = PublishSubject.create() } protected val mutex = ThreadBox(InnerState()) override val currentVault: Vault get() = mutex.locked { vault } - private val _updatesPublisher = PublishSubject.create() - override val updates: Observable - get() = _updatesPublisher + get() = mutex.content._updatesPublisher + + override fun track(): Pair> { + return mutex.locked { + Pair(vault, updates.bufferUntilSubscribed()) + } + } /** * Returns a snapshot of the heads of LinearStates. @@ -82,7 +87,9 @@ open class InMemoryVaultService(protected val services: ServiceHub) : SingletonS } if (netDelta != Vault.NoUpdate) { - _updatesPublisher.onNext(netDelta) + mutex.locked { + _updatesPublisher.onNext(netDelta) + } } return changedVault } diff --git a/node/src/main/kotlin/com/r3corda/node/driver/Driver.kt b/node/src/main/kotlin/com/r3corda/node/driver/Driver.kt index 3a30fa7ae6..cfce8ee824 100644 --- a/node/src/main/kotlin/com/r3corda/node/driver/Driver.kt +++ b/node/src/main/kotlin/com/r3corda/node/driver/Driver.kt @@ -3,7 +3,6 @@ package com.r3corda.node.driver import com.google.common.net.HostAndPort import com.r3corda.core.ThreadBox import com.r3corda.core.crypto.Party -import com.r3corda.core.crypto.X509Utilities import com.r3corda.core.crypto.generateKeyPair import com.r3corda.core.node.NodeInfo import com.r3corda.core.node.services.NetworkMapCache @@ -66,6 +65,7 @@ interface DriverDSLExposedInterface { * @param serverAddress the artemis server to connect to, for example a [Node]. */ fun startClient(providedName: String, serverAddress: HostAndPort): Future + /** * Starts a local [ArtemisMessagingServer] of which there may only be one. */ @@ -345,7 +345,7 @@ class DriverDSL( return Executors.newSingleThreadExecutor().submit(Callable { client.configureWithDevSSLCertificate() - client.start() + client.start(null) thread { client.run() } state.locked { clients.add(client) diff --git a/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt index 304de8a0ef..337abb98c1 100644 --- a/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt @@ -31,13 +31,17 @@ import com.r3corda.node.services.events.NodeSchedulerService import com.r3corda.node.services.events.ScheduledActivityObserver import com.r3corda.node.services.identity.InMemoryIdentityService import com.r3corda.node.services.keys.PersistentKeyManagementService +import com.r3corda.node.services.messaging.CordaRPCOps import com.r3corda.node.services.monitor.NodeMonitorService import com.r3corda.node.services.network.InMemoryNetworkMapCache import com.r3corda.node.services.network.NetworkMapService import com.r3corda.node.services.network.NetworkMapService.Companion.REGISTER_PROTOCOL_TOPIC import com.r3corda.node.services.network.NodeRegistration import com.r3corda.node.services.network.PersistentNetworkMapService -import com.r3corda.node.services.persistence.* +import com.r3corda.node.services.persistence.NodeAttachmentService +import com.r3corda.node.services.persistence.PerFileCheckpointStorage +import com.r3corda.node.services.persistence.PerFileTransactionStorage +import com.r3corda.node.services.persistence.StorageServiceImpl import com.r3corda.node.services.statemachine.StateMachineManager import com.r3corda.node.services.transactions.NotaryService import com.r3corda.node.services.transactions.SimpleNotaryService @@ -225,7 +229,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap ScheduledActivityObserver(services) } - startMessagingService() + startMessagingService(ServerRPCOps(services, smm, database)) runOnStop += Runnable { net.stop() } _networkMapRegistrationFuture.setFuture(registerWithNetworkMap()) isPreviousCheckpointsPresent = checkpointStorage.checkpoints.any() @@ -412,7 +416,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap protected abstract fun makeMessagingService(): MessagingServiceInternal - protected abstract fun startMessagingService() + protected abstract fun startMessagingService(cordaRPCOps: CordaRPCOps?) protected open fun initialiseStorageService(dir: Path): Pair { val attachments = makeAttachmentStorage(dir) diff --git a/node/src/main/kotlin/com/r3corda/node/internal/Node.kt b/node/src/main/kotlin/com/r3corda/node/internal/Node.kt index 7edf11f73d..efbc8cacb3 100644 --- a/node/src/main/kotlin/com/r3corda/node/internal/Node.kt +++ b/node/src/main/kotlin/com/r3corda/node/internal/Node.kt @@ -11,6 +11,7 @@ import com.r3corda.node.services.api.MessagingServiceInternal import com.r3corda.node.services.config.FullNodeConfiguration import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.messaging.ArtemisMessagingServer +import com.r3corda.node.services.messaging.CordaRPCOps import com.r3corda.node.services.messaging.NodeMessagingClient import com.r3corda.node.services.transactions.PersistentUniquenessProvider import com.r3corda.node.servlets.AttachmentDownloadServlet @@ -123,14 +124,12 @@ class Node(val p2pAddr: HostAndPort, val webServerAddr: HostAndPort, messageBroker = ArtemisMessagingServer(configuration, p2pAddr, services.networkMapCache) p2pAddr }() - val ops = ServerRPCOps(services) val myIdentityOrNullIfNetworkMapService = if (networkMapService != null) services.storageService.myLegalIdentityKey.public else null return NodeMessagingClient(configuration, serverAddr, myIdentityOrNullIfNetworkMapService, serverThread, - persistenceTx = { body: () -> Unit -> databaseTransaction(database) { body() } }, - rpcOps = ops) + persistenceTx = { body: () -> Unit -> databaseTransaction(database) { body() } }) } - override fun startMessagingService() { + override fun startMessagingService(cordaRPCOps: CordaRPCOps?) { // Start up the embedded MQ server messageBroker?.apply { runOnStop += Runnable { messageBroker?.stop() } @@ -139,9 +138,9 @@ class Node(val p2pAddr: HostAndPort, val webServerAddr: HostAndPort, } // Start up the MQ client. - (net as NodeMessagingClient).apply { - start() - } + val net = net as NodeMessagingClient + net.configureWithDevSSLCertificate() // TODO: Client might need a separate certificate + net.start(cordaRPCOps) } private fun initWebServer(): Server { diff --git a/node/src/main/kotlin/com/r3corda/node/internal/ServerRPCOps.kt b/node/src/main/kotlin/com/r3corda/node/internal/ServerRPCOps.kt index 711e3bdfa2..a4bfcb378d 100644 --- a/node/src/main/kotlin/com/r3corda/node/internal/ServerRPCOps.kt +++ b/node/src/main/kotlin/com/r3corda/node/internal/ServerRPCOps.kt @@ -1,14 +1,40 @@ package com.r3corda.node.internal +import com.r3corda.core.contracts.ContractState +import com.r3corda.core.contracts.StateAndRef +import com.r3corda.core.node.services.Vault import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.messaging.CordaRPCOps +import com.r3corda.node.services.messaging.StateMachineInfo +import com.r3corda.node.services.messaging.StateMachineUpdate +import com.r3corda.node.services.statemachine.StateMachineManager +import com.r3corda.node.utilities.databaseTransaction +import org.jetbrains.exposed.sql.Database +import rx.Observable /** * Server side implementations of RPCs available to MQ based client tools. Execution takes place on the server * thread (i.e. serially). Arguments are serialised and deserialised automatically. */ -class ServerRPCOps(services: ServiceHubInternal) : CordaRPCOps { +class ServerRPCOps( + val services: ServiceHubInternal, + val stateMachineManager: StateMachineManager, + val database: Database +) : CordaRPCOps { override val protocolVersion: Int = 0 - // TODO: Add useful RPCs for client apps (examining the vault, etc) + override fun vaultAndUpdates(): Pair>, Observable> { + return databaseTransaction(database) { + val (vault, updates) = services.vaultService.track() + Pair(vault.states.toList(), updates) + } + } + override fun verifiedTransactions() = services.storageService.validatedTransactions.track() + override fun stateMachinesAndUpdates(): Pair, Observable> { + val (allStateMachines, changes) = stateMachineManager.track() + return Pair( + allStateMachines.map { StateMachineInfo.fromProtocolStateMachineImpl(it) }, + changes.map { StateMachineUpdate.fromStateMachineChange(it) } + ) + } } diff --git a/node/src/main/kotlin/com/r3corda/node/services/messaging/CordaRPCOps.kt b/node/src/main/kotlin/com/r3corda/node/services/messaging/CordaRPCOps.kt index 8c91bb7055..30f9e255ec 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/messaging/CordaRPCOps.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/messaging/CordaRPCOps.kt @@ -1,11 +1,74 @@ package com.r3corda.node.services.messaging +import com.r3corda.core.contracts.ContractState +import com.r3corda.core.contracts.StateAndRef +import com.r3corda.core.node.services.Vault +import com.r3corda.core.protocols.StateMachineRunId +import com.r3corda.core.transactions.SignedTransaction +import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl +import com.r3corda.node.services.statemachine.StateMachineManager +import com.r3corda.node.utilities.AddOrRemove import rx.Observable +data class StateMachineInfo( + val id: StateMachineRunId, + val protocolLogicClassName: String, + val progressTrackerStepAndUpdates: Pair>? +) { + companion object { + fun fromProtocolStateMachineImpl(psm: ProtocolStateMachineImpl<*>): StateMachineInfo { + return StateMachineInfo( + id = psm.id, + protocolLogicClassName = psm.logic.javaClass.simpleName, + progressTrackerStepAndUpdates = psm.logic.track() + ) + } + } +} + +sealed class StateMachineUpdate { + class Added(val stateMachineInfo: StateMachineInfo) : StateMachineUpdate() + class Removed(val stateMachineRunId: StateMachineRunId) : StateMachineUpdate() + + companion object { + fun fromStateMachineChange(change: StateMachineManager.Change): StateMachineUpdate { + return when (change.addOrRemove) { + AddOrRemove.ADD -> { + val stateMachineInfo = StateMachineInfo( + id = change.id, + protocolLogicClassName = change.logic.javaClass.simpleName, + progressTrackerStepAndUpdates = change.logic.track() + ) + StateMachineUpdate.Added(stateMachineInfo) + } + AddOrRemove.REMOVE -> { + StateMachineUpdate.Removed(change.id) + } + } + } + } +} + /** * RPC operations that the node exposes to clients using the Java client library. These can be called from * client apps and are implemented by the node in the [ServerRPCOps] class. */ interface CordaRPCOps : RPCOps { - // TODO: Add useful RPCs for client apps (examining the vault, etc) -} \ No newline at end of file + /** + * Returns a pair of currently in-progress state machine infos and an observable of future state machine adds/removes. + */ + @RPCReturnsObservables + fun stateMachinesAndUpdates(): Pair, Observable> + + /** + * Returns a pair of head states in the vault and an observable of future updates to the vault. + */ + @RPCReturnsObservables + fun vaultAndUpdates(): Pair>, Observable> + + /** + * Returns a pair of all recorded transactions and an observable of future recorded ones. + */ + @RPCReturnsObservables + fun verifiedTransactions(): Pair, Observable> +} diff --git a/node/src/main/kotlin/com/r3corda/node/services/messaging/NodeMessagingClient.kt b/node/src/main/kotlin/com/r3corda/node/services/messaging/NodeMessagingClient.kt index 52e310565c..f8869f75fe 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/messaging/NodeMessagingClient.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/messaging/NodeMessagingClient.kt @@ -54,8 +54,7 @@ class NodeMessagingClient(config: NodeConfiguration, val myIdentity: PublicKey?, val executor: AffinityExecutor, val persistentInbox: Boolean = true, - val persistenceTx: (() -> Unit) -> Unit = { it() }, - private val rpcOps: CordaRPCOps? = null) : ArtemisMessagingComponent(config), MessagingServiceInternal { + val persistenceTx: (() -> Unit) -> Unit = { it() }) : ArtemisMessagingComponent(config), MessagingServiceInternal { companion object { val log = loggerFor() @@ -113,7 +112,7 @@ class NodeMessagingClient(config: NodeConfiguration, require(config.basedir.fileSystem == FileSystems.getDefault()) { "Artemis only uses the default file system" } } - fun start() { + fun start(rpcOps: CordaRPCOps? = null) { state.locked { check(!started) { "start can't be called twice" } started = true @@ -150,6 +149,7 @@ class NodeMessagingClient(config: NodeConfiguration, session.createTemporaryQueue("activemq.notifications", "rpc.qremovals", "_AMQ_NotifType = 1") rpcConsumer = session.createConsumer(RPC_REQUESTS_QUEUE) rpcNotificationConsumer = session.createConsumer("rpc.qremovals") + dispatcher = createRPCDispatcher(state, rpcOps) } } } @@ -392,7 +392,9 @@ class NodeMessagingClient(config: NodeConfiguration, } } - private fun createRPCDispatcher(ops: CordaRPCOps) = object : RPCDispatcher(ops) { + var dispatcher: RPCDispatcher? = null + + private fun createRPCDispatcher(state: ThreadBox, ops: CordaRPCOps) = object : RPCDispatcher(ops) { override fun send(bits: SerializedBytes<*>, toAddress: String) { state.locked { val msg = session!!.createMessage(false).apply { @@ -404,6 +406,4 @@ class NodeMessagingClient(config: NodeConfiguration, } } } - - private val dispatcher = if (rpcOps != null) createRPCDispatcher(rpcOps) else null } diff --git a/node/src/main/kotlin/com/r3corda/node/services/persistence/PerFileTransactionStorage.kt b/node/src/main/kotlin/com/r3corda/node/services/persistence/PerFileTransactionStorage.kt index 8acb1ba197..0e320cbabf 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/persistence/PerFileTransactionStorage.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/persistence/PerFileTransactionStorage.kt @@ -1,17 +1,19 @@ package com.r3corda.node.services.persistence -import com.r3corda.core.transactions.SignedTransaction +import com.r3corda.core.ThreadBox +import com.r3corda.core.bufferUntilSubscribed import com.r3corda.core.crypto.SecureHash import com.r3corda.core.node.services.TransactionStorage import com.r3corda.core.serialization.deserialize import com.r3corda.core.serialization.serialize +import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.utilities.loggerFor import com.r3corda.core.utilities.trace import rx.Observable import rx.subjects.PublishSubject import java.nio.file.Files import java.nio.file.Path -import java.util.concurrent.ConcurrentHashMap +import java.util.* import javax.annotation.concurrent.ThreadSafe /** @@ -19,40 +21,50 @@ import javax.annotation.concurrent.ThreadSafe */ @ThreadSafe class PerFileTransactionStorage(val storeDir: Path) : TransactionStorage { - companion object { private val logger = loggerFor() private val fileExtension = ".transaction" } - private val _transactions = ConcurrentHashMap() + private val mutex = ThreadBox(object { + val transactionsMap = HashMap() + val updatesPublisher = PublishSubject.create() - private val _updatesPublisher = PublishSubject.create() + fun notify(transaction: SignedTransaction) = updatesPublisher.onNext(transaction) + }) override val updates: Observable - get() = _updatesPublisher - - private fun notify(transaction: SignedTransaction) = _updatesPublisher.onNext(transaction) + get() = mutex.content.updatesPublisher init { logger.trace { "Initialising per file transaction storage on $storeDir" } Files.createDirectories(storeDir) - Files.list(storeDir) - .filter { it.toString().toLowerCase().endsWith(fileExtension) } - .map { Files.readAllBytes(it).deserialize() } - .forEach { _transactions[it.id] = it } + mutex.locked { + Files.list(storeDir) + .filter { it.toString().toLowerCase().endsWith(fileExtension) } + .map { Files.readAllBytes(it).deserialize() } + .forEach { transactionsMap[it.id] = it } + } } override fun addTransaction(transaction: SignedTransaction) { val transactionFile = storeDir.resolve("${transaction.id.toString().toLowerCase()}$fileExtension") transaction.serialize().writeToFile(transactionFile) - _transactions[transaction.id] = transaction + mutex.locked { + transactionsMap[transaction.id] = transaction + notify(transaction) + } logger.trace { "Stored $transaction to $transactionFile" } - notify(transaction) } - override fun getTransaction(id: SecureHash): SignedTransaction? = _transactions[id] + override fun getTransaction(id: SecureHash): SignedTransaction? = mutex.locked { transactionsMap[id] } - val transactions: Iterable get() = _transactions.values.toList() + val transactions: Iterable get() = mutex.locked { transactionsMap.values.toList() } -} \ No newline at end of file + override fun track(): Pair, Observable> { + return mutex.locked { + Pair(transactionsMap.values.toList(), updates.bufferUntilSubscribed()) + } + } + +} diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt index 8b3aacf6b9..bd85ca7235 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt @@ -123,7 +123,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService * Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and * calls to [allStateMachines] */ - fun getAllStateMachinesAndChanges(): Pair>, Observable> { + fun track(): Pair>, Observable> { return mutex.locked { val bufferedChanges = UnicastSubject.create() changesPublisher.subscribe(bufferedChanges) diff --git a/node/src/main/kotlin/com/r3corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/com/r3corda/node/services/vault/NodeVaultService.kt index ef988ff0ef..085daf1f93 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/vault/NodeVaultService.kt @@ -1,6 +1,8 @@ package com.r3corda.node.services.vault import com.google.common.collect.Sets +import com.r3corda.core.ThreadBox +import com.r3corda.core.bufferUntilSubscribed import com.r3corda.core.contracts.* import com.r3corda.core.crypto.SecureHash import com.r3corda.core.node.ServiceHub @@ -17,8 +19,6 @@ import org.jetbrains.exposed.sql.statements.InsertStatement import rx.Observable import rx.subjects.PublishSubject import java.security.PublicKey -import java.util.concurrent.locks.ReentrantLock -import kotlin.concurrent.withLock /** * Currently, the node vault service is a very simple RDBMS backed implementation. It will change significantly when @@ -42,23 +42,48 @@ class NodeVaultService(private val services: ServiceHub) : SingletonSerializeAsT val index = integer("output_index") } - private val unconsumedStates = object : AbstractJDBCHashSet(StatesSetTable) { - override fun elementFromRow(it: ResultRow): StateRef = StateRef(SecureHash.SHA256(it[table.txhash]), it[table.index]) + private val mutex = ThreadBox(object { + val unconsumedStates = object : AbstractJDBCHashSet(StatesSetTable) { + override fun elementFromRow(it: ResultRow): StateRef = StateRef(SecureHash.SHA256(it[table.txhash]), it[table.index]) - override fun addElementToInsert(insert: InsertStatement, entry: StateRef, finalizables: MutableList<() -> Unit>) { - insert[table.txhash] = entry.txhash.bits - insert[table.index] = entry.index + override fun addElementToInsert(it: InsertStatement, entry: StateRef, finalizables: MutableList<() -> Unit>) { + it[table.txhash] = entry.txhash.bits + it[table.index] = entry.index + } } - } + val _updatesPublisher = PublishSubject.create() - protected val mutex = ReentrantLock() + fun allUnconsumedStates(): Iterable> { + // Order by txhash for if and when transaction storage has some caching. + // Map to StateRef and then to StateAndRef. Use Sequence to avoid conversion to ArrayList that Iterable.map() performs. + return unconsumedStates.asSequence().map { + val storedTx = services.storageService.validatedTransactions.getTransaction(it.txhash) ?: throw Error("Found transaction hash ${it.txhash} in unconsumed contract states that is not in transaction storage.") + StateAndRef(storedTx.tx.outputs[it.index], it) + }.asIterable() + } - override val currentVault: Vault get() = mutex.withLock { Vault(allUnconsumedStates()) } + fun recordUpdate(update: Vault.Update): Vault.Update { + if (update != Vault.NoUpdate) { + val producedStateRefs = update.produced.map { it.ref } + val consumedStateRefs = update.consumed + log.trace { "Removing $consumedStateRefs consumed contract states and adding $producedStateRefs produced contract states to the database." } + unconsumedStates.removeAll(consumedStateRefs) + unconsumedStates.addAll(producedStateRefs) + } + return update + } + }) - private val _updatesPublisher = PublishSubject.create() + override val currentVault: Vault get() = mutex.locked { Vault(allUnconsumedStates()) } override val updates: Observable - get() = _updatesPublisher + get() = mutex.locked { _updatesPublisher } + + override fun track(): Pair> { + return mutex.locked { + Pair(Vault(allUnconsumedStates()), _updatesPublisher.bufferUntilSubscribed()) + } + } /** * Returns a snapshot of the heads of LinearStates. @@ -72,10 +97,9 @@ class NodeVaultService(private val services: ServiceHub) : SingletonSerializeAsT val ourKeys = services.keyManagementService.keys.keys val netDelta = txns.fold(Vault.NoUpdate) { netDelta, txn -> netDelta + makeUpdate(txn, netDelta, ourKeys) } if (netDelta != Vault.NoUpdate) { - mutex.withLock { + mutex.locked { recordUpdate(netDelta) } - _updatesPublisher.onNext(netDelta) } return currentVault } @@ -91,7 +115,9 @@ class NodeVaultService(private val services: ServiceHub) : SingletonSerializeAsT // i.e. retainAll() iterates over consumed, checking contains() on the parameter. Sets.union() does not physically create // a new collection and instead contains() just checks the contains() of both parameters, and so we don't end up // iterating over all (a potentially very large) unconsumedStates at any point. - consumed.retainAll(Sets.union(netDelta.produced, unconsumedStates)) + mutex.locked { + consumed.retainAll(Sets.union(netDelta.produced, unconsumedStates)) + } // Is transaction irrelevant? if (consumed.isEmpty() && ourNewStates.isEmpty()) { @@ -112,24 +138,4 @@ class NodeVaultService(private val services: ServiceHub) : SingletonSerializeAsT false } } - - private fun recordUpdate(update: Vault.Update): Vault.Update { - if (update != Vault.NoUpdate) { - val producedStateRefs = update.produced.map { it.ref } - val consumedStateRefs = update.consumed - log.trace { "Removing $consumedStateRefs consumed contract states and adding $producedStateRefs produced contract states to the database." } - unconsumedStates.removeAll(consumedStateRefs) - unconsumedStates.addAll(producedStateRefs) - } - return update - } - - private fun allUnconsumedStates(): Iterable> { - // Order by txhash for if and when transaction storage has some caching. - // Map to StateRef and then to StateAndRef. Use Sequence to avoid conversion to ArrayList that Iterable.map() performs. - return unconsumedStates.asSequence().map { - val storedTx = services.storageService.validatedTransactions.getTransaction(it.txhash) ?: throw Error("Found transaction hash ${it.txhash} in unconsumed contract states that is not in transaction storage.") - StateAndRef(storedTx.tx.outputs[it.index], it) - }.asIterable() - } -} \ No newline at end of file +} diff --git a/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt b/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt index 4ac574c4bb..ee046fed2f 100644 --- a/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt @@ -35,7 +35,6 @@ import org.junit.Test import rx.Observable import java.io.ByteArrayInputStream import java.io.ByteArrayOutputStream -import java.nio.file.Path import java.security.KeyPair import java.security.PublicKey import java.util.* @@ -437,6 +436,9 @@ class TwoPartyTradeProtocolTests { } class RecordingTransactionStorage(val delegate: TransactionStorage) : TransactionStorage { + override fun track(): Pair, Observable> { + return delegate.track() + } val records: MutableList = Collections.synchronizedList(ArrayList()) override val updates: Observable diff --git a/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt b/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt index f8295037ad..2a33c31d09 100644 --- a/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt +++ b/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt @@ -22,6 +22,7 @@ import com.r3corda.core.utilities.loggerFor import com.r3corda.node.internal.AbstractNode import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.keys.E2ETestKeyManagementService +import com.r3corda.node.services.messaging.CordaRPCOps import com.r3corda.node.services.network.InMemoryNetworkMapService import com.r3corda.node.services.transactions.InMemoryUniquenessProvider import com.r3corda.node.utilities.databaseTransaction @@ -102,7 +103,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, override fun makeKeyManagementService(): KeyManagementService = E2ETestKeyManagementService(setOf(storage.myLegalIdentityKey)) - override fun startMessagingService() { + override fun startMessagingService(cordaRPCOps: CordaRPCOps?) { // Nothing to do } diff --git a/test-utils/src/main/kotlin/com/r3corda/testing/node/MockServices.kt b/test-utils/src/main/kotlin/com/r3corda/testing/node/MockServices.kt index d640ee544b..fea06b139b 100644 --- a/test-utils/src/main/kotlin/com/r3corda/testing/node/MockServices.kt +++ b/test-utils/src/main/kotlin/com/r3corda/testing/node/MockServices.kt @@ -117,6 +117,10 @@ class MockAttachmentStorage : AttachmentStorage { } open class MockTransactionStorage : TransactionStorage { + override fun track(): Pair, Observable> { + return Pair(txns.values.toList(), _updatesPublisher) + } + private val txns = HashMap() private val _updatesPublisher = PublishSubject.create() @@ -153,4 +157,4 @@ fun makeTestDataSourceProperties(nodeName: String = SecureHash.randomSHA256().to props.setProperty("dataSource.user", "sa") props.setProperty("dataSource.password", "") return props -} \ No newline at end of file +}