diff --git a/core/src/main/kotlin/com/r3corda/core/testing/InMemoryWalletService.kt b/core/src/main/kotlin/com/r3corda/core/testing/InMemoryWalletService.kt index 8c1ec72b30..37beeab2bf 100644 --- a/core/src/main/kotlin/com/r3corda/core/testing/InMemoryWalletService.kt +++ b/core/src/main/kotlin/com/r3corda/core/testing/InMemoryWalletService.kt @@ -22,6 +22,8 @@ import javax.annotation.concurrent.ThreadSafe */ @ThreadSafe open class InMemoryWalletService(private val services: ServiceHub) : SingletonSerializeAsToken(), WalletService { + class ClashingThreads(threads: Set) : + Exception("There are multiple linear states pointing to the same thread. The clashing thread(s): $threads") private val log = loggerFor() // Variables inside InnerState are protected with a lock by the ThreadBox and aren't in scope unless you're @@ -44,7 +46,7 @@ open class InMemoryWalletService(private val services: ServiceHub) : SingletonSe * Returns a snapshot of the heads of LinearStates */ override val linearHeads: Map> - get() = mutex.locked { wallet }.let { wallet -> + get() = currentWallet.let { wallet -> wallet.states.filterStatesOfType().associateBy { it.state.data.thread }.mapValues { it.value } } @@ -74,10 +76,17 @@ open class InMemoryWalletService(private val services: ServiceHub) : SingletonSe val combinedDelta = delta + walletAndDelta.second Pair(wallet, combinedDelta) } + + val clashingThreads = walletAndNetDelta.first.clashingThreads + if (!clashingThreads.isEmpty()) { + throw ClashingThreads(clashingThreads) + } + wallet = walletAndNetDelta.first netDelta = walletAndNetDelta.second return@locked wallet } + if (netDelta != Wallet.NoUpdate) { _updatesPublisher.onNext(netDelta) } @@ -120,4 +129,23 @@ open class InMemoryWalletService(private val services: ServiceHub) : SingletonSe return Pair(Wallet(newStates), change) } -} \ No newline at end of file + + companion object { + + // Returns the set of LinearState threads that clash in the wallet + val Wallet.clashingThreads: Set get() { + val clashingThreads = HashSet() + val threadsSeen = HashSet() + for (linearState in states.filterStatesOfType()) { + val thread = linearState.state.data.thread + if (threadsSeen.contains(thread)) { + clashingThreads.add(thread) + } else { + threadsSeen.add(thread) + } + } + return clashingThreads + } + + } +}