mirror of
https://github.com/corda/corda.git
synced 2025-01-18 02:39:51 +00:00
CORDA-2748 Always set the ThreadLocal in the Fiber from the Thread, even if not yet set in the Thread. (#4896)
This commit is contained in:
parent
fb4dc0a6ac
commit
31100cd708
@ -3,13 +3,19 @@ package net.corda.node.services.statemachine
|
||||
import co.paralleluniverse.concurrent.util.ThreadAccess
|
||||
import co.paralleluniverse.fibers.Fiber
|
||||
import java.lang.reflect.Field
|
||||
import java.lang.reflect.Method
|
||||
|
||||
private val fiberThreadLocalsField: Field = Fiber::class.java.getDeclaredField("fiberLocals").apply { this.isAccessible = true }
|
||||
|
||||
private fun <V> Fiber<V>.swappedOutThreadLocals(): Any = fiberThreadLocalsField.get(this)
|
||||
|
||||
private val threadLocalInitialValueMethod: Method = ThreadLocal::class.java.getDeclaredMethod("initialValue")
|
||||
.apply { this.isAccessible = true }
|
||||
|
||||
private fun <T> ThreadLocal<T>.initialValue(): T? = threadLocalInitialValueMethod.invoke(this) as T?
|
||||
|
||||
// TODO: This method uses a built-in Quasar function to make a map of all ThreadLocals. This is probably inefficient, but the only API readily available.
|
||||
fun <V, T> Fiber<V>.swappedOutThreadLocalValue(threadLocal: ThreadLocal<T>): T? {
|
||||
val threadLocals = swappedOutThreadLocals()
|
||||
return ThreadAccess.toMap(threadLocals)[threadLocal] as T?
|
||||
return (ThreadAccess.toMap(threadLocals)[threadLocal] as T?) ?: threadLocal.initialValue()
|
||||
}
|
||||
|
@ -157,7 +157,6 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
val continuation = processEvent(transitionExecutor, nextEvent)
|
||||
when (continuation) {
|
||||
is FlowContinuation.Resume -> {
|
||||
openThreadLocalWormhole()
|
||||
return continuation.result
|
||||
}
|
||||
is FlowContinuation.Throw -> {
|
||||
@ -170,6 +169,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
}
|
||||
} finally {
|
||||
checkDbTransaction(isDbTransactionOpenOnExit)
|
||||
openThreadLocalWormhole()
|
||||
}
|
||||
}
|
||||
|
||||
@ -215,7 +215,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
val threadLocal = getTransientField(TransientValues::database).hikariPoolThreadLocal
|
||||
if (threadLocal != null) {
|
||||
val valueFromThread = swappedOutThreadLocalValue(threadLocal)
|
||||
if (valueFromThread != null) threadLocal.set(valueFromThread)
|
||||
threadLocal.set(valueFromThread)
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user