rewrite RaftUniquenessProvider to use Hibernate

Author: szymonsztuka · 2017-08-24 10:01:04 +01:00 · committed by GitHub
parent d6d7fb52b4
commit d0a3aa3fc7
11 changed files with 93 additions and 35 deletions

View File

@@ -10,8 +10,7 @@ import net.corda.node.utilities.DatabaseTransactionManager
 import net.corda.node.utilities.parserTransactionIsolationLevel
 import org.hibernate.SessionFactory
 import org.hibernate.boot.MetadataSources
-import org.hibernate.boot.model.naming.Identifier
-import org.hibernate.boot.model.naming.PhysicalNamingStrategyStandardImpl
+import org.hibernate.boot.model.naming.*
 import org.hibernate.boot.registry.BootstrapServiceRegistryBuilder
 import org.hibernate.cfg.Configuration
 import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider
@@ -67,7 +66,8 @@ class HibernateConfiguration(createSchemaService: () -> SchemaService, private v
             // TODO: require mechanism to set schemaOptions (databaseSchema, tablePrefix) which are not global to session
             schema.mappedTypes.forEach { config.addAnnotatedClass(it) }
         }
-        val sessionFactory = buildSessionFactory(config, metadataSources, "")
+        val sessionFactory = buildSessionFactory(config, metadataSources, databaseProperties.getProperty("serverNameTablePrefix",""))
         logger.info("Created session factory for schemas: $schemas")
         return sessionFactory
     }

View File

@@ -18,6 +18,7 @@ import net.corda.node.services.persistence.DBTransactionMappingStorage
 import net.corda.node.services.persistence.DBTransactionStorage
 import net.corda.node.services.persistence.NodeAttachmentService
 import net.corda.node.services.transactions.PersistentUniquenessProvider
+import net.corda.node.services.transactions.RaftUniquenessProvider
 import net.corda.node.services.vault.VaultSchemaV1

 /**
@@ -44,7 +45,9 @@ class NodeSchemaService(customSchemas: Set<MappedSchema> = emptySet()) : SchemaS
             PersistentNetworkMapService.NetworkNode::class.java,
             PersistentNetworkMapService.NetworkSubscriber::class.java,
             NodeMessagingClient.ProcessedMessage::class.java,
-            NodeMessagingClient.RetryMessage::class.java
+            NodeMessagingClient.RetryMessage::class.java,
+            NodeAttachmentService.DBAttachment::class.java,
+            RaftUniquenessProvider.RaftState::class.java
     ))

     // Required schemas are those used by internal Corda services

View File

@@ -8,9 +8,8 @@ import io.atomix.copycat.server.StateMachine
 import io.atomix.copycat.server.storage.snapshot.SnapshotReader
 import io.atomix.copycat.server.storage.snapshot.SnapshotWriter
 import net.corda.core.utilities.loggerFor
-import net.corda.node.utilities.CordaPersistence
-import net.corda.node.utilities.JDBCHashMap
-import java.util.*
+import net.corda.node.utilities.*
+import java.util.LinkedHashMap

 /**
  * A distributed map state machine that doesn't allow overriding values. The state machine is replicated
@@ -20,9 +19,9 @@ import java.util.*
  * to disk, and sharing them across the cluster. A new node joining the cluster will have to obtain and install a snapshot
  * containing the entire JDBC table contents.
  */
-class DistributedImmutableMap<K : Any, V : Any>(val db: CordaPersistence, tableName: String) : StateMachine(), Snapshottable {
+class DistributedImmutableMap<K : Any, V : Any, E, EK>(val db: CordaPersistence, createMap: () -> AppendOnlyPersistentMap<K, V, E, EK>) : StateMachine(), Snapshottable {
     companion object {
-        private val log = loggerFor<DistributedImmutableMap<*, *>>()
+        private val log = loggerFor<DistributedImmutableMap<*, *, *, *>>()
     }

     object Commands {
@@ -38,7 +37,7 @@ class DistributedImmutableMap<K : Any, V : Any>(val db: CordaPersistence, tableN
         class Get<out K, V>(val key: K) : Query<V?>
     }

-    private val map = db.transaction { JDBCHashMap<K, V>(tableName) }
+    private val map = db.transaction { createMap() }

     /** Gets a value for the given [Commands.Get.key] */
     fun get(commit: Commit<Commands.Get<K, V>>): V? {
@@ -80,7 +79,7 @@ class DistributedImmutableMap<K : Any, V : Any>(val db: CordaPersistence, tableN
     override fun snapshot(writer: SnapshotWriter) {
         db.transaction {
             writer.writeInt(map.size)
-            map.entries.forEach { writer.writeObject(it.key to it.value) }
+            map.allPersisted().forEach { writer.writeObject(it.first to it.second) }
         }
     }
@@ -92,7 +91,7 @@ class DistributedImmutableMap<K : Any, V : Any>(val db: CordaPersistence, tableN
             // TODO: read & put entries in batches
             for (i in 1..size) {
                 val (key, value) = reader.readObject<Pair<K, V>>()
-                map.put(key, value)
+                map[key] = value
             }
         }
     }
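
The key change in this file: the state machine no longer constructs a JDBC-backed map from a table name, it accepts a factory for any AppendOnlyPersistentMap. A minimal self-contained sketch of that factory-injection pattern, with a hypothetical SimpleStore standing in for Corda's AppendOnlyPersistentMap (which is not reproduced here):

// SimpleStore is a hypothetical stand-in for AppendOnlyPersistentMap<K, V, E, EK>.
class SimpleStore<K : Any, V : Any> {
    private val backing = LinkedHashMap<K, V>()
    operator fun get(key: K): V? = backing[key]
    operator fun set(key: K, value: V) { backing[key] = value }
    fun allPersisted(): Sequence<Pair<K, V>> = backing.entries.asSequence().map { it.key to it.value }
    val size get() = backing.size
}

// The state machine receives a storage factory instead of a JDBC table name,
// so the backing store is pluggable and testable.
class ImmutableMapSketch<K : Any, V : Any>(createMap: () -> SimpleStore<K, V>) {
    private val map = createMap()
    fun putIfAbsent(key: K, value: V): Boolean =
            if (map[key] == null) { map[key] = value; true } else false
}

fun main() {
    val machine = ImmutableMapSketch { SimpleStore<String, String>() }
    println(machine.putIfAbsent("tx-hash", "consumed"))  // true: first commit wins
    println(machine.putIfAbsent("tx-hash", "other"))     // false: values are never overridden
}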

View File

@@ -18,16 +18,22 @@ import net.corda.core.crypto.SecureHash
 import net.corda.core.identity.Party
 import net.corda.core.node.services.UniquenessException
 import net.corda.core.node.services.UniquenessProvider
+import net.corda.core.serialization.SerializationDefaults
 import net.corda.core.serialization.SingletonSerializeAsToken
+import net.corda.core.serialization.deserialize
+import net.corda.core.serialization.serialize
 import net.corda.core.utilities.loggerFor
 import net.corda.node.services.api.ServiceHubInternal
+import net.corda.node.utilities.AppendOnlyPersistentMap
 import net.corda.node.utilities.CordaPersistence
 import net.corda.nodeapi.config.SSLConfiguration
 import java.nio.file.Path
 import java.util.concurrent.CompletableFuture
 import javax.annotation.concurrent.ThreadSafe
+import javax.persistence.Column
+import javax.persistence.Entity
+import javax.persistence.Id
+import javax.persistence.Lob

 /**
  * A uniqueness provider that records committed input states in a distributed collection replicated and
@@ -41,9 +47,35 @@ import javax.annotation.concurrent.ThreadSafe
 class RaftUniquenessProvider(services: ServiceHubInternal) : UniquenessProvider, SingletonSerializeAsToken() {
     companion object {
         private val log = loggerFor<RaftUniquenessProvider>()
-        private val DB_TABLE_NAME = "notary_committed_states"
+
+        fun createMap(): AppendOnlyPersistentMap<String, Any, RaftState, String> =
+                AppendOnlyPersistentMap(
+                        toPersistentEntityKey = { it },
+                        fromPersistentEntity = {
+                            Pair(it.key, it.value.deserialize(context = SerializationDefaults.STORAGE_CONTEXT))
+                        },
+                        toPersistentEntity = { k: String, v: Any ->
+                            RaftState().apply {
+                                key = k
+                                value = v.serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes
+                            }
+                        },
+                        persistentEntityClass = RaftState::class.java
+                )
     }
+
+    @Entity
+    @javax.persistence.Table(name = "notary_committed_states")
+    class RaftState(
+            @Id
+            @Column
+            var key: String = "",
+
+            @Lob
+            @Column
+            var value: ByteArray = ByteArray(0)
+    )

     /** Directory storing the Raft log and state machine snapshots */
     private val storagePath: Path = services.configuration.baseDirectory

     /** Address of the Copycat node run by this Corda node */
@@ -70,7 +102,8 @@ class RaftUniquenessProvider(services: ServiceHubInternal) : UniquenessProvider,
     fun start() {
         log.info("Creating Copycat server, log stored in: ${storagePath.toFile()}")
-        val stateMachineFactory = { DistributedImmutableMap<String, ByteArray>(db, DB_TABLE_NAME) }
+        val stateMachineFactory = {
+            DistributedImmutableMap(db, RaftUniquenessProvider.Companion::createMap) }
         val address = Address(myAddress.host, myAddress.port)
         val storage = buildStorage(storagePath)
         val transport = buildTransport(transportConfiguration)
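
The companion's createMap() wires domain keys and values through a JPA entity: serialize on write, deserialize on read. A self-contained sketch of the same round-trip, with a hypothetical RaftStateSketch entity and plain java.io serialization standing in for Corda's STORAGE_CONTEXT wire format:

import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.ObjectInputStream
import java.io.ObjectOutputStream
import java.io.Serializable
import javax.persistence.Column
import javax.persistence.Entity
import javax.persistence.Id
import javax.persistence.Lob

// Sketch entity mirroring RaftState above; the table name is illustrative only.
@Entity
@javax.persistence.Table(name = "notary_committed_states_sketch")
class RaftStateSketch(
        @Id @Column var key: String = "",
        @Lob @Column var value: ByteArray = ByteArray(0))

// Hypothetical converters mirroring toPersistentEntity/fromPersistentEntity,
// with java.io serialization in place of Corda's storage serialization context.
fun toEntity(k: String, v: Serializable): RaftStateSketch = RaftStateSketch().apply {
    key = k
    val bytes = ByteArrayOutputStream()
    ObjectOutputStream(bytes).use { it.writeObject(v) }  // close() flushes the stream
    value = bytes.toByteArray()
}

fun fromEntity(e: RaftStateSketch): Pair<String, Any> =
        Pair(e.key, ObjectInputStream(ByteArrayInputStream(e.value)).readObject())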

View File

@@ -34,6 +34,8 @@ class AppendOnlyPersistentMap<K, V, E, out EK> (
         return cache.get(key).orElse(null)
     }

+    val size get() = allPersisted().toList().size
+
     /**
      * Returns all key/value pairs from the underlying storage.
      */
@@ -105,10 +107,28 @@ class AppendOnlyPersistentMap<K, V, E, out EK> (
         }
     }

+    fun putAll(entries: Map<K,V>) {
+        entries.forEach {
+            set(it.key, it.value)
+        }
+    }
+
     private fun loadValue(key: K): V? {
         val result = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(key))
         return result?.let(fromPersistentEntity)?.second
     }
+
+    operator fun contains(key: K) = get(key) != null
+
+    /**
+     * Removes all of the mappings from this map and underlying storage. The map will be empty after this call returns.
+     * WARNING!! The method is not thread safe.
+     */
+    fun clear() {
+        val session = DatabaseTransactionManager.current().session
+        val deleteQuery = session.criteriaBuilder.createCriteriaDelete(persistentEntityClass)
+        deleteQuery.from(persistentEntityClass)
+        session.createQuery(deleteQuery).executeUpdate()
+        cache.invalidateAll()
+    }
 }
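
The new clear() issues a bulk JPA criteria DELETE (no WHERE clause, so every mapped row is removed) and then invalidates the cache. A standalone sketch of that criteria-delete call, written against a JPA EntityManager; the Hibernate session used above exposes the same criteria API:

import javax.persistence.EntityManager

// Sketch: bulk-delete all rows of `entityClass`, mirroring AppendOnlyPersistentMap.clear().
fun <E> deleteAll(em: EntityManager, entityClass: Class<E>): Int {
    val delete = em.criteriaBuilder.createCriteriaDelete(entityClass)
    delete.from(entityClass)                       // root of the DELETE; no predicate added
    return em.createQuery(delete).executeUpdate()  // number of rows deleted
}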

View File

@@ -8,18 +8,17 @@ class NonInvalidatingUnboundCache<K, V> private constructor(
         val cache: LoadingCache<K, V>
 ): LoadingCache<K, V> by cache {

-    constructor(concurrencyLevel: Int, loadFunction: (K) -> V) :
-            this(buildCache(concurrencyLevel, loadFunction, RemovalListener<K, V> {
-                //no removal
-            }))
-
-    constructor(concurrencyLevel: Int, loadFunction: (K) -> V, removalListener: RemovalListener<K, V>) :
-            this(buildCache(concurrencyLevel, loadFunction, removalListener))
+    constructor(concurrencyLevel: Int, loadFunction: (K) -> V, removalListener: RemovalListener<K, V> = RemovalListener {},
+                keysToPreload: () -> Iterable<K> = { emptyList() }) :
+            this(buildCache(concurrencyLevel, loadFunction, removalListener, keysToPreload))

     private companion object {
-        private fun <K, V> buildCache(concurrencyLevel: Int, loadFunction: (K) -> V, removalListener: RemovalListener<K, V>): LoadingCache<K, V> {
+        private fun <K, V> buildCache(concurrencyLevel: Int, loadFunction: (K) -> V, removalListener: RemovalListener<K, V>,
+                                      keysToPreload: () -> Iterable<K>): LoadingCache<K, V> {
             val builder = CacheBuilder.newBuilder().concurrencyLevel(concurrencyLevel).removalListener(removalListener)
-            return builder.build(NonInvalidatingCacheLoader(loadFunction))
+            return builder.build(NonInvalidatingCacheLoader(loadFunction)).apply {
+                getAll(keysToPreload())
+            }
         }
     }
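
The reworked constructor folds the two old overloads into one with default arguments and adds keysToPreload: after the Guava cache is built, getAll() synchronously loads those keys so they are warm before the first real lookup. A minimal sketch of the idiom, with an illustrative loader in place of NonInvalidatingCacheLoader:

import com.google.common.cache.CacheBuilder
import com.google.common.cache.CacheLoader
import com.google.common.cache.LoadingCache

fun buildPreloadedCache(keysToPreload: () -> Iterable<String>): LoadingCache<String, Int> =
        CacheBuilder.newBuilder()
                .concurrencyLevel(8)
                .build(CacheLoader.from { key: String -> key.length })
                .apply { getAll(keysToPreload()) }  // eagerly loads these keys at construction

fun main() {
    val cache = buildPreloadedCache { listOf("alpha", "beta") }
    println(cache.size())  // 2: both keys were loaded before first use
}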

View File

@@ -60,7 +60,7 @@ class PersistentMap<K, V, E, out EK> (
         return cache.asMap().asSequence().map { Pair(it.key, it.value.get()) }
     }

-    override val size = cache.size().toInt()
+    override val size get() = cache.size().toInt()

     private tailrec fun set(key: K, value: V, logWarning: Boolean = true, store: (K,V) -> V?, replace: (K, V) -> Unit) : Boolean {
         var insertionAttempt = false
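
This small change fixes a real bug: override val size = cache.size().toInt() evaluates the initializer once at construction, so the reported size never changes; adding get() makes it a computed property that is re-evaluated on every access. A small sketch of the difference:

// `val x = expr` captures expr once at construction; `val x get() = expr`
// re-evaluates on every read, which is the behaviour PersistentMap.size needs.
class Counter {
    private val items = mutableListOf<Int>()
    val sizeAtConstruction = items.size   // frozen at 0 forever
    val size get() = items.size           // always reflects the current contents
    fun add(n: Int) = items.add(n)
}

fun main() {
    val c = Counter().apply { add(1); add(2) }
    println(c.sizeAtConstruction)  // 0, the stale value the old code returned
    println(c.size)                // 2
}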

View File

@@ -28,13 +28,12 @@ class DistributedImmutableMapTests : TestDependencyInjectionBase() {
     lateinit var cluster: List<Member>
     lateinit var transaction: Transaction
-    lateinit var database: CordaPersistence
+    private val databases: MutableList<CordaPersistence> = mutableListOf()

     @Before
     fun setup() {
         LogHelper.setLevel("-org.apache.activemq")
         LogHelper.setLevel(NetworkMapService::class)
-        database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), createIdentityService = ::makeTestIdentityService)
         cluster = setUpCluster()
     }
@@ -46,7 +45,7 @@ class DistributedImmutableMapTests : TestDependencyInjectionBase() {
             it.client.close()
             it.server.shutdown()
         }
-        database.close()
+        databases.forEach { it.close() }
     }
@@ -87,8 +86,9 @@ class DistributedImmutableMapTests : TestDependencyInjectionBase() {
     private fun createReplica(myAddress: NetworkHostAndPort, clusterAddress: NetworkHostAndPort? = null): CompletableFuture<Member> {
         val storage = Storage.builder().withStorageLevel(StorageLevel.MEMORY).build()
         val address = Address(myAddress.host, myAddress.port)
-        val stateMachineFactory = { DistributedImmutableMap<String, ByteArray>(database, "commited_states_${myAddress.port}") }
+        val database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties("serverNameTablePrefix", "PORT_${myAddress.port}_"), createIdentityService = ::makeTestIdentityService)
+        databases.add(database)
+        val stateMachineFactory = { DistributedImmutableMap(database, RaftUniquenessProvider.Companion::createMap) }

         val server = CopycatServer.builder(address)
                 .withStateMachine(stateMachineFactory)

View File

@@ -630,13 +630,15 @@ class DriverDSL(
                 advertisedServices = advertisedServices,
                 rpcUsers = rpcUsers,
                 verifierType = verifierType,
-                customOverrides = mapOf("notaryNodeAddress" to notaryClusterAddress.toString()),
+                customOverrides = mapOf("notaryNodeAddress" to notaryClusterAddress.toString(),
+                        "database.serverNameTablePrefix" to if (nodeNames.isNotEmpty()) nodeNames.first().toString().replace(Regex("[^0-9A-Za-z]+"),"") else ""),
                 startInSameProcess = startInSameProcess
         )
         // All other nodes will join the cluster
         val restNotaryFutures = nodeNames.drop(1).map {
             val nodeAddress = portAllocation.nextHostAndPort()
-            val configOverride = mapOf("notaryNodeAddress" to nodeAddress.toString(), "notaryClusterAddresses" to listOf(notaryClusterAddress.toString()))
+            val configOverride = mapOf("notaryNodeAddress" to nodeAddress.toString(), "notaryClusterAddresses" to listOf(notaryClusterAddress.toString()),
+                    "database.serverNameTablePrefix" to it.toString().replace(Regex("[^0-9A-Za-z]+"), ""))
             startNode(it, advertisedServices, rpcUsers, verifierType, configOverride)
         }
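
Both overrides derive the per-node table prefix by stripping every run of non-alphanumeric characters from the node name, keeping the prefix safe for use in SQL identifiers. For example (the node name is hypothetical):

fun main() {
    val nodeName = "Notary Service-0"                          // hypothetical node name
    val prefix = nodeName.replace(Regex("[^0-9A-Za-z]+"), "")  // strip anything SQL-unfriendly
    println(prefix)                                            // NotaryService0
}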

View File

@@ -209,9 +209,10 @@ fun makeTestDataSourceProperties(nodeName: String = SecureHash.randomSHA256().to
     return props
 }

-fun makeTestDatabaseProperties(): Properties {
+fun makeTestDatabaseProperties(key: String? = null, value: String? = null): Properties {
     val props = Properties()
     props.setProperty("transactionIsolationLevel", "repeatableRead") //for other possible values see net.corda.node.utilities.CordaPersistence.parserTransactionIsolationLevel(String)
+    if (key != null) { props.setProperty(key, value) }
     return props
 }

View File

@@ -15,7 +15,6 @@ import net.corda.core.node.services.ServiceType
 import net.corda.core.utilities.WHITESPACE
 import net.corda.core.utilities.getOrThrow
 import net.corda.node.internal.Node
-import net.corda.node.serialization.NodeClock
 import net.corda.node.services.config.ConfigHelper
 import net.corda.node.services.config.FullNodeConfiguration
 import net.corda.node.services.config.configOf
@@ -125,7 +124,8 @@ abstract class NodeBasedTest : TestDependencyInjectionBase() {
         val masterNodeFuture = startNode(
                 getX509Name("${notaryName.commonName}-0", "London", "demo@r3.com", null),
                 advertisedServices = setOf(serviceInfo),
-                configOverrides = mapOf("notaryNodeAddress" to nodeAddresses[0]))
+                configOverrides = mapOf("notaryNodeAddress" to nodeAddresses[0],
+                        "database" to mapOf("serverNameTablePrefix" to if (clusterSize > 1) "${notaryName.commonName}0".replace(Regex("[^0-9A-Za-z]+"),"") else "")))

         val remainingNodesFutures = (1 until clusterSize).map {
             startNode(
@@ -133,7 +133,8 @@ abstract class NodeBasedTest : TestDependencyInjectionBase() {
                     advertisedServices = setOf(serviceInfo),
                     configOverrides = mapOf(
                             "notaryNodeAddress" to nodeAddresses[it],
-                            "notaryClusterAddresses" to listOf(nodeAddresses[0])))
+                            "notaryClusterAddresses" to listOf(nodeAddresses[0]),
+                            "database" to mapOf("serverNameTablePrefix" to "${notaryName.commonName}$it".replace(Regex("[^0-9A-Za-z]+"), ""))))
         }
         return remainingNodesFutures.transpose().flatMap { remainingNodes ->