From 0131163db046ed29b610763839bd409e7621d605 Mon Sep 17 00:00:00 2001
From: Chris Cochrane <78791827+chriscochrane@users.noreply.github.com>
Date: Thu, 24 Nov 2022 18:13:35 +0000
Subject: [PATCH] ENT-8814 - back-fit changes from Enterprise to OS (#7272)

---
 .../persistence/DBTransactionStorage.kt       | 12 ++--
 .../node/utilities/AppendOnlyPersistentMap.kt | 64 ++++++++++++++-----
 2 files changed, 54 insertions(+), 22 deletions(-)

diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt
index aeeea1dba8..24046f2941 100644
--- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt
+++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt
@@ -101,7 +101,9 @@ class DBTransactionStorage(private val database: CordaPersistence, cacheFactory:
         // Rough estimate for the average of a public key and the transaction metadata - hard to get exact figures here,
         // as public keys can vary in size a lot, and if someone else is holding a reference to the key, it won't add
         // to the memory pressure at all here.
-        private const val transactionSignatureOverheadEstimate = 1024
+        private const val TRANSACTION_SIGNATURE_OVERHEAD_BYTES = 1024
+        private const val TXCACHEVALUE_OVERHEAD_BYTES = 80
+        private const val SECUREHASH_OVERHEAD_BYTES = 24
 
         private val logger = contextLogger()
 
@@ -134,13 +136,13 @@ class DBTransactionStorage(private val database: CordaPersistence, cacheFactory:
                         )
                     },
                     persistentEntityClass = DBTransaction::class.java,
-                    weighingFunc = { hash, tx -> hash.size + weighTx(tx) }
+                    weighingFunc = { hash, tx -> SECUREHASH_OVERHEAD_BYTES + hash.size + weighTx(tx) }
             )
         }
 
-        private fun weighTx(tx: AppendOnlyPersistentMapBase.Transactional<TxCacheValue>): Int {
-            val actTx = tx.peekableValue ?: return 0
-            return actTx.sigs.sumBy { it.size + transactionSignatureOverheadEstimate } + actTx.txBits.size
+        private fun weighTx(actTx: TxCacheValue?): Int {
+            if (actTx == null) return 0
+            return TXCACHEVALUE_OVERHEAD_BYTES + actTx.sigs.sumBy { it.size + TRANSACTION_SIGNATURE_OVERHEAD_BYTES } + actTx.txBits.size
         }
 
         private val log = contextLogger()
diff --git a/node/src/main/kotlin/net/corda/node/utilities/AppendOnlyPersistentMap.kt b/node/src/main/kotlin/net/corda/node/utilities/AppendOnlyPersistentMap.kt
index f45ddbb7cf..570172fa06 100644
--- a/node/src/main/kotlin/net/corda/node/utilities/AppendOnlyPersistentMap.kt
+++ b/node/src/main/kotlin/net/corda/node/utilities/AppendOnlyPersistentMap.kt
@@ -32,8 +32,10 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
         private val log = contextLogger()
     }
 
+    protected class PendingKeyValue(val transactions: MutableSet<DatabaseTransaction>, val estimatedSize: Int)
+
     protected abstract val cache: LoadingCache<K, Transactional<V>>
-    protected val pendingKeys = ConcurrentHashMap<K, MutableSet<DatabaseTransaction>>()
+    protected val pendingKeys = ConcurrentHashMap<K, PendingKeyValue>()
 
     /**
      * Returns the value associated with the key, first loading that value from the storage if necessary.
@@ -85,7 +87,8 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
                     // for cases where the value passed to set differs from that in the cache, but an update function has decided that this
                     // differing value should not be written to the database.
                     if (wasWritten) {
-                        Transactional.InFlight(this, key, _readerValueLoader = { loadValue(key) }).apply { alsoWrite(value) }
+                        Transactional.InFlight(this, key, weight = weight(key, value), _readerValueLoader = { loadValue(key) })
+                                .apply { alsoWrite(value) }
                     } else {
                         oldValueInCache
                     }
@@ -120,7 +123,8 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
             Transactional.Committed(oldValue)
         } else {
             // Some database transactions, including us, writing, with readers seeing whatever is in the database and writers seeing the (in memory) value.
-            Transactional.InFlight(this, key, _readerValueLoader = { loadValue(key) }).apply { alsoWrite(value) }
+            Transactional.InFlight(this, key, weight = weight(key, value), _readerValueLoader = { loadValue(key) })
+                    .apply { alsoWrite(value) }
         }
     }
 
@@ -214,11 +218,12 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
 
     protected fun transactionalLoadValue(key: K): Transactional<V> {
         // This gets called if a value is read and the cache has no Transactional for this key yet.
-        return if (anyoneWriting(key)) {
+        val estimatedSize = anyoneWriting(key)
+        return if (estimatedSize != -1) {
             // If someone is writing (but not us)
             // For those not writing, they need to re-load the value from the database (which their database transaction MIGHT see).
             // For those writing, they need to re-load the value from the database (which their database transaction CAN see).
-            Transactional.InFlight(this, key, { loadValue(key) }, { loadValue(key)!! })
+            Transactional.InFlight(this, key, estimatedSize, { loadValue(key) }, { loadValue(key)!! })
         } else {
             // If no one is writing, then the value may or may not exist in the database.
             Transactional.Unknown(this, key) { loadValue(key) }
@@ -240,21 +245,24 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
     }
 
     // Helpers to know if transaction(s) are currently writing the given key.
-    private fun weAreWriting(key: K): Boolean = pendingKeys[key]?.contains(contextTransaction) ?: false
+    private fun weAreWriting(key: K): Boolean = pendingKeys[key]?.transactions?.contains(contextTransaction) ?: false
 
-    private fun anyoneWriting(key: K): Boolean = pendingKeys[key]?.isNotEmpty() ?: false
+    private fun anyoneWriting(key: K): Int = pendingKeys[key]?.estimatedSize ?: -1
+
+    protected open fun weight(key: K, value: V): Int = 1
 
     // Indicate this database transaction is a writer of this key.
-    private fun addPendingKey(key: K, databaseTransaction: DatabaseTransaction): Boolean {
+    private fun addPendingKey(key: K, databaseTransaction: DatabaseTransaction, estimatedSize: Int): Boolean {
         var added = true
-        pendingKeys.compute(key) { _, oldSet ->
+        pendingKeys.compute(key) { _, value: PendingKeyValue? ->
+            val oldSet = value?.transactions
             if (oldSet == null) {
                 val newSet = HashSet<DatabaseTransaction>(0)
                 newSet += databaseTransaction
-                newSet
+                PendingKeyValue(newSet, estimatedSize)
             } else {
                 added = oldSet.add(databaseTransaction)
-                oldSet
+                value
             }
         }
         return added
@@ -262,12 +270,13 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
 
     // Remove this database transaction as a writer of this key, because the transaction committed or rolled back.
     private fun removePendingKey(key: K, databaseTransaction: DatabaseTransaction) {
-        pendingKeys.compute(key) { _, oldSet ->
+        pendingKeys.compute(key) { _, value: PendingKeyValue? ->
+            val oldSet = value?.transactions
             if (oldSet == null) {
-                oldSet
+                null
             } else {
                 oldSet -= databaseTransaction
-                if (oldSet.size == 0) null else oldSet
+                if (oldSet.size == 0) null else value
             }
         }
     }
@@ -278,10 +287,12 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
      * There are 3 states.  Globally missing, globally visible, and being written in a transaction somewhere now or in
      * the past (and it rolled back).
      */
