diff --git a/core/src/main/kotlin/com/r3corda/core/node/services/Services.kt b/core/src/main/kotlin/com/r3corda/core/node/services/Services.kt index 70cd2d8500..8bb335b309 100644 --- a/core/src/main/kotlin/com/r3corda/core/node/services/Services.kt +++ b/core/src/main/kotlin/com/r3corda/core/node/services/Services.kt @@ -23,7 +23,11 @@ val TOPIC_DEFAULT_POSTFIX = ".0" * change out from underneath you, even though the canonical currently-best-known wallet may change as we learn * about new transactions from our peers and generate new transactions that consume states ourselves. * - * This absract class has no references to Cash contracts. + * This abstract class has no references to Cash contracts. + * + * [states] Holds the list of states that are *active* and *relevant*. + * Active means they haven't been consumed yet (or we don't know about it). + * Relevant means they contain at least one of our pubkeys */ class Wallet(val states: List>) { @Suppress("UNCHECKED_CAST") diff --git a/core/src/main/kotlin/com/r3corda/core/testing/AlwaysSucceedContract.kt b/core/src/main/kotlin/com/r3corda/core/testing/AlwaysSucceedContract.kt new file mode 100644 index 0000000000..646ec06801 --- /dev/null +++ b/core/src/main/kotlin/com/r3corda/core/testing/AlwaysSucceedContract.kt @@ -0,0 +1,10 @@ +package com.r3corda.core.testing + +import com.r3corda.core.contracts.Contract +import com.r3corda.core.contracts.TransactionForContract +import com.r3corda.core.crypto.SecureHash + +class AlwaysSucceedContract(override val legalContractReference: SecureHash = SecureHash.sha256("Always succeed contract")) : Contract { + override fun verify(tx: TransactionForContract) { + } +} diff --git a/core/src/main/kotlin/com/r3corda/core/testing/DummyLinearState.kt b/core/src/main/kotlin/com/r3corda/core/testing/DummyLinearState.kt new file mode 100644 index 0000000000..756e1c4494 --- /dev/null +++ b/core/src/main/kotlin/com/r3corda/core/testing/DummyLinearState.kt @@ -0,0 +1,18 @@ +package com.r3corda.core.testing + +import com.r3corda.core.contracts.Contract +import com.r3corda.core.contracts.LinearState +import com.r3corda.core.crypto.SecureHash +import java.security.PublicKey +import java.util.* + +class DummyLinearState( + override val thread: SecureHash = SecureHash.randomSHA256(), + override val contract: Contract = AlwaysSucceedContract(), + override val participants: List = listOf(), + val nonce: SecureHash = SecureHash.randomSHA256()) : LinearState { + + override fun isRelevant(ourKeys: Set): Boolean { + return participants.any { ourKeys.contains(it) } + } +} diff --git a/core/src/main/kotlin/com/r3corda/core/testing/InMemoryWalletService.kt b/core/src/main/kotlin/com/r3corda/core/testing/InMemoryWalletService.kt index 8c1ec72b30..45bf312a3d 100644 --- a/core/src/main/kotlin/com/r3corda/core/testing/InMemoryWalletService.kt +++ b/core/src/main/kotlin/com/r3corda/core/testing/InMemoryWalletService.kt @@ -22,6 +22,8 @@ import javax.annotation.concurrent.ThreadSafe */ @ThreadSafe open class InMemoryWalletService(private val services: ServiceHub) : SingletonSerializeAsToken(), WalletService { + class ClashingThreads(threads: Set, transactions: Iterable) : + Exception("There are multiple linear head states after processing transactions $transactions. The clashing thread(s): $threads") private val log = loggerFor() // Variables inside InnerState are protected with a lock by the ThreadBox and aren't in scope unless you're @@ -44,7 +46,7 @@ open class InMemoryWalletService(private val services: ServiceHub) : SingletonSe * Returns a snapshot of the heads of LinearStates */ override val linearHeads: Map> - get() = mutex.locked { wallet }.let { wallet -> + get() = currentWallet.let { wallet -> wallet.states.filterStatesOfType().associateBy { it.state.data.thread }.mapValues { it.value } } @@ -74,10 +76,17 @@ open class InMemoryWalletService(private val services: ServiceHub) : SingletonSe val combinedDelta = delta + walletAndDelta.second Pair(wallet, combinedDelta) } + + val clashingThreads = walletAndNetDelta.first.clashingThreads + if (!clashingThreads.isEmpty()) { + throw ClashingThreads(clashingThreads, txns) + } + wallet = walletAndNetDelta.first netDelta = walletAndNetDelta.second return@locked wallet } + if (netDelta != Wallet.NoUpdate) { _updatesPublisher.onNext(netDelta) } @@ -120,4 +129,23 @@ open class InMemoryWalletService(private val services: ServiceHub) : SingletonSe return Pair(Wallet(newStates), change) } -} \ No newline at end of file + + companion object { + + // Returns the set of LinearState threads that clash in the wallet + val Wallet.clashingThreads: Set get() { + val clashingThreads = HashSet() + val threadsSeen = HashSet() + for (linearState in states.filterStatesOfType()) { + val thread = linearState.state.data.thread + if (threadsSeen.contains(thread)) { + clashingThreads.add(thread) + } else { + threadsSeen.add(thread) + } + } + return clashingThreads + } + + } +} diff --git a/node/src/test/kotlin/com/r3corda/node/services/WalletWithCashTest.kt b/node/src/test/kotlin/com/r3corda/node/services/WalletWithCashTest.kt index 100a072288..1107f610ab 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/WalletWithCashTest.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/WalletWithCashTest.kt @@ -4,6 +4,7 @@ import com.r3corda.contracts.cash.Cash import com.r3corda.contracts.cash.cashBalances import com.r3corda.contracts.testing.fillWithSomeTestCash import com.r3corda.core.contracts.* +import com.r3corda.core.crypto.SecureHash import com.r3corda.core.node.ServiceHub import com.r3corda.core.node.services.testing.MockKeyManagementService import com.r3corda.core.node.services.testing.MockStorageService @@ -14,6 +15,7 @@ import com.r3corda.node.services.wallet.NodeWalletService import org.junit.After import org.junit.Before import org.junit.Test +import org.assertj.core.api.Assertions.assertThatThrownBy; import java.util.* import kotlin.test.assertEquals import kotlin.test.assertNull @@ -96,4 +98,62 @@ class WalletWithCashTest { // TODO: Flesh out these tests as needed. } + + + @Test + fun branchingLinearStatesFails() { + val (wallet, services) = make() + + val freshKey = services.keyManagementService.freshKey() + + val thread = SecureHash.sha256("thread") + + // Issue a linear state + val dummyIssue = TransactionType.General.Builder(notary = DUMMY_NOTARY).apply { + addOutputState(DummyLinearState(thread = thread, participants = listOf(freshKey.public))) + signWith(freshKey) + }.toSignedTransaction() + + wallet.notify(dummyIssue.tx) + assertEquals(1, wallet.currentWallet.states.size) + + // Issue another linear state of the same thread (nonce different) + val dummyIssue2 = TransactionType.General.Builder(notary = DUMMY_NOTARY).apply { + addOutputState(DummyLinearState(thread = thread, participants = listOf(freshKey.public))) + signWith(freshKey) + }.toSignedTransaction() + + assertThatThrownBy { + wallet.notify(dummyIssue2.tx) + } + assertEquals(1, wallet.currentWallet.states.size) + } + + @Test + fun sequencingLinearStatesWorks() { + val (wallet, services) = make() + + val freshKey = services.keyManagementService.freshKey() + + val thread = SecureHash.sha256("thread") + + // Issue a linear state + val dummyIssue = TransactionType.General.Builder(notary = DUMMY_NOTARY).apply { + addOutputState(DummyLinearState(thread = thread, participants = listOf(freshKey.public))) + signWith(freshKey) + }.toSignedTransaction() + + wallet.notify(dummyIssue.tx) + assertEquals(1, wallet.currentWallet.states.size) + + // Move the same state + val dummyMove = TransactionType.General.Builder(notary = DUMMY_NOTARY).apply { + addOutputState(DummyLinearState(thread = thread, participants = listOf(freshKey.public))) + addInputState(dummyIssue.tx.outRef(0)) + signWith(DUMMY_NOTARY_KEY) + }.toSignedTransaction() + + wallet.notify(dummyMove.tx) + assertEquals(1, wallet.currentWallet.states.size) + } }