From f09bff9c0fb8ec393197106f32426e04d08b7bc8 Mon Sep 17 00:00:00 2001 From: josecoll Date: Thu, 29 Aug 2019 09:55:21 +0100 Subject: [PATCH] Make concurrent updates to contractStateTypeMappings thread safe. (#5410) --- .../node/services/vault/NodeVaultService.kt | 12 +++++-- .../services/vault/NodeVaultServiceTest.kt | 34 +++++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt index 5480c0264b..309553bb17 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt @@ -28,6 +28,8 @@ import java.security.PublicKey import java.time.Clock import java.time.Instant import java.util.* +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.CopyOnWriteArraySet import javax.persistence.Tuple import javax.persistence.criteria.CriteriaBuilder import javax.persistence.criteria.CriteriaUpdate @@ -91,7 +93,8 @@ class NodeVaultService( * Maintain a list of contract state interfaces to concrete types stored in the vault * for usage in generic queries of type queryBy or queryBy> */ - private val contractStateTypeMappings = mutableMapOf>().toSynchronised() + @VisibleForTesting + internal val contractStateTypeMappings = ConcurrentHashMap>() override fun start() { bootstrapContractStateTypes() @@ -103,7 +106,7 @@ class NodeVaultService( if (!seen) { val contractTypes = deriveContractTypes(concreteType) contractTypes.map { - val contractStateType = contractStateTypeMappings.getOrPut(it.name) { mutableSetOf() } + val contractStateType = contractStateTypeMappings.getOrPut(it.name) { CopyOnWriteArraySet() } contractStateType.add(concreteType.name) } } @@ -203,6 +206,9 @@ class NodeVaultService( override val updates: Observable> get() = mutex.locked { _updatesInDbTx } + @VisibleForTesting + internal val publishUpdates get() = mutex.locked { updatesPublisher } + /** Groups adjacent transactions into batches to generate separate net updates per transaction type. */ override fun notifyAll(statesToRecord: StatesToRecord, txns: Iterable) { if (statesToRecord == StatesToRecord.NONE || !txns.any()) return @@ -716,7 +722,7 @@ class NodeVaultService( concreteType?.let { val contractTypes = deriveContractTypes(it) contractTypes.map { - val contractStateType = contractStateTypeMappings.getOrPut(it.name) { mutableSetOf() } + val contractStateType = contractStateTypeMappings.getOrPut(it.name) { CopyOnWriteArraySet() } contractStateType.add(concreteType.name) } } diff --git a/node/src/test/kotlin/net/corda/node/services/vault/NodeVaultServiceTest.kt b/node/src/test/kotlin/net/corda/node/services/vault/NodeVaultServiceTest.kt index a9c50a3ede..1da96f4213 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/NodeVaultServiceTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/NodeVaultServiceTest.kt @@ -4,6 +4,7 @@ import co.paralleluniverse.fibers.Suspendable import com.nhaarman.mockito_kotlin.* import net.corda.core.contracts.* import net.corda.core.crypto.NullKeys +import net.corda.core.crypto.SecureHash import net.corda.core.crypto.generateKeyPair import net.corda.core.identity.* import net.corda.core.internal.NotaryChangeTransactionBuilder @@ -19,6 +20,7 @@ import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.TransactionBuilder import net.corda.core.utilities.NonEmptySet import net.corda.core.utilities.OpaqueBytes +import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.toNonEmptySet import net.corda.finance.* import net.corda.finance.contracts.asset.Cash @@ -945,4 +947,36 @@ class NodeVaultServiceTest { assertTrue(it) } } + + @Test + fun `test concurrent update of contract state type mappings`() { + // no registered contract state types at start-up. + assertEquals(0, vaultService.contractStateTypeMappings.size) + + fun makeCash(amount: Amount, issuer: AbstractParty, depositRef: Byte = 1) = + StateAndRef( + TransactionState(Cash.State(amount `issued by` issuer.ref(depositRef), identity.party), Cash.PROGRAM_ID, DUMMY_NOTARY, constraint = AlwaysAcceptAttachmentConstraint), + StateRef(SecureHash.randomSHA256(), Random().nextInt(32)) + ) + + val cashIssued = setOf>(makeCash(100.DOLLARS, dummyCashIssuer.party)) + val cashUpdate = Vault.Update(emptySet(), cashIssued) + + val service = Executors.newFixedThreadPool(10) + (1..100).map { + service.submit { + database.transaction { + vaultService.publishUpdates.onNext(cashUpdate) + } + } + }.forEach { it.getOrThrow() } + + vaultService.contractStateTypeMappings.forEach { + println("${it.key} = ${it.value}") + } + // Cash.State and its superclasses and interfaces: FungibleAsset, FungibleState, OwnableState, QueryableState + assertEquals(4, vaultService.contractStateTypeMappings.size) + + service.shutdown() + } }