+    @Suppress("MagicNumber")
     sealed class Transactional<T> {
         abstract val value: T
         abstract val isPresent: Boolean
         abstract val peekableValue: T?
+        abstract val shallowSize: Int
 
         fun orElse(alt: T?) = if (isPresent) value else alt
 
@@ -291,6 +302,8 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
                 get() = true
             override val peekableValue: T?
                 get() = value
+            override val shallowSize: Int
+                get() = 48
         }
 
         // No one can see it.
@@ -301,6 +314,8 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
                 get() = false
             override val peekableValue: T?
                 get() = null
+            override val shallowSize: Int
+                get() = 16
         }
 
         // No one is writing, but we haven't looked in the database yet.  This can only be when there are no writers.
@@ -323,12 +338,15 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
             }
             val isResolved: Boolean get() = valueWithoutIsolationDelegate.isInitialized()
             override val peekableValue: T? get() = if (isResolved && isPresent) value else null
+            override val shallowSize: Int
+                get() = 128
         }
 
         // Written in a transaction (uncommitted) somewhere, but there's a small window when this might be seen after commit,
         // hence the committed flag.
         class InFlight<K, T>(private val map: AppendOnlyPersistentMapBase<K, T, *, *>,
                              private val key: K,
+                             val weight: Int,
                              private val _readerValueLoader: () -> T?,
                              private val _writerValueLoader: () -> T = { throw IllegalAccessException("No value loader provided") }) : Transactional<T>() {
 
@@ -352,7 +370,7 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
                 val tx = contextTransaction
                 val strongKey = key
                 val strongMap = map
-                if (map.addPendingKey(key, tx)) {
+                if (map.addPendingKey(key, tx, weight)) {
                     // If the transaction commits, update cache to make globally visible if we're first for this key,
                     // and then stop saying the transaction is writing the key.
                     tx.onCommit {
@@ -414,6 +432,9 @@ abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
             // The value from the perspective of the eviction algorithm of the cache.  i.e. we want to reveal memory footprint to it etc.
             override val peekableValue: T?
                 get() = if (writerValueLoader.get() != _writerValueLoader) writerValueLoader.get()() else if (readerValueLoader.get() != _readerValueLoader) readerValueLoader.get()() else null
+
+            override val shallowSize: Int
+                get() = 256
         }
     }
 }
@@ -445,15 +466,24 @@ class WeightBasedAppendOnlyPersistentMap<K, V, E, out EK>(
         fromPersistentEntity: (E) -> Pair<K, V>,
         toPersistentEntity: (key: K, value: V) -> E,
         persistentEntityClass: Class<E>,
-        weighingFunc: (K, Transactional<V>) -> Int
+        private val weighingFunc: (K, V?) -> Int
 ) : AppendOnlyPersistentMapBase<K, V, E, EK>(
         toPersistentEntityKey,
         fromPersistentEntity,
         toPersistentEntity,
         persistentEntityClass) {
+
+    override fun weight(key: K, value: V): Int = weighingFunc(key, value)
+
     override val cache = NonInvalidatingWeightBasedCache(
             cacheFactory = cacheFactory,
             name = name,
-            weigher = Weigher { key, value -> weighingFunc(key, value) },
+            weigher = Weigher { key, value: Transactional<V> ->
+                value.shallowSize + if (value is Transactional.InFlight<*, *>) {
+                    value.weight * 2
+                } else {
+                    weighingFunc(key, value.peekableValue)
+                }
+            },
             loadFunction = { key: K -> transactionalLoadValue(key) })
 }