diff --git a/node/src/main/kotlin/net/corda/node/services/network/PersistentNetworkMapService.kt b/node/src/main/kotlin/net/corda/node/services/network/PersistentNetworkMapService.kt index f8f340f817..6607fcc92a 100644 --- a/node/src/main/kotlin/net/corda/node/services/network/PersistentNetworkMapService.kt +++ b/node/src/main/kotlin/net/corda/node/services/network/PersistentNetworkMapService.kt @@ -1,14 +1,19 @@ package net.corda.node.services.network -import net.corda.core.internal.ThreadBox +import net.corda.core.crypto.toBase58String import net.corda.core.identity.PartyAndCertificate +import net.corda.core.internal.ThreadBox import net.corda.core.messaging.SingleMessageRecipient +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.deserialize +import net.corda.core.serialization.serialize import net.corda.node.services.api.ServiceHubInternal import net.corda.node.utilities.* -import org.bouncycastle.asn1.x500.X500Name -import org.jetbrains.exposed.sql.ResultRow -import org.jetbrains.exposed.sql.statements.InsertStatement -import java.util.Collections.synchronizedMap +import java.io.ByteArrayInputStream +import java.security.cert.CertificateFactory +import javax.persistence.* +import java.io.Serializable +import java.util.* /** * A network map service backed by a database to survive restarts of the node hosting it. @@ -22,30 +27,97 @@ class PersistentNetworkMapService(services: ServiceHubInternal, minimumPlatformV : AbstractNetworkMapService(services, minimumPlatformVersion) { // Only the node_party_path column is needed to reconstruct a PartyAndCertificate but we have the others for human readability - private object Table : JDBCHashedTable("${NODE_DATABASE_PREFIX}network_map_nodes") { - val nodeParty = partyAndCertificate("node_party_name", "node_party_key", "node_party_certificate", "node_party_path") - val registrationInfo = blob("node_registration_info") + @Entity + @Table(name = "${NODE_DATABASE_PREFIX}network_map_nodes") + class NetworkNode( + @EmbeddedId + @Column + var nodeParty: NodeParty = NodeParty(), + + @Lob + @Column + var registrationInfo: ByteArray = ByteArray(0) + ) + + @Embeddable + data class NodeParty( + @Column(name = "node_party_name") + var name: String = "", + + @Column(name = "node_party_key", length = 4096) + var owningKey: String = "", // PublicKey + + @Column(name = "node_party_certificate", length = 4096) + var certificate: ByteArray = ByteArray(0), + + @Column(name = "node_party_path", length = 4096) + var certPath: ByteArray = ByteArray(0) + ): Serializable + + private companion object { + private val factory = CertificateFactory.getInstance("X.509") + + fun createNetworkNodesMap(): PersistentMap { + return PersistentMap( + toPersistentEntityKey = { NodeParty( + it.name.toString(), + it.owningKey.toBase58String(), + it.certificate.encoded, + it.certPath.encoded + ) }, + fromPersistentEntity = { + Pair(PartyAndCertificate(factory.generateCertPath(ByteArrayInputStream(it.nodeParty.certPath))), + it.registrationInfo.deserialize(context = SerializationDefaults.STORAGE_CONTEXT)) + }, + toPersistentEntity = { key: PartyAndCertificate, value: NodeRegistrationInfo -> + NetworkNode().apply { + // TODO: We should understand an X500Name database field type, rather than manually doing the conversion ourselves + nodeParty = NodeParty( + key.name.toString(), + key.owningKey.toBase58String(), + key.certificate.encoded, + key.certPath.encoded + ) + registrationInfo = value.serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes + } + }, + persistentEntityClass = NetworkNode::class.java + ) + } + + fun createNetworkSubscribersMap(): PersistentMap { + return PersistentMap( + toPersistentEntityKey = { it.serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes}, + fromPersistentEntity = { + Pair(it.key.deserialize(context = SerializationDefaults.STORAGE_CONTEXT), + it.value.deserialize(context = SerializationDefaults.STORAGE_CONTEXT)) + }, + toPersistentEntity = { _key: SingleMessageRecipient, _value: LastAcknowledgeInfo -> + NetworkSubscriber().apply { + key = _key.serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes + value = _value.serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes + } + }, + persistentEntityClass = NetworkSubscriber::class.java + ) + } } - override val nodeRegistrations: MutableMap = synchronizedMap(object : AbstractJDBCHashMap(Table, loadOnInit = true) { - // TODO: We should understand an X500Name database field type, rather than manually doing the conversion ourselves - override fun keyFromRow(row: ResultRow): PartyAndCertificate = PartyAndCertificate(row[table.nodeParty.certPath]) + override val nodeRegistrations: MutableMap = + Collections.synchronizedMap(createNetworkNodesMap()) - override fun valueFromRow(row: ResultRow): NodeRegistrationInfo = deserializeFromBlob(row[table.registrationInfo]) + @Entity + @Table(name = "${NODE_DATABASE_PREFIX}network_map_subscribers") + class NetworkSubscriber( + @Id + @Column(length = 4096) + var key: ByteArray = ByteArray(0), - override fun addKeyToInsert(insert: InsertStatement, entry: Map.Entry, finalizables: MutableList<() -> Unit>) { - insert[table.nodeParty.name] = entry.key.name.toString() - insert[table.nodeParty.owningKey] = entry.key.owningKey - insert[table.nodeParty.certificate] = entry.key.certificate - insert[table.nodeParty.certPath] = entry.key.certPath - } + @Column(length = 4096) + var value: ByteArray = ByteArray(0) + ) - override fun addValueToInsert(insert: InsertStatement, entry: Map.Entry, finalizables: MutableList<() -> Unit>) { - insert[table.registrationInfo] = serializeToBlob(entry.value, finalizables) - } - }) - - override val subscribers = ThreadBox(JDBCHashMap("${NODE_DATABASE_PREFIX}network_map_subscribers", loadOnInit = true)) + override val subscribers = ThreadBox(createNetworkSubscribersMap()) init { // Initialise the network map version with the current highest persisted version, or zero if there are no entries. diff --git a/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt b/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt index e1cd929b1c..81699b2783 100644 --- a/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt +++ b/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt @@ -11,6 +11,7 @@ import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.node.services.api.SchemaService import net.corda.node.services.events.NodeSchedulerService import net.corda.node.services.keys.PersistentKeyManagementService +import net.corda.node.services.network.PersistentNetworkMapService import net.corda.node.services.persistence.DBCheckpointStorage import net.corda.node.services.persistence.DBTransactionMappingStorage import net.corda.node.services.persistence.DBTransactionStorage @@ -37,9 +38,10 @@ class NodeSchemaService(customSchemas: Set = emptySet()) : SchemaS DBTransactionMappingStorage.DBTransactionMapping::class.java, PersistentKeyManagementService.PersistentKey::class.java, PersistentUniquenessProvider.PersistentUniqueness::class.java, - PersistentUniquenessProvider.PersistentUniqueness::class.java, NodeSchedulerService.PersistentScheduledState::class.java, - NodeAttachmentService.DBAttachment::class.java + NodeAttachmentService.DBAttachment::class.java, + PersistentNetworkMapService.NetworkNode::class.java, + PersistentNetworkMapService.NetworkSubscriber::class.java )) // Required schemas are those used by internal Corda services diff --git a/node/src/main/kotlin/net/corda/node/utilities/PersistentMap.kt b/node/src/main/kotlin/net/corda/node/utilities/PersistentMap.kt index 23e02fa5ec..1c90f89856 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/PersistentMap.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/PersistentMap.kt @@ -16,7 +16,7 @@ class PersistentMap ( val fromPersistentEntity: (E) -> Pair, val toPersistentEntity: (key: K, value: V) -> E, val persistentEntityClass: Class -) { +) : MutableMap, AbstractMap() { private companion object { val log = loggerFor>() @@ -26,9 +26,15 @@ class PersistentMap ( concurrencyLevel = 8, loadFunction = { key -> Optional.ofNullable(loadValue(key)) }, removalListener = ExplicitRemoval(toPersistentEntityKey, persistentEntityClass) - ) + ).apply { + //preload to allow all() to take data only from the cache (cache is unbound) + val session = DatabaseTransactionManager.current().session + val criteriaQuery = session.criteriaBuilder.createQuery(persistentEntityClass) + criteriaQuery.select(criteriaQuery.from(persistentEntityClass)) + getAll(session.createQuery(criteriaQuery).resultList.map { e -> fromPersistentEntity(e as E).first }.asIterable()) + } - class ExplicitRemoval(val toPersistentEntityKey: (K) -> EK, val persistentEntityClass: Class): RemovalListener { + class ExplicitRemoval(private val toPersistentEntityKey: (K) -> EK, private val persistentEntityClass: Class): RemovalListener { override fun onRemoval(notification: RemovalNotification?) { when (notification?.cause) { RemovalCause.EXPLICIT -> { @@ -46,38 +52,37 @@ class PersistentMap ( } } - operator fun get(key: K): V? { + override operator fun get(key: K): V? { return cache.get(key).orElse(null) } fun all(): Sequence> { - return cache.asMap().map { entry -> Pair(entry.key as K, entry.value as V) }.asSequence() + return cache.asMap().map { entry -> Pair(entry.key as K, entry.value.get()) }.asSequence() } - private tailrec fun set(key: K, value: V, logWarning: Boolean = true, store: (K,V) -> V?): Boolean { + override val size = cache.size().toInt() + + private tailrec fun set(key: K, value: V, logWarning: Boolean = true, store: (K,V) -> V?, replace: (K, V) -> Unit) : Boolean { var insertionAttempt = false var isUnique = true val existingInCache = cache.get(key) { // Thread safe, if multiple threads may wait until the first one has loaded. insertionAttempt = true - // Key wasn't in the cache and might be in the underlying storage. - // Depending on 'store' method, this may insert without checking key duplication or it may avoid inserting a duplicated key. - val existingInDb = store(key, value) - if (existingInDb != null) { // Always reuse an existing value from the storage of a duplicated key. - Optional.of(existingInDb) - } else { - Optional.of(value) - } + // Value wasn't in the cache and wasn't in DB (because the cache is unbound). + // Store the value, depending on store implementation this may replace existing entry in DB. + store(key, value) + Optional.of(value) } if (!insertionAttempt) { if (existingInCache.isPresent) { - // Key already exists in cache, do nothing. + // Key already exists in cache, store the new value in the DB (depends on tore implementation) and refresh cache. isUnique = false + replace(key, value) } else { // This happens when the key was queried before with no value associated. We invalidate the cached null // value and recursively call set again. This is to avoid race conditions where another thread queries after // the invalidate but before the set. cache.invalidate(key) - return set(key, value, logWarning, store) + return set(key, value, logWarning, store, replace) } } if (logWarning && !isUnique) { @@ -88,30 +93,65 @@ class PersistentMap ( /** * Associates the specified value with the specified key in this map and persists it. - * If the map previously contained a mapping for the key, the behaviour is unpredictable and may throw an error from the underlying storage. + * WARNING! If the map previously contained a mapping for the key, the behaviour is unpredictable and may throw an error from the underlying storage. */ operator fun set(key: K, value: V) = - set(key, value, logWarning = false) { - key,value -> DatabaseTransactionManager.current().session.save(toPersistentEntity(key,value)) - null - } + set(key, value, + logWarning = false, + store = { key: K, value: V -> + DatabaseTransactionManager.current().session.save(toPersistentEntity(key,value)) + null + }, + replace = { _: K, _: V -> Unit } + ) /** * Associates the specified value with the specified key in this map and persists it. - * If the map previously contained a mapping for the key, the old value is not replaced. + * WARNING! If the map previously contained a mapping for the key, the old value is not replaced. * @return true if added key was unique, otherwise false */ - fun addWithDuplicatesAllowed(key: K, value: V): Boolean = - set(key, value) { - key, value -> - val existingEntry = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(key)) - if (existingEntry == null) { - DatabaseTransactionManager.current().session.save(toPersistentEntity(key,value)) - null - } else { - fromPersistentEntity(existingEntry).second - } - } + fun addWithDuplicatesAllowed(key: K, value: V) = + set(key, value, + store = { key, value -> + val existingEntry = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(key)) + if (existingEntry == null) { + DatabaseTransactionManager.current().session.save(toPersistentEntity(key, value)) + null + } else { + fromPersistentEntity(existingEntry).second + } + }, + replace = { _: K, _: V -> Unit } + ) + + /** + * Associates the specified value with the specified key in this map and persists it. + * @return true if added key was unique, otherwise false + */ + fun addWithDuplicatesReplaced(key: K, value: V) = + set(key, value, + logWarning = false, + store = { k: K, v: V -> merge(k, v) }, + replace = { k: K, v: V -> replaceValue(k, v) } + ) + + private fun replaceValue(key: K, value: V) { + synchronized(this) { + merge(key, value) + cache.put(key, Optional.of(value)) + } + } + + private fun merge(key: K, value: V): V? { + val existingEntry = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(key)) + return if (existingEntry != null) { + DatabaseTransactionManager.current().session.merge(toPersistentEntity(key,value)) + fromPersistentEntity(existingEntry).second + } else { + DatabaseTransactionManager.current().session.save(toPersistentEntity(key,value)) + null + } + } private fun loadValue(key: K): V? { val result = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(key)) @@ -121,9 +161,92 @@ class PersistentMap ( /** * Removes the mapping for the specified key from this map and underlying storage if present. */ - fun remove(key: K): V? { + override fun remove(key: K): V? { val result = cache.get(key).orElse(null) cache.invalidate(key) return result } + + private class NotReallyMutableEntry(key: K, value: V) : AbstractMap.SimpleImmutableEntry(key, value), MutableMap.MutableEntry { + override fun setValue(newValue: V): V { + throw UnsupportedOperationException("Not really mutable. Implement if really required.") + } + } + + private inner class EntryIterator : MutableIterator> { + private val iterator = all().map { NotReallyMutableEntry(it.first, it.second) }.iterator() + + private var current: MutableMap.MutableEntry? = null + + override fun hasNext(): Boolean = iterator.hasNext() + + override fun next(): MutableMap.MutableEntry { + val extractedNext = iterator.next() + current = extractedNext + return extractedNext + } + + override fun remove() { + val savedCurrent = current ?: throw IllegalStateException("Not called next() yet or already removed.") + current = null + remove(savedCurrent.key) + } + } + + override val keys: MutableSet get() { + return object : AbstractSet() { + override val size: Int get() = this@PersistentMap.size + override fun iterator(): MutableIterator { + return object : MutableIterator { + private val entryIterator = EntryIterator() + + override fun hasNext(): Boolean = entryIterator.hasNext() + override fun next(): K = entryIterator.next().key + override fun remove() { + entryIterator.remove() + } + } + } + } + } + + override val values: MutableCollection get() { + return object : AbstractCollection() { + override val size: Int get() = this@PersistentMap.size + override fun iterator(): MutableIterator { + return object : MutableIterator { + private val entryIterator = EntryIterator() + + override fun hasNext(): Boolean = entryIterator.hasNext() + override fun next(): V = entryIterator.next().value + override fun remove() { + entryIterator.remove() + } + } + } + } + } + + override val entries: MutableSet> get() { + return object : AbstractSet>() { + override val size: Int get() = this@PersistentMap.size + override fun iterator(): MutableIterator> { + return object : MutableIterator> { + private val entryIterator = EntryIterator() + + override fun hasNext(): Boolean = entryIterator.hasNext() + override fun next(): MutableMap.MutableEntry = entryIterator.next() + override fun remove() { + entryIterator.remove() + } + } + } + } + } + + override fun put(key: K, value: V): V? { + val old = cache.get(key) + addWithDuplicatesReplaced(key, value) + return old.orElse(null) + } }