Made vault updates contain full StateAndRef in the consumed set (instead of just StateRef). This allows subscribers to check whether the update contains relevant states.

Cash balances are now calculated by keeping only the aggregate values (it no longer needs to iterate through all states in the vault).
This commit is contained in:
Andrius Dagys 2016-12-14 15:25:43 +00:00
parent adc70569b1
commit 7cb4cbcad4
6 changed files with 75 additions and 37 deletions

View File

@ -8,13 +8,12 @@ import net.corda.client.fxutils.map
import net.corda.contracts.asset.Cash
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef
import net.corda.core.node.services.Vault
import rx.Observable
data class Diff<out T : ContractState>(
val added: Collection<StateAndRef<T>>,
val removed: Collection<StateRef>
val removed: Collection<StateAndRef<T>>
)
/**
@ -27,11 +26,10 @@ class ContractStateModel {
Diff(it.produced, it.consumed)
}
private val cashStatesDiff: Observable<Diff<Cash.State>> = contractStatesDiff.map {
// We can't filter removed hashes here as we don't have type info
Diff(it.added.filterCashStateAndRefs(), it.removed)
Diff(it.added.filterCashStateAndRefs(), it.removed.filterCashStateAndRefs())
}
val cashStates: ObservableList<StateAndRef<Cash.State>> = cashStatesDiff.fold(FXCollections.observableArrayList()) { list, statesDiff ->
list.removeIf { it.ref in statesDiff.removed }
list.removeIf { it in statesDiff.removed }
list.addAll(statesDiff.added)
}

View File

@ -50,7 +50,10 @@ class Vault(val states: Iterable<StateAndRef<ContractState>>) {
* If the vault observes multiple transactions simultaneously, where some transactions consume the outputs of some of the
* other transactions observed, then the changes are observed "net" of those.
*/
data class Update(val consumed: Set<StateRef>, val produced: Set<StateAndRef<ContractState>>) {
data class Update(val consumed: Set<StateAndRef<ContractState>>, val produced: Set<StateAndRef<ContractState>>) {
/** Checks whether the update contains a state of the specified type. */
inline fun <reified T : ContractState> containsType() = consumed.any { it.state.data is T } || produced.any { it.state.data is T }
/**
* Combine two updates into a single update with the combined inputs and outputs of the two updates but net
* any outputs of the left-hand-side (this) that are consumed by the inputs of the right-hand-side (rhs).
@ -58,12 +61,10 @@ class Vault(val states: Iterable<StateAndRef<ContractState>>) {
* i.e. the net effect in terms of state live-ness of receiving the combined update is the same as receiving this followed by rhs.
*/
operator fun plus(rhs: Update): Update {
val previouslyProduced = produced.map { it.ref }
val previouslyConsumed = consumed
val combined = Vault.Update(
previouslyConsumed + (rhs.consumed - previouslyProduced),
consumed + (rhs.consumed - produced),
// The ordering below matters to preserve ordering of consumed/produced Sets when they are insertion order dependent implementations.
produced.filter { it.ref !in rhs.consumed }.toSet() + rhs.produced)
produced.filter { it !in rhs.consumed }.toSet() + rhs.produced)
return combined
}
@ -120,15 +121,7 @@ interface VaultService {
* Returns a map of how much cash we have in each currency, ignoring details like issuer. Note: currencies for
* which we have no cash evaluate to null (not present in map), not 0.
*/
@Suppress("UNCHECKED_CAST")
val cashBalances: Map<Currency, Amount<Currency>>
get() = currentVault.states.
// Select the states we own which are cash, ignore the rest, take the amounts.
mapNotNull { (it.state.data as? FungibleAsset<Currency>)?.amount }.
// Turn into a Map<Currency, List<Amount>> like { GBP -> (£100, £500, etc), USD -> ($2000, $50) }
groupBy { it.token.product }.
// Collapse to Map<Currency, Amount> by summing all the amounts of the same currency together.
mapValues { it.value.map { Amount(it.quantity, it.token.product) }.sumOrThrow() }
/**
* Atomically get the current vault and a stream of updates. Note that the Observable buffers updates until the
@ -172,7 +165,7 @@ interface VaultService {
*/
fun whenConsumed(ref: StateRef): ListenableFuture<Vault.Update> {
val future = SettableFuture.create<Vault.Update>()
updates.filter { ref in it.consumed }.first().subscribe {
updates.filter { it.consumed.any { it.ref == ref } }.first().subscribe {
future.set(it)
}
return future

View File

@ -46,7 +46,7 @@ class VaultUpdateTests {
@Test
fun `something plus nothing is something`() {
val before = Vault.Update(setOf(stateRef0, stateRef1), setOf(stateAndRef2, stateAndRef3))
val before = Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef2, stateAndRef3))
val after = before + Vault.NoUpdate
assertEquals(before, after)
}
@ -54,32 +54,32 @@ class VaultUpdateTests {
@Test
fun `nothing plus something is something`() {
val before = Vault.NoUpdate
val after = before + Vault.Update(setOf(stateRef0, stateRef1), setOf(stateAndRef2, stateAndRef3))
val expected = Vault.Update(setOf(stateRef0, stateRef1), setOf(stateAndRef2, stateAndRef3))
val after = before + Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef2, stateAndRef3))
val expected = Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef2, stateAndRef3))
assertEquals(expected, after)
}
@Test
fun `something plus consume state 0 is something without state 0 output`() {
val before = Vault.Update(setOf(stateRef2, stateRef3), setOf(stateAndRef0, stateAndRef1))
val after = before + Vault.Update(setOf(stateRef0), setOf())
val expected = Vault.Update(setOf(stateRef2, stateRef3), setOf(stateAndRef1))
val before = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1))
val after = before + Vault.Update(setOf(stateAndRef0), setOf())
val expected = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef1))
assertEquals(expected, after)
}
@Test
fun `something plus produce state 4 is something with additional state 4 output`() {
val before = Vault.Update(setOf(stateRef2, stateRef3), setOf(stateAndRef0, stateAndRef1))
val before = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1))
val after = before + Vault.Update(setOf(), setOf(stateAndRef4))
val expected = Vault.Update(setOf(stateRef2, stateRef3), setOf(stateAndRef0, stateAndRef1, stateAndRef4))
val expected = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1, stateAndRef4))
assertEquals(expected, after)
}
@Test
fun `something plus consume states 0 and 1, and produce state 4, is something without state 0 and 1 outputs and only state 4 output`() {
val before = Vault.Update(setOf(stateRef2, stateRef3), setOf(stateAndRef0, stateAndRef1))
val after = before + Vault.Update(setOf(stateRef0, stateRef1), setOf(stateAndRef4))
val expected = Vault.Update(setOf(stateRef2, stateRef3), setOf(stateAndRef4))
val before = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1))
val after = before + Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef4))
val expected = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef4))
assertEquals(expected, after)
}
}

