ENT-1383 Memory weight based transaction cache (#2355)

* ENT-1383 Make the transaction cache in DBTransactionStorage memory-weight based (rather than count based) so large transactions can no longer use an undue amount of memory.

* Code review: formatting and legibility

* Fix stupid type cast error

* More formatting
This commit is contained in:
Christian Sailer 2018-01-15 13:48:55 +00:00 committed by GitHub
parent 591e37adb3
commit df195b20bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 133 additions and 32 deletions

View File

@ -214,7 +214,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
val networkMapCache = NetworkMapCacheImpl(PersistentNetworkMapCache(database, networkParameters.notaries), identityService)
val (keyPairs, info) = initNodeInfo(networkMapCache, identity, identityKeyPair)
identityService.loadIdentities(info.legalIdentitiesAndCerts)
val transactionStorage = makeTransactionStorage(database)
val transactionStorage = makeTransactionStorage(database, configuration.transactionCacheSizeBytes)
val nodeServices = makeServices(keyPairs, schemaService, transactionStorage, database, info, identityService, networkMapCache)
val notaryService = makeNotaryService(nodeServices, database)
val smm = makeStateMachineManager(database)
@ -559,7 +559,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
return tokenizableServices
}
protected open fun makeTransactionStorage(database: CordaPersistence): WritableTransactionStorage = DBTransactionStorage()
protected open fun makeTransactionStorage(database: CordaPersistence, transactionCacheSizeBytes: Long): WritableTransactionStorage = DBTransactionStorage(transactionCacheSizeBytes)
private fun makeVaultObservers(schedulerService: SchedulerService, hibernateConfig: HibernateConfiguration, smm: StateMachineManager, schemaService: SchemaService, flowLogicRefFactory: FlowLogicRefFactory) {
VaultSoftLockManager.install(services.vaultService, smm)
ScheduledActivityObserver.install(services.vaultService, schedulerService, flowLogicRefFactory)

View File

@ -14,6 +14,8 @@ import java.net.URL
import java.nio.file.Path
import java.util.*
val Int.MB: Long get() = this * 1024L * 1024L
interface NodeConfiguration : NodeSSLConfiguration {
// myLegalName should be only used in the initial network registration, we should use the name from the certificate instead of this.
// TODO: Remove this so we don't accidentally use this identity in the code?
@ -43,6 +45,17 @@ interface NodeConfiguration : NodeSSLConfiguration {
val sshd: SSHDConfiguration?
val database: DatabaseConfig
val useAMQPBridges: Boolean get() = true
val transactionCacheSizeBytes: Long get() = defaultTransactionCacheSize
companion object {
// default to at least 8MB and a bit extra for larger heap sizes
val defaultTransactionCacheSize: Long = 8.MB + getAdditionalCacheMemory()
// add 5% of any heapsize over 300MB to the default transaction cache size
private fun getAdditionalCacheMemory(): Long {
return Math.max((Runtime.getRuntime().maxMemory() - 300.MB) / 20, 0)
}
}
}
data class DevModeOptions(val disableCheckpointChecker: Boolean = false)
@ -118,7 +131,8 @@ data class NodeConfigurationImpl(
override val additionalNodeInfoPollingFrequencyMsec: Long = 5.seconds.toMillis(),
override val sshd: SSHDConfiguration? = null,
override val database: DatabaseConfig = DatabaseConfig(initialiseSchema = devMode, exportHibernateJMXStatistics = devMode),
override val useAMQPBridges: Boolean = true
override val useAMQPBridges: Boolean = true,
override val transactionCacheSizeBytes: Long = NodeConfiguration.defaultTransactionCacheSize
) : NodeConfiguration {
override val exportJMXto: String get() = "http"

View File

@ -1,10 +1,12 @@
package net.corda.node.services.persistence
import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.TransactionSignature
import net.corda.core.internal.ThreadBox
import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.*
import net.corda.core.transactions.CoreTransaction
import net.corda.core.transactions.SignedTransaction
import net.corda.node.services.api.WritableTransactionStorage
import net.corda.node.utilities.*
@ -13,9 +15,15 @@ import net.corda.nodeapi.internal.persistence.bufferUntilDatabaseCommit
import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction
import rx.Observable
import rx.subjects.PublishSubject
import java.util.*
import javax.persistence.*
class DBTransactionStorage : WritableTransactionStorage, SingletonSerializeAsToken() {
// cache value type to just store the immutable bits of a signed transaction plus conversion helpers
typealias TxCacheValue = Pair<SerializedBytes<CoreTransaction>, List<TransactionSignature>>
fun TxCacheValue.toSignedTx() = SignedTransaction(this.first, this.second)
fun SignedTransaction.toTxCacheValue() = TxCacheValue(this.txBits, this.sigs)
class DBTransactionStorage(cacheSizeBytes: Long) : WritableTransactionStorage, SingletonSerializeAsToken() {
@Entity
@Table(name = "${NODE_DATABASE_PREFIX}transactions")
@ -30,40 +38,63 @@ class DBTransactionStorage : WritableTransactionStorage, SingletonSerializeAsTok
)
private companion object {
fun createTransactionsMap(): AppendOnlyPersistentMap<SecureHash, SignedTransaction, DBTransaction, String> {
return AppendOnlyPersistentMap(
fun createTransactionsMap(maxSizeInBytes: Long)
: AppendOnlyPersistentMapBase<SecureHash, TxCacheValue, DBTransaction, String> {
return WeightBasedAppendOnlyPersistentMap<SecureHash, TxCacheValue, DBTransaction, String>(
toPersistentEntityKey = { it.toString() },
fromPersistentEntity = {
Pair(SecureHash.parse(it.txId),
it.transaction.deserialize<SignedTransaction>(context = SerializationDefaults.STORAGE_CONTEXT))
it.transaction.deserialize<SignedTransaction>(context = SerializationDefaults.STORAGE_CONTEXT)
.toTxCacheValue())
},
toPersistentEntity = { key: SecureHash, value: SignedTransaction ->
toPersistentEntity = { key: SecureHash, value: TxCacheValue ->
DBTransaction().apply {
txId = key.toString()
transaction = value.serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes
transaction = value.toSignedTx().
serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes
}
},
persistentEntityClass = DBTransaction::class.java
persistentEntityClass = DBTransaction::class.java,
maxWeight = maxSizeInBytes,
weighingFunc = { hash, tx -> hash.size + weighTx(tx) }
)
}
// Rough estimate for the average of a public key and the transaction metadata - hard to get exact figures here,
// as public keys can vary in size a lot, and if someone else is holding a reference to the key, it won't add
// to the memory pressure at all here.
private const val transactionSignatureOverheadEstimate = 1024
private fun weighTx(tx: Optional<TxCacheValue>): Int {
if (!tx.isPresent) {
return 0
}
val actTx = tx.get()
return actTx.second.sumBy { it.size + transactionSignatureOverheadEstimate } + actTx.first.size
}
}
private val txStorage = createTransactionsMap()
private val txStorage = ThreadBox(createTransactionsMap(cacheSizeBytes))
override fun addTransaction(transaction: SignedTransaction): Boolean =
txStorage.addWithDuplicatesAllowed(transaction.id, transaction).apply {
txStorage.locked {
addWithDuplicatesAllowed(transaction.id, transaction.toTxCacheValue()).apply {
updatesPublisher.bufferUntilDatabaseCommit().onNext(transaction)
}
}
override fun getTransaction(id: SecureHash): SignedTransaction? = txStorage[id]
override fun getTransaction(id: SecureHash): SignedTransaction? = txStorage.content[id]?.toSignedTx()
private val updatesPublisher = PublishSubject.create<SignedTransaction>().toSerialized()
override val updates: Observable<SignedTransaction> = updatesPublisher.wrapWithDatabaseTransaction()
override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> =
DataFeed(txStorage.allPersisted().map { it.second }.toList(), updatesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction())
override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> {
return txStorage.locked {
DataFeed(allPersisted().map { it.second.toSignedTx() }.toList(), updatesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction())
}
}
@VisibleForTesting
val transactions: Iterable<SignedTransaction>
get() = txStorage.allPersisted().map { it.second }.toList()
get() = txStorage.content.allPersisted().map { it.second.toSignedTx() }.toList()
}

View File

@ -1,5 +1,7 @@
package net.corda.node.utilities
import com.google.common.cache.LoadingCache
import com.google.common.cache.Weigher
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.persistence.currentDBSession
import java.util.*
@ -10,23 +12,18 @@ import java.util.*
* behaviour is unpredictable! There is a best-effort check for double inserts, but this should *not* be relied on, so
* ONLY USE THIS IF YOUR TABLE IS APPEND-ONLY
*/
class AppendOnlyPersistentMap<K, V, E, out EK>(
abstract class AppendOnlyPersistentMapBase<K, V, E, out EK>(
val toPersistentEntityKey: (K) -> EK,
val fromPersistentEntity: (E) -> Pair<K, V>,
val toPersistentEntity: (key: K, value: V) -> E,
val persistentEntityClass: Class<E>,
cacheBound: Long = 1024
) { //TODO determine cacheBound based on entity class later or with node config allowing tuning, or using some heuristic based on heap size
val persistentEntityClass: Class<E>
) {
private companion object {
private val log = contextLogger()
}
private val cache = NonInvalidatingCache<K, Optional<V>>(
bound = cacheBound,
concurrencyLevel = 8,
loadFunction = { key -> Optional.ofNullable(loadValue(key)) }
)
abstract protected val cache: LoadingCache<K, Optional<V>>
/**
* Returns the value associated with the key, first loading that value from the storage if necessary.
@ -116,7 +113,7 @@ class AppendOnlyPersistentMap<K, V, E, out EK>(
}
}
private fun loadValue(key: K): V? {
protected fun loadValue(key: K): V? {
val result = currentDBSession().find(persistentEntityClass, toPersistentEntityKey(key))
return result?.let(fromPersistentEntity)?.second
}
@ -135,3 +132,45 @@ class AppendOnlyPersistentMap<K, V, E, out EK>(
cache.invalidateAll()
}
}
class AppendOnlyPersistentMap<K, V, E, out EK>(
toPersistentEntityKey: (K) -> EK,
fromPersistentEntity: (E) -> Pair<K, V>,
toPersistentEntity: (key: K, value: V) -> E,
persistentEntityClass: Class<E>,
cacheBound: Long = 1024
) : AppendOnlyPersistentMapBase<K, V, E, EK>(
toPersistentEntityKey,
fromPersistentEntity,
toPersistentEntity,
persistentEntityClass) {
//TODO determine cacheBound based on entity class later or with node config allowing tuning, or using some heuristic based on heap size
override val cache = NonInvalidatingCache<K, Optional<V>>(
bound = cacheBound,
concurrencyLevel = 8,
loadFunction = { key -> Optional.ofNullable(loadValue(key)) })
}
class WeightBasedAppendOnlyPersistentMap<K, V, E, out EK>(
toPersistentEntityKey: (K) -> EK,
fromPersistentEntity: (E) -> Pair<K, V>,
toPersistentEntity: (key: K, value: V) -> E,
persistentEntityClass: Class<E>,
maxWeight: Long,
weighingFunc: (K, Optional<V>) -> Int
) : AppendOnlyPersistentMapBase<K, V, E, EK>(
toPersistentEntityKey,
fromPersistentEntity,
toPersistentEntity,
persistentEntityClass) {
override val cache = NonInvalidatingWeightBasedCache<K, Optional<V>>(
maxWeight = maxWeight,
concurrencyLevel = 8,
weigher = object : Weigher<K, Optional<V>> {
override fun weigh(key: K, value: Optional<V>): Int {
return weighingFunc(key, value)
}
},
loadFunction = { key -> Optional.ofNullable(loadValue(key)) }
)
}

View File

@ -3,6 +3,7 @@ package net.corda.node.utilities
import com.google.common.cache.CacheBuilder
import com.google.common.cache.CacheLoader
import com.google.common.cache.LoadingCache
import com.google.common.cache.Weigher
import com.google.common.util.concurrent.ListenableFuture
@ -21,7 +22,7 @@ class NonInvalidatingCache<K, V> private constructor(
}
// TODO look into overriding loadAll() if we ever use it
private class NonInvalidatingCacheLoader<K, V>(val loadFunction: (K) -> V) : CacheLoader<K, V>() {
class NonInvalidatingCacheLoader<K, V>(val loadFunction: (K) -> V) : CacheLoader<K, V>() {
override fun reload(key: K, oldValue: V): ListenableFuture<V> {
throw IllegalStateException("Non invalidating cache refreshed")
}
@ -29,3 +30,18 @@ class NonInvalidatingCache<K, V> private constructor(
override fun load(key: K) = loadFunction(key)
}
}
class NonInvalidatingWeightBasedCache<K, V> private constructor(
val cache: LoadingCache<K, V>
) : LoadingCache<K, V> by cache {
constructor (maxWeight: Long, concurrencyLevel: Int, weigher: Weigher<K, V>, loadFunction: (K) -> V) :
this(buildCache(maxWeight, concurrencyLevel, weigher, loadFunction))
private companion object {
private fun <K, V> buildCache(maxWeight: Long, concurrencyLevel: Int, weigher: Weigher<K, V>, loadFunction: (K) -> V): LoadingCache<K, V> {
val builder = CacheBuilder.newBuilder().maximumWeight(maxWeight).weigher(weigher).concurrencyLevel(concurrencyLevel)
return builder.build(NonInvalidatingCache.NonInvalidatingCacheLoader(loadFunction))
}
}
}

View File

@ -314,8 +314,8 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) {
return mockNet.createNode(MockNodeParameters(legalName = name), nodeFactory = { args ->
object : MockNetwork.MockNode(args) {
// That constructs a recording tx storage
override fun makeTransactionStorage(database: CordaPersistence): WritableTransactionStorage {
return RecordingTransactionStorage(database, super.makeTransactionStorage(database))
override fun makeTransactionStorage(database: CordaPersistence, transactionCacheSizeBytes: Long): WritableTransactionStorage {
return RecordingTransactionStorage(database, super.makeTransactionStorage(database, transactionCacheSizeBytes))
}
}
})

View File

@ -10,6 +10,7 @@ import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.WireTransaction
import net.corda.node.services.transactions.PersistentUniquenessProvider
import net.corda.node.internal.configureDatabase
import net.corda.node.services.config.NodeConfiguration
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.testing.*
@ -173,7 +174,7 @@ class DBTransactionStorageTests {
private fun newTransactionStorage() {
database.transaction {
transactionStorage = DBTransactionStorage()
transactionStorage = DBTransactionStorage(NodeConfiguration.defaultTransactionCacheSize)
}
}