diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/CordaPersistence.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/CordaPersistence.kt index 5303b202f3..3e818b4f76 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/CordaPersistence.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/CordaPersistence.kt @@ -1,13 +1,16 @@ package net.corda.nodeapi.internal.persistence +import co.paralleluniverse.strands.Strand import net.corda.core.schemas.MappedSchema import net.corda.core.utilities.contextLogger import rx.Observable import rx.Subscriber +import rx.subjects.PublishSubject import rx.subjects.UnicastSubject import java.io.Closeable import java.sql.Connection import java.sql.SQLException +import java.util.* import java.util.concurrent.CopyOnWriteArrayList import javax.persistence.AttributeConverter import javax.sql.DataSource @@ -40,6 +43,9 @@ enum class TransactionIsolationLevel { val jdbcValue: Int = java.sql.Connection::class.java.getField(jdbcString).get(null) as Int } +private val _contextDatabase = ThreadLocal() +val contextDatabase get() = _contextDatabase.get() ?: error("Was expecting to find CordaPersistence set on current thread: ${Strand.currentStrand()}") + class CordaPersistence( val dataSource: DataSource, databaseConfig: DatabaseConfig, @@ -51,7 +57,7 @@ class CordaPersistence( private val log = contextLogger() } - val defaultIsolationLevel = databaseConfig.transactionIsolationLevel + private val defaultIsolationLevel = databaseConfig.transactionIsolationLevel val hibernateConfig: HibernateConfiguration by lazy { transaction { @@ -60,8 +66,19 @@ class CordaPersistence( } val entityManagerFactory get() = hibernateConfig.sessionFactoryForRegisteredSchemas + data class Boundary(val txId: UUID) + + internal val transactionBoundaries = PublishSubject.create().toSerialized() + init { - DatabaseTransactionManager(this) + // Found a unit test that was forgetting to close the database transactions. When you close() on the top level + // database transaction it will reset the threadLocalTx back to null, so if it isn't then there is still a + // database transaction open. The [transaction] helper above handles this in a finally clause for you + // but any manual database transaction management is liable to have this problem. + contextTransactionOrNull?.let { + error("Was not expecting to find existing database transaction on current strand when setting database: ${Strand.currentStrand()}, $it") + } + _contextDatabase.set(this) // Check not in read-only mode. transaction { check(!connection.metaData.isReadOnly) { "Database should not be readonly." } @@ -72,25 +89,29 @@ class CordaPersistence( const val DATA_SOURCE_URL = "dataSource.url" } - /** - * Creates an instance of [DatabaseTransaction], with the given transaction isolation level. - */ - fun createTransaction(isolationLevel: TransactionIsolationLevel): DatabaseTransaction { - // We need to set the database for the current [Thread] or [Fiber] here as some tests share threads across databases. - DatabaseTransactionManager.dataSource = this - return DatabaseTransactionManager.currentOrNew(isolationLevel) + fun currentOrNew(isolation: TransactionIsolationLevel = defaultIsolationLevel): DatabaseTransaction { + return contextTransactionOrNull ?: newTransaction(isolation) + } + + fun newTransaction(isolation: TransactionIsolationLevel = defaultIsolationLevel): DatabaseTransaction { + return DatabaseTransaction(isolation.jdbcValue, contextTransactionOrNull, this).also { + contextTransactionOrNull = it + } } /** - * Creates an instance of [DatabaseTransaction], with the default transaction isolation level. + * Creates an instance of [DatabaseTransaction], with the given transaction isolation level. */ - fun createTransaction(): DatabaseTransaction = createTransaction(defaultIsolationLevel) + fun createTransaction(isolationLevel: TransactionIsolationLevel = defaultIsolationLevel): DatabaseTransaction { + // We need to set the database for the current [Thread] or [Fiber] here as some tests share threads across databases. + _contextDatabase.set(this) + return currentOrNew(isolationLevel) + } fun createSession(): Connection { // We need to set the database for the current [Thread] or [Fiber] here as some tests share threads across databases. - DatabaseTransactionManager.dataSource = this - val ctx = DatabaseTransactionManager.currentOrNull() - return ctx?.connection ?: throw IllegalStateException("Was expecting to find database transaction: must wrap calling code within a transaction.") + _contextDatabase.set(this) + return contextTransaction.connection } /** @@ -99,7 +120,7 @@ class CordaPersistence( * @param statement to be executed in the scope of this transaction. */ fun transaction(isolationLevel: TransactionIsolationLevel, statement: DatabaseTransaction.() -> T): T { - DatabaseTransactionManager.dataSource = this + _contextDatabase.set(this) return transaction(isolationLevel, 2, statement) } @@ -110,7 +131,7 @@ class CordaPersistence( fun transaction(statement: DatabaseTransaction.() -> T): T = transaction(defaultIsolationLevel, statement) private fun transaction(isolationLevel: TransactionIsolationLevel, recoverableFailureTolerance: Int, statement: DatabaseTransaction.() -> T): T { - val outer = DatabaseTransactionManager.currentOrNull() + val outer = contextTransactionOrNull return if (outer != null) { outer.statement() } else { @@ -126,7 +147,7 @@ class CordaPersistence( log.warn("Cleanup task failed:", t) } while (true) { - val transaction = DatabaseTransactionManager.currentOrNew(isolationLevel) + val transaction = contextDatabase.currentOrNew(isolationLevel) // XXX: Does this code really support statement changing the contextDatabase? try { val answer = transaction.statement() transaction.commit() @@ -160,8 +181,8 @@ class CordaPersistence( * For examples, see the call hierarchy of this function. */ fun rx.Observer.bufferUntilDatabaseCommit(): rx.Observer { - val currentTxId = DatabaseTransactionManager.transactionId - val databaseTxBoundary: Observable = DatabaseTransactionManager.transactionBoundaries.first { it.txId == currentTxId } + val currentTxId = contextTransaction.id + val databaseTxBoundary: Observable = contextDatabase.transactionBoundaries.first { it.txId == currentTxId } val subject = UnicastSubject.create() subject.delaySubscription(databaseTxBoundary).subscribe(this) databaseTxBoundary.doOnCompleted { subject.onCompleted() } @@ -169,12 +190,12 @@ fun rx.Observer.bufferUntilDatabaseCommit(): rx.Observer { } // A subscriber that delegates to multiple others, wrapping a database transaction around the combination. -private class DatabaseTransactionWrappingSubscriber(val db: CordaPersistence?) : Subscriber() { +private class DatabaseTransactionWrappingSubscriber(private val db: CordaPersistence?) : Subscriber() { // Some unsubscribes happen inside onNext() so need something that supports concurrent modification. val delegates = CopyOnWriteArrayList>() fun forEachSubscriberWithDbTx(block: Subscriber.() -> Unit) { - (db ?: DatabaseTransactionManager.dataSource).transaction { + (db ?: contextDatabase).transaction { delegates.filter { !it.isUnsubscribed }.forEach { it.block() } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/DatabaseTransaction.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/DatabaseTransaction.kt index 4bb1bbd42e..a1c16fa9eb 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/DatabaseTransaction.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/DatabaseTransaction.kt @@ -1,23 +1,29 @@ package net.corda.nodeapi.internal.persistence +import co.paralleluniverse.strands.Strand import org.hibernate.Session import org.hibernate.Transaction -import rx.subjects.Subject import java.sql.Connection import java.util.* +fun currentDBSession(): Session = contextTransaction.session +private val _contextTransaction = ThreadLocal() +var contextTransactionOrNull: DatabaseTransaction? + get() = _contextTransaction.get() + set(transaction) = _contextTransaction.set(transaction) +val contextTransaction get() = contextTransactionOrNull ?: error("Was expecting to find transaction set on current strand: ${Strand.currentStrand()}") + class DatabaseTransaction( isolation: Int, - private val threadLocal: ThreadLocal, - private val transactionBoundaries: Subject, - val cordaPersistence: CordaPersistence + private val outerTransaction: DatabaseTransaction?, + val database: CordaPersistence ) { val id: UUID = UUID.randomUUID() private var _connectionCreated = false val connectionCreated get() = _connectionCreated val connection: Connection by lazy(LazyThreadSafetyMode.NONE) { - cordaPersistence.dataSource.connection + database.dataSource.connection .apply { _connectionCreated = true // only set the transaction isolation level if it's actually changed - setting isn't free. @@ -28,16 +34,13 @@ class DatabaseTransaction( } private val sessionDelegate = lazy { - val session = cordaPersistence.entityManagerFactory.withOptions().connection(connection).openSession() + val session = database.entityManagerFactory.withOptions().connection(connection).openSession() hibernateTransaction = session.beginTransaction() session } val session: Session by sessionDelegate private lateinit var hibernateTransaction: Transaction - - val outerTransaction: DatabaseTransaction? = threadLocal.get() - fun commit() { if (sessionDelegate.isInitialized()) { hibernateTransaction.commit() @@ -63,9 +66,9 @@ class DatabaseTransaction( if (_connectionCreated) { connection.close() } - threadLocal.set(outerTransaction) + contextTransactionOrNull = outerTransaction if (outerTransaction == null) { - transactionBoundaries.onNext(DatabaseTransactionManager.Boundary(id)) + database.transactionBoundaries.onNext(CordaPersistence.Boundary(id)) } } } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/DatabaseTransactionManager.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/DatabaseTransactionManager.kt deleted file mode 100644 index ade1603002..0000000000 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/DatabaseTransactionManager.kt +++ /dev/null @@ -1,77 +0,0 @@ -package net.corda.nodeapi.internal.persistence - -import co.paralleluniverse.strands.Strand -import org.hibernate.Session -import rx.subjects.PublishSubject -import rx.subjects.Subject -import java.util.* -import java.util.concurrent.ConcurrentHashMap - -fun currentDBSession(): Session = DatabaseTransactionManager.current().session - -class DatabaseTransactionManager(initDataSource: CordaPersistence) { - companion object { - private val threadLocalDb = ThreadLocal() - private val threadLocalTx = ThreadLocal() - private val databaseToInstance = ConcurrentHashMap() - - fun setThreadLocalTx(tx: DatabaseTransaction?): DatabaseTransaction? { - val oldTx = threadLocalTx.get() - threadLocalTx.set(tx) - return oldTx - } - - fun restoreThreadLocalTx(context: DatabaseTransaction?) { - if (context != null) { - threadLocalDb.set(context.cordaPersistence) - } - threadLocalTx.set(context) - } - - var dataSource: CordaPersistence - get() = threadLocalDb.get() ?: throw IllegalStateException("Was expecting to find CordaPersistence set on current thread: ${Strand.currentStrand()}") - set(value) = threadLocalDb.set(value) - - val transactionId: UUID - get() = threadLocalTx.get()?.id ?: throw IllegalStateException("Was expecting to find transaction set on current strand: ${Strand.currentStrand()}") - - val manager: DatabaseTransactionManager get() = databaseToInstance[dataSource]!! - - val transactionBoundaries: Subject get() = manager._transactionBoundaries - - fun currentOrNull(): DatabaseTransaction? = manager.currentOrNull() - - fun currentOrNew(isolation: TransactionIsolationLevel = dataSource.defaultIsolationLevel): DatabaseTransaction { - return currentOrNull() ?: manager.newTransaction(isolation.jdbcValue) - } - - fun current(): DatabaseTransaction = currentOrNull() ?: error("No transaction in context.") - - fun newTransaction(isolation: TransactionIsolationLevel = dataSource.defaultIsolationLevel): DatabaseTransaction { - return manager.newTransaction(isolation.jdbcValue) - } - } - - data class Boundary(val txId: UUID) - - private val _transactionBoundaries = PublishSubject.create().toSerialized() - - init { - // Found a unit test that was forgetting to close the database transactions. When you close() on the top level - // database transaction it will reset the threadLocalTx back to null, so if it isn't then there is still a - // database transaction open. The [transaction] helper above handles this in a finally clause for you - // but any manual database transaction management is liable to have this problem. - if (threadLocalTx.get() != null) { - throw IllegalStateException("Was not expecting to find existing database transaction on current strand when setting database: ${Strand.currentStrand()}, ${threadLocalTx.get()}") - } - dataSource = initDataSource - databaseToInstance[dataSource] = this - } - - private fun newTransaction(isolation: Int) = - DatabaseTransaction(isolation, threadLocalTx, transactionBoundaries, dataSource).apply { - threadLocalTx.set(this) - } - - private fun currentOrNull(): DatabaseTransaction? = threadLocalTx.get() -} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/HibernateConfiguration.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/HibernateConfiguration.kt index 0e17c06e4c..3861edb4a3 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/HibernateConfiguration.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/HibernateConfiguration.kt @@ -128,15 +128,16 @@ class HibernateConfiguration( class NodeDatabaseConnectionProvider : ConnectionProvider { override fun closeConnection(conn: Connection) { conn.autoCommit = false - val tx = DatabaseTransactionManager.current() - tx.commit() - tx.close() + contextTransaction.run { + commit() + close() + } } override fun supportsAggressiveRelease(): Boolean = true override fun getConnection(): Connection { - return DatabaseTransactionManager.newTransaction().connection + return contextDatabase.newTransaction().connection } override fun unwrap(unwrapType: Class): T { diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt b/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt index da6dd489f8..60c5eb373a 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt @@ -17,7 +17,6 @@ import net.corda.core.node.services.vault.AttachmentSort import net.corda.core.serialization.* import net.corda.core.utilities.contextLogger import net.corda.node.services.vault.HibernateAttachmentQueryCriteriaParser -import net.corda.nodeapi.internal.persistence.DatabaseTransactionManager import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.nodeapi.internal.persistence.currentDBSession import java.io.* @@ -242,8 +241,7 @@ class NodeAttachmentService(metrics: MetricRegistry) : AttachmentStorage, Single override fun queryAttachments(criteria: AttachmentQueryCriteria, sorting: AttachmentSort?): List { log.info("Attachment query criteria: $criteria, sorting: $sorting") - - val session = DatabaseTransactionManager.current().session + val session = currentDBSession() val criteriaBuilder = session.criteriaBuilder val criteriaQuery = criteriaBuilder.createQuery(DBAttachment::class.java) diff --git a/node/src/main/kotlin/net/corda/node/services/schema/HibernateObserver.kt b/node/src/main/kotlin/net/corda/node/services/schema/HibernateObserver.kt index eb58d0871e..62d5cee4d8 100644 --- a/node/src/main/kotlin/net/corda/node/services/schema/HibernateObserver.kt +++ b/node/src/main/kotlin/net/corda/node/services/schema/HibernateObserver.kt @@ -10,8 +10,8 @@ import net.corda.core.schemas.PersistentStateRef import net.corda.core.utilities.contextLogger import net.corda.core.utilities.debug import net.corda.node.services.api.SchemaService -import net.corda.nodeapi.internal.persistence.DatabaseTransactionManager import net.corda.nodeapi.internal.persistence.HibernateConfiguration +import net.corda.nodeapi.internal.persistence.contextTransaction import org.hibernate.FlushMode import rx.Observable @@ -54,7 +54,7 @@ class HibernateObserver private constructor(private val config: HibernateConfigu internal fun persistStatesWithSchema(statesAndRefs: List, schema: MappedSchema) { val sessionFactory = config.sessionFactoryForSchemas(setOf(schema)) val session = sessionFactory.withOptions(). - connection(DatabaseTransactionManager.current().connection). + connection(contextTransaction.connection). flushMode(FlushMode.MANUAL). openSession() session.use { thisSession -> diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index 4b0408c622..dca556ea59 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -27,7 +27,8 @@ import net.corda.node.services.statemachine.transitions.FlowContinuation import net.corda.node.services.statemachine.transitions.StateMachine import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseTransaction -import net.corda.nodeapi.internal.persistence.DatabaseTransactionManager +import net.corda.nodeapi.internal.persistence.contextTransaction +import net.corda.nodeapi.internal.persistence.contextTransactionOrNull import org.slf4j.Logger import org.slf4j.LoggerFactory import java.nio.file.Paths @@ -58,8 +59,8 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } private fun extractThreadLocalTransaction(): TransientReference { - val transaction = DatabaseTransactionManager.current() - DatabaseTransactionManager.setThreadLocalTx(null) + val transaction = contextTransaction + contextTransactionOrNull = null return TransientReference(transaction) } } @@ -234,7 +235,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, parkAndSerialize { _, _ -> logger.trace { "Suspended on $ioRequest" } - DatabaseTransactionManager.setThreadLocalTx(transaction.value) + contextTransactionOrNull = transaction.value val event = try { Event.Suspend( ioRequest = ioRequest, diff --git a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt index f215d87cee..a731f4841e 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt @@ -17,12 +17,8 @@ import net.corda.core.transactions.NotaryChangeWireTransaction import net.corda.core.transactions.WireTransaction import net.corda.core.utilities.* import net.corda.node.services.api.VaultServiceInternal -import net.corda.nodeapi.internal.persistence.HibernateConfiguration import net.corda.node.services.statemachine.FlowStateMachineImpl -import net.corda.nodeapi.internal.persistence.DatabaseTransactionManager -import net.corda.nodeapi.internal.persistence.bufferUntilDatabaseCommit -import net.corda.nodeapi.internal.persistence.currentDBSession -import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction +import net.corda.nodeapi.internal.persistence.* import org.hibernate.Session import rx.Observable import rx.subjects.PublishSubject @@ -479,8 +475,7 @@ class NodeVaultService( } } - private fun getSession() = DatabaseTransactionManager.currentOrNew().session - + private fun getSession() = contextDatabase.currentOrNew().session /** * Derive list from existing vault states and then incrementally update using vault observables */ diff --git a/node/src/test/kotlin/net/corda/node/services/schema/HibernateObserverTests.kt b/node/src/test/kotlin/net/corda/node/services/schema/HibernateObserverTests.kt index 0a657ad95e..c83b75b89f 100644 --- a/node/src/test/kotlin/net/corda/node/services/schema/HibernateObserverTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/schema/HibernateObserverTests.kt @@ -15,7 +15,6 @@ import net.corda.core.schemas.QueryableState import net.corda.node.services.api.SchemaService import net.corda.node.internal.configureDatabase import net.corda.nodeapi.internal.persistence.DatabaseConfig -import net.corda.nodeapi.internal.persistence.DatabaseTransactionManager import net.corda.testing.internal.LogHelper import net.corda.testing.TestIdentity import net.corda.testing.contracts.DummyContract @@ -74,11 +73,11 @@ class HibernateObserverTests { database.transaction { val MEGA_CORP = TestIdentity(CordaX500Name("MegaCorp", "London", "GB")).party rawUpdatesPublisher.onNext(Vault.Update(emptySet(), setOf(StateAndRef(TransactionState(TestState(), DummyContract.PROGRAM_ID, MEGA_CORP), StateRef(SecureHash.sha256("dummy"), 0))))) - val parentRowCountResult = DatabaseTransactionManager.current().connection.prepareStatement("select count(*) from Parents").executeQuery() + val parentRowCountResult = connection.prepareStatement("select count(*) from Parents").executeQuery() parentRowCountResult.next() val parentRows = parentRowCountResult.getInt(1) parentRowCountResult.close() - val childrenRowCountResult = DatabaseTransactionManager.current().connection.prepareStatement("select count(*) from Children").executeQuery() + val childrenRowCountResult = connection.prepareStatement("select count(*) from Children").executeQuery() childrenRowCountResult.next() val childrenRows = childrenRowCountResult.getInt(1) childrenRowCountResult.close() diff --git a/node/src/test/kotlin/net/corda/node/utilities/ObservablesTests.kt b/node/src/test/kotlin/net/corda/node/utilities/ObservablesTests.kt index 68918167f0..9aeb4a96d8 100644 --- a/node/src/test/kotlin/net/corda/node/utilities/ObservablesTests.kt +++ b/node/src/test/kotlin/net/corda/node/utilities/ObservablesTests.kt @@ -16,9 +16,7 @@ import java.io.Closeable import java.util.* class ObservablesTests { - - private fun isInDatabaseTransaction(): Boolean = (DatabaseTransactionManager.currentOrNull() != null) - + private fun isInDatabaseTransaction() = contextTransactionOrNull != null private val toBeClosed = mutableListOf() private fun createDatabase(): CordaPersistence { @@ -168,7 +166,7 @@ class ObservablesTests { observableWithDbTx.first().subscribe { undelayedEvent.set(it to isInDatabaseTransaction()) } fun observeSecondEvent(event: Int, future: SettableFuture>) { - future.set(event to if (isInDatabaseTransaction()) DatabaseTransactionManager.transactionId else null) + future.set(event to if (isInDatabaseTransaction()) contextTransaction.id else null) } observableWithDbTx.skip(1).first().subscribe { observeSecondEvent(it, delayedEventFromSecondObserver) }