CORDA-2748 Always set the ThreadLocal in the Fiber from the Thread, even if not yet set in the Thread. (#4896)

This commit is contained in:
Rick Parker 2019-03-18 11:51:08 +00:00 committed by GitHub
parent fb4dc0a6ac
commit 31100cd708
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 3 deletions

View File

@ -3,13 +3,19 @@ package net.corda.node.services.statemachine
import co.paralleluniverse.concurrent.util.ThreadAccess import co.paralleluniverse.concurrent.util.ThreadAccess
import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Fiber
import java.lang.reflect.Field import java.lang.reflect.Field
import java.lang.reflect.Method
private val fiberThreadLocalsField: Field = Fiber::class.java.getDeclaredField("fiberLocals").apply { this.isAccessible = true } private val fiberThreadLocalsField: Field = Fiber::class.java.getDeclaredField("fiberLocals").apply { this.isAccessible = true }
private fun <V> Fiber<V>.swappedOutThreadLocals(): Any = fiberThreadLocalsField.get(this) private fun <V> Fiber<V>.swappedOutThreadLocals(): Any = fiberThreadLocalsField.get(this)
private val threadLocalInitialValueMethod: Method = ThreadLocal::class.java.getDeclaredMethod("initialValue")
.apply { this.isAccessible = true }
private fun <T> ThreadLocal<T>.initialValue(): T? = threadLocalInitialValueMethod.invoke(this) as T?
// TODO: This method uses a built-in Quasar function to make a map of all ThreadLocals. This is probably inefficient, but the only API readily available. // TODO: This method uses a built-in Quasar function to make a map of all ThreadLocals. This is probably inefficient, but the only API readily available.
fun <V, T> Fiber<V>.swappedOutThreadLocalValue(threadLocal: ThreadLocal<T>): T? { fun <V, T> Fiber<V>.swappedOutThreadLocalValue(threadLocal: ThreadLocal<T>): T? {
val threadLocals = swappedOutThreadLocals() val threadLocals = swappedOutThreadLocals()
return ThreadAccess.toMap(threadLocals)[threadLocal] as T? return (ThreadAccess.toMap(threadLocals)[threadLocal] as T?) ?: threadLocal.initialValue()
} }

View File

@ -157,7 +157,6 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val continuation = processEvent(transitionExecutor, nextEvent) val continuation = processEvent(transitionExecutor, nextEvent)
when (continuation) { when (continuation) {
is FlowContinuation.Resume -> { is FlowContinuation.Resume -> {
openThreadLocalWormhole()
return continuation.result return continuation.result
} }
is FlowContinuation.Throw -> { is FlowContinuation.Throw -> {
@ -170,6 +169,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
} }
} finally { } finally {
checkDbTransaction(isDbTransactionOpenOnExit) checkDbTransaction(isDbTransactionOpenOnExit)
openThreadLocalWormhole()
} }
} }
@ -215,7 +215,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val threadLocal = getTransientField(TransientValues::database).hikariPoolThreadLocal val threadLocal = getTransientField(TransientValues::database).hikariPoolThreadLocal
if (threadLocal != null) { if (threadLocal != null) {
val valueFromThread = swappedOutThreadLocalValue(threadLocal) val valueFromThread = swappedOutThreadLocalValue(threadLocal)
if (valueFromThread != null) threadLocal.set(valueFromThread) threadLocal.set(valueFromThread)
} }
} }