CORDA-2748 Always set the ThreadLocal in the Fiber from the Thread, even if not yet set in the Thread. (#4896)

This commit is contained in:
Rick Parker 2019-03-18 11:51:08 +00:00 committed by GitHub
parent fb4dc0a6ac
commit 31100cd708
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 3 deletions

View File

@ -3,13 +3,19 @@ package net.corda.node.services.statemachine
import co.paralleluniverse.concurrent.util.ThreadAccess
import co.paralleluniverse.fibers.Fiber
import java.lang.reflect.Field
import java.lang.reflect.Method
private val fiberThreadLocalsField: Field = Fiber::class.java.getDeclaredField("fiberLocals").apply { this.isAccessible = true }
private fun <V> Fiber<V>.swappedOutThreadLocals(): Any = fiberThreadLocalsField.get(this)
private val threadLocalInitialValueMethod: Method = ThreadLocal::class.java.getDeclaredMethod("initialValue")
.apply { this.isAccessible = true }
private fun <T> ThreadLocal<T>.initialValue(): T? = threadLocalInitialValueMethod.invoke(this) as T?
// TODO: This method uses a built-in Quasar function to make a map of all ThreadLocals. This is probably inefficient, but the only API readily available.
fun <V, T> Fiber<V>.swappedOutThreadLocalValue(threadLocal: ThreadLocal<T>): T? {
val threadLocals = swappedOutThreadLocals()
return ThreadAccess.toMap(threadLocals)[threadLocal] as T?
return (ThreadAccess.toMap(threadLocals)[threadLocal] as T?) ?: threadLocal.initialValue()
}

View File

@ -157,7 +157,6 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val continuation = processEvent(transitionExecutor, nextEvent)
when (continuation) {
is FlowContinuation.Resume -> {
openThreadLocalWormhole()
return continuation.result
}
is FlowContinuation.Throw -> {
@ -170,6 +169,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
}
} finally {
checkDbTransaction(isDbTransactionOpenOnExit)
openThreadLocalWormhole()
}
}
@ -215,7 +215,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val threadLocal = getTransientField(TransientValues::database).hikariPoolThreadLocal
if (threadLocal != null) {
val valueFromThread = swappedOutThreadLocalValue(threadLocal)
if (valueFromThread != null) threadLocal.set(valueFromThread)
threadLocal.set(valueFromThread)
}
}