Refactor the wallet code:

- Rename NodeWalletService to InMemoryWalletService and move into the core module where it's available for unit testing.
- Make a new NodeWalletService that just inherits from InMemoryWalletService and doesn't customise it at all, for now.
- Take the cash specific functionality out of Wallet and into an extension property in the Cash contract (this compiles as CashKt.getCashBalance(wallet) for java users).
- Return the generated states in the fillWalletWithTestCash function.
This commit is contained in:
Mike Hearn 2016-06-21 20:35:34 +02:00
parent f3d4639059
commit 7ee6bd05ce
10 changed files with 160 additions and 188 deletions

View File

@ -5,9 +5,9 @@ import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.crypto.newSecureRandom
import com.r3corda.core.crypto.toStringShort
import com.r3corda.core.node.services.Wallet
import com.r3corda.core.utilities.Emoji
import java.security.PublicKey
import java.security.SecureRandom
import java.util.*
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@ -224,3 +224,15 @@ fun Iterable<ContractState>.sumCashOrNull() = filterIsInstance<Cash.State>().map
/** Sums the cash states in the list, returning zero of the given currency if there are none. */
fun Iterable<ContractState>.sumCashOrZero(currency: Issued<Currency>) = filterIsInstance<Cash.State>().map { it.amount }.sumOrZero<Issued<Currency>>(currency)
/**
* Returns a map of how much cash we have in each currency, ignoring details like issuer. Note: currencies for
* which we have no cash evaluate to null (not present in map), not 0.
*/
val Wallet.cashBalances: Map<Currency, Amount<Currency>> get() = states.
// Select the states we own which are cash, ignore the rest, take the amounts.
mapNotNull { (it.state.data as? Cash.State)?.amount }.
// Turn into a Map<Currency, List<Amount>> like { GBP -> (£100, £500, etc), USD -> ($2000, $50) }
groupBy { it.token.product }.
// Collapse to Map<Currency, Amount> by summing all the amounts of the same currency together.
mapValues { it.value.map { Amount(it.quantity, it.token.product) }.sumOrThrow() }

View File

