From 236a47104f26db681f40b64c1456247f01e3e9aa Mon Sep 17 00:00:00 2001 From: "rick.parker" Date: Wed, 7 Sep 2016 10:20:40 +0100 Subject: [PATCH] Persitent network map and key service. Temporary persistence workaround for scheduler. --- .../r3corda/core/node/services/Services.kt | 8 +- .../node/utilities/JDBCHashMapTestSuite.kt | 72 ++++- .../com/r3corda/node/internal/AbstractNode.kt | 31 +- .../node/services/api/AbstractNodeService.kt | 9 +- .../events/ScheduledActivityObserver.kt | 7 + .../keys/PersistentKeyManagementService.kt | 42 +++ .../services/network/NetworkMapService.kt | 35 ++- .../network/PersistentNetworkMapService.kt | 29 ++ .../node/services/wallet/NodeWalletService.kt | 158 ++++++---- .../com/r3corda/node/utilities/JDBCHashMap.kt | 10 +- .../services/InMemoryNetworkMapServiceTest.kt | 284 ++++++++++-------- .../node/services/NodeWalletServiceTest.kt | 53 ++-- .../PersistentNetworkMapServiceTest.kt | 118 ++++++++ .../node/services/WalletWithCashTest.kt | 194 ++++++------ .../kotlin/com/r3corda/demos/TraderDemo.kt | 9 +- .../com/r3corda/testing/node/MockNode.kt | 7 +- 16 files changed, 723 insertions(+), 343 deletions(-) create mode 100644 node/src/main/kotlin/com/r3corda/node/services/keys/PersistentKeyManagementService.kt create mode 100644 node/src/main/kotlin/com/r3corda/node/services/network/PersistentNetworkMapService.kt create mode 100644 node/src/test/kotlin/com/r3corda/node/services/PersistentNetworkMapServiceTest.kt diff --git a/core/src/main/kotlin/com/r3corda/core/node/services/Services.kt b/core/src/main/kotlin/com/r3corda/core/node/services/Services.kt index 022a56e32a..8f35c31c03 100644 --- a/core/src/main/kotlin/com/r3corda/core/node/services/Services.kt +++ b/core/src/main/kotlin/com/r3corda/core/node/services/Services.kt @@ -4,7 +4,6 @@ import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.SettableFuture import com.r3corda.core.contracts.* import com.r3corda.core.crypto.Party -import com.r3corda.core.crypto.SecureHash import com.r3corda.core.transactions.WireTransaction import java.security.KeyPair import java.security.PrivateKey @@ -29,13 +28,13 @@ val DEFAULT_SESSION_ID = 0L * * This abstract class has no references to Cash contracts. * - * [states] Holds the list of states that are *active* and *relevant*. + * [states] Holds the states that are *active* and *relevant*. * Active means they haven't been consumed yet (or we don't know about it). * Relevant means they contain at least one of our pubkeys. */ class Wallet(val states: Iterable>) { @Suppress("UNCHECKED_CAST") - inline fun statesOfType() = states.filter { it.state.data is T } as List> + inline fun statesOfType() = states.filter { it.state.data is T } as List> /** * Represents an update observed by the Wallet that will be notified to observers. Include the [StateRef]s of @@ -57,7 +56,8 @@ class Wallet(val states: Iterable>) { val previouslyConsumed = consumed val combined = Wallet.Update( previouslyConsumed + (rhs.consumed - previouslyProduced), - rhs.produced + produced.filter { it.ref !in rhs.consumed }) + // The ordering below matters to preserve ordering of consumed/produced Sets when they are insertion order dependent implementations. + produced.filter { it.ref !in rhs.consumed }.toSet() + rhs.produced) return combined } diff --git a/node/src/integration-test/kotlin/com/r3corda/node/utilities/JDBCHashMapTestSuite.kt b/node/src/integration-test/kotlin/com/r3corda/node/utilities/JDBCHashMapTestSuite.kt index 6cea8bcbaf..bb6fcf8fed 100644 --- a/node/src/integration-test/kotlin/com/r3corda/node/utilities/JDBCHashMapTestSuite.kt +++ b/node/src/integration-test/kotlin/com/r3corda/node/utilities/JDBCHashMapTestSuite.kt @@ -2,21 +2,25 @@ package com.r3corda.node.utilities import com.r3corda.testing.node.makeTestDataSourceProperties import junit.framework.TestSuite +import org.assertj.core.api.Assertions.assertThat import org.jetbrains.exposed.sql.Transaction import org.jetbrains.exposed.sql.transactions.TransactionManager import org.junit.AfterClass import org.junit.BeforeClass +import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.Suite import java.io.Closeable import java.sql.Connection +import java.util.* @RunWith(Suite::class) @Suite.SuiteClasses( JDBCHashMapTestSuite.MapLoadOnInitFalse::class, JDBCHashMapTestSuite.MapLoadOnInitTrue::class, JDBCHashMapTestSuite.SetLoadOnInitFalse::class, - JDBCHashMapTestSuite.SetLoadOnInitTrue::class) + JDBCHashMapTestSuite.SetLoadOnInitTrue::class, + JDBCHashMapTestSuite.MapCanBeReloaded::class) class JDBCHashMapTestSuite { companion object { lateinit var dataSource: Closeable @@ -144,4 +148,70 @@ class JDBCHashMapTestSuite { return set } } + + /** + * Test that the contents of a map can be reloaded from the database. + * + * If the Map reloads, then so will the Set as it just delegates. + */ + class MapCanBeReloaded { + private val ops = listOf(Triple(AddOrRemove.ADD, "A", "1"), + Triple(AddOrRemove.ADD, "B", "2"), + Triple(AddOrRemove.ADD, "C", "3"), + Triple(AddOrRemove.ADD, "D", "4"), + Triple(AddOrRemove.ADD, "E", "5"), + Triple(AddOrRemove.REMOVE, "A", "6"), + Triple(AddOrRemove.ADD, "G", "7"), + Triple(AddOrRemove.ADD, "H", "8"), + Triple(AddOrRemove.REMOVE, "D", "9"), + Triple(AddOrRemove.ADD, "C", "10")) + + private fun applyOpsToMap(map: MutableMap): MutableMap { + for (op in ops) { + if (op.first == AddOrRemove.ADD) { + map[op.second] = op.third + } else { + map.remove(op.second) + } + } + return map + } + + private val transientMapForComparison = applyOpsToMap(LinkedHashMap()) + + companion object { + lateinit var dataSource: Closeable + + @JvmStatic + @BeforeClass + fun before() { + dataSource = configureDatabase(makeTestDataSourceProperties()).first + } + + @JvmStatic + @AfterClass + fun after() { + dataSource.close() + } + } + + + @Test + fun `fill map and check content after reconstruction`() { + databaseTransaction { + val persistentMap = JDBCHashMap("the_table") + // Populate map the first time. + applyOpsToMap(persistentMap) + assertThat(persistentMap.entries).containsExactly(*transientMapForComparison.entries.toTypedArray()) + } + databaseTransaction { + val persistentMap = JDBCHashMap("the_table", loadOnInit = false) + assertThat(persistentMap.entries).containsExactly(*transientMapForComparison.entries.toTypedArray()) + } + databaseTransaction { + val persistentMap = JDBCHashMap("the_table", loadOnInit = true) + assertThat(persistentMap.entries).containsExactly(*transientMapForComparison.entries.toTypedArray()) + } + } + } } diff --git a/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt index adca0fd3e9..3c197f1643 100644 --- a/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt @@ -28,13 +28,13 @@ import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.events.NodeSchedulerService import com.r3corda.node.services.events.ScheduledActivityObserver import com.r3corda.node.services.identity.InMemoryIdentityService -import com.r3corda.node.services.keys.E2ETestKeyManagementService +import com.r3corda.node.services.keys.PersistentKeyManagementService import com.r3corda.node.services.monitor.WalletMonitorService import com.r3corda.node.services.network.InMemoryNetworkMapCache -import com.r3corda.node.services.network.InMemoryNetworkMapService import com.r3corda.node.services.network.NetworkMapService import com.r3corda.node.services.network.NetworkMapService.Companion.REGISTER_PROTOCOL_TOPIC import com.r3corda.node.services.network.NodeRegistration +import com.r3corda.node.services.network.PersistentNetworkMapService import com.r3corda.node.services.persistence.NodeAttachmentService import com.r3corda.node.services.persistence.PerFileCheckpointStorage import com.r3corda.node.services.persistence.PerFileTransactionStorage @@ -296,17 +296,17 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration, "Initial network map address must indicate a node that provides a network map service" } services.networkMapCache.addNode(info) - if (networkMapService != null) { - // Only register if we are pointed at a network map service and it's not us. - // TODO: Return a future so the caller knows these operations may not have completed yet, and can monitor if needed - updateRegistration(networkMapService, AddOrRemove.ADD) - return services.networkMapCache.addMapService(net, networkMapService, true, null) - } // In the unit test environment, we may run without any network map service sometimes. - if (inNodeNetworkMapService == null) + if (networkMapService == null && inNodeNetworkMapService == null) return noNetworkMapConfigured() + else + return registerWithNetworkMap(networkMapService ?: info.address) + } + + private fun registerWithNetworkMap(networkMapServiceAddress: SingleMessageRecipient): ListenableFuture { // Register for updates, even if we're the one running the network map. - return services.networkMapCache.addMapService(net, info.address, true, null) + updateRegistration(networkMapServiceAddress, AddOrRemove.ADD) + return services.networkMapCache.addMapService(net, networkMapServiceAddress, true, null) } /** This is overriden by the mock node implementation to enable operation without any network map service */ @@ -318,8 +318,9 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration, private fun updateRegistration(networkMapAddr: SingleMessageRecipient, type: AddOrRemove): ListenableFuture { // Register this node against the network - val expires = platformClock.instant() + NetworkMapService.DEFAULT_EXPIRATION_PERIOD - val reg = NodeRegistration(info, networkMapSeq++, type, expires) + val instant = platformClock.instant() + val expires = instant + NetworkMapService.DEFAULT_EXPIRATION_PERIOD + val reg = NodeRegistration(info, instant.toEpochMilli(), type, expires) val sessionID = random63BitValue() val request = NetworkMapService.RegistrationRequest(reg.toWire(storage.myLegalIdentityKey.private), net.myAddress, sessionID) val message = net.createMessage(REGISTER_PROTOCOL_TOPIC, DEFAULT_SESSION_ID, request.serialize().bits) @@ -333,12 +334,10 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration, return future } - protected open fun makeKeyManagementService(): KeyManagementService = E2ETestKeyManagementService(setOf(storage.myLegalIdentityKey)) + protected open fun makeKeyManagementService(): KeyManagementService = PersistentKeyManagementService(setOf(storage.myLegalIdentityKey)) open protected fun makeNetworkMapService() { - val expires = platformClock.instant() + NetworkMapService.DEFAULT_EXPIRATION_PERIOD - val reg = NodeRegistration(info, Long.MAX_VALUE, AddOrRemove.ADD, expires) - inNodeNetworkMapService = InMemoryNetworkMapService(services, reg) + inNodeNetworkMapService = PersistentNetworkMapService(services) } open protected fun makeNotaryService(type: ServiceType): NotaryService { diff --git a/node/src/main/kotlin/com/r3corda/node/services/api/AbstractNodeService.kt b/node/src/main/kotlin/com/r3corda/node/services/api/AbstractNodeService.kt index 452447c786..ab27742dea 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/api/AbstractNodeService.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/api/AbstractNodeService.kt @@ -2,6 +2,7 @@ package com.r3corda.node.services.api import com.google.common.util.concurrent.ListenableFuture import com.r3corda.core.messaging.Message +import com.r3corda.core.messaging.MessageHandlerRegistration import com.r3corda.core.node.services.DEFAULT_SESSION_ID import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.serialization.SingletonSerializeAsToken @@ -36,8 +37,8 @@ abstract class AbstractNodeService(val services: ServiceHubInternal) : Singleton protected inline fun addMessageHandler(topic: String, crossinline handler: (Q) -> R, - crossinline exceptionConsumer: (Message, Exception) -> Unit) { - net.addMessageHandler(topic, DEFAULT_SESSION_ID, null) { message, r -> + crossinline exceptionConsumer: (Message, Exception) -> Unit): MessageHandlerRegistration { + return net.addMessageHandler(topic, DEFAULT_SESSION_ID, null) { message, r -> try { val request = message.data.deserialize() val response = handler(request) @@ -62,8 +63,8 @@ abstract class AbstractNodeService(val services: ServiceHubInternal) : Singleton */ protected inline fun addMessageHandler(topic: String, - crossinline handler: (Q) -> R) { - addMessageHandler(topic, handler, { message: Message, exception: Exception -> throw exception }) + crossinline handler: (Q) -> R): MessageHandlerRegistration { + return addMessageHandler(topic, handler, { message: Message, exception: Exception -> throw exception }) } /** diff --git a/node/src/main/kotlin/com/r3corda/node/services/events/ScheduledActivityObserver.kt b/node/src/main/kotlin/com/r3corda/node/services/events/ScheduledActivityObserver.kt index be75393bed..99e567f236 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/events/ScheduledActivityObserver.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/events/ScheduledActivityObserver.kt @@ -18,6 +18,13 @@ class ScheduledActivityObserver(val services: ServiceHubInternal) { update.consumed.forEach { services.schedulerService.unscheduleStateActivity(it) } update.produced.forEach { scheduleStateActivity(it, services.protocolLogicRefFactory) } } + + // In the short term, to get restart-able IRS demo, re-initialise from wallet state + // TODO: there's a race condition here. We need to move persistence into the scheduler but that is a bigger + // change so I want to revisit as a distinct branch/PR. + for (state in services.walletService.currentWallet.statesOfType()) { + scheduleStateActivity(state, services.protocolLogicRefFactory) + } } private fun scheduleStateActivity(produced: StateAndRef, protocolLogicRefFactory: ProtocolLogicRefFactory) { diff --git a/node/src/main/kotlin/com/r3corda/node/services/keys/PersistentKeyManagementService.kt b/node/src/main/kotlin/com/r3corda/node/services/keys/PersistentKeyManagementService.kt new file mode 100644 index 0000000000..b6978bd476 --- /dev/null +++ b/node/src/main/kotlin/com/r3corda/node/services/keys/PersistentKeyManagementService.kt @@ -0,0 +1,42 @@ +package com.r3corda.node.services.keys + +import com.r3corda.core.ThreadBox +import com.r3corda.core.crypto.generateKeyPair +import com.r3corda.core.node.services.KeyManagementService +import com.r3corda.core.serialization.SingletonSerializeAsToken +import com.r3corda.node.utilities.JDBCHashMap +import java.security.KeyPair +import java.security.PrivateKey +import java.security.PublicKey +import java.util.* + +/** + * A persistent re-implementation of [E2ETestKeyManagementService] to support node re-start. + * + * This is not the long-term implementation. See the list of items in the above class. + * + * This class needs database transactions to be in-flight during method calls and init. + */ +class PersistentKeyManagementService(initialKeys: Set) : SingletonSerializeAsToken(), KeyManagementService { + private class InnerState { + val keys = JDBCHashMap("key_pairs", loadOnInit = false) + } + + private val mutex = ThreadBox(InnerState()) + + init { + mutex.locked { + keys.putAll(initialKeys.associate { Pair(it.public, it.private) }) + } + } + + override val keys: Map get() = mutex.locked { HashMap(keys) } + + override fun freshKey(): KeyPair { + val keypair = generateKeyPair() + mutex.locked { + keys[keypair.public] = keypair.private + } + return keypair + } +} diff --git a/node/src/main/kotlin/com/r3corda/node/services/network/NetworkMapService.kt b/node/src/main/kotlin/com/r3corda/node/services/network/NetworkMapService.kt index 08c1ef916e..34411bf568 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/network/NetworkMapService.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/network/NetworkMapService.kt @@ -6,6 +6,7 @@ import com.r3corda.core.crypto.DigitalSignature import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.SignedData import com.r3corda.core.crypto.signWithECDSA +import com.r3corda.core.messaging.MessageHandlerRegistration import com.r3corda.core.messaging.MessageRecipients import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.core.node.NodeInfo @@ -82,13 +83,13 @@ interface NetworkMapService { } @ThreadSafe -class InMemoryNetworkMapService(services: ServiceHubInternal, home: NodeRegistration) : AbstractNetworkMapService(services) { +class InMemoryNetworkMapService(services: ServiceHubInternal) : AbstractNetworkMapService(services) { override val registeredNodes: MutableMap = ConcurrentHashMap() override val subscribers = ThreadBox(mutableMapOf()) init { - setup(home) + setup() } } @@ -99,7 +100,8 @@ class InMemoryNetworkMapService(services: ServiceHubInternal, home: NodeRegistra * subscriber clean up and is simpler to persist than the previous implementation based on a set of missing messages acks. */ @ThreadSafe -abstract class AbstractNetworkMapService(services: ServiceHubInternal) : NetworkMapService, AbstractNodeService(services) { +abstract class AbstractNetworkMapService +(services: ServiceHubInternal) : NetworkMapService, AbstractNodeService(services) { protected abstract val registeredNodes: MutableMap // Map from subscriber address, to most recently acknowledged update map version. @@ -121,35 +123,38 @@ abstract class AbstractNetworkMapService(services: ServiceHubInternal) : Network */ val maxSizeRegistrationRequestBytes = 5500 + private val handlers = ArrayList() + // Filter reduces this to the entries that add a node to the map override val nodes: List get() = registeredNodes.mapNotNull { if (it.value.reg.type == AddOrRemove.ADD) it.value.reg.node else null } - protected fun setup(home: NodeRegistration) { - // Register the local node with the service - val homeIdentity = home.node.identity - val registrationInfo = NodeRegistrationInfo(home, mapVersionIncrementAndGet()) - registeredNodes[homeIdentity] = registrationInfo - + protected fun setup() { // Register message handlers - addMessageHandler(NetworkMapService.FETCH_PROTOCOL_TOPIC, + handlers += addMessageHandler(NetworkMapService.FETCH_PROTOCOL_TOPIC, { req: NetworkMapService.FetchMapRequest -> processFetchAllRequest(req) } ) - addMessageHandler(NetworkMapService.QUERY_PROTOCOL_TOPIC, + handlers += addMessageHandler(NetworkMapService.QUERY_PROTOCOL_TOPIC, { req: NetworkMapService.QueryIdentityRequest -> processQueryRequest(req) } ) - addMessageHandler(NetworkMapService.REGISTER_PROTOCOL_TOPIC, + handlers += addMessageHandler(NetworkMapService.REGISTER_PROTOCOL_TOPIC, { req: NetworkMapService.RegistrationRequest -> processRegistrationChangeRequest(req) } ) - addMessageHandler(NetworkMapService.SUBSCRIPTION_PROTOCOL_TOPIC, + handlers += addMessageHandler(NetworkMapService.SUBSCRIPTION_PROTOCOL_TOPIC, { req: NetworkMapService.SubscribeRequest -> processSubscriptionRequest(req) } ) - net.addMessageHandler(NetworkMapService.PUSH_ACK_PROTOCOL_TOPIC, DEFAULT_SESSION_ID, null) { message, r -> + handlers += net.addMessageHandler(NetworkMapService.PUSH_ACK_PROTOCOL_TOPIC, DEFAULT_SESSION_ID, null) { message, r -> val req = message.data.deserialize() processAcknowledge(req) } + } - // TODO: notify subscribers of name service registration. Network service is not up, so how? + @VisibleForTesting + fun unregisterNetworkHandlers() { + for (handler in handlers) { + net.removeMessageHandler(handler) + } + handlers.clear() } private fun addSubscriber(subscriber: MessageRecipients) { diff --git a/node/src/main/kotlin/com/r3corda/node/services/network/PersistentNetworkMapService.kt b/node/src/main/kotlin/com/r3corda/node/services/network/PersistentNetworkMapService.kt new file mode 100644 index 0000000000..bc4b294e54 --- /dev/null +++ b/node/src/main/kotlin/com/r3corda/node/services/network/PersistentNetworkMapService.kt @@ -0,0 +1,29 @@ +package com.r3corda.node.services.network + +import com.r3corda.core.ThreadBox +import com.r3corda.core.crypto.Party +import com.r3corda.core.messaging.SingleMessageRecipient +import com.r3corda.node.services.api.ServiceHubInternal +import com.r3corda.node.utilities.JDBCHashMap +import java.util.* + +/** + * A network map service backed by a database to survive restarts of the node hosting it. + * + * Majority of the logic is inherited from [AbstractNetworkMapService]. + * + * This class needs database transactions to be in-flight during method calls and init, otherwise it will throw + * exceptions. + */ +class PersistentNetworkMapService(services: ServiceHubInternal) : AbstractNetworkMapService(services) { + + override val registeredNodes: MutableMap = Collections.synchronizedMap(JDBCHashMap("network_map_nodes", loadOnInit = true)) + + override val subscribers = ThreadBox(JDBCHashMap("network_map_subscribers", loadOnInit = true)) + + init { + // Initialise the network map version with the current highest persisted version, or zero if there are no entries. + _mapVersion.set(registeredNodes.values.map { it.mapVersion }.max() ?: 0) + setup() + } +} diff --git a/node/src/main/kotlin/com/r3corda/node/services/wallet/NodeWalletService.kt b/node/src/main/kotlin/com/r3corda/node/services/wallet/NodeWalletService.kt index ae6153eecb..83d837ae21 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/wallet/NodeWalletService.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/wallet/NodeWalletService.kt @@ -1,85 +1,135 @@ package com.r3corda.node.services.wallet -import com.r3corda.core.contracts.ContractState -import com.r3corda.core.contracts.StateAndRef -import com.r3corda.core.contracts.StateRef +import com.google.common.collect.Sets +import com.r3corda.core.contracts.* import com.r3corda.core.crypto.SecureHash import com.r3corda.core.node.ServiceHub import com.r3corda.core.node.services.Wallet -import com.r3corda.core.testing.InMemoryWalletService +import com.r3corda.core.node.services.WalletService +import com.r3corda.core.serialization.SingletonSerializeAsToken +import com.r3corda.core.transactions.WireTransaction import com.r3corda.core.utilities.loggerFor import com.r3corda.core.utilities.trace -import com.r3corda.node.utilities.databaseTransaction -import org.jetbrains.exposed.sql.* -import org.jetbrains.exposed.sql.SchemaUtils.create +import com.r3corda.node.utilities.AbstractJDBCHashSet +import com.r3corda.node.utilities.JDBCHashedTable +import org.jetbrains.exposed.sql.ResultRow +import org.jetbrains.exposed.sql.statements.InsertStatement +import rx.Observable +import rx.subjects.PublishSubject +import java.security.PublicKey +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.withLock /** * Currently, the node wallet service is a very simple RDBMS backed implementation. It will change significantly when * we add further functionality as the design for the wallet and wallet service matures. * + * This class needs database transactions to be in-flight during method calls and init, and will throw exceptions if + * this is not the case. + * * TODO: move query / filter criteria into the database query. * TODO: keep an audit trail with time stamps of previously unconsumed states "as of" a particular point in time. * TODO: have transaction storage do some caching. */ -class NodeWalletService(services: ServiceHub) : InMemoryWalletService(services) { +class NodeWalletService(private val services: ServiceHub) : SingletonSerializeAsToken(), WalletService { - override val log = loggerFor() - - // For now we are just tracking the current state, with no historical reporting ability. - private object UnconsumedStates : Table("vault_unconsumed_states") { - val txhash = binary("transaction_id", 32).primaryKey() - val index = integer("output_index").primaryKey() + private companion object { + val log = loggerFor() } - init { - // TODO: at some future point, we'll use some schema creation tool to deploy database artifacts if the database - // is not yet initalised to the right version of the schema. - createTablesIfNecessary() - - // Note that our wallet implementation currently does nothing with respect to attempting to apply criteria in the database. - mutex.locked { wallet = Wallet(allUnconsumedStates()) } - - // Now we need to make sure we listen to updates - updates.subscribe { recordUpdate(it) } + private object StatesSetTable : JDBCHashedTable("vault_unconsumed_states") { + val txhash = binary("transaction_id", 32) + val index = integer("output_index") } - private fun recordUpdate(update: Wallet.Update) { - val producedStateRefs = update.produced.map { it.ref } - val consumedStateRefs = update.consumed - log.trace { "Removing $consumedStateRefs consumed contract states and adding $producedStateRefs produced contract states to the database." } - databaseTransaction { - // Note we also remove the produced in case we are re-inserting in some form of recovery situation. - for (consumed in (consumedStateRefs + producedStateRefs)) { - UnconsumedStates.deleteWhere { - (UnconsumedStates.txhash eq consumed.txhash.bits) and (UnconsumedStates.index eq consumed.index) - } - } - for (produced in producedStateRefs) { - UnconsumedStates.insert { - it[txhash] = produced.txhash.bits - it[index] = produced.index - } - } + private val unconsumedStates = object : AbstractJDBCHashSet(StatesSetTable) { + override fun elementFromRow(it: ResultRow): StateRef = StateRef(SecureHash.SHA256(it[table.txhash]), it[table.index]) + + override fun addElementToInsert(insert: InsertStatement, entry: StateRef, finalizables: MutableList<() -> Unit>) { + insert[table.txhash] = entry.txhash.bits + insert[table.index] = entry.index } } - private fun createTablesIfNecessary() { - log.trace { "Creating database tables if necessary." } - databaseTransaction { - create(UnconsumedStates) + protected val mutex = ReentrantLock() + + override val currentWallet: Wallet get() = mutex.withLock { Wallet(allUnconsumedStates()) } + + private val _updatesPublisher = PublishSubject.create() + + override val updates: Observable + get() = _updatesPublisher + + /** + * Returns a snapshot of the heads of LinearStates. + * + * TODO: Represent this using an actual JDBCHashMap or look at vault design further. + */ + override val linearHeads: Map> + get() = currentWallet.states.filterStatesOfType().associateBy { it.state.data.linearId }.mapValues { it.value } + + override fun notifyAll(txns: Iterable): Wallet { + val ourKeys = services.keyManagementService.keys.keys + val netDelta = txns.fold(Wallet.NoUpdate) { netDelta, txn -> netDelta + makeUpdate(txn, netDelta, ourKeys) } + if (netDelta != Wallet.NoUpdate) { + mutex.withLock { + recordUpdate(netDelta) + } + _updatesPublisher.onNext(netDelta) } + return currentWallet + } + + private fun makeUpdate(tx: WireTransaction, netDelta: Wallet.Update, ourKeys: Set): Wallet.Update { + val ourNewStates = tx.outputs. + filter { isRelevant(it.data, ourKeys) }. + map { tx.outRef(it.data) } + + // Now calculate the states that are being spent by this transaction. + val consumed = tx.inputs.toHashSet() + // We use Guava union here as it's lazy for contains() which is how retainAll() is implemented. + // i.e. retainAll() iterates over consumed, checking contains() on the parameter. Sets.union() does not physically create + // a new collection and instead contains() just checks the contains() of both parameters, and so we don't end up + // iterating over all (a potentially very large) unconsumedStates at any point. + consumed.retainAll(Sets.union(netDelta.produced, unconsumedStates)) + + // Is transaction irrelevant? + if (consumed.isEmpty() && ourNewStates.isEmpty()) { + log.trace { "tx ${tx.id} was irrelevant to this wallet, ignoring" } + return Wallet.NoUpdate + } + + return Wallet.Update(consumed, ourNewStates.toHashSet()) + } + + private fun isRelevant(state: ContractState, ourKeys: Set): Boolean { + return if (state is OwnableState) { + state.owner in ourKeys + } else if (state is LinearState) { + // It's potentially of interest to the wallet + state.isRelevant(ourKeys) + } else { + false + } + } + + private fun recordUpdate(update: Wallet.Update): Wallet.Update { + if (update != Wallet.NoUpdate) { + val producedStateRefs = update.produced.map { it.ref } + val consumedStateRefs = update.consumed + log.trace { "Removing $consumedStateRefs consumed contract states and adding $producedStateRefs produced contract states to the database." } + unconsumedStates.removeAll(consumedStateRefs) + unconsumedStates.addAll(producedStateRefs) + } + return update } private fun allUnconsumedStates(): Iterable> { // Order by txhash for if and when transaction storage has some caching. - // Map to StateRef and then to StateAndRef. - return databaseTransaction { - UnconsumedStates.selectAll().orderBy(UnconsumedStates.txhash) - .map { StateRef(SecureHash.SHA256(it[UnconsumedStates.txhash]), it[UnconsumedStates.index]) } - .map { - val storedTx = services.storageService.validatedTransactions.getTransaction(it.txhash) ?: throw Error("Found transaction hash ${it.txhash} in unconsumed contract states that is not in transaction storage.") - StateAndRef(storedTx.tx.outputs[it.index], it) - } - } + // Map to StateRef and then to StateAndRef. Use Sequence to avoid conversion to ArrayList that Iterable.map() performs. + return unconsumedStates.asSequence().map { + val storedTx = services.storageService.validatedTransactions.getTransaction(it.txhash) ?: throw Error("Found transaction hash ${it.txhash} in unconsumed contract states that is not in transaction storage.") + StateAndRef(storedTx.tx.outputs[it.index], it) + }.asIterable() } } \ No newline at end of file diff --git a/node/src/main/kotlin/com/r3corda/node/utilities/JDBCHashMap.kt b/node/src/main/kotlin/com/r3corda/node/utilities/JDBCHashMap.kt index b2443d806b..21d08888ff 100644 --- a/node/src/main/kotlin/com/r3corda/node/utilities/JDBCHashMap.kt +++ b/node/src/main/kotlin/com/r3corda/node/utilities/JDBCHashMap.kt @@ -336,6 +336,7 @@ abstract class AbstractJDBCHashMap(val ta override fun put(key: K, value: V): V? { var oldValue: V? = null + var oldSeqNo: Int? = null getBucket(key) buckets.compute(key.hashCode()) { hashCode, list -> val newList = list ?: newBucket() @@ -344,12 +345,13 @@ abstract class AbstractJDBCHashMap(val ta val entry = iterator.next() if (entry.key == key) { oldValue = entry.value + oldSeqNo = entry.seqNo iterator.remove() deleteRecord(entry) break } } - val seqNo = addRecord(key, value) + val seqNo = addRecord(key, value, oldSeqNo) val newEntry = NotReallyMutableEntry(key, value, seqNo) newList.add(newEntry) newList @@ -450,7 +452,7 @@ abstract class AbstractJDBCHashMap(val ta } } - private fun addRecord(key: K, value: V): Int { + private fun addRecord(key: K, value: V, oldSeqNo: Int?): Int { val finalizables = mutableListOf<() -> Unit>() try { return table.insert { @@ -458,6 +460,10 @@ abstract class AbstractJDBCHashMap(val ta val entry = SimpleEntry(key, value) addKeyToInsert(it, entry, finalizables) addValueToInsert(it, entry, finalizables) + if (oldSeqNo != null) { + it[seqNo] = oldSeqNo + it.generatedKey = oldSeqNo + } } get table.seqNo } finally { finalizables.forEach { it() } diff --git a/node/src/test/kotlin/com/r3corda/node/services/InMemoryNetworkMapServiceTest.kt b/node/src/test/kotlin/com/r3corda/node/services/InMemoryNetworkMapServiceTest.kt index aa17f0ce9a..b5dce96739 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/InMemoryNetworkMapServiceTest.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/InMemoryNetworkMapServiceTest.kt @@ -1,25 +1,14 @@ package com.r3corda.node.services import com.google.common.util.concurrent.ListenableFuture -import com.google.common.util.concurrent.SettableFuture import com.r3corda.core.map -import com.r3corda.core.messaging.TopicSession -import com.r3corda.core.messaging.runOnNextMessage -import com.r3corda.core.messaging.send import com.r3corda.core.random63BitValue -import com.r3corda.core.serialization.deserialize +import com.r3corda.node.services.network.AbstractNetworkMapService import com.r3corda.node.services.network.InMemoryNetworkMapService import com.r3corda.node.services.network.NetworkMapService -import com.r3corda.node.services.network.NetworkMapService.* -import com.r3corda.node.services.network.NetworkMapService.Companion.FETCH_PROTOCOL_TOPIC -import com.r3corda.node.services.network.NetworkMapService.Companion.PUSH_ACK_PROTOCOL_TOPIC -import com.r3corda.node.services.network.NetworkMapService.Companion.REGISTER_PROTOCOL_TOPIC -import com.r3corda.node.services.network.NetworkMapService.Companion.SUBSCRIPTION_PROTOCOL_TOPIC import com.r3corda.node.services.network.NodeRegistration import com.r3corda.node.utilities.AddOrRemove -import com.r3corda.protocols.ServiceRequestMessage import com.r3corda.testing.node.MockNetwork -import com.r3corda.testing.node.MockNetwork.MockNode import org.junit.Before import org.junit.Test import java.security.PrivateKey @@ -30,7 +19,164 @@ import kotlin.test.assertNotNull import kotlin.test.assertNull import kotlin.test.assertTrue -class InMemoryNetworkMapServiceTest { +/** + * Abstracted out test logic to be re-used by [PersistentNetworkMapServiceTest]. + */ +abstract class AbstractNetworkMapServiceTest { + + protected fun success(mapServiceNode: MockNetwork.MockNode, + registerNode: MockNetwork.MockNode, + service: () -> AbstractNetworkMapService, + swizzle: () -> Unit) { + // For persistent service, switch out the implementation for a newly instantiated one so we can check the state is preserved. + swizzle() + + // Confirm the service contains no nodes as own node only registered if network is run. + assertEquals(0, service().nodes.count()) + assertNull(service().processQueryRequest(NetworkMapService.QueryIdentityRequest(registerNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node) + + // Register the new node + val instant = Instant.now() + val expires = instant + NetworkMapService.DEFAULT_EXPIRATION_PERIOD + val nodeKey = registerNode.storage.myLegalIdentityKey + val addChange = NodeRegistration(registerNode.info, instant.toEpochMilli(), AddOrRemove.ADD, expires) + val addWireChange = addChange.toWire(nodeKey.private) + service().processRegistrationChangeRequest(NetworkMapService.RegistrationRequest(addWireChange, mapServiceNode.info.address, Long.MIN_VALUE)) + swizzle() + + assertEquals(1, service().nodes.count()) + assertEquals(registerNode.info, service().processQueryRequest(NetworkMapService.QueryIdentityRequest(registerNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node) + + // Re-registering should be a no-op + service().processRegistrationChangeRequest(NetworkMapService.RegistrationRequest(addWireChange, mapServiceNode.info.address, Long.MIN_VALUE)) + swizzle() + + assertEquals(1, service().nodes.count()) + + // Confirm that de-registering the node succeeds and drops it from the node lists + val removeChange = NodeRegistration(registerNode.info, instant.toEpochMilli()+1, AddOrRemove.REMOVE, expires) + val removeWireChange = removeChange.toWire(nodeKey.private) + assert(service().processRegistrationChangeRequest(NetworkMapService.RegistrationRequest(removeWireChange, mapServiceNode.info.address, Long.MIN_VALUE)).success) + swizzle() + + assertNull(service().processQueryRequest(NetworkMapService.QueryIdentityRequest(registerNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node) + swizzle() + + // Trying to de-register a node that doesn't exist should fail + assert(!service().processRegistrationChangeRequest(NetworkMapService.RegistrationRequest(removeWireChange, mapServiceNode.info.address, Long.MIN_VALUE)).success) + } + + protected fun `success with network`(network: MockNetwork, + mapServiceNode: MockNetwork.MockNode, + registerNode: MockNetwork.MockNode, + swizzle: () -> Unit) { + // For persistent service, switch out the implementation for a newly instantiated one so we can check the state is preserved. + swizzle() + + // Confirm all nodes have registered themselves + network.runNetwork() + var fetchPsm = fetchMap(registerNode, mapServiceNode, false) + network.runNetwork() + assertEquals(2, fetchPsm.get()?.count()) + + // Forcibly deregister the second node + val nodeKey = registerNode.storage.myLegalIdentityKey + val instant = Instant.now() + val expires = instant + NetworkMapService.DEFAULT_EXPIRATION_PERIOD + val reg = NodeRegistration(registerNode.info, instant.toEpochMilli()+1, AddOrRemove.REMOVE, expires) + val registerPsm = registration(registerNode, mapServiceNode, reg, nodeKey.private) + network.runNetwork() + assertTrue(registerPsm.get().success) + + swizzle() + + // Now only map service node should be registered + fetchPsm = fetchMap(registerNode, mapServiceNode, false) + network.runNetwork() + assertEquals(mapServiceNode.info, fetchPsm.get()?.filter { it.type == AddOrRemove.ADD }?.map { it.node }?.single()) + } + + protected fun `subscribe with network`(network: MockNetwork, + mapServiceNode: MockNetwork.MockNode, + registerNode: MockNetwork.MockNode, + service: () -> AbstractNetworkMapService, + swizzle: () -> Unit) { + // For persistent service, switch out the implementation for a newly instantiated one so we can check the state is preserved. + swizzle() + + // Test subscribing to updates + network.runNetwork() + val subscribePsm = subscribe(registerNode, mapServiceNode, true) + network.runNetwork() + subscribePsm.get() + + swizzle() + + val startingMapVersion = service().mapVersion + + // Check the unacknowledged count is zero + assertEquals(0, service().getUnacknowledgedCount(registerNode.info.address, startingMapVersion)) + + // Fire off an update + val nodeKey = registerNode.storage.myLegalIdentityKey + var seq = 0L + val expires = Instant.now() + NetworkMapService.DEFAULT_EXPIRATION_PERIOD + var reg = NodeRegistration(registerNode.info, seq++, AddOrRemove.ADD, expires) + var wireReg = reg.toWire(nodeKey.private) + service().notifySubscribers(wireReg, startingMapVersion + 1) + + swizzle() + + // Check the unacknowledged count is one + assertEquals(1, service().getUnacknowledgedCount(registerNode.info.address, startingMapVersion + 1)) + + // Send in an acknowledgment and verify the count goes down + updateAcknowlege(registerNode, mapServiceNode, startingMapVersion + 1) + network.runNetwork() + + swizzle() + + assertEquals(0, service().getUnacknowledgedCount(registerNode.info.address, startingMapVersion + 1)) + + // Intentionally fill the pending acknowledgements to verify it doesn't drop subscribers before the limit + // is hit. On the last iteration overflow the pending list, and check the node is unsubscribed + for (i in 0..service().maxUnacknowledgedUpdates) { + reg = NodeRegistration(registerNode.info, seq++, AddOrRemove.ADD, expires) + wireReg = reg.toWire(nodeKey.private) + service().notifySubscribers(wireReg, i + startingMapVersion + 2) + + swizzle() + + if (i < service().maxUnacknowledgedUpdates) { + assertEquals(i + 1, service().getUnacknowledgedCount(registerNode.info.address, i + startingMapVersion + 2)) + } else { + assertNull(service().getUnacknowledgedCount(registerNode.info.address, i + startingMapVersion + 2)) + } + } + } + + private fun registration(registerNode: MockNetwork.MockNode, mapServiceNode: MockNetwork.MockNode, reg: NodeRegistration, privateKey: PrivateKey): ListenableFuture { + val req = NetworkMapService.RegistrationRequest(reg.toWire(privateKey), registerNode.services.networkService.myAddress, random63BitValue()) + return registerNode.sendAndReceive(NetworkMapService.REGISTER_PROTOCOL_TOPIC, mapServiceNode, req) + } + + private fun subscribe(registerNode: MockNetwork.MockNode, mapServiceNode: MockNetwork.MockNode, subscribe: Boolean): ListenableFuture { + val req = NetworkMapService.SubscribeRequest(subscribe, registerNode.services.networkService.myAddress, random63BitValue()) + return registerNode.sendAndReceive(NetworkMapService.SUBSCRIPTION_PROTOCOL_TOPIC, mapServiceNode, req) + } + + private fun updateAcknowlege(registerNode: MockNetwork.MockNode, mapServiceNode: MockNetwork.MockNode, mapVersion: Int) { + val req = NetworkMapService.UpdateAcknowledge(mapVersion, registerNode.services.networkService.myAddress) + registerNode.send(NetworkMapService.PUSH_ACK_PROTOCOL_TOPIC, mapServiceNode, req) + } + + private fun fetchMap(registerNode: MockNetwork.MockNode, mapServiceNode: MockNetwork.MockNode, subscribe: Boolean, ifChangedSinceVersion: Int? = null): Future?> { + val req = NetworkMapService.FetchMapRequest(subscribe, ifChangedSinceVersion, registerNode.services.networkService.myAddress, random63BitValue()) + return registerNode.sendAndReceive(NetworkMapService.FETCH_PROTOCOL_TOPIC, mapServiceNode, req).map { it.nodes } + } +} + +class InMemoryNetworkMapServiceTest : AbstractNetworkMapServiceTest() { lateinit var network: MockNetwork @Before @@ -45,33 +191,7 @@ class InMemoryNetworkMapServiceTest { fun success() { val (mapServiceNode, registerNode) = network.createTwoNodes() val service = mapServiceNode.inNodeNetworkMapService!! as InMemoryNetworkMapService - - // Confirm the service contains only its own node - assertEquals(1, service.nodes.count()) - assertNull(service.processQueryRequest(QueryIdentityRequest(registerNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node) - - // Register the second node - var seq = 1L - val expires = Instant.now() + NetworkMapService.DEFAULT_EXPIRATION_PERIOD - val nodeKey = registerNode.storage.myLegalIdentityKey - val addChange = NodeRegistration(registerNode.info, seq++, AddOrRemove.ADD, expires) - val addWireChange = addChange.toWire(nodeKey.private) - service.processRegistrationChangeRequest(RegistrationRequest(addWireChange, mapServiceNode.info.address, Long.MIN_VALUE)) - assertEquals(2, service.nodes.count()) - assertEquals(mapServiceNode.info, service.processQueryRequest(QueryIdentityRequest(mapServiceNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node) - - // Re-registering should be a no-op - service.processRegistrationChangeRequest(RegistrationRequest(addWireChange, mapServiceNode.info.address, Long.MIN_VALUE)) - assertEquals(2, service.nodes.count()) - - // Confirm that de-registering the node succeeds and drops it from the node lists - val removeChange = NodeRegistration(registerNode.info, seq, AddOrRemove.REMOVE, expires) - val removeWireChange = removeChange.toWire(nodeKey.private) - assert(service.processRegistrationChangeRequest(RegistrationRequest(removeWireChange, mapServiceNode.info.address, Long.MIN_VALUE)).success) - assertNull(service.processQueryRequest(QueryIdentityRequest(registerNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node) - - // Trying to de-register a node that doesn't exist should fail - assert(!service.processRegistrationChangeRequest(RegistrationRequest(removeWireChange, mapServiceNode.info.address, Long.MIN_VALUE)).success) + success(mapServiceNode, registerNode, { service }, { }) } @Test @@ -80,93 +200,13 @@ class InMemoryNetworkMapServiceTest { // Confirm there's a network map service on node 0 assertNotNull(mapServiceNode.inNodeNetworkMapService) - - // Confirm all nodes have registered themselves - network.runNetwork() - var fetchPsm = fetchMap(registerNode, mapServiceNode, false) - network.runNetwork() - assertEquals(2, fetchPsm.get()?.count()) - - // Forcibly deregister the second node - val nodeKey = registerNode.storage.myLegalIdentityKey - val expires = Instant.now() + NetworkMapService.DEFAULT_EXPIRATION_PERIOD - val seq = 2L - val reg = NodeRegistration(registerNode.info, seq, AddOrRemove.REMOVE, expires) - val registerPsm = registration(registerNode, mapServiceNode, reg, nodeKey.private) - network.runNetwork() - assertTrue(registerPsm.get().success) - - // Now only map service node should be registered - fetchPsm = fetchMap(registerNode, mapServiceNode, false) - network.runNetwork() - assertEquals(mapServiceNode.info, fetchPsm.get()?.filter { it.type == AddOrRemove.ADD }?.map { it.node }?.single()) + `success with network`(network, mapServiceNode, registerNode, { }) } @Test fun `subscribe with network`() { val (mapServiceNode, registerNode) = network.createTwoNodes() val service = (mapServiceNode.inNodeNetworkMapService as InMemoryNetworkMapService) - - // Test subscribing to updates - network.runNetwork() - val subscribePsm = subscribe(registerNode, mapServiceNode, true) - network.runNetwork() - subscribePsm.get() - - val startingMapVersion = service.mapVersion - - // Check the unacknowledged count is zero - assertEquals(0, service.getUnacknowledgedCount(registerNode.info.address, startingMapVersion)) - - // Fire off an update - val nodeKey = registerNode.storage.myLegalIdentityKey - var seq = 0L - val expires = Instant.now() + NetworkMapService.DEFAULT_EXPIRATION_PERIOD - var reg = NodeRegistration(registerNode.info, seq++, AddOrRemove.ADD, expires) - var wireReg = reg.toWire(nodeKey.private) - service.notifySubscribers(wireReg, startingMapVersion + 1) - - // Check the unacknowledged count is one - assertEquals(1, service.getUnacknowledgedCount(registerNode.info.address, startingMapVersion + 1)) - - // Send in an acknowledgment and verify the count goes down - updateAcknowlege(registerNode, mapServiceNode, startingMapVersion + 1) - network.runNetwork() - - assertEquals(0, service.getUnacknowledgedCount(registerNode.info.address, startingMapVersion + 1)) - - // Intentionally fill the pending acknowledgements to verify it doesn't drop subscribers before the limit - // is hit. On the last iteration overflow the pending list, and check the node is unsubscribed - for (i in 0..service.maxUnacknowledgedUpdates) { - reg = NodeRegistration(registerNode.info, seq++, AddOrRemove.ADD, expires) - wireReg = reg.toWire(nodeKey.private) - service.notifySubscribers(wireReg, i + startingMapVersion + 2) - if (i < service.maxUnacknowledgedUpdates) { - assertEquals(i + 1, service.getUnacknowledgedCount(registerNode.info.address, i + startingMapVersion + 2)) - } else { - assertNull(service.getUnacknowledgedCount(registerNode.info.address, i + startingMapVersion + 2)) - } - } + `subscribe with network`(network, mapServiceNode, registerNode, { service }, { }) } - - private fun registration(registerNode: MockNode, mapServiceNode: MockNode, reg: NodeRegistration, privateKey: PrivateKey): ListenableFuture { - val req = RegistrationRequest(reg.toWire(privateKey), registerNode.services.networkService.myAddress, random63BitValue()) - return registerNode.sendAndReceive(REGISTER_PROTOCOL_TOPIC, mapServiceNode, req) - } - - private fun subscribe(registerNode: MockNode, mapServiceNode: MockNode, subscribe: Boolean): ListenableFuture { - val req = SubscribeRequest(subscribe, registerNode.services.networkService.myAddress, random63BitValue()) - return registerNode.sendAndReceive(SUBSCRIPTION_PROTOCOL_TOPIC, mapServiceNode, req) - } - - private fun updateAcknowlege(registerNode: MockNode, mapServiceNode: MockNode, mapVersion: Int) { - val req = UpdateAcknowledge(mapVersion, registerNode.services.networkService.myAddress) - registerNode.send(PUSH_ACK_PROTOCOL_TOPIC, mapServiceNode, req) - } - - private fun fetchMap(registerNode: MockNode, mapServiceNode: MockNode, subscribe: Boolean, ifChangedSinceVersion: Int? = null): Future?> { - val req = FetchMapRequest(subscribe, ifChangedSinceVersion, registerNode.services.networkService.myAddress, random63BitValue()) - return registerNode.sendAndReceive(FETCH_PROTOCOL_TOPIC, mapServiceNode, req).map { it.nodes } - } - } \ No newline at end of file diff --git a/node/src/test/kotlin/com/r3corda/node/services/NodeWalletServiceTest.kt b/node/src/test/kotlin/com/r3corda/node/services/NodeWalletServiceTest.kt index 779c5eb4df..0475307a78 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/NodeWalletServiceTest.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/NodeWalletServiceTest.kt @@ -2,15 +2,16 @@ package com.r3corda.node.services import com.r3corda.contracts.testing.fillWithSomeTestCash import com.r3corda.core.contracts.DOLLARS -import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.node.services.TxWritableStorageService import com.r3corda.core.node.services.WalletService -import com.r3corda.testing.node.MockServices -import com.r3corda.testing.node.makeTestDataSourceProperties +import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.utilities.DUMMY_NOTARY import com.r3corda.core.utilities.LogHelper import com.r3corda.node.services.wallet.NodeWalletService import com.r3corda.node.utilities.configureDatabase +import com.r3corda.node.utilities.databaseTransaction +import com.r3corda.testing.node.MockServices +import com.r3corda.testing.node.makeTestDataSourceProperties import org.assertj.core.api.Assertions.assertThat import org.junit.After import org.junit.Before @@ -35,37 +36,39 @@ class NodeWalletServiceTest { @Test fun `states not local to instance`() { - val services1 = object : MockServices() { - override val walletService: WalletService = NodeWalletService(this) + databaseTransaction { + val services1 = object : MockServices() { + override val walletService: WalletService = NodeWalletService(this) - override fun recordTransactions(txs: Iterable) { - for (stx in txs) { - storageService.validatedTransactions.addTransaction(stx) - walletService.notify(stx.tx) + override fun recordTransactions(txs: Iterable) { + for (stx in txs) { + storageService.validatedTransactions.addTransaction(stx) + walletService.notify(stx.tx) + } } } - } - services1.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + services1.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) - val w1 = services1.walletService.currentWallet - assertThat(w1.states).hasSize(3) + val w1 = services1.walletService.currentWallet + assertThat(w1.states).hasSize(3) - val originalStorage = services1.storageService - val services2 = object : MockServices() { - override val walletService: WalletService = NodeWalletService(this) + val originalStorage = services1.storageService + val services2 = object : MockServices() { + override val walletService: WalletService = NodeWalletService(this) - // We need to be able to find the same transactions as before, too. - override val storageService: TxWritableStorageService get() = originalStorage + // We need to be able to find the same transactions as before, too. + override val storageService: TxWritableStorageService get() = originalStorage - override fun recordTransactions(txs: Iterable) { - for (stx in txs) { - storageService.validatedTransactions.addTransaction(stx) - walletService.notify(stx.tx) + override fun recordTransactions(txs: Iterable) { + for (stx in txs) { + storageService.validatedTransactions.addTransaction(stx) + walletService.notify(stx.tx) + } } } - } - val w2 = services2.walletService.currentWallet - assertThat(w2.states).hasSize(3) + val w2 = services2.walletService.currentWallet + assertThat(w2.states).hasSize(3) + } } } \ No newline at end of file diff --git a/node/src/test/kotlin/com/r3corda/node/services/PersistentNetworkMapServiceTest.kt b/node/src/test/kotlin/com/r3corda/node/services/PersistentNetworkMapServiceTest.kt new file mode 100644 index 0000000000..1d96815e31 --- /dev/null +++ b/node/src/test/kotlin/com/r3corda/node/services/PersistentNetworkMapServiceTest.kt @@ -0,0 +1,118 @@ +package com.r3corda.node.services + +import com.r3corda.core.messaging.SingleMessageRecipient +import com.r3corda.core.node.NodeInfo +import com.r3corda.core.node.services.ServiceType +import com.r3corda.node.services.api.ServiceHubInternal +import com.r3corda.node.services.config.NodeConfiguration +import com.r3corda.node.services.network.AbstractNetworkMapService +import com.r3corda.node.services.network.InMemoryNetworkMapService +import com.r3corda.node.services.network.NetworkMapService +import com.r3corda.node.services.network.PersistentNetworkMapService +import com.r3corda.node.utilities.configureDatabase +import com.r3corda.node.utilities.databaseTransaction +import com.r3corda.testing.node.MockNetwork +import com.r3corda.testing.node.makeTestDataSourceProperties +import org.junit.After +import org.junit.Before +import org.junit.Test +import java.io.Closeable +import java.nio.file.Path +import java.security.KeyPair + +/** + * This class mirrors [InMemoryNetworkMapServiceTest] but switches in a [PersistentNetworkMapService] and + * repeatedly replaces it with new instances to check that the service correctly restores the most recent state. + */ +class PersistentNetworkMapServiceTest : AbstractNetworkMapServiceTest() { + lateinit var network: MockNetwork + lateinit var dataSource: Closeable + + @Before + fun setup() { + network = MockNetwork() + } + + @After + fun tearDown() { + dataSource.close() + } + + /** + * We use a special [NetworkMapService] that allows us to switch in a new instance at any time to check that the + * state within it is correctly restored. + */ + private class SwizzleNetworkMapService(services: ServiceHubInternal) : NetworkMapService { + var delegate: AbstractNetworkMapService = InMemoryNetworkMapService(services) + + override val nodes: List + get() = delegate.nodes + + fun swizzle() { + delegate.unregisterNetworkHandlers() + delegate=makeNetworkMapService(delegate.services) + } + + private fun makeNetworkMapService(services: ServiceHubInternal): AbstractNetworkMapService { + return PersistentNetworkMapService(services) + } + } + + private object NodeFactory : MockNetwork.Factory { + override fun create(dir: Path, config: NodeConfiguration, network: MockNetwork, networkMapAddr: SingleMessageRecipient?, + advertisedServices: Set, id: Int, keyPair: KeyPair?): MockNetwork.MockNode { + return object : MockNetwork.MockNode(dir, config, network, networkMapAddr, advertisedServices, id, keyPair) { + + override fun makeNetworkMapService() { + inNodeNetworkMapService = SwizzleNetworkMapService(services) + } + } + } + } + + /** + * Perform basic tests of registering, de-registering and fetching the full network map. + */ + @Test + fun success() { + val (mapServiceNode, registerNode) = network.createTwoNodes(NodeFactory) + val service = mapServiceNode.inNodeNetworkMapService!! as SwizzleNetworkMapService + + // We have to set this up after the non-persistent nodes as they install a dummy transaction manager. + dataSource = configureDatabase(makeTestDataSourceProperties()).first + + databaseTransaction { + success(mapServiceNode, registerNode, { service.delegate }, {service.swizzle()}) + } + } + + @Test + fun `success with network`() { + val (mapServiceNode, registerNode) = network.createTwoNodes(NodeFactory) + + // Confirm there's a network map service on node 0 + val service = mapServiceNode.inNodeNetworkMapService!! as SwizzleNetworkMapService + + // We have to set this up after the non-persistent nodes as they install a dummy transaction manager. + dataSource = configureDatabase(makeTestDataSourceProperties()).first + + databaseTransaction { + `success with network`(network, mapServiceNode, registerNode, { service.swizzle() }) + } + } + + @Test + fun `subscribe with network`() { + val (mapServiceNode, registerNode) = network.createTwoNodes(NodeFactory) + + // Confirm there's a network map service on node 0 + val service = mapServiceNode.inNodeNetworkMapService!! as SwizzleNetworkMapService + + // We have to set this up after the non-persistent nodes as they install a dummy transaction manager. + dataSource = configureDatabase(makeTestDataSourceProperties()).first + + databaseTransaction { + `subscribe with network`(network, mapServiceNode, registerNode, { service.delegate }, { service.swizzle() }) + } + } +} diff --git a/node/src/test/kotlin/com/r3corda/node/services/WalletWithCashTest.kt b/node/src/test/kotlin/com/r3corda/node/services/WalletWithCashTest.kt index 5bafcd85e3..06f57e9a73 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/WalletWithCashTest.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/WalletWithCashTest.kt @@ -12,10 +12,10 @@ import com.r3corda.core.utilities.DUMMY_NOTARY_KEY import com.r3corda.core.utilities.LogHelper import com.r3corda.node.services.wallet.NodeWalletService import com.r3corda.node.utilities.configureDatabase +import com.r3corda.node.utilities.databaseTransaction +import com.r3corda.testing.* import com.r3corda.testing.node.MockServices import com.r3corda.testing.node.makeTestDataSourceProperties -import com.r3corda.testing.DummyLinearContract -import com.r3corda.testing.* import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.After import org.junit.Before @@ -36,13 +36,16 @@ class WalletWithCashTest { fun setUp() { LogHelper.setLevel(NodeWalletService::class) dataSource = configureDatabase(makeTestDataSourceProperties()).first - services = object : MockServices() { - override val walletService: WalletService = NodeWalletService(this) + databaseTransaction { + services = object : MockServices() { + override val walletService: WalletService = NodeWalletService(this) - override fun recordTransactions(txs: Iterable) { - for (stx in txs) { - storageService.validatedTransactions.addTransaction(stx) - walletService.notify(stx.tx) + override fun recordTransactions(txs: Iterable) { + for (stx in txs) { + storageService.validatedTransactions.addTransaction(stx) + } + // Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions. + walletService.notifyAll(txs.map { it.tx }) } } } @@ -56,102 +59,111 @@ class WalletWithCashTest { @Test fun splits() { - // Fix the PRNG so that we get the same splits every time. - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + databaseTransaction { + // Fix the PRNG so that we get the same splits every time. + services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) - val w = wallet.currentWallet - assertEquals(3, w.states.toList().size) + val w = wallet.currentWallet + assertEquals(3, w.states.toList().size) - val state = w.states.toList()[0].state.data as Cash.State - assertEquals(30.45.DOLLARS `issued by` DUMMY_CASH_ISSUER, state.amount) - assertEquals(services.key.public, state.owner) + val state = w.states.toList()[0].state.data as Cash.State + assertEquals(30.45.DOLLARS `issued by` DUMMY_CASH_ISSUER, state.amount) + assertEquals(services.key.public, state.owner) - assertEquals(34.70.DOLLARS `issued by` DUMMY_CASH_ISSUER, (w.states.toList()[2].state.data as Cash.State).amount) - assertEquals(34.85.DOLLARS `issued by` DUMMY_CASH_ISSUER, (w.states.toList()[1].state.data as Cash.State).amount) - } - - @Test - fun basics() { - // A tx that sends us money. - val freshKey = services.keyManagementService.freshKey() - val usefulTX = TransactionType.General.Builder(null).apply { - Cash().generateIssue(this, 100.DOLLARS `issued by` MEGA_CORP.ref(1), freshKey.public, DUMMY_NOTARY) - signWith(MEGA_CORP_KEY) - }.toSignedTransaction() - val myOutput = usefulTX.toLedgerTransaction(services).outRef(0) - - // A tx that spends our money. - val spendTX = TransactionType.General.Builder(DUMMY_NOTARY).apply { - Cash().generateSpend(this, 80.DOLLARS, BOB_PUBKEY, listOf(myOutput)) - signWith(freshKey) - signWith(DUMMY_NOTARY_KEY) - }.toSignedTransaction() - - // A tx that doesn't send us anything. - val irrelevantTX = TransactionType.General.Builder(DUMMY_NOTARY).apply { - Cash().generateIssue(this, 100.DOLLARS `issued by` MEGA_CORP.ref(1), BOB_KEY.public, DUMMY_NOTARY) - signWith(MEGA_CORP_KEY) - signWith(DUMMY_NOTARY_KEY) - }.toSignedTransaction() - - assertNull(wallet.currentWallet.cashBalances[USD]) - wallet.notify(usefulTX.tx) - assertEquals(100.DOLLARS, wallet.currentWallet.cashBalances[USD]) - wallet.notify(irrelevantTX.tx) - assertEquals(100.DOLLARS, wallet.currentWallet.cashBalances[USD]) - wallet.notify(spendTX.tx) - assertEquals(20.DOLLARS, wallet.currentWallet.cashBalances[USD]) - - // TODO: Flesh out these tests as needed. - } - - - @Test - fun branchingLinearStatesFailsToVerify() { - val freshKey = services.keyManagementService.freshKey() - val linearId = UniqueIdentifier() - - // Issue a linear state - val dummyIssue = TransactionType.General.Builder(notary = DUMMY_NOTARY).apply { - addOutputState(DummyLinearContract.State(linearId = linearId, participants = listOf(freshKey.public))) - addOutputState(DummyLinearContract.State(linearId = linearId, participants = listOf(freshKey.public))) - signWith(freshKey) - signWith(DUMMY_NOTARY_KEY) - }.toSignedTransaction() - - assertThatThrownBy { - dummyIssue.toLedgerTransaction(services).verify() + assertEquals(34.70.DOLLARS `issued by` DUMMY_CASH_ISSUER, (w.states.toList()[2].state.data as Cash.State).amount) + assertEquals(34.85.DOLLARS `issued by` DUMMY_CASH_ISSUER, (w.states.toList()[1].state.data as Cash.State).amount) } } @Test - fun sequencingLinearStatesWorks() { - val freshKey = services.keyManagementService.freshKey() + fun `issue and spend total correctly and irrelevant ignored`() { + databaseTransaction { + // A tx that sends us money. + val freshKey = services.keyManagementService.freshKey() + val usefulTX = TransactionType.General.Builder(null).apply { + Cash().generateIssue(this, 100.DOLLARS `issued by` MEGA_CORP.ref(1), freshKey.public, DUMMY_NOTARY) + signWith(MEGA_CORP_KEY) + }.toSignedTransaction() + val myOutput = usefulTX.toLedgerTransaction(services).outRef(0) - val linearId = UniqueIdentifier() + // A tx that spends our money. + val spendTX = TransactionType.General.Builder(DUMMY_NOTARY).apply { + Cash().generateSpend(this, 80.DOLLARS, BOB_PUBKEY, listOf(myOutput)) + signWith(freshKey) + signWith(DUMMY_NOTARY_KEY) + }.toSignedTransaction() - // Issue a linear state - val dummyIssue = TransactionType.General.Builder(notary = DUMMY_NOTARY).apply { - addOutputState(DummyLinearContract.State(linearId = linearId, participants = listOf(freshKey.public))) - signWith(freshKey) - signWith(DUMMY_NOTARY_KEY) - }.toSignedTransaction() + // A tx that doesn't send us anything. + val irrelevantTX = TransactionType.General.Builder(DUMMY_NOTARY).apply { + Cash().generateIssue(this, 100.DOLLARS `issued by` MEGA_CORP.ref(1), BOB_KEY.public, DUMMY_NOTARY) + signWith(MEGA_CORP_KEY) + signWith(DUMMY_NOTARY_KEY) + }.toSignedTransaction() - dummyIssue.toLedgerTransaction(services).verify() + assertNull(wallet.currentWallet.cashBalances[USD]) + services.recordTransactions(usefulTX) + assertEquals(100.DOLLARS, wallet.currentWallet.cashBalances[USD]) + services.recordTransactions(irrelevantTX) + assertEquals(100.DOLLARS, wallet.currentWallet.cashBalances[USD]) + services.recordTransactions(spendTX) - wallet.notify(dummyIssue.tx) - assertEquals(1, wallet.currentWallet.states.toList().size) + assertEquals(20.DOLLARS, wallet.currentWallet.cashBalances[USD]) - // Move the same state - val dummyMove = TransactionType.General.Builder(notary = DUMMY_NOTARY).apply { - addOutputState(DummyLinearContract.State(linearId = linearId, participants = listOf(freshKey.public))) - addInputState(dummyIssue.tx.outRef(0)) - signWith(DUMMY_NOTARY_KEY) - }.toSignedTransaction() + // TODO: Flesh out these tests as needed. + } + } - dummyIssue.toLedgerTransaction(services).verify() - wallet.notify(dummyMove.tx) - assertEquals(1, wallet.currentWallet.states.toList().size) + @Test + fun `branching LinearStates fails to verify`() { + databaseTransaction { + val freshKey = services.keyManagementService.freshKey() + val linearId = UniqueIdentifier() + + // Issue a linear state + val dummyIssue = TransactionType.General.Builder(notary = DUMMY_NOTARY).apply { + addOutputState(DummyLinearContract.State(linearId = linearId, participants = listOf(freshKey.public))) + addOutputState(DummyLinearContract.State(linearId = linearId, participants = listOf(freshKey.public))) + signWith(freshKey) + signWith(DUMMY_NOTARY_KEY) + }.toSignedTransaction() + + assertThatThrownBy { + dummyIssue.toLedgerTransaction(services).verify() + } + } + } + + @Test + fun `sequencing LinearStates works`() { + databaseTransaction { + val freshKey = services.keyManagementService.freshKey() + + val linearId = UniqueIdentifier() + + // Issue a linear state + val dummyIssue = TransactionType.General.Builder(notary = DUMMY_NOTARY).apply { + addOutputState(DummyLinearContract.State(linearId = linearId, participants = listOf(freshKey.public))) + signWith(freshKey) + signWith(DUMMY_NOTARY_KEY) + }.toSignedTransaction() + + dummyIssue.toLedgerTransaction(services).verify() + + services.recordTransactions(dummyIssue) + assertEquals(1, wallet.currentWallet.states.toList().size) + + // Move the same state + val dummyMove = TransactionType.General.Builder(notary = DUMMY_NOTARY).apply { + addOutputState(DummyLinearContract.State(linearId = linearId, participants = listOf(freshKey.public))) + addInputState(dummyIssue.tx.outRef(0)) + signWith(DUMMY_NOTARY_KEY) + }.toSignedTransaction() + + dummyIssue.toLedgerTransaction(services).verify() + + services.recordTransactions(dummyMove) + assertEquals(1, wallet.currentWallet.states.toList().size) + } } } diff --git a/src/main/kotlin/com/r3corda/demos/TraderDemo.kt b/src/main/kotlin/com/r3corda/demos/TraderDemo.kt index 7808161554..e2bdbb3a18 100644 --- a/src/main/kotlin/com/r3corda/demos/TraderDemo.kt +++ b/src/main/kotlin/com/r3corda/demos/TraderDemo.kt @@ -27,6 +27,7 @@ import com.r3corda.node.services.messaging.NodeMessagingClient import com.r3corda.node.services.network.NetworkMapService import com.r3corda.node.services.persistence.NodeAttachmentService import com.r3corda.node.services.transactions.SimpleNotaryService +import com.r3corda.node.utilities.databaseTransaction import com.r3corda.protocols.HandshakeMessage import com.r3corda.protocols.NotaryProtocol import com.r3corda.protocols.TwoPartyTradeProtocol @@ -198,9 +199,11 @@ private fun runBuyer(node: Node, amount: Amount) { // Self issue some cash. // // TODO: At some point this demo should be extended to have a central bank node. - node.services.fillWithSomeTestCash(300000.DOLLARS, - outputNotary = node.info.identity, // In this demo, the buyer and notary are the same. - ownedBy = node.services.keyManagementService.freshKey().public) + databaseTransaction { + node.services.fillWithSomeTestCash(300000.DOLLARS, + outputNotary = node.info.identity, // In this demo, the buyer and notary are the same. + ownedBy = node.services.keyManagementService.freshKey().public) + } // Wait around until a node asks to start a trade with us. In a real system, this part would happen out of band // via some other system like an exchange or maybe even a manual messaging system like Bloomberg. But for the diff --git a/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt b/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt index 044a849655..ec2ebe8d43 100644 --- a/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt +++ b/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt @@ -21,10 +21,7 @@ import com.r3corda.core.utilities.loggerFor import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.keys.E2ETestKeyManagementService import com.r3corda.node.services.network.InMemoryNetworkMapService -import com.r3corda.node.services.network.NetworkMapService -import com.r3corda.node.services.network.NodeRegistration import com.r3corda.node.services.transactions.InMemoryUniquenessProvider -import com.r3corda.node.utilities.AddOrRemove import com.r3corda.protocols.ServiceRequestMessage import org.jetbrains.exposed.sql.transactions.TransactionManager import org.slf4j.Logger @@ -114,9 +111,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, } override fun makeNetworkMapService() { - val expires = platformClock.instant() + NetworkMapService.DEFAULT_EXPIRATION_PERIOD - val reg = NodeRegistration(info, Long.MAX_VALUE, AddOrRemove.ADD, expires) - inNodeNetworkMapService = InMemoryNetworkMapService(services, reg) + inNodeNetworkMapService = InMemoryNetworkMapService(services) } override fun generateKeyPair(): KeyPair = keyPair ?: super.generateKeyPair()