diff --git a/client/src/main/kotlin/net/corda/client/model/ContractStateModel.kt b/client/src/main/kotlin/net/corda/client/model/ContractStateModel.kt index affa3aa7b1..33ec550322 100644 --- a/client/src/main/kotlin/net/corda/client/model/ContractStateModel.kt +++ b/client/src/main/kotlin/net/corda/client/model/ContractStateModel.kt @@ -8,13 +8,12 @@ import net.corda.client.fxutils.map import net.corda.contracts.asset.Cash import net.corda.core.contracts.ContractState import net.corda.core.contracts.StateAndRef -import net.corda.core.contracts.StateRef import net.corda.core.node.services.Vault import rx.Observable data class Diff( val added: Collection>, - val removed: Collection + val removed: Collection> ) /** @@ -27,11 +26,10 @@ class ContractStateModel { Diff(it.produced, it.consumed) } private val cashStatesDiff: Observable> = contractStatesDiff.map { - // We can't filter removed hashes here as we don't have type info - Diff(it.added.filterCashStateAndRefs(), it.removed) + Diff(it.added.filterCashStateAndRefs(), it.removed.filterCashStateAndRefs()) } val cashStates: ObservableList> = cashStatesDiff.fold(FXCollections.observableArrayList()) { list, statesDiff -> - list.removeIf { it.ref in statesDiff.removed } + list.removeIf { it in statesDiff.removed } list.addAll(statesDiff.added) } diff --git a/core/src/main/kotlin/net/corda/core/node/services/Services.kt b/core/src/main/kotlin/net/corda/core/node/services/Services.kt index 5f5215f1ac..5103ed5a9c 100644 --- a/core/src/main/kotlin/net/corda/core/node/services/Services.kt +++ b/core/src/main/kotlin/net/corda/core/node/services/Services.kt @@ -50,7 +50,10 @@ class Vault(val states: Iterable>) { * If the vault observes multiple transactions simultaneously, where some transactions consume the outputs of some of the * other transactions observed, then the changes are observed "net" of those. */ - data class Update(val consumed: Set, val produced: Set>) { + data class Update(val consumed: Set>, val produced: Set>) { + /** Checks whether the update contains a state of the specified type. */ + inline fun containsType() = consumed.any { it.state.data is T } || produced.any { it.state.data is T } + /** * Combine two updates into a single update with the combined inputs and outputs of the two updates but net * any outputs of the left-hand-side (this) that are consumed by the inputs of the right-hand-side (rhs). @@ -58,12 +61,10 @@ class Vault(val states: Iterable>) { * i.e. the net effect in terms of state live-ness of receiving the combined update is the same as receiving this followed by rhs. */ operator fun plus(rhs: Update): Update { - val previouslyProduced = produced.map { it.ref } - val previouslyConsumed = consumed val combined = Vault.Update( - previouslyConsumed + (rhs.consumed - previouslyProduced), + consumed + (rhs.consumed - produced), // The ordering below matters to preserve ordering of consumed/produced Sets when they are insertion order dependent implementations. - produced.filter { it.ref !in rhs.consumed }.toSet() + rhs.produced) + produced.filter { it !in rhs.consumed }.toSet() + rhs.produced) return combined } @@ -120,15 +121,7 @@ interface VaultService { * Returns a map of how much cash we have in each currency, ignoring details like issuer. Note: currencies for * which we have no cash evaluate to null (not present in map), not 0. */ - @Suppress("UNCHECKED_CAST") val cashBalances: Map> - get() = currentVault.states. - // Select the states we own which are cash, ignore the rest, take the amounts. - mapNotNull { (it.state.data as? FungibleAsset)?.amount }. - // Turn into a Map> like { GBP -> (£100, £500, etc), USD -> ($2000, $50) } - groupBy { it.token.product }. - // Collapse to Map by summing all the amounts of the same currency together. - mapValues { it.value.map { Amount(it.quantity, it.token.product) }.sumOrThrow() } /** * Atomically get the current vault and a stream of updates. Note that the Observable buffers updates until the @@ -172,7 +165,7 @@ interface VaultService { */ fun whenConsumed(ref: StateRef): ListenableFuture { val future = SettableFuture.create() - updates.filter { ref in it.consumed }.first().subscribe { + updates.filter { it.consumed.any { it.ref == ref } }.first().subscribe { future.set(it) } return future diff --git a/core/src/test/kotlin/net/corda/core/node/VaultUpdateTests.kt b/core/src/test/kotlin/net/corda/core/node/VaultUpdateTests.kt index 60731e0315..8964357beb 100644 --- a/core/src/test/kotlin/net/corda/core/node/VaultUpdateTests.kt +++ b/core/src/test/kotlin/net/corda/core/node/VaultUpdateTests.kt @@ -46,7 +46,7 @@ class VaultUpdateTests { @Test fun `something plus nothing is something`() { - val before = Vault.Update(setOf(stateRef0, stateRef1), setOf(stateAndRef2, stateAndRef3)) + val before = Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef2, stateAndRef3)) val after = before + Vault.NoUpdate assertEquals(before, after) } @@ -54,32 +54,32 @@ class VaultUpdateTests { @Test fun `nothing plus something is something`() { val before = Vault.NoUpdate - val after = before + Vault.Update(setOf(stateRef0, stateRef1), setOf(stateAndRef2, stateAndRef3)) - val expected = Vault.Update(setOf(stateRef0, stateRef1), setOf(stateAndRef2, stateAndRef3)) + val after = before + Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef2, stateAndRef3)) + val expected = Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef2, stateAndRef3)) assertEquals(expected, after) } @Test fun `something plus consume state 0 is something without state 0 output`() { - val before = Vault.Update(setOf(stateRef2, stateRef3), setOf(stateAndRef0, stateAndRef1)) - val after = before + Vault.Update(setOf(stateRef0), setOf()) - val expected = Vault.Update(setOf(stateRef2, stateRef3), setOf(stateAndRef1)) + val before = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1)) + val after = before + Vault.Update(setOf(stateAndRef0), setOf()) + val expected = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef1)) assertEquals(expected, after) } @Test fun `something plus produce state 4 is something with additional state 4 output`() { - val before = Vault.Update(setOf(stateRef2, stateRef3), setOf(stateAndRef0, stateAndRef1)) + val before = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1)) val after = before + Vault.Update(setOf(), setOf(stateAndRef4)) - val expected = Vault.Update(setOf(stateRef2, stateRef3), setOf(stateAndRef0, stateAndRef1, stateAndRef4)) + val expected = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1, stateAndRef4)) assertEquals(expected, after) } @Test fun `something plus consume states 0 and 1, and produce state 4, is something without state 0 and 1 outputs and only state 4 output`() { - val before = Vault.Update(setOf(stateRef2, stateRef3), setOf(stateAndRef0, stateAndRef1)) - val after = before + Vault.Update(setOf(stateRef0, stateRef1), setOf(stateAndRef4)) - val expected = Vault.Update(setOf(stateRef2, stateRef3), setOf(stateAndRef4)) + val before = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1)) + val after = before + Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef4)) + val expected = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef4)) assertEquals(expected, after) } } diff --git a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt index 5ef9163882..ac32fbf048 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt @@ -42,7 +42,7 @@ abstract class ServiceHubInternal : PluginServiceHub { abstract val schemaService: SchemaService abstract override val networkService: MessagingServiceInternal - + /** * Given a list of [SignedTransaction]s, writes them to the given storage for validated transactions and then * sends them to the vault for further processing. This is intended for implementations to call from diff --git a/node/src/main/kotlin/net/corda/node/services/events/ScheduledActivityObserver.kt b/node/src/main/kotlin/net/corda/node/services/events/ScheduledActivityObserver.kt index 5babdb4ea6..8b8f7beec0 100644 --- a/node/src/main/kotlin/net/corda/node/services/events/ScheduledActivityObserver.kt +++ b/node/src/main/kotlin/net/corda/node/services/events/ScheduledActivityObserver.kt @@ -14,7 +14,7 @@ import net.corda.node.services.api.ServiceHubInternal class ScheduledActivityObserver(val services: ServiceHubInternal) { init { services.vaultService.rawUpdates.subscribe { update -> - update.consumed.forEach { services.schedulerService.unscheduleStateActivity(it) } + update.consumed.forEach { services.schedulerService.unscheduleStateActivity(it.ref) } update.produced.forEach { scheduleStateActivity(it, services.flowLogicRefFactory) } } } diff --git a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt index 72946d5272..bce32f0b26 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt @@ -51,6 +51,11 @@ class NodeVaultService(private val services: ServiceHub) : SingletonSerializeAsT override fun toString() = "$txnId: $note" } + private object CashBalanceTable : JDBCHashedTable("${NODE_DATABASE_PREFIX}vault_cash_balances") { + val currency = varchar("currency", 3) + val amount = long("amount") + } + private object TransactionNotesTable : JDBCHashedTable("${NODE_DATABASE_PREFIX}vault_txn_notes") { val txnId = secureHash("txnId").index() val note = text("note") @@ -80,6 +85,19 @@ class NodeVaultService(private val services: ServiceHub) : SingletonSerializeAsT } } + val cashBalances = object : AbstractJDBCHashMap, CashBalanceTable>(CashBalanceTable) { + override fun keyFromRow(row: ResultRow): Currency = Currency.getInstance(row[table.currency]) + override fun valueFromRow(row: ResultRow): Amount = Amount(row[table.amount], keyFromRow(row)) + + override fun addKeyToInsert(insert: InsertStatement, entry: Map.Entry>, finalizables: MutableList<() -> Unit>) { + insert[table.currency] = entry.key.currencyCode + } + + override fun addValueToInsert(insert: InsertStatement, entry: Map.Entry>, finalizables: MutableList<() -> Unit>) { + insert[table.amount] = entry.value.quantity + } + } + val _updatesPublisher = PublishSubject.create() val _rawUpdatesPublisher = PublishSubject.create() // For use during publishing only. @@ -97,15 +115,38 @@ class NodeVaultService(private val services: ServiceHub) : SingletonSerializeAsT fun recordUpdate(update: Vault.Update): Vault.Update { if (update != Vault.NoUpdate) { val producedStateRefs = update.produced.map { it.ref } - val consumedStateRefs = update.consumed + val consumedStateRefs = update.consumed.map { it.ref } log.trace { "Removing $consumedStateRefs consumed contract states and adding $producedStateRefs produced contract states to the database." } unconsumedStates.removeAll(consumedStateRefs) unconsumedStates.addAll(producedStateRefs) } return update } + + // TODO: consider moving this logic outside the vault + fun maybeUpdateCashBalances(update: Vault.Update) { + if (update.containsType()) { + val consumed = sumCashStates(update.consumed) + val produced = sumCashStates(update.produced) + (produced.keys + consumed.keys).map { currency -> + val producedAmount = produced[currency] ?: Amount(0, currency) + val consumedAmount = consumed[currency] ?: Amount(0, currency) + val currentBalance = cashBalances[currency] ?: Amount(0, currency) + cashBalances[currency] = currentBalance + producedAmount - consumedAmount + } + } + } + + @Suppress("UNCHECKED_CAST") + private fun sumCashStates(states: Iterable>): Map> { + return states.mapNotNull { (it.state.data as? FungibleAsset)?.amount } + .groupBy { it.token.product } + .mapValues { it.value.map { Amount(it.quantity, it.token.product) }.sumOrThrow() } + } }) + override val cashBalances: Map> get() = mutex.locked { HashMap(cashBalances) } + override val currentVault: Vault get() = mutex.locked { Vault(allUnconsumedStates()) } override val rawUpdates: Observable @@ -134,6 +175,7 @@ class NodeVaultService(private val services: ServiceHub) : SingletonSerializeAsT if (netDelta != Vault.NoUpdate) { mutex.locked { recordUpdate(netDelta) + maybeUpdateCashBalances(netDelta) updatesPublisher.onNext(netDelta) } } @@ -278,22 +320,27 @@ class NodeVaultService(private val services: ServiceHub) : SingletonSerializeAsT map { tx.outRef(it.data) } // Now calculate the states that are being spent by this transaction. - val consumed = tx.inputs.toHashSet() + val consumedRefs = tx.inputs.toHashSet() // We use Guava union here as it's lazy for contains() which is how retainAll() is implemented. // i.e. retainAll() iterates over consumed, checking contains() on the parameter. Sets.union() does not physically create // a new collection and instead contains() just checks the contains() of both parameters, and so we don't end up // iterating over all (a potentially very large) unconsumedStates at any point. mutex.locked { - consumed.retainAll(Sets.union(netDelta.produced, unconsumedStates)) + consumedRefs.retainAll(Sets.union(netDelta.produced, unconsumedStates)) } // Is transaction irrelevant? - if (consumed.isEmpty() && ourNewStates.isEmpty()) { + if (consumedRefs.isEmpty() && ourNewStates.isEmpty()) { log.trace { "tx ${tx.id} was irrelevant to this vault, ignoring" } return Vault.NoUpdate } - return Vault.Update(consumed, ourNewStates.toHashSet()) + val consumedStates = consumedRefs.map { + val state = services.loadState(it) + StateAndRef(state, it) + }.toSet() + + return Vault.Update(consumedStates, ourNewStates.toHashSet()) } private fun isRelevant(state: ContractState, ourKeys: Set): Boolean {