diff --git a/core/src/main/kotlin/net/corda/core/internal/ToggleField.kt b/core/src/main/kotlin/net/corda/core/internal/ToggleField.kt index 8e0db8303d..36ae18d05c 100644 --- a/core/src/main/kotlin/net/corda/core/internal/ToggleField.kt +++ b/core/src/main/kotlin/net/corda/core/internal/ToggleField.kt @@ -45,18 +45,21 @@ class ThreadLocalToggleField(name: String) : ToggleField(name) { /** The named thread has leaked from a previous test. */ class ThreadLeakException : RuntimeException("Leaked thread detected: ${Thread.currentThread().name}") -/** @param exceptionHandler should throw the exception, or may return normally to suppress inheritance. */ +/** @param isAGlobalThreadBeingCreated whether a global thread (that should not inherit any value) is being created. */ class InheritableThreadLocalToggleField(name: String, private val log: Logger = loggerFor>(), - private val exceptionHandler: (ThreadLeakException) -> Unit = { throw it }) : ToggleField(name) { + private val isAGlobalThreadBeingCreated: (Array) -> Boolean) : ToggleField(name) { private inner class Holder(value: T) : AtomicReference(value) { fun valueOrDeclareLeak() = get() ?: throw ThreadLeakException() fun childValue(): Holder? { - get() != null && return this // Current thread isn't leaked. - val e = ThreadLeakException() - exceptionHandler(e) - log.warn(e.message) - return null + val e = ThreadLeakException() // Expensive, but so is starting the new thread. + return if (isAGlobalThreadBeingCreated(e.stackTrace)) { + get() ?: log.warn(e.message) + null + } else { + get() ?: log.error(e.message) + this + } } } diff --git a/core/src/main/kotlin/net/corda/core/internal/concurrent/CordaFutureImpl.kt b/core/src/main/kotlin/net/corda/core/internal/concurrent/CordaFutureImpl.kt index e9a949d2e9..a97eff4c56 100644 --- a/core/src/main/kotlin/net/corda/core/internal/concurrent/CordaFutureImpl.kt +++ b/core/src/main/kotlin/net/corda/core/internal/concurrent/CordaFutureImpl.kt @@ -57,6 +57,9 @@ fun CordaFuture.flatMap(transform: (V) -> CordaFuture): Cor }) } +/** Wrap a CompletableFuture, for example one that was returned by some API. */ +fun CompletableFuture.asCordaFuture(): CordaFuture = CordaFutureImpl(this) + /** * If all of the given futures succeed, the returned future's outcome is a list of all their values. * The values are in the same order as the futures in the collection, not the order of completion. diff --git a/core/src/main/kotlin/net/corda/core/serialization/internal/SerializationEnvironment.kt b/core/src/main/kotlin/net/corda/core/serialization/internal/SerializationEnvironment.kt index 5c77ce3bf7..06531d0c81 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/internal/SerializationEnvironment.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/internal/SerializationEnvironment.kt @@ -43,12 +43,12 @@ val _globalSerializationEnv = SimpleToggleField("globa @VisibleForTesting val _contextSerializationEnv = ThreadLocalToggleField("contextSerializationEnv") @VisibleForTesting -val _inheritableContextSerializationEnv = InheritableThreadLocalToggleField("inheritableContextSerializationEnv") suppressInherit@ { - it.stackTrace.forEach { - // A dying Netty thread's death event restarting the Netty global executor: - it.className == "io.netty.util.concurrent.GlobalEventExecutor" && it.methodName == "startThread" && return@suppressInherit +val _inheritableContextSerializationEnv = InheritableThreadLocalToggleField("inheritableContextSerializationEnv") { stack -> + stack.fold(false) { isAGlobalThreadBeingCreated, e -> + isAGlobalThreadBeingCreated || + (e.className == "io.netty.util.concurrent.GlobalEventExecutor" && e.methodName == "startThread") || + (e.className == "java.util.concurrent.ForkJoinPool\$DefaultForkJoinWorkerThreadFactory" && e.methodName == "newThread") } - throw it } private val serializationEnvProperties = listOf(_nodeSerializationEnv, _globalSerializationEnv, _contextSerializationEnv, _inheritableContextSerializationEnv) val effectiveSerializationEnv: SerializationEnvironment diff --git a/core/src/test/kotlin/net/corda/core/internal/ToggleFieldTest.kt b/core/src/test/kotlin/net/corda/core/internal/ToggleFieldTest.kt index b00e5d4e03..0967e93b8a 100644 --- a/core/src/test/kotlin/net/corda/core/internal/ToggleFieldTest.kt +++ b/core/src/test/kotlin/net/corda/core/internal/ToggleFieldTest.kt @@ -7,7 +7,10 @@ import com.nhaarman.mockito_kotlin.verifyNoMoreInteractions import net.corda.core.internal.concurrent.fork import net.corda.core.utilities.getOrThrow import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.Rule import org.junit.Test +import org.junit.rules.TestRule +import org.junit.runners.model.Statement import org.slf4j.Logger import java.util.concurrent.ExecutorService import java.util.concurrent.Executors @@ -28,9 +31,34 @@ private fun withSingleThreadExecutor(callable: ExecutorService.() -> T) = Ex } class ToggleFieldTest { + companion object { + @Suppress("JAVA_CLASS_ON_COMPANION") + private val companionName = javaClass.name + + private fun globalThreadCreationMethod(task: () -> T) = task() + } + + private val log = mock() + @Rule + @JvmField + val verifyNoMoreInteractions = TestRule { base, _ -> + object : Statement() { + override fun evaluate() { + base.evaluate() + verifyNoMoreInteractions(log) // Only on success. + } + } + } + + private fun inheritableThreadLocalToggleField() = InheritableThreadLocalToggleField("inheritable", log) { stack -> + stack.fold(false) { isAGlobalThreadBeingCreated, e -> + isAGlobalThreadBeingCreated || (e.className == companionName && e.methodName == "globalThreadCreationMethod") + } + } + @Test fun `toggle is enforced`() { - listOf(SimpleToggleField("simple"), ThreadLocalToggleField("local"), InheritableThreadLocalToggleField("inheritable")).forEach { field -> + listOf(SimpleToggleField("simple"), ThreadLocalToggleField("local"), inheritableThreadLocalToggleField()).forEach { field -> assertNull(field.get()) assertThatThrownBy { field.set(null) }.isInstanceOf(IllegalStateException::class.java) field.set("hello") @@ -71,7 +99,7 @@ class ToggleFieldTest { @Test fun `inheritable thread local works`() { - val field = InheritableThreadLocalToggleField("field") + val field = inheritableThreadLocalToggleField() assertNull(field.get()) field.set("hello") assertEquals("hello", field.get()) @@ -84,7 +112,7 @@ class ToggleFieldTest { @Test fun `existing threads do not inherit`() { - val field = InheritableThreadLocalToggleField("field") + val field = inheritableThreadLocalToggleField() withSingleThreadExecutor { field.set("hello") assertEquals("hello", field.get()) @@ -93,16 +121,8 @@ class ToggleFieldTest { } @Test - fun `with default exception handler, inherited values are poisoned on clear`() { - `inherited values are poisoned on clear`(InheritableThreadLocalToggleField("field") { throw it }) - } - - @Test - fun `with lenient exception handler, inherited values are poisoned on clear`() { - `inherited values are poisoned on clear`(InheritableThreadLocalToggleField("field") {}) - } - - private fun `inherited values are poisoned on clear`(field: InheritableThreadLocalToggleField) { + fun `inherited values are poisoned on clear`() { + val field = inheritableThreadLocalToggleField() field.set("hello") withSingleThreadExecutor { assertEquals("hello", fork(field::get).getOrThrow()) @@ -121,39 +141,70 @@ class ToggleFieldTest { } } + /** We log an error rather than failing-fast as the new thread may be an undetected global. */ @Test - fun `with default exception handler, leaked thread is detected as soon as it tries to create another`() { - val field = InheritableThreadLocalToggleField("field") { throw it } + fun `leaked thread propagates holder to non-global thread, with error`() { + val field = inheritableThreadLocalToggleField() field.set("hello") withSingleThreadExecutor { assertEquals("hello", fork(field::get).getOrThrow()) field.set(null) // The executor thread is now considered leaked. - val threadName = fork { Thread.currentThread().name }.getOrThrow() - val future = fork(::Thread) - assertThatThrownBy { future.getOrThrow() } - .isInstanceOf(ThreadLeakException::class.java) - .hasMessageContaining(threadName) + fork { + val leakedThreadName = Thread.currentThread().name + verifyNoMoreInteractions(log) + withSingleThreadExecutor { + // If ThreadLeakException is seen in practice, these errors form a trail of where the holder has been: + verify(log).error(argThat { contains(leakedThreadName) }) + val newThreadName = fork { Thread.currentThread().name }.getOrThrow() + val future = fork(field::get) + assertThatThrownBy { future.getOrThrow() } + .isInstanceOf(ThreadLeakException::class.java) + .hasMessageContaining(newThreadName) + fork { + verifyNoMoreInteractions(log) + withSingleThreadExecutor { + verify(log).error(argThat { contains(newThreadName) }) + } + }.getOrThrow() + } + }.getOrThrow() } } @Test - fun `with lenient exception handler, leaked thread logs a warning and does not propagate the holder`() { - val log = mock() - val field = InheritableThreadLocalToggleField("field", log) {} + fun `leaked thread does not propagate holder to global thread, with warning`() { + val field = inheritableThreadLocalToggleField() field.set("hello") withSingleThreadExecutor { assertEquals("hello", fork(field::get).getOrThrow()) field.set(null) // The executor thread is now considered leaked. - val threadName = fork { Thread.currentThread().name }.getOrThrow() fork { - verifyNoMoreInteractions(log) - withSingleThreadExecutor { - verify(log).warn(argThat { contains(threadName) }) - // In practice the new thread is for example a static thread we can't get rid of: - assertNull(fork(field::get).getOrThrow()) + val leakedThreadName = Thread.currentThread().name + globalThreadCreationMethod { + verifyNoMoreInteractions(log) + withSingleThreadExecutor { + verify(log).warn(argThat { contains(leakedThreadName) }) + // In practice the new thread is for example a static thread we can't get rid of: + assertNull(fork(field::get).getOrThrow()) + } + } + }.getOrThrow() + } + } + + @Test + fun `non-leaked thread does not propagate holder to global thread, without warning`() { + val field = inheritableThreadLocalToggleField() + field.set("hello") + withSingleThreadExecutor { + fork { + assertEquals("hello", field.get()) + globalThreadCreationMethod { + withSingleThreadExecutor { + assertNull(fork(field::get).getOrThrow()) + } } }.getOrThrow() } - verifyNoMoreInteractions(log) } } diff --git a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt index 25b931fb76..3dda01deb1 100644 --- a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt @@ -40,10 +40,7 @@ import net.corda.testing.node.MockServices.Companion.makeTestDataSourcePropertie import net.corda.testing.node.MockServices.Companion.makeTestDatabaseProperties import net.corda.testing.node.MockServices.Companion.makeTestIdentityService import org.assertj.core.api.Assertions.assertThat -import org.junit.After -import org.junit.Before -import org.junit.Rule -import org.junit.Test +import org.junit.* import java.nio.file.Paths import java.security.PublicKey import java.time.Clock @@ -77,7 +74,6 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() { private lateinit var smmHasRemovedAllFlows: CountDownLatch private lateinit var kms: MockKeyManagementService private lateinit var mockSMM: StateMachineManager - private val ourIdentity = ALICE_NAME var calls: Int = 0 /** @@ -132,6 +128,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() { } } + private var allowedUnsuspendedFiberCount = 0 @After fun tearDown() { // We need to make sure the StateMachineManager is done before shutting down executors. @@ -141,6 +138,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() { smmExecutor.shutdown() smmExecutor.awaitTermination(60, TimeUnit.SECONDS) database.close() + mockSMM.stop(allowedUnsuspendedFiberCount) } // Ignore IntelliJ when it says these properties can be private, if they are we cannot serialise them @@ -224,6 +222,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() { @Test fun `test activity due in the future and schedule another later`() { + allowedUnsuspendedFiberCount = 1 val time = stoppedClock.instant() + 1.days scheduleTX(time) diff --git a/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt b/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt index e95243479a..9de048310e 100644 --- a/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt @@ -6,6 +6,8 @@ import io.atomix.copycat.client.CopycatClient import io.atomix.copycat.server.CopycatServer import io.atomix.copycat.server.storage.Storage import io.atomix.copycat.server.storage.StorageLevel +import net.corda.core.internal.concurrent.asCordaFuture +import net.corda.core.internal.concurrent.transpose import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.getOrThrow import net.corda.node.utilities.CordaPersistence @@ -17,10 +19,7 @@ import net.corda.testing.freeLocalHostAndPort import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties import net.corda.testing.node.MockServices.Companion.makeTestDatabaseProperties import net.corda.testing.node.MockServices.Companion.makeTestIdentityService -import org.junit.After -import org.junit.Before -import org.junit.Rule -import org.junit.Test +import org.junit.* import java.util.concurrent.CompletableFuture import kotlin.test.assertEquals import kotlin.test.assertTrue @@ -44,10 +43,8 @@ class DistributedImmutableMapTests { @After fun tearDown() { LogHelper.reset("org.apache.activemq") - cluster.forEach { - it.client.close() - it.server.shutdown() - } + cluster.map { it.client.close().asCordaFuture() }.transpose().getOrThrow() + cluster.map { it.server.shutdown().asCordaFuture() }.transpose().getOrThrow() databases.forEach { it.close() } }