diff --git a/build.gradle b/build.gradle index 3fe642542e..1850c8ff6f 100644 --- a/build.gradle +++ b/build.gradle @@ -57,6 +57,7 @@ buildscript { ext.jsr305_version = constants.getProperty("jsr305Version") ext.shiro_version = '1.4.0' ext.artifactory_plugin_version = constants.getProperty('artifactoryPluginVersion') + ext.hikari_version = '2.5.1' ext.liquibase_version = '3.5.5' ext.artifactory_contextUrl = 'https://ci-artifactory.corda.r3cev.com/artifactory' ext.snake_yaml_version = constants.getProperty('snakeYamlVersion') diff --git a/node-api/build.gradle b/node-api/build.gradle index 0f2c1a8c3e..5aee94f5b8 100644 --- a/node-api/build.gradle +++ b/node-api/build.gradle @@ -27,6 +27,9 @@ dependencies { compile "org.apache.qpid:proton-j:$protonj_version" + // SQL connection pooling library + compile "com.zaxxer:HikariCP:$hikari_version" + // ClassGraph: classpath scanning compile "io.github.classgraph:classgraph:$class_graph_version" diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/CordaPersistence.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/CordaPersistence.kt index 751805d315..ab039bf05f 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/CordaPersistence.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/CordaPersistence.kt @@ -1,6 +1,9 @@ package net.corda.nodeapi.internal.persistence import co.paralleluniverse.strands.Strand +import com.zaxxer.hikari.HikariDataSource +import com.zaxxer.hikari.pool.HikariPool +import com.zaxxer.hikari.util.ConcurrentBag import net.corda.core.internal.NamedCacheFactory import net.corda.core.schemas.MappedSchema import net.corda.core.utilities.contextLogger @@ -9,9 +12,10 @@ import rx.Observable import rx.Subscriber import rx.subjects.UnicastSubject import java.io.Closeable +import java.lang.reflect.Field import java.sql.Connection import java.sql.SQLException -import java.util.UUID +import java.util.* import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.CopyOnWriteArrayList import java.util.concurrent.atomic.AtomicInteger @@ -254,6 +258,24 @@ class CordaPersistence( // DataSource doesn't implement AutoCloseable so we just have to hope that the implementation does so that we can close it (_dataSource as? AutoCloseable)?.close() } + + val hikariPoolThreadLocal: ThreadLocal>? by lazy(LazyThreadSafetyMode.PUBLICATION) { + val hikariDataSource = dataSource as? HikariDataSource + if (hikariDataSource == null) { + null + } else { + val poolField: Field = HikariDataSource::class.java.getDeclaredField("pool") + poolField.isAccessible = true + val pool: HikariPool = poolField.get(hikariDataSource) as HikariPool + val connectionBagField: Field = HikariPool::class.java.getDeclaredField("connectionBag") + connectionBagField.isAccessible = true + val connectionBag: ConcurrentBag = connectionBagField.get(pool) as ConcurrentBag + val threadListField: Field = ConcurrentBag::class.java.getDeclaredField("threadList") + threadListField.isAccessible = true + val threadList: ThreadLocal> = threadListField.get(connectionBag) as ThreadLocal> + threadList + } + } } /** diff --git a/node/build.gradle b/node/build.gradle index a4088e80c2..56dc13295a 100644 --- a/node/build.gradle +++ b/node/build.gradle @@ -145,7 +145,7 @@ dependencies { compile "org.postgresql:postgresql:$postgresql_version" // SQL connection pooling library - compile "com.zaxxer:HikariCP:2.5.1" + compile "com.zaxxer:HikariCP:${hikari_version}" // Hibernate: an object relational mapper for writing state objects to the database automatically. compile "org.hibernate:hibernate-core:$hibernate_version" diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FiberUtils.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FiberUtils.kt new file mode 100644 index 0000000000..ee50b52bd7 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FiberUtils.kt @@ -0,0 +1,15 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.concurrent.util.ThreadAccess +import co.paralleluniverse.fibers.Fiber +import java.lang.reflect.Field + +private val fiberThreadLocalsField: Field = Fiber::class.java.getDeclaredField("fiberLocals").apply { this.isAccessible = true } + +private fun Fiber.swappedOutThreadLocals(): Any = fiberThreadLocalsField.get(this) + +// TODO: This method uses a built-in Quasar function to make a map of all ThreadLocals. This is probably inefficient, but the only API readily available. +fun Fiber.swappedOutThreadLocalValue(threadLocal: ThreadLocal): T? { + val threadLocals = swappedOutThreadLocals() + return ThreadAccess.toMap(threadLocals)[threadLocal] as T? +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index 1284b45f7e..254f530df8 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -156,7 +156,10 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } val continuation = processEvent(transitionExecutor, nextEvent) when (continuation) { - is FlowContinuation.Resume -> return continuation.result + is FlowContinuation.Resume -> { + openThreadLocalWormhole() + return continuation.result + } is FlowContinuation.Throw -> { continuation.throwable.fillInStackTrace() throw continuation.throwable @@ -208,13 +211,21 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, MDC.put("thread-id", Thread.currentThread().id.toString()) } + private fun openThreadLocalWormhole() { + val threadLocal = getTransientField(TransientValues::database).hikariPoolThreadLocal + if (threadLocal != null) { + val valueFromThread = swappedOutThreadLocalValue(threadLocal) + if (valueFromThread != null) threadLocal.set(valueFromThread) + } + } + @Suspendable override fun run() { logic.progressTracker?.currentStep = ProgressTracker.STARTING logic.stateMachine = this + openThreadLocalWormhole() setLoggingContext() - initialiseFlow() logger.debug { "Calling flow: $logic" }