mirror of
https://github.com/corda/corda.git
synced 2025-01-31 00:24:59 +00:00
CORDA-2748 Always set the ThreadLocal in the Fiber from the Thread, even if not yet set in the Thread. (#4896)
This commit is contained in:
parent
fb4dc0a6ac
commit
31100cd708
@ -3,13 +3,19 @@ package net.corda.node.services.statemachine
|
|||||||
import co.paralleluniverse.concurrent.util.ThreadAccess
|
import co.paralleluniverse.concurrent.util.ThreadAccess
|
||||||
import co.paralleluniverse.fibers.Fiber
|
import co.paralleluniverse.fibers.Fiber
|
||||||
import java.lang.reflect.Field
|
import java.lang.reflect.Field
|
||||||
|
import java.lang.reflect.Method
|
||||||
|
|
||||||
private val fiberThreadLocalsField: Field = Fiber::class.java.getDeclaredField("fiberLocals").apply { this.isAccessible = true }
|
private val fiberThreadLocalsField: Field = Fiber::class.java.getDeclaredField("fiberLocals").apply { this.isAccessible = true }
|
||||||
|
|
||||||
private fun <V> Fiber<V>.swappedOutThreadLocals(): Any = fiberThreadLocalsField.get(this)
|
private fun <V> Fiber<V>.swappedOutThreadLocals(): Any = fiberThreadLocalsField.get(this)
|
||||||
|
|
||||||
|
private val threadLocalInitialValueMethod: Method = ThreadLocal::class.java.getDeclaredMethod("initialValue")
|
||||||
|
.apply { this.isAccessible = true }
|
||||||
|
|
||||||
|
private fun <T> ThreadLocal<T>.initialValue(): T? = threadLocalInitialValueMethod.invoke(this) as T?
|
||||||
|
|
||||||
// TODO: This method uses a built-in Quasar function to make a map of all ThreadLocals. This is probably inefficient, but the only API readily available.
|
// TODO: This method uses a built-in Quasar function to make a map of all ThreadLocals. This is probably inefficient, but the only API readily available.
|
||||||
fun <V, T> Fiber<V>.swappedOutThreadLocalValue(threadLocal: ThreadLocal<T>): T? {
|
fun <V, T> Fiber<V>.swappedOutThreadLocalValue(threadLocal: ThreadLocal<T>): T? {
|
||||||
val threadLocals = swappedOutThreadLocals()
|
val threadLocals = swappedOutThreadLocals()
|
||||||
return ThreadAccess.toMap(threadLocals)[threadLocal] as T?
|
return (ThreadAccess.toMap(threadLocals)[threadLocal] as T?) ?: threadLocal.initialValue()
|
||||||
}
|
}
|
||||||
|
@ -157,7 +157,6 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
val continuation = processEvent(transitionExecutor, nextEvent)
|
val continuation = processEvent(transitionExecutor, nextEvent)
|
||||||
when (continuation) {
|
when (continuation) {
|
||||||
is FlowContinuation.Resume -> {
|
is FlowContinuation.Resume -> {
|
||||||
openThreadLocalWormhole()
|
|
||||||
return continuation.result
|
return continuation.result
|
||||||
}
|
}
|
||||||
is FlowContinuation.Throw -> {
|
is FlowContinuation.Throw -> {
|
||||||
@ -170,6 +169,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
checkDbTransaction(isDbTransactionOpenOnExit)
|
checkDbTransaction(isDbTransactionOpenOnExit)
|
||||||
|
openThreadLocalWormhole()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -215,7 +215,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
val threadLocal = getTransientField(TransientValues::database).hikariPoolThreadLocal
|
val threadLocal = getTransientField(TransientValues::database).hikariPoolThreadLocal
|
||||||
if (threadLocal != null) {
|
if (threadLocal != null) {
|
||||||
val valueFromThread = swappedOutThreadLocalValue(threadLocal)
|
val valueFromThread = swappedOutThreadLocalValue(threadLocal)
|
||||||
if (valueFromThread != null) threadLocal.set(valueFromThread)
|
threadLocal.set(valueFromThread)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user