diff --git a/node/src/main/kotlin/net/corda/node/utilities/PersistentMap.kt b/node/src/main/kotlin/net/corda/node/utilities/PersistentMap.kt index d205e2be15..48667faa4e 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/PersistentMap.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/PersistentMap.kt @@ -60,81 +60,32 @@ class PersistentMap( override val size get() = cache.estimatedSize().toInt() - private tailrec fun set(key: K, value: V, logWarning: Boolean = true, store: (K, V) -> V?, replace: (K, V) -> Unit): Boolean { + private tailrec fun set(key: K, value: V): Boolean { var insertionAttempt = false var isUnique = true val existingInCache = cache.get(key) { // Thread safe, if multiple threads may wait until the first one has loaded. insertionAttempt = true - // Value wasn't in the cache and wasn't in DB (because the cache is unbound). - // Store the value, depending on store implementation this may replace existing entry in DB. - store(key, value) + // Value wasn't in the cache and wasn't in DB (because the cache is unbound) so save it. + merge(key, value) Optional.of(value) }!! if (!insertionAttempt) { if (existingInCache.isPresent) { - // Key already exists in cache, store the new value in the DB (depends on tore implementation) and refresh cache. + // Key already exists in cache, store the new value in the DB and refresh cache. isUnique = false - replace(key, value) + replaceValue(key, value) } else { // This happens when the key was queried before with no value associated. We invalidate the cached null // value and recursively call set again. This is to avoid race conditions where another thread queries after // the invalidate but before the set. cache.invalidate(key) - return set(key, value, logWarning, store, replace) + return set(key, value) } } - if (logWarning && !isUnique) { - log.warn("Double insert in ${this.javaClass.name} for entity class $persistentEntityClass key $key, not inserting the second time") - } return isUnique } - /** - * Associates the specified value with the specified key in this map and persists it. - * WARNING! If the map previously contained a mapping for the key, the behaviour is unpredictable and may throw an error from the underlying storage. - */ - operator fun set(key: K, value: V) = - set(key, value, - logWarning = false, - store = { k: K, v: V -> - currentDBSession().save(toPersistentEntity(k, v)) - null - }, - replace = { _: K, _: V -> Unit } - ) - - /** - * Associates the specified value with the specified key in this map and persists it. - * WARNING! If the map previously contained a mapping for the key, the old value is not replaced. - * @return true if added key was unique, otherwise false - */ - fun addWithDuplicatesAllowed(key: K, value: V) = - set(key, value, - store = { k, v -> - val session = currentDBSession() - val existingEntry = session.find(persistentEntityClass, toPersistentEntityKey(k)) - if (existingEntry == null) { - session.save(toPersistentEntity(k, v)) - null - } else { - fromPersistentEntity(existingEntry).second - } - }, - replace = { _: K, _: V -> Unit } - ) - - /** - * Associates the specified value with the specified key in this map and persists it. - * @return true if added key was unique, otherwise false - */ - private fun addWithDuplicatesReplaced(key: K, value: V) = - set(key, value, - logWarning = false, - store = { k: K, v: V -> merge(k, v) }, - replace = { k: K, v: V -> replaceValue(k, v) } - ) - private fun replaceValue(key: K, value: V) { synchronized(this) { merge(key, value) @@ -248,9 +199,13 @@ class PersistentMap( } } + /** + * Associates the specified value with the specified key in this map and persists it. + * @return true if added key was unique, otherwise false + */ override fun put(key: K, value: V): V? { val old = cache.get(key) - addWithDuplicatesReplaced(key, value) + set(key, value) return old!!.orElse(null) } diff --git a/node/src/test/kotlin/net/corda/node/utilities/PersistentMapTests.kt b/node/src/test/kotlin/net/corda/node/utilities/PersistentMapTests.kt new file mode 100644 index 0000000000..7745f1a63b --- /dev/null +++ b/node/src/test/kotlin/net/corda/node/utilities/PersistentMapTests.kt @@ -0,0 +1,157 @@ +package net.corda.node.utilities + +import net.corda.core.crypto.SecureHash +import net.corda.node.internal.configureDatabase +import net.corda.node.services.upgrade.ContractUpgradeServiceImpl +import net.corda.nodeapi.internal.persistence.DatabaseConfig +import net.corda.testing.internal.rigorousMock +import net.corda.testing.node.MockServices +import org.junit.Test +import kotlin.test.assertEquals + +class PersistentMapTests { + private val databaseConfig = DatabaseConfig() + private val database get() = configureDatabase(dataSourceProps, databaseConfig, rigorousMock()) + private val dataSourceProps = MockServices.makeTestDataSourceProperties() + + //create a test map using an existing db table + private fun createTestMap(): PersistentMap { + return PersistentMap( + toPersistentEntityKey = { it }, + fromPersistentEntity = { Pair(it.stateRef, it.upgradedContractClassName) }, + toPersistentEntity = { key: String, value: String -> + ContractUpgradeServiceImpl.DBContractUpgrade().apply { + stateRef = key + upgradedContractClassName = value + } + }, + persistentEntityClass = ContractUpgradeServiceImpl.DBContractUpgrade::class.java + ) + } + + @Test + fun `make sure persistence works`() { + val testHash = SecureHash.randomSHA256().toString() + + database.transaction { + val map = createTestMap() + map.put(testHash, "test") + assertEquals(map[testHash], "test") + } + + database.transaction { + val reloadedMap = createTestMap() + assertEquals("test", reloadedMap[testHash]) + } + } + + @Test + fun `make sure persistence works using assignment operator`() { + val testHash = SecureHash.randomSHA256().toString() + + database.transaction { + val map = createTestMap() + map[testHash] = "test" + assertEquals("test", map[testHash]) + } + + database.transaction { + val reloadedMap = createTestMap() + assertEquals("test", reloadedMap[testHash]) + } + } + + @Test + fun `make sure updating works`() { + val testHash = SecureHash.randomSHA256().toString() + + database.transaction { + val map = createTestMap() + map.put(testHash, "test") + + map.put(testHash, "updated") + assertEquals("updated", map[testHash]) + } + + database.transaction { + val reloadedMap = createTestMap() + assertEquals("updated", reloadedMap[testHash]) + } + } + + @Test + fun `make sure updating works using assignment operator`() { + val testHash = SecureHash.randomSHA256().toString() + + database.transaction { + val map = createTestMap() + map[testHash] = "test" + + map[testHash] = "updated" + assertEquals("updated", map[testHash]) + } + + database.transaction { + val reloadedMap = createTestMap() + assertEquals("updated", reloadedMap[testHash]) + } + } + + @Test + fun `make sure removal works`() { + val testHash = SecureHash.randomSHA256().toString() + + database.transaction { + val map = createTestMap() + map[testHash] = "test" + } + + database.transaction { + val reloadedMap = createTestMap() + //check that the item was persisted + assertEquals("test", reloadedMap[testHash]) + + reloadedMap.remove(testHash) + //check that the item was removed in the version of the map + assertEquals(null, reloadedMap[testHash]) + } + + database.transaction { + val reloadedMap = createTestMap() + //check that the item was removed from the persistent store + assertEquals(null, reloadedMap[testHash]) + } + } + + @Test + fun `make sure persistence works against base class`() { + val testHash = SecureHash.randomSHA256().toString() + + database.transaction { + val map = createTestMap() + map.put(testHash, "test") + assertEquals(map[testHash], "test") + } + + database.transaction { + val reloadedMap = createTestMap() + assertEquals("test", reloadedMap[testHash]) + } + } + + @Test + fun `make sure persistence works using assignment operator base class`() { + val testHash = SecureHash.randomSHA256().toString() + + database.transaction { + val map = createTestMap() as MutableMap + map[testHash] = "test" + assertEquals("test", map[testHash]) + } + + database.transaction { + val reloadedMap = createTestMap() as MutableMap + assertEquals("test", reloadedMap[testHash]) + } + } +} \ No newline at end of file