Fix up HibernateObserver to allow cascading persistence after bug report (#524)

This commit is contained in:
Rick Parker 2017-04-10 11:33:03 +01:00 committed by GitHub
parent c17fe29a62
commit d31a6fae85
4 changed files with 168 additions and 17 deletions

View File

@ -287,7 +287,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
VaultSoftLockManager(vault, smm) VaultSoftLockManager(vault, smm)
CashBalanceAsMetricsObserver(services, database) CashBalanceAsMetricsObserver(services, database)
ScheduledActivityObserver(services) ScheduledActivityObserver(services)
HibernateObserver(vault, schemas) HibernateObserver(vault.rawUpdates, schemas)
} }
private fun makeInfo(): NodeInfo { private fun makeInfo(): NodeInfo {

View File

@ -3,21 +3,25 @@ package net.corda.node.services.schema
import net.corda.core.contracts.ContractState import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateAndRef import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef import net.corda.core.contracts.StateRef
import net.corda.core.node.services.VaultService import net.corda.core.node.services.Vault
import net.corda.core.schemas.MappedSchema import net.corda.core.schemas.MappedSchema
import net.corda.core.schemas.PersistentStateRef import net.corda.core.schemas.PersistentStateRef
import net.corda.core.schemas.QueryableState import net.corda.core.schemas.QueryableState
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
import net.corda.core.utilities.loggerFor import net.corda.core.utilities.loggerFor
import net.corda.node.services.api.SchemaService import net.corda.node.services.api.SchemaService
import org.hibernate.FlushMode
import org.hibernate.SessionFactory import org.hibernate.SessionFactory
import org.hibernate.boot.MetadataSources
import org.hibernate.boot.model.naming.Identifier import org.hibernate.boot.model.naming.Identifier
import org.hibernate.boot.model.naming.PhysicalNamingStrategyStandardImpl import org.hibernate.boot.model.naming.PhysicalNamingStrategyStandardImpl
import org.hibernate.boot.registry.BootstrapServiceRegistryBuilder
import org.hibernate.cfg.Configuration import org.hibernate.cfg.Configuration
import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider
import org.hibernate.engine.jdbc.env.spi.JdbcEnvironment import org.hibernate.engine.jdbc.env.spi.JdbcEnvironment
import org.hibernate.service.UnknownUnwrapTypeException import org.hibernate.service.UnknownUnwrapTypeException
import org.jetbrains.exposed.sql.transactions.TransactionManager import org.jetbrains.exposed.sql.transactions.TransactionManager
import rx.Observable
import java.sql.Connection import java.sql.Connection
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
@ -25,7 +29,7 @@ import java.util.concurrent.ConcurrentHashMap
* A vault observer that extracts Object Relational Mappings for contract states that support it, and persists them with Hibernate. * A vault observer that extracts Object Relational Mappings for contract states that support it, and persists them with Hibernate.
*/ */
// TODO: Manage version evolution of the schemas via additional tooling. // TODO: Manage version evolution of the schemas via additional tooling.
class HibernateObserver(vaultService: VaultService, val schemaService: SchemaService) { class HibernateObserver(vaultUpdates: Observable<Vault.Update>, val schemaService: SchemaService) {
companion object { companion object {
val logger = loggerFor<HibernateObserver>() val logger = loggerFor<HibernateObserver>()
} }
@ -37,7 +41,7 @@ class HibernateObserver(vaultService: VaultService, val schemaService: SchemaSer
schemaService.schemaOptions.map { it.key }.forEach { schemaService.schemaOptions.map { it.key }.forEach {
makeSessionFactoryForSchema(it) makeSessionFactoryForSchema(it)
} }
vaultService.rawUpdates.subscribe { persist(it.produced) } vaultUpdates.subscribe { persist(it.produced) }
} }
private fun sessionFactoryForSchema(schema: MappedSchema): SessionFactory { private fun sessionFactoryForSchema(schema: MappedSchema): SessionFactory {
@ -46,10 +50,12 @@ class HibernateObserver(vaultService: VaultService, val schemaService: SchemaSer
private fun makeSessionFactoryForSchema(schema: MappedSchema): SessionFactory { private fun makeSessionFactoryForSchema(schema: MappedSchema): SessionFactory {
logger.info("Creating session factory for schema $schema") logger.info("Creating session factory for schema $schema")
val serviceRegistry = BootstrapServiceRegistryBuilder().build()
val metadataSources = MetadataSources(serviceRegistry)
// We set a connection provider as the auto schema generation requires it. The auto schema generation will not // We set a connection provider as the auto schema generation requires it. The auto schema generation will not
// necessarily remain and would likely be replaced by something like Liquibase. For now it is very convenient though. // necessarily remain and would likely be replaced by something like Liquibase. For now it is very convenient though.
// TODO: replace auto schema generation as it isn't intended for production use, according to Hibernate docs. // TODO: replace auto schema generation as it isn't intended for production use, according to Hibernate docs.
val config = Configuration().setProperty("hibernate.connection.provider_class", NodeDatabaseConnectionProvider::class.java.name) val config = Configuration(metadataSources).setProperty("hibernate.connection.provider_class", NodeDatabaseConnectionProvider::class.java.name)
.setProperty("hibernate.hbm2ddl.auto", "update") .setProperty("hibernate.hbm2ddl.auto", "update")
.setProperty("hibernate.show_sql", "false") .setProperty("hibernate.show_sql", "false")
.setProperty("hibernate.format_sql", "true") .setProperty("hibernate.format_sql", "true")
@ -61,14 +67,8 @@ class HibernateObserver(vaultService: VaultService, val schemaService: SchemaSer
} }
val tablePrefix = options?.tablePrefix ?: "contract_" // We always have this as the default for aesthetic reasons. val tablePrefix = options?.tablePrefix ?: "contract_" // We always have this as the default for aesthetic reasons.
logger.debug { "Table prefix = $tablePrefix" } logger.debug { "Table prefix = $tablePrefix" }
config.setPhysicalNamingStrategy(object : PhysicalNamingStrategyStandardImpl() {
override fun toPhysicalTableName(name: Identifier?, context: JdbcEnvironment?): Identifier {
val default = super.toPhysicalTableName(name, context)
return Identifier.toIdentifier(tablePrefix + default.text, default.isQuoted)
}
})
schema.mappedTypes.forEach { config.addAnnotatedClass(it) } schema.mappedTypes.forEach { config.addAnnotatedClass(it) }
val sessionFactory = config.buildSessionFactory() val sessionFactory = buildSessionFactory(config, metadataSources, tablePrefix)
logger.info("Created session factory for schema $schema") logger.info("Created session factory for schema $schema")
return sessionFactory return sessionFactory
} }
@ -87,11 +87,36 @@ class HibernateObserver(vaultService: VaultService, val schemaService: SchemaSer
private fun persistStateWithSchema(state: QueryableState, stateRef: StateRef, schema: MappedSchema) { private fun persistStateWithSchema(state: QueryableState, stateRef: StateRef, schema: MappedSchema) {
val sessionFactory = sessionFactoryForSchema(schema) val sessionFactory = sessionFactoryForSchema(schema)
val session = sessionFactory.openStatelessSession(TransactionManager.current().connection) val session = sessionFactory.withOptions().
connection(TransactionManager.current().connection).
flushMode(FlushMode.MANUAL).
openSession()
session.use { session.use {
val mappedObject = schemaService.generateMappedObject(state, schema) val mappedObject = schemaService.generateMappedObject(state, schema)
mappedObject.stateRef = PersistentStateRef(stateRef) mappedObject.stateRef = PersistentStateRef(stateRef)
session.insert(mappedObject) it.persist(mappedObject)
it.flush()
}
}
private fun buildSessionFactory(config: Configuration, metadataSources: MetadataSources, tablePrefix: String): SessionFactory {
config.standardServiceRegistryBuilder.applySettings(config.properties)
val metadata = metadataSources.getMetadataBuilder(config.standardServiceRegistryBuilder.build()).run {
applyPhysicalNamingStrategy(object : PhysicalNamingStrategyStandardImpl() {
override fun toPhysicalTableName(name: Identifier?, context: JdbcEnvironment?): Identifier {
val default = super.toPhysicalTableName(name, context)
return Identifier.toIdentifier(tablePrefix + default.text, default.isQuoted)
}
})
build()
}
return metadata.sessionFactoryBuilder.run {
allowOutOfTransactionUpdateOperations(true)
applySecondLevelCacheSupport(false)
applyQueryCacheSupport(false)
enableReleaseResourcesOnCloseEnabled(true)
build()
} }
} }

View File

@ -0,0 +1,127 @@
package net.corda.node.services
import net.corda.core.contracts.Contract
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionState
import net.corda.core.crypto.CompositeKey
import net.corda.core.crypto.SecureHash
import net.corda.core.node.services.Vault
import net.corda.core.schemas.MappedSchema
import net.corda.core.schemas.PersistentState
import net.corda.core.schemas.QueryableState
import net.corda.core.utilities.LogHelper
import net.corda.node.services.api.SchemaService
import net.corda.node.services.schema.HibernateObserver
import net.corda.node.utilities.configureDatabase
import net.corda.node.utilities.databaseTransaction
import net.corda.testing.MEGA_CORP
import net.corda.testing.node.makeTestDataSourceProperties
import org.hibernate.annotations.Cascade
import org.hibernate.annotations.CascadeType
import org.jetbrains.exposed.sql.Database
import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.junit.After
import org.junit.Before
import org.junit.Test
import rx.subjects.PublishSubject
import java.io.Closeable
import javax.persistence.*
import kotlin.test.assertEquals
class HibernateObserverTests {
lateinit var dataSource: Closeable
lateinit var database: Database
@Before
fun setUp() {
LogHelper.setLevel(HibernateObserver::class)
val dataSourceAndDatabase = configureDatabase(makeTestDataSourceProperties())
dataSource = dataSourceAndDatabase.first
database = dataSourceAndDatabase.second
}
@After
fun cleanUp() {
dataSource.close()
LogHelper.reset(HibernateObserver::class)
}
class SchemaFamily
@Entity
@Table(name = "Parents")
class Parent : PersistentState() {
@OneToMany(fetch = FetchType.LAZY)
@JoinColumns(JoinColumn(name = "transaction_id"), JoinColumn(name = "output_index"))
@OrderColumn
@Cascade(CascadeType.PERSIST)
var children: MutableSet<Child> = mutableSetOf()
}
@Suppress("unused")
@Entity
@Table(name = "Children")
class Child {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
@Column(name = "child_id", unique = true, nullable = false)
var childId: Int? = null
@ManyToOne(fetch = FetchType.LAZY)
@JoinColumns(JoinColumn(name = "transaction_id"), JoinColumn(name = "output_index"))
var parent: Parent? = null
}
class TestState : QueryableState {
override fun supportedSchemas(): Iterable<MappedSchema> {
throw UnsupportedOperationException()
}
override fun generateMappedObject(schema: MappedSchema): PersistentState {
throw UnsupportedOperationException()
}
override val contract: Contract
get() = throw UnsupportedOperationException()
override val participants: List<CompositeKey>
get() = throw UnsupportedOperationException()
}
// This method does not use back quotes for a nice name since it seems to kill the kotlin compiler.
@Test
fun testChildObjectsArePersisted() {
val testSchema = object : MappedSchema(SchemaFamily::class.java, 1, setOf(Parent::class.java, Child::class.java)) {}
val rawUpdatesPublisher = PublishSubject.create<Vault.Update>()
val schemaService = object : SchemaService {
override val schemaOptions: Map<MappedSchema, SchemaService.SchemaOptions> = emptyMap()
override fun selectSchemas(state: QueryableState): Iterable<MappedSchema> = setOf(testSchema)
override fun generateMappedObject(state: QueryableState, schema: MappedSchema): PersistentState {
val parent = Parent()
parent.children.add(Child())
parent.children.add(Child())
return parent
}
}
@Suppress("UNUSED_VARIABLE")
val observer = HibernateObserver(rawUpdatesPublisher, schemaService)
databaseTransaction(database) {
rawUpdatesPublisher.onNext(Vault.Update(emptySet(), setOf(StateAndRef(TransactionState(TestState(), MEGA_CORP), StateRef(SecureHash.sha256("dummy"), 0)))))
val parentRowCountResult = TransactionManager.current().connection.prepareStatement("select count(*) from contract_Parents").executeQuery()
parentRowCountResult.next()
val parentRows = parentRowCountResult.getInt(1)
parentRowCountResult.close()
val childrenRowCountResult = TransactionManager.current().connection.prepareStatement("select count(*) from contract_Children").executeQuery()
childrenRowCountResult.next()
val childrenRows = childrenRowCountResult.getInt(1)
childrenRowCountResult.close()
assertEquals(1, parentRows, "Expected one parent")
assertEquals(2, childrenRows, "Expected two children")
}
}
}

View File

@ -17,8 +17,8 @@ import net.corda.core.utilities.DUMMY_NOTARY
import net.corda.node.services.persistence.InMemoryStateMachineRecordedTransactionMappingStorage import net.corda.node.services.persistence.InMemoryStateMachineRecordedTransactionMappingStorage
import net.corda.node.services.schema.HibernateObserver import net.corda.node.services.schema.HibernateObserver
import net.corda.node.services.schema.NodeSchemaService import net.corda.node.services.schema.NodeSchemaService
import net.corda.node.services.vault.NodeVaultService
import net.corda.node.services.transactions.InMemoryTransactionVerifierService import net.corda.node.services.transactions.InMemoryTransactionVerifierService
import net.corda.node.services.vault.NodeVaultService
import net.corda.testing.MEGA_CORP import net.corda.testing.MEGA_CORP
import net.corda.testing.MINI_CORP import net.corda.testing.MINI_CORP
import net.corda.testing.MOCK_VERSION import net.corda.testing.MOCK_VERSION
@ -28,7 +28,6 @@ import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream import java.io.ByteArrayOutputStream
import java.io.File import java.io.File
import java.io.InputStream import java.io.InputStream
import java.nio.file.Path
import java.nio.file.Paths import java.nio.file.Paths
import java.security.KeyPair import java.security.KeyPair
import java.security.PrivateKey import java.security.PrivateKey
@ -74,7 +73,7 @@ open class MockServices(val key: KeyPair = generateKeyPair()) : ServiceHub {
fun makeVaultService(dataSourceProps: Properties): VaultService { fun makeVaultService(dataSourceProps: Properties): VaultService {
val vaultService = NodeVaultService(this, dataSourceProps) val vaultService = NodeVaultService(this, dataSourceProps)
// Vault cash spending requires access to contract_cash_states and their updates // Vault cash spending requires access to contract_cash_states and their updates
HibernateObserver(vaultService, NodeSchemaService()) HibernateObserver(vaultService.rawUpdates, NodeSchemaService())
return vaultService return vaultService
} }
} }