mirror of
https://github.com/corda/corda.git
synced 2025-01-21 03:55:00 +00:00
CORDA-1004 Quasar-friendly ThreadLocal solution (#2594)
* Use FastThreadLocalThread in fiber scheduler * Test that thread locals aren't serialized
This commit is contained in:
parent
f7c9f0d10e
commit
3391810101
165
core/src/test/kotlin/net/corda/core/flows/FastThreadLocalTest.kt
Normal file
165
core/src/test/kotlin/net/corda/core/flows/FastThreadLocalTest.kt
Normal file
@ -0,0 +1,165 @@
|
||||
package net.corda.core.flows
|
||||
|
||||
import co.paralleluniverse.fibers.Fiber
|
||||
import co.paralleluniverse.fibers.FiberExecutorScheduler
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import co.paralleluniverse.io.serialization.ByteArraySerializer
|
||||
import co.paralleluniverse.strands.SuspendableCallable
|
||||
import io.netty.util.concurrent.FastThreadLocal
|
||||
import io.netty.util.concurrent.FastThreadLocalThread
|
||||
import net.corda.core.internal.concurrent.OpenFuture
|
||||
import net.corda.core.internal.concurrent.openFuture
|
||||
import net.corda.core.internal.rootCause
|
||||
import net.corda.core.utilities.getOrThrow
|
||||
import org.assertj.core.api.Assertions.catchThrowable
|
||||
import org.hamcrest.Matchers.lessThanOrEqualTo
|
||||
import org.junit.After
|
||||
import org.junit.Assert.assertThat
|
||||
import org.junit.Test
|
||||
import java.util.*
|
||||
import java.util.concurrent.ExecutorService
|
||||
import java.util.concurrent.Executors
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertNotNull
|
||||
|
||||
class FastThreadLocalTest {
|
||||
private inner class ExpensiveObj {
|
||||
init {
|
||||
expensiveObjCount.andIncrement
|
||||
}
|
||||
}
|
||||
|
||||
private val expensiveObjCount = AtomicInteger()
|
||||
private lateinit var pool: ExecutorService
|
||||
private lateinit var scheduler: FiberExecutorScheduler
|
||||
private fun init(threadCount: Int, threadImpl: (Runnable) -> Thread) {
|
||||
pool = Executors.newFixedThreadPool(threadCount, threadImpl)
|
||||
scheduler = FiberExecutorScheduler(null, pool)
|
||||
}
|
||||
|
||||
@After
|
||||
fun poolShutdown() = try {
|
||||
pool.shutdown()
|
||||
} catch (e: UninitializedPropertyAccessException) {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
@After
|
||||
fun schedulerShutdown() = try {
|
||||
scheduler.shutdown()
|
||||
} catch (e: UninitializedPropertyAccessException) {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `ThreadLocal with plain old Thread is fiber-local`() {
|
||||
init(3, ::Thread)
|
||||
val threadLocal = object : ThreadLocal<ExpensiveObj>() {
|
||||
override fun initialValue() = ExpensiveObj()
|
||||
}
|
||||
assertEquals(0, runFibers(100, threadLocal::get))
|
||||
assertEquals(100, expensiveObjCount.get())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `ThreadLocal with FastThreadLocalThread is fiber-local`() {
|
||||
init(3, ::FastThreadLocalThread)
|
||||
val threadLocal = object : ThreadLocal<ExpensiveObj>() {
|
||||
override fun initialValue() = ExpensiveObj()
|
||||
}
|
||||
assertEquals(0, runFibers(100, threadLocal::get))
|
||||
assertEquals(100, expensiveObjCount.get())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `FastThreadLocal with plain old Thread is fiber-local`() {
|
||||
init(3, ::Thread)
|
||||
val threadLocal = object : FastThreadLocal<ExpensiveObj>() {
|
||||
override fun initialValue() = ExpensiveObj()
|
||||
}
|
||||
assertEquals(0, runFibers(100, threadLocal::get))
|
||||
assertEquals(100, expensiveObjCount.get())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `FastThreadLocal with FastThreadLocalThread is not fiber-local`() {
|
||||
init(3, ::FastThreadLocalThread)
|
||||
val threadLocal = object : FastThreadLocal<ExpensiveObj>() {
|
||||
override fun initialValue() = ExpensiveObj()
|
||||
}
|
||||
runFibers(100, threadLocal::get) // Return value could be anything.
|
||||
assertThat(expensiveObjCount.get(), lessThanOrEqualTo(3))
|
||||
}
|
||||
|
||||
/** @return the number of times a different expensive object was obtained post-suspend. */
|
||||
private fun runFibers(fiberCount: Int, threadLocalGet: () -> ExpensiveObj): Int {
|
||||
val fibers = (0 until fiberCount).map { Fiber(scheduler, FiberTask(threadLocalGet)) }
|
||||
val startedFibers = fibers.map { it.start() }
|
||||
return startedFibers.map { it.get() }.count { it }
|
||||
}
|
||||
|
||||
private class FiberTask(private val threadLocalGet: () -> ExpensiveObj) : SuspendableCallable<Boolean> {
|
||||
@Suspendable
|
||||
override fun run(): Boolean {
|
||||
val first = threadLocalGet()
|
||||
Fiber.sleep(1)
|
||||
return threadLocalGet() != first
|
||||
}
|
||||
}
|
||||
|
||||
private class UnserializableObj {
|
||||
@Suppress("unused")
|
||||
private val fail: Nothing by lazy { throw UnsupportedOperationException("Nice try.") }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `ThreadLocal content is not serialized`() {
|
||||
contentIsNotSerialized(object : ThreadLocal<UnserializableObj>() {
|
||||
override fun initialValue() = UnserializableObj()
|
||||
}::get)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `FastThreadLocal content is not serialized`() {
|
||||
contentIsNotSerialized(object : FastThreadLocal<UnserializableObj>() {
|
||||
override fun initialValue() = UnserializableObj()
|
||||
}::get)
|
||||
}
|
||||
|
||||
private fun contentIsNotSerialized(threadLocalGet: () -> UnserializableObj) {
|
||||
init(1, ::FastThreadLocalThread)
|
||||
// Use false like AbstractKryoSerializationScheme, the default of true doesn't work at all:
|
||||
val serializer = Fiber.getFiberSerializer(false)
|
||||
val returnValue = UUID.randomUUID()
|
||||
val deserializedFiber = serializer.read(openFuture<ByteArray>().let {
|
||||
Fiber(scheduler, FiberTask2(threadLocalGet, false, serializer, it, returnValue)).start()
|
||||
it.getOrThrow()
|
||||
}) as Fiber<*>
|
||||
assertEquals(returnValue, Fiber.unparkDeserialized(deserializedFiber, scheduler).get())
|
||||
assertEquals("Nice try.", openFuture<ByteArray>().let {
|
||||
Fiber(scheduler, FiberTask2(threadLocalGet, true, serializer, it, returnValue)).start()
|
||||
catchThrowable { it.getOrThrow() }
|
||||
}.rootCause.message)
|
||||
}
|
||||
|
||||
private class FiberTask2(
|
||||
@Transient private val threadLocalGet: () -> UnserializableObj,
|
||||
private val retainObj: Boolean,
|
||||
@Transient private val serializer: ByteArraySerializer,
|
||||
@Transient private val bytesFuture: OpenFuture<ByteArray>,
|
||||
private val returnValue: UUID) : SuspendableCallable<UUID> {
|
||||
@Suspendable
|
||||
override fun run(): UUID {
|
||||
var obj: UnserializableObj? = threadLocalGet()
|
||||
assertNotNull(obj)
|
||||
if (!retainObj) {
|
||||
@Suppress("UNUSED_VALUE")
|
||||
obj = null
|
||||
}
|
||||
// In retainObj false case, check this doesn't attempt to serialize fields of currentThread:
|
||||
Fiber.parkAndSerialize { fiber, _ -> bytesFuture.capture { serializer.write(fiber) } }
|
||||
return returnValue
|
||||
}
|
||||
}
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
package net.corda.node.utilities
|
||||
|
||||
import com.google.common.util.concurrent.SettableFuture
|
||||
import io.netty.util.concurrent.FastThreadLocalThread
|
||||
import java.util.*
|
||||
import java.util.concurrent.CompletableFuture
|
||||
import java.util.concurrent.Executor
|
||||
@ -55,8 +56,8 @@ interface AffinityExecutor : Executor {
|
||||
private val threads = Collections.synchronizedSet(HashSet<Thread>())
|
||||
|
||||
init {
|
||||
setThreadFactory(fun(runnable: Runnable): Thread {
|
||||
val thread = object : Thread() {
|
||||
setThreadFactory { runnable ->
|
||||
val thread = object : FastThreadLocalThread() {
|
||||
override fun run() {
|
||||
try {
|
||||
runnable.run()
|
||||
@ -68,8 +69,8 @@ interface AffinityExecutor : Executor {
|
||||
thread.isDaemon = true
|
||||
thread.name = threadName
|
||||
threads += thread
|
||||
return thread
|
||||
})
|
||||
thread
|
||||
}
|
||||
}
|
||||
|
||||
override val isOnThread: Boolean get() = Thread.currentThread() in threads
|
||||
|
@ -20,6 +20,7 @@ dependencies {
|
||||
|
||||
// Unit testing helpers.
|
||||
compile "junit:junit:$junit_version"
|
||||
compile 'org.hamcrest:hamcrest-library:1.3'
|
||||
compile "com.nhaarman:mockito-kotlin:1.1.0"
|
||||
|
||||
// Guava: Google test library (collections test suite)
|
||||
|
Loading…
Reference in New Issue
Block a user