View File

@ -42,7 +42,7 @@ abstract class ServiceHubInternal : PluginServiceHub {
abstract val schemaService: SchemaService
abstract override val networkService: MessagingServiceInternal
/**
* Given a list of [SignedTransaction]s, writes them to the given storage for validated transactions and then
* sends them to the vault for further processing. This is intended for implementations to call from

View File

@ -14,7 +14,7 @@ import net.corda.node.services.api.ServiceHubInternal
class ScheduledActivityObserver(val services: ServiceHubInternal) {
init {
services.vaultService.rawUpdates.subscribe { update ->
update.consumed.forEach { services.schedulerService.unscheduleStateActivity(it) }
update.consumed.forEach { services.schedulerService.unscheduleStateActivity(it.ref) }
update.produced.forEach { scheduleStateActivity(it, services.flowLogicRefFactory) }
}
}

View File

@ -51,6 +51,11 @@ class NodeVaultService(private val services: ServiceHub) : SingletonSerializeAsT
override fun toString() = "$txnId: $note"
}
private object CashBalanceTable : JDBCHashedTable("${NODE_DATABASE_PREFIX}vault_cash_balances") {
val currency = varchar("currency", 3)
val amount = long("amount")
}
private object TransactionNotesTable : JDBCHashedTable("${NODE_DATABASE_PREFIX}vault_txn_notes") {
val txnId = secureHash("txnId").index()
val note = text("note")
@ -80,6 +85,19 @@ class NodeVaultService(private val services: ServiceHub) : SingletonSerializeAsT
}
}
val cashBalances = object : AbstractJDBCHashMap<Currency, Amount<Currency>, CashBalanceTable>(CashBalanceTable) {
override fun keyFromRow(row: ResultRow): Currency = Currency.getInstance(row[table.currency])
override fun valueFromRow(row: ResultRow): Amount<Currency> = Amount(row[table.amount], keyFromRow(row))
override fun addKeyToInsert(insert: InsertStatement, entry: Map.Entry<Currency, Amount<Currency>>, finalizables: MutableList<() -> Unit>) {
insert[table.currency] = entry.key.currencyCode
}
override fun addValueToInsert(insert: InsertStatement, entry: Map.Entry<Currency, Amount<Currency>>, finalizables: MutableList<() -> Unit>) {
insert[table.amount] = entry.value.quantity
}
}
val _updatesPublisher = PublishSubject.create<Vault.Update>()
val _rawUpdatesPublisher = PublishSubject.create<Vault.Update>()
// For use during publishing only.
@ -97,15 +115,38 @@ class NodeVaultService(private val services: ServiceHub) : SingletonSerializeAsT
fun recordUpdate(update: Vault.Update): Vault.Update {
if (update != Vault.NoUpdate) {
val producedStateRefs = update.produced.map { it.ref }
val consumedStateRefs = update.consumed
val consumedStateRefs = update.consumed.map { it.ref }
log.trace { "Removing $consumedStateRefs consumed contract states and adding $producedStateRefs produced contract states to the database." }
unconsumedStates.removeAll(consumedStateRefs)
unconsumedStates.addAll(producedStateRefs)
}
return update
}
// TODO: consider moving this logic outside the vault
fun maybeUpdateCashBalances(update: Vault.Update) {
if (update.containsType<Cash.State>()) {
val consumed = sumCashStates(update.consumed)
val produced = sumCashStates(update.produced)
(produced.keys + consumed.keys).map { currency ->
val producedAmount = produced[currency] ?: Amount(0, currency)
val consumedAmount = consumed[currency] ?: Amount(0, currency)
val currentBalance = cashBalances[currency] ?: Amount(0, currency)
cashBalances[currency] = currentBalance + producedAmount - consumedAmount
}
}
}
@Suppress("UNCHECKED_CAST")
private fun sumCashStates(states: Iterable<StateAndRef<ContractState>>): Map<Currency, Amount<Currency>> {
return states.mapNotNull { (it.state.data as? FungibleAsset<Currency>)?.amount }
.groupBy { it.token.product }
.mapValues { it.value.map { Amount(it.quantity, it.token.product) }.sumOrThrow() }
}
})
override val cashBalances: Map<Currency, Amount<Currency>> get() = mutex.locked { HashMap(cashBalances) }
override val currentVault: Vault get() = mutex.locked { Vault(allUnconsumedStates()) }
override val rawUpdates: Observable<Vault.Update>
@ -134,6 +175,7 @@ class NodeVaultService(private val services: ServiceHub) : SingletonSerializeAsT
if (netDelta != Vault.NoUpdate) {
mutex.locked {
recordUpdate(netDelta)
maybeUpdateCashBalances(netDelta)
updatesPublisher.onNext(netDelta)
}
}
@ -278,22 +320,27 @@ class NodeVaultService(private val services: ServiceHub) : SingletonSerializeAsT
map { tx.outRef<ContractState>(it.data) }
// Now calculate the states that are being spent by this transaction.
val consumed = tx.inputs.toHashSet()
val consumedRefs = tx.inputs.toHashSet()
// We use Guava union here as it's lazy for contains() which is how retainAll() is implemented.
// i.e. retainAll() iterates over consumed, checking contains() on the parameter. Sets.union() does not physically create
// a new collection and instead contains() just checks the contains() of both parameters, and so we don't end up
// iterating over all (a potentially very large) unconsumedStates at any point.
mutex.locked {
consumed.retainAll(Sets.union(netDelta.produced, unconsumedStates))
consumedRefs.retainAll(Sets.union(netDelta.produced, unconsumedStates))
}
// Is transaction irrelevant?
if (consumed.isEmpty() && ourNewStates.isEmpty()) {
if (consumedRefs.isEmpty() && ourNewStates.isEmpty()) {
log.trace { "tx ${tx.id} was irrelevant to this vault, ignoring" }
return Vault.NoUpdate
}
return Vault.Update(consumed, ourNewStates.toHashSet())
val consumedStates = consumedRefs.map {
val state = services.loadState(it)
StateAndRef(state, it)
}.toSet()
return Vault.Update(consumedStates, ourNewStates.toHashSet())
}
private fun isRelevant(state: ContractState, ourKeys: Set<PublicKey>): Boolean {