From 7cb4cbcad40836517a0dcdc529e93c3d56e7b05a Mon Sep 17 00:00:00 2001 From: Andrius Dagys Date: Wed, 14 Dec 2016 15:25:43 +0000 Subject: [PATCH] Made vault updates contain full StateAndRef in the consumed set (instead of just StateRef). This allows subscribers to check whether the update contains relevant states. Cash balances are now calculated by keeping only the aggregate values (it no longer needs to iterate through all states in the vault). --- .../corda/client/model/ContractStateModel.kt | 8 +-- .../net/corda/core/node/services/Services.kt | 21 +++---- .../net/corda/core/node/VaultUpdateTests.kt | 22 +++---- .../node/services/api/ServiceHubInternal.kt | 2 +- .../events/ScheduledActivityObserver.kt | 2 +- .../node/services/vault/NodeVaultService.kt | 57 +++++++++++++++++-- 6 files changed, 75 insertions(+), 37 deletions(-) diff --git a/client/src/main/kotlin/net/corda/client/model/ContractStateModel.kt b/client/src/main/kotlin/net/corda/client/model/ContractStateModel.kt index affa3aa7b1..33ec550322 100644 --- a/client/src/main/kotlin/net/corda/client/model/ContractStateModel.kt +++ b/client/src/main/kotlin/net/corda/client/model/ContractStateModel.kt @@ -8,13 +8,12 @@ import net.corda.client.fxutils.map import net.corda.contracts.asset.Cash import net.corda.core.contracts.ContractState import net.corda.core.contracts.StateAndRef -import net.corda.core.contracts.StateRef import net.corda.core.node.services.Vault import rx.Observable data class Diff( val added: Collection>, - val removed: Collection + val removed: Collection> ) /** @@ -27,11 +26,10 @@ class ContractStateModel { Diff(it.produced, it.consumed) } private val cashStatesDiff: Observable> = contractStatesDiff.map { - // We can't filter removed hashes here as we don't have type info - Diff(it.added.filterCashStateAndRefs(), it.removed) + Diff(it.added.filterCashStateAndRefs(), it.removed.filterCashStateAndRefs()) } val cashStates: ObservableList> = cashStatesDiff.fold(FXCollections.observableArrayList()) { list, statesDiff -> - list.removeIf { it.ref in statesDiff.removed } + list.removeIf { it in statesDiff.removed } list.addAll(statesDiff.added) } diff --git a/core/src/main/kotlin/net/corda/core/node/services/Services.kt b/core/src/main/kotlin/net/corda/core/node/services/Services.kt index 5f5215f1ac..5103ed5a9c 100644 --- a/core/src/main/kotlin/net/corda/core/node/services/Services.kt +++ b/core/src/main/kotlin/net/corda/core/node/services/Services.kt @@ -50,7 +50,10 @@ class Vault(val states: Iterable>) { * If the vault observes multiple transactions simultaneously, where some transactions consume the outputs of some of the * other transactions observed, then the changes are observed "net" of those. */ - data class Update(val consumed: Set, val produced: Set>) { + data class Update(val consumed: Set>, val produced: Set>) { + /** Checks whether the update contains a state of the specified type. */ + inline fun containsType() = consumed.any { it.state.data is T } || produced.any { it.state.data is T } + /** * Combine two updates into a single update with the combined inputs and outputs of the two updates but net * any outputs of the left-hand-side (this) that are consumed by the inputs of the right-hand-side (rhs). @@ -58,12 +61,10 @@ class Vault(val states: Iterable>) { * i.e. the net effect in terms of state live-ness of receiving the combined update is the same as receiving this followed by rhs. */ operator fun plus(rhs: Update): Update { - val previouslyProduced = produced.map { it.ref } - val previouslyConsumed = consumed val combined = Vault.Update( - previouslyConsumed + (rhs.consumed - previouslyProduced), + consumed + (rhs.consumed - produced), // The ordering below matters to preserve ordering of consumed/produced Sets when they are insertion order dependent implementations. - produced.filter { it.ref !in rhs.consumed }.toSet() + rhs.produced) + produced.filter { it !in rhs.consumed }.toSet() + rhs.produced) return combined } @@ -120,15 +121,7 @@ interface VaultService { * Returns a map of how much cash we have in each currency, ignoring details like issuer. Note: currencies for * which we have no cash evaluate to null (not present in map), not 0. */ - @Suppress("UNCHECKED_CAST") val cashBalances: Map> - get() = currentVault.states. - // Select the states we own which are cash, ignore the rest, take the amounts. - mapNotNull { (it.state.data as? FungibleAsset)?.amount }. - // Turn into a Map> like { GBP -> (£100, £500, etc), USD -> ($2000, $50) } - groupBy { it.token.product }. - // Collapse to Map by summing all the amounts of the same currency together. - mapValues { it.value.map { Amount(it.quantity, it.token.product) }.sumOrThrow() } /** * Atomically get the current vault and a stream of updates. Note that the Observable buffers updates until the @@ -172,7 +165,7 @@ interface VaultService { */ fun whenConsumed(ref: StateRef): ListenableFuture { val future = SettableFuture.create() - updates.filter { ref in it.consumed }.first().subscribe { + updates.filter { it.consumed.any { it.ref == ref } }.first().subscribe { future.set(it) } return future diff --git a/core/src/test/kotlin/net/corda/core/node/VaultUpdateTests.kt b/core/src/test/kotlin/net/corda/core/node/VaultUpdateTests.kt index 60731e0315..8964357beb 100644 --- a/core/src/test/kotlin/net/corda/core/node/VaultUpdateTests.kt +++ b/core/src/test/kotlin/net/corda/core/node/VaultUpdateTests.kt @@ -46,7 +46,7 @@ class VaultUpdateTests { @Test fun `something plus nothing is something`() { - val before = Vault.Update(setOf(stateRef0, stateRef1), setOf(stateAndRef2, stateAndRef3)) + val before = Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef2, stateAndRef3)) val after = before + Vault.NoUpdate assertEquals(before, after) } @@ -54,32 +54,32 @@ class VaultUpdateTests { @Test fun `nothing plus something is something`() { val before = Vault.NoUpdate - val after = before + Vault.Update(setOf(stateRef0, stateRef1), setOf(stateAndRef2, stateAndRef3)) - val expected = Vault.Update(setOf(stateRef0, stateRef1), setOf(stateAndRef2, stateAndRef3)) + val after = before + Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef2, stateAndRef3)) + val expected = Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef2, stateAndRef3)) assertEquals(expected, after) } @Test fun `something plus consume state 0 is something without state 0 output`() { - val before = Vault.Update(setOf(stateRef2, stateRef3), setOf(stateAndRef0, stateAndRef1)) - val after = before + Vault.Update(setOf(stateRef0), setOf()) - val expected = Vault.Update(setOf(stateRef2, stateRef3), setOf(stateAndRef1)) + val before = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1)) + val after = before + Vault.Update(setOf(stateAndRef0), setOf()) + val expected = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef1)) assertEquals(expected, after) } @Test fun `something plus produce state 4 is something with additional state 4 output`() { - val before = Vault.Update(setOf(stateRef2, stateRef3), setOf(stateAndRef0, stateAndRef1)) + val before = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1)) val after = before + Vault.Update(setOf(), setOf(stateAndRef4)) - val expected = Vault.Update(setOf(stateRef2, stateRef3), setOf(stateAndRef0, stateAndRef1, stateAndRef4)) + val expected = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1, stateAndRef4)) assertEquals(expected, after) } @Test fun `something plus consume states 0 and 1, and produce state 4, is something without state 0 and 1 outputs and only state 4 output`() { - val before = Vault.Update(setOf(stateRef2, stateRef3), setOf(stateAndRef0, stateAndRef1)) - val after = before + Vault.Update(setOf(stateRef0, stateRef1), setOf(stateAndRef4)) - val expected = Vault.Update(setOf(stateRef2, stateRef3), setOf(stateAndRef4)) + val before = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1)) + val after = before + Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef4)) + val expected = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef4)) assertEquals(expected, after) } } diff --git a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt index 5ef9163882..ac32fbf048 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt @@ -42,7 +42,7 @@ abstract class ServiceHubInternal : PluginServiceHub { abstract val schemaService: SchemaService abstract override val networkService: MessagingServiceInternal - + /** * Given a list of [SignedTransaction]s, writes them to the given storage for validated transactions and then * sends them to the vault for further processing. This is intended for implementations to call from diff --git a/node/src/main/kotlin/net/corda/node/services/events/ScheduledActivityObserver.kt b/node/src/main/kotlin/net/corda/node/services/events/ScheduledActivityObserver.kt index 5babdb4ea6..8b8f7beec0 100644 --- a/node/src/main/kotlin/net/corda/node/services/events/ScheduledActivityObserver.kt +++ b/node/src/main/kotlin/net/corda/node/services/events/ScheduledActivityObserver.kt @@ -14,7 +14,7 @@ import net.corda.node.services.api.ServiceHubInternal class ScheduledActivityObserver(val services: ServiceHubInternal) { init { services.vaultService.rawUpdates.subscribe { update -> - update.consumed.forEach { services.schedulerService.unscheduleStateActivity(it) } + update.consumed.forEach { services.schedulerService.unscheduleStateActivity(it.ref) } update.produced.forEach { scheduleStateActivity(it, services.flowLogicRefFactory) } } } diff --git a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt index 72946d5272..bce32f0b26 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt @@ -51,6 +51,11 @@ class NodeVaultService(private val services: ServiceHub) : SingletonSerializeAsT override fun toString() = "$txnId: $note" } + private object CashBalanceTable : JDBCHashedTable("${NODE_DATABASE_PREFIX}vault_cash_balances") { + val currency = varchar("currency", 3) + val amount = long("amount") + } + private object TransactionNotesTable : JDBCHashedTable("${NODE_DATABASE_PREFIX}vault_txn_notes") { val txnId = secureHash("txnId").index() val note = text("note") @@ -80,6 +85,19 @@ class NodeVaultService(private val services: ServiceHub) : SingletonSerializeAsT } } + val cashBalances = object : AbstractJDBCHashMap, CashBalanceTable>(CashBalanceTable) { + override fun keyFromRow(row: ResultRow): Currency = Currency.getInstance(row[table.currency]) + override fun valueFromRow(row: ResultRow): Amount = Amount(row[table.amount], keyFromRow(row)) + + override fun addKeyToInsert(insert: InsertStatement, entry: Map.Entry>, finalizables: MutableList<() -> Unit>) { + insert[table.currency] = entry.key.currencyCode + } + + override fun addValueToInsert(insert: InsertStatement, entry: Map.Entry>, finalizables: MutableList<() -> Unit>) { + insert[table.amount] = entry.value.quantity + } + } + val _updatesPublisher = PublishSubject.create() val _rawUpdatesPublisher = PublishSubject.create() // For use during publishing only. @@ -97,15 +115,38 @@ class NodeVaultService(private val services: ServiceHub) : SingletonSerializeAsT fun recordUpdate(update: Vault.Update): Vault.Update { if (update != Vault.NoUpdate) { val producedStateRefs = update.produced.map { it.ref } - val consumedStateRefs = update.consumed + val consumedStateRefs = update.consumed.map { it.ref } log.trace { "Removing $consumedStateRefs consumed contract states and adding $producedStateRefs produced contract states to the database." } unconsumedStates.removeAll(consumedStateRefs) unconsumedStates.addAll(producedStateRefs) } return update } + + // TODO: consider moving this logic outside the vault + fun maybeUpdateCashBalances(update: Vault.Update) { + if (update.containsType()) { + val consumed = sumCashStates(update.consumed) + val produced = sumCashStates(update.produced) + (produced.keys + consumed.keys).map { currency -> + val producedAmount = produced[currency] ?: Amount(0, currency) + val consumedAmount = consumed[currency] ?: Amount(0, currency) + val currentBalance = cashBalances[currency] ?: Amount(0, currency) + cashBalances[currency] = currentBalance + producedAmount - consumedAmount + } + } + } + + @Suppress("UNCHECKED_CAST") + private fun sumCashStates(states: Iterable>): Map> { + return states.mapNotNull { (it.state.data as? FungibleAsset)?.amount } + .groupBy { it.token.product } + .mapValues { it.value.map { Amount(it.quantity, it.token.product) }.sumOrThrow() } + } }) + override val cashBalances: Map> get() = mutex.locked { HashMap(cashBalances) } + override val currentVault: Vault get() = mutex.locked { Vault(allUnconsumedStates()) } override val rawUpdates: Observable @@ -134,6 +175,7 @@ class NodeVaultService(private val services: ServiceHub) : SingletonSerializeAsT if (netDelta != Vault.NoUpdate) { mutex.locked { recordUpdate(netDelta) + maybeUpdateCashBalances(netDelta) updatesPublisher.onNext(netDelta) } } @@ -278,22 +320,27 @@ class NodeVaultService(private val services: ServiceHub) : SingletonSerializeAsT map { tx.outRef(it.data) } // Now calculate the states that are being spent by this transaction. - val consumed = tx.inputs.toHashSet() + val consumedRefs = tx.inputs.toHashSet() // We use Guava union here as it's lazy for contains() which is how retainAll() is implemented. // i.e. retainAll() iterates over consumed, checking contains() on the parameter. Sets.union() does not physically create // a new collection and instead contains() just checks the contains() of both parameters, and so we don't end up // iterating over all (a potentially very large) unconsumedStates at any point. mutex.locked { - consumed.retainAll(Sets.union(netDelta.produced, unconsumedStates)) + consumedRefs.retainAll(Sets.union(netDelta.produced, unconsumedStates)) } // Is transaction irrelevant? - if (consumed.isEmpty() && ourNewStates.isEmpty()) { + if (consumedRefs.isEmpty() && ourNewStates.isEmpty()) { log.trace { "tx ${tx.id} was irrelevant to this vault, ignoring" } return Vault.NoUpdate } - return Vault.Update(consumed, ourNewStates.toHashSet()) + val consumedStates = consumedRefs.map { + val state = services.loadState(it) + StateAndRef(state, it) + }.toSet() + + return Vault.Update(consumedStates, ourNewStates.toHashSet()) } private fun isRelevant(state: ContractState, ourKeys: Set): Boolean {