@ -2,7 +2,10 @@
package com.r3corda.contracts.testing
import com.r3corda.contracts.cash.Cash
import com.r3corda.core.contracts.*
import com.r3corda.core.contracts.Amount
import com.r3corda.core.contracts.Issued
import com.r3corda.core.contracts.SignedTransaction
import com.r3corda.core.contracts.TransactionType
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.Wallet
@ -56,10 +59,7 @@ fun ServiceHub.fillWithSomeTestCash(howMuch: Amount<Currency>,
stx.tx.outputs.indices.map { i -> stx.tx.outRef<Cash.State>(i) }
}
return object : Wallet() {
override val states: List<StateAndRef<ContractState>> = states
override val cashBalances: Map<Currency, Amount<Currency>> = mapOf(howMuch.token to howMuch)
}
return Wallet(states)
}
private fun calculateRandomlySizedAmounts(howMuch: Amount<Currency>, min: Int, max: Int, rng: Random): LongArray {

View File

@ -6,7 +6,6 @@ import com.r3corda.core.crypto.SecureHash
import java.security.KeyPair
import java.security.PrivateKey
import java.security.PublicKey
import java.util.*
/**
* Postfix for base topics when sending a request to a service.
@ -26,18 +25,10 @@ val TOPIC_DEFAULT_POSTFIX = ".0"
*
* This absract class has no references to Cash contracts.
*/
abstract class Wallet {
abstract val states: List<StateAndRef<ContractState>>
class Wallet(val states: List<StateAndRef<ContractState>>) {
@Suppress("UNCHECKED_CAST")
inline fun <reified T : OwnableState> statesOfType() = states.filter { it.state.data is T } as List<StateAndRef<T>>
/**
* Returns a map of how much cash we have in each currency, ignoring details like issuer. Note: currencies for
* which we have no cash evaluate to null (not present in map), not 0.
*/
abstract val cashBalances: Map<Currency, Amount<Currency>>
/**
* Represents an update observed by the Wallet that will be notified to observers. Include the [StateRef]s of
* transaction outputs that were consumed (inputs) and the [ContractState]s produced (outputs) to/by the transaction
@ -82,12 +73,6 @@ interface WalletService {
*/
val currentWallet: Wallet
/**
* Returns a snapshot of how much cash we have in each currency, ignoring details like issuer. Note: currencies for
* which we have no cash evaluate to null, not 0.
*/
val cashBalances: Map<Currency, Amount<Currency>>
/**
* Returns a snapshot of the heads of LinearStates
*/

View File

@ -0,0 +1,123 @@
package com.r3corda.core.testing
import com.r3corda.core.ThreadBox
import com.r3corda.core.contracts.*
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.Wallet
import com.r3corda.core.node.services.WalletService
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.utilities.loggerFor
import com.r3corda.core.utilities.trace
import rx.Observable
import rx.subjects.PublishSubject
import java.security.PublicKey
import java.util.*
import javax.annotation.concurrent.ThreadSafe
/**
* This class implements a simple, in memory wallet that tracks states that are owned by us, and also has a convenience
* method to auto-generate some self-issued cash states that can be used for test trading. A real wallet would persist
* states relevant to us into a database and once such a wallet is implemented, this scaffolding can be removed.
*/
@ThreadSafe
open class InMemoryWalletService(private val services: ServiceHub) : SingletonSerializeAsToken(), WalletService {
private val log = loggerFor<InMemoryWalletService>()
// Variables inside InnerState are protected with a lock by the ThreadBox and aren't in scope unless you're
// inside mutex.locked {} code block. So we can't forget to take the lock unless we accidentally leak a reference
// to wallet somewhere.
private class InnerState {
var wallet = Wallet(emptyList<StateAndRef<OwnableState>>())
}
private val mutex = ThreadBox(InnerState())
override val currentWallet: Wallet get() = mutex.locked { wallet }
private val _updatesPublisher = PublishSubject.create<Wallet.Update>()
override val updates: Observable<Wallet.Update>
get() = _updatesPublisher
/**
* Returns a snapshot of the heads of LinearStates
*/
override val linearHeads: Map<SecureHash, StateAndRef<LinearState>>
get() = mutex.locked { wallet }.let { wallet ->
wallet.states.filterStatesOfType<LinearState>().associateBy { it.state.data.thread }.mapValues { it.value }
}
override fun notifyAll(txns: Iterable<WireTransaction>): Wallet {
val ourKeys = services.keyManagementService.keys.keys
// Note how terribly incomplete this all is!
//
// - We don't notify anyone of anything, there are no event listeners.
// - We don't handle or even notice invalidations due to double spends of things in our wallet.
// - We have no concept of confidence (for txns where there is no definite finality).
// - No notification that keys are used, for the case where we observe a spend of our own states.
// - No ability to create complex spends.
// - No logging or tracking of how the wallet got into this state.
// - No persistence.
// - Does tx relevancy calculation and key management need to be interlocked? Probably yes.
//
// ... and many other things .... (Wallet.java in bitcoinj is several thousand lines long)
var netDelta = Wallet.NoUpdate
val changedWallet = mutex.locked {
// Starting from the current wallet, keep applying the transaction updates, calculating a new Wallet each
// time, until we get to the result (this is perhaps a bit inefficient, but it's functional and easily
// unit tested).
val walletAndNetDelta = txns.fold(Pair(currentWallet, Wallet.NoUpdate)) { walletAndDelta, tx ->
val (wallet, delta) = walletAndDelta.first.update(tx, ourKeys)
val combinedDelta = delta + walletAndDelta.second
Pair(wallet, combinedDelta)
}
wallet = walletAndNetDelta.first
netDelta = walletAndNetDelta.second
return@locked wallet
}
if (netDelta != Wallet.NoUpdate) {
_updatesPublisher.onNext(netDelta)
}
return changedWallet
}
private fun isRelevant(state: ContractState, ourKeys: Set<PublicKey>): Boolean {
return if (state is OwnableState) {
state.owner in ourKeys
} else if (state is LinearState) {
// It's potentially of interest to the wallet
state.isRelevant(ourKeys)
} else {
false
}
}
private fun Wallet.update(tx: WireTransaction, ourKeys: Set<PublicKey>): Pair<Wallet, Wallet.Update> {
val ourNewStates = tx.outputs.
filter { isRelevant(it.data, ourKeys) }.
map { tx.outRef<ContractState>(it.data) }
// Now calculate the states that are being spent by this transaction.
val consumed: Set<StateRef> = states.map { it.ref }.intersect(tx.inputs)
// Is transaction irrelevant?
if (consumed.isEmpty() && ourNewStates.isEmpty()) {
log.trace { "tx ${tx.id} was irrelevant to this wallet, ignoring" }
return Pair(this, Wallet.NoUpdate)
}
val change = Wallet.Update(consumed, HashSet(ourNewStates))
// And calculate the new wallet.
val newStates = states.filter { it.ref !in consumed } + ourNewStates
log.trace {
"Applied tx ${tx.id.prefixChars()} to the wallet: consumed ${consumed.size} states and added ${newStates.size}"
}
return Pair(Wallet(newStates), change)
}
}

View File

@ -1,6 +1,7 @@
package com.r3corda.node.services.wallet
import com.codahale.metrics.Gauge
import com.r3corda.contracts.cash.cashBalances
import com.r3corda.core.node.services.Wallet
import com.r3corda.node.services.api.ServiceHubInternal
import java.util.*
@ -9,7 +10,6 @@ import java.util.*
* This class observes the wallet and reflect current cash balances as exposed metrics in the monitoring service.
*/
class CashBalanceAsMetricsObserver(val serviceHubInternal: ServiceHubInternal) {
init {
// TODO: Need to consider failure scenarios. This needs to run if the TX is successfully recorded
serviceHubInternal.walletService.updates.subscribe { update ->

View File

@ -1,129 +1,10 @@
package com.r3corda.node.services.wallet
import com.r3corda.core.ThreadBox
import com.r3corda.core.contracts.*
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.node.services.Wallet
import com.r3corda.core.node.services.WalletService
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.utilities.loggerFor
import com.r3corda.core.utilities.trace
import com.r3corda.node.services.api.ServiceHubInternal
import rx.Observable
import rx.subjects.PublishSubject
import java.security.PublicKey
import java.util.*
import javax.annotation.concurrent.ThreadSafe
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.testing.InMemoryWalletService
/**
* This class implements a simple, in memory wallet that tracks states that are owned by us, and also has a convenience
* method to auto-generate some self-issued cash states that can be used for test trading. A real wallet would persist
* states relevant to us into a database and once such a wallet is implemented, this scaffolding can be removed.
* Currently, the node wallet service is just the in-memory wallet service until we have finished evaluating and
* selecting a persistence layer (probably an ORM over a SQL DB).
*/
@ThreadSafe
class NodeWalletService(private val services: ServiceHubInternal) : SingletonSerializeAsToken(), WalletService {
private val log = loggerFor<NodeWalletService>()
// Variables inside InnerState are protected with a lock by the ThreadBox and aren't in scope unless you're
// inside mutex.locked {} code block. So we can't forget to take the lock unless we accidentally leak a reference
// to wallet somewhere.
private class InnerState {
var wallet: Wallet = WalletImpl(emptyList<StateAndRef<OwnableState>>())
}
private val mutex = ThreadBox(InnerState())
override val currentWallet: Wallet get() = mutex.locked { wallet }
private val _updatesPublisher = PublishSubject.create<Wallet.Update>()
override val updates: Observable<Wallet.Update>
get() = _updatesPublisher
/**
* Returns a snapshot of how much cash we have in each currency, ignoring details like issuer. Note: currencies for
* which we have no cash evaluate to null, not 0.
*/
override val cashBalances: Map<Currency, Amount<Currency>> get() = mutex.locked { wallet }.cashBalances
/**
* Returns a snapshot of the heads of LinearStates
*/
override val linearHeads: Map<SecureHash, StateAndRef<LinearState>>
get() = mutex.locked { wallet }.let { wallet ->
wallet.states.filterStatesOfType<LinearState>().associateBy { it.state.data.thread }.mapValues { it.value }
}
override fun notifyAll(txns: Iterable<WireTransaction>): Wallet {
val ourKeys = services.keyManagementService.keys.keys
// Note how terribly incomplete this all is!
//
// - We don't notify anyone of anything, there are no event listeners.
// - We don't handle or even notice invalidations due to double spends of things in our wallet.
// - We have no concept of confidence (for txns where there is no definite finality).
// - No notification that keys are used, for the case where we observe a spend of our own states.
// - No ability to create complex spends.
// - No logging or tracking of how the wallet got into this state.
// - No persistence.
// - Does tx relevancy calculation and key management need to be interlocked? Probably yes.
//
// ... and many other things .... (Wallet.java in bitcoinj is several thousand lines long)
var netDelta = Wallet.NoUpdate
val changedWallet = mutex.locked {
// Starting from the current wallet, keep applying the transaction updates, calculating a new Wallet each
// time, until we get to the result (this is perhaps a bit inefficient, but it's functional and easily
// unit tested).
val walletAndNetDelta = txns.fold(Pair(currentWallet, Wallet.NoUpdate)) { walletAndDelta, tx ->
val (wallet, delta) = walletAndDelta.first.update(tx, ourKeys)
val combinedDelta = delta + walletAndDelta.second
Pair(wallet, combinedDelta)
}
wallet = walletAndNetDelta.first
netDelta = walletAndNetDelta.second
return@locked wallet
}
if (netDelta != Wallet.NoUpdate) {
_updatesPublisher.onNext(netDelta)
}
return changedWallet
}
private fun isRelevant(state: ContractState, ourKeys: Set<PublicKey>): Boolean {
return if (state is OwnableState) {
state.owner in ourKeys
} else if (state is LinearState) {
// It's potentially of interest to the wallet
state.isRelevant(ourKeys)
} else {
false
}
}
private fun Wallet.update(tx: WireTransaction, ourKeys: Set<PublicKey>): Pair<Wallet, Wallet.Update> {
val ourNewStates = tx.outputs.
filter { isRelevant(it.data, ourKeys) }.
map { tx.outRef<ContractState>(it.data) }
// Now calculate the states that are being spent by this transaction.
val consumed: Set<StateRef> = states.map { it.ref }.intersect(tx.inputs)
// Is transaction irrelevant?
if (consumed.isEmpty() && ourNewStates.isEmpty()) {
log.trace { "tx ${tx.id} was irrelevant to this wallet, ignoring" }
return Pair(this, Wallet.NoUpdate)
}
val change = Wallet.Update(consumed, HashSet(ourNewStates))
// And calculate the new wallet.
val newStates = states.filter { it.ref !in consumed } + ourNewStates
log.trace {
"Applied tx ${tx.id.prefixChars()} to the wallet: consumed ${consumed.size} states and added ${newStates.size}"
}
return Pair(WalletImpl(newStates), change)
}
}
class NodeWalletService(services: ServiceHub) : InMemoryWalletService(services)

View File

@ -1,32 +0,0 @@
package com.r3corda.node.services.wallet
import com.r3corda.contracts.cash.Cash
import com.r3corda.core.contracts.Amount
import com.r3corda.core.contracts.ContractState
import com.r3corda.core.contracts.StateAndRef
import com.r3corda.core.contracts.sumOrThrow
import com.r3corda.core.node.services.Wallet
import java.util.*
/**
* A wallet (name may be temporary) wraps a set of states that are useful for us to keep track of, for instance,
* because we own them. This class represents an immutable, stable state of a wallet: it is guaranteed not to
* change out from underneath you, even though the canonical currently-best-known wallet may change as we learn
* about new transactions from our peers and generate new transactions that consume states ourselves.
*
* This concrete implementation references Cash contracts.
*/
class WalletImpl(override val states: List<StateAndRef<ContractState>>) : Wallet() {
/**
* Returns a map of how much cash we have in each currency, ignoring details like issuer. Note: currencies for
* which we have no cash evaluate to null (not present in map), not 0.
*/
override val cashBalances: Map<Currency, Amount<Currency>> get() = states.
// Select the states we own which are cash, ignore the rest, take the amounts.
mapNotNull { (it.state.data as? Cash.State)?.amount }.
// Turn into a Map<Currency, List<Amount>> like { GBP -> (£100, £500, etc), USD -> ($2000, $50) }
groupBy { it.token.product }.
// Collapse to Map<Currency, Amount> by summing all the amounts of the same currency together.
mapValues { it.value.map { Amount(it.quantity, it.token.product) }.sumOrThrow() }
}

View File

@ -28,7 +28,6 @@ import com.r3corda.node.services.persistence.NodeAttachmentService
import com.r3corda.node.services.persistence.PerFileTransactionStorage
import com.r3corda.node.services.persistence.StorageServiceImpl
import com.r3corda.node.services.statemachine.StateMachineManager
import com.r3corda.node.services.wallet.WalletImpl
import com.r3corda.protocols.TwoPartyTradeProtocol
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
@ -458,7 +457,7 @@ class TwoPartyTradeProtocolTests {
arg(MEGA_CORP_PUBKEY) { Cash.Commands.Move() }
}
val wallet = WalletImpl(listOf<StateAndRef<Cash.State>>(lookup("bob cash 1"), lookup("bob cash 2")))
val wallet = Wallet(listOf<StateAndRef<Cash.State>>(lookup("bob cash 1"), lookup("bob cash 2")))
return Pair(wallet, listOf(eb1, bc1, bc2))
}
@ -478,7 +477,7 @@ class TwoPartyTradeProtocolTests {
attachment(attachmentID)
}
val wallet = WalletImpl(listOf<StateAndRef<Cash.State>>(lookup("alice's paper")))
val wallet = Wallet(listOf<StateAndRef<Cash.State>>(lookup("alice's paper")))
return Pair(wallet, listOf(ap))
}

View File

@ -1,6 +1,7 @@
package com.r3corda.node.services
import com.r3corda.contracts.cash.Cash
import com.r3corda.contracts.cash.cashBalances
import com.r3corda.contracts.testing.fillWithSomeTestCash
import com.r3corda.core.contracts.*
import com.r3corda.core.node.ServiceHub
@ -17,7 +18,9 @@ import java.util.*
import kotlin.test.assertEquals
import kotlin.test.assertNull
class NodeWalletServiceTest {
// TODO: Move this to the cash contract tests once mock services are further split up.
class WalletWithCashTest {
val kms = MockKeyManagementService(ALICE_KEY)
@Before
@ -83,13 +86,13 @@ class NodeWalletServiceTest {
signWith(DUMMY_NOTARY_KEY)
}.toSignedTransaction()
assertNull(wallet.cashBalances[USD])
assertNull(wallet.currentWallet.cashBalances[USD])
wallet.notify(usefulTX.tx)
assertEquals(100.DOLLARS, wallet.cashBalances[USD])
assertEquals(100.DOLLARS, wallet.currentWallet.cashBalances[USD])
wallet.notify(irrelevantTX.tx)
assertEquals(100.DOLLARS, wallet.cashBalances[USD])
assertEquals(100.DOLLARS, wallet.currentWallet.cashBalances[USD])
wallet.notify(spendTX.tx)
assertEquals(20.DOLLARS, wallet.cashBalances[USD])
assertEquals(20.DOLLARS, wallet.currentWallet.cashBalances[USD])
// TODO: Flesh out these tests as needed.
}

View File

@ -3,6 +3,7 @@ package com.r3corda.demos
import co.paralleluniverse.fibers.Suspendable
import com.google.common.net.HostAndPort
import com.r3corda.contracts.CommercialPaper
import com.r3corda.contracts.cash.cashBalances
import com.r3corda.contracts.testing.fillWithSomeTestCash
import com.r3corda.core.contracts.*
import com.r3corda.core.crypto.Party
@ -266,7 +267,7 @@ class TraderDemoProtocolBuyer(private val attachmentsPath: Path,
}
private fun logBalance() {
val balances = serviceHub.walletService.cashBalances.entries.map { "${it.key.currencyCode} ${it.value}" }
val balances = serviceHub.walletService.currentWallet.cashBalances.entries.map { "${it.key.currencyCode} ${it.value}" }
logger.info("Remaining balance: ${balances.joinToString()}")
}