CORDA-806 Remove initialiseSerialization from rpcDriver (#2084)

and fix a leak or two
This commit is contained in:
Andrzej Cichocki 2017-11-29 17:42:39 +00:00 committed by GitHub
parent 2525fb52be
commit 3c31fdf31d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 117 additions and 46 deletions

View File

@ -12,11 +12,14 @@ import net.corda.core.serialization.serialize
import net.corda.core.utilities.* import net.corda.core.utilities.*
import net.corda.node.services.messaging.RPCServerConfiguration import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.nodeapi.RPCApi import net.corda.nodeapi.RPCApi
import net.corda.testing.SerializationEnvironmentRule
import net.corda.testing.driver.poll import net.corda.testing.driver.poll
import net.corda.testing.internal.* import net.corda.testing.internal.*
import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.SimpleString
import org.junit.After
import org.junit.Assert.assertEquals import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue import org.junit.Assert.assertTrue
import org.junit.Rule
import org.junit.Test import org.junit.Test
import rx.Observable import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
@ -26,6 +29,14 @@ import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicInteger
class RPCStabilityTests { class RPCStabilityTests {
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule(true)
private val pool = Executors.newFixedThreadPool(10, testThreadFactory())
@After
fun shutdown() {
pool.shutdown()
}
object DummyOps : RPCOps { object DummyOps : RPCOps {
override val protocolVersion = 0 override val protocolVersion = 0
@ -197,9 +208,9 @@ class RPCStabilityTests {
val proxy = startRpcClient<LeakObservableOps>(server.get().broker.hostAndPort!!).get() val proxy = startRpcClient<LeakObservableOps>(server.get().broker.hostAndPort!!).get()
// Leak many observables // Leak many observables
val N = 200 val N = 200
(1..N).toList().parallelStream().forEach { (1..N).map {
proxy.leakObservable() pool.fork { proxy.leakObservable(); Unit }
} }.transpose().getOrThrow()
// In a loop force GC and check whether the server is notified // In a loop force GC and check whether the server is notified
while (true) { while (true) {
System.gc() System.gc()
@ -231,7 +242,7 @@ class RPCStabilityTests {
assertEquals("pong", client.ping()) assertEquals("pong", client.ping())
serverFollower.shutdown() serverFollower.shutdown()
startRpcServer<ReconnectOps>(ops = ops, customPort = serverPort).getOrThrow() startRpcServer<ReconnectOps>(ops = ops, customPort = serverPort).getOrThrow()
val pingFuture = ForkJoinPool.commonPool().fork(client::ping) val pingFuture = pool.fork(client::ping)
assertEquals("pong", pingFuture.getOrThrow(10.seconds)) assertEquals("pong", pingFuture.getOrThrow(10.seconds))
clientFollower.shutdown() // Driver would do this after the new server, causing hang. clientFollower.shutdown() // Driver would do this after the new server, causing hang.
} }

View File

@ -6,14 +6,20 @@ import net.corda.core.internal.concurrent.map
import net.corda.core.messaging.RPCOps import net.corda.core.messaging.RPCOps
import net.corda.node.services.messaging.RPCServerConfiguration import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.nodeapi.User import net.corda.nodeapi.User
import net.corda.testing.SerializationEnvironmentRule
import net.corda.testing.internal.RPCDriverExposedDSLInterface import net.corda.testing.internal.RPCDriverExposedDSLInterface
import net.corda.testing.internal.rpcTestUser import net.corda.testing.internal.rpcTestUser
import net.corda.testing.internal.startInVmRpcClient import net.corda.testing.internal.startInVmRpcClient
import net.corda.testing.internal.startRpcClient import net.corda.testing.internal.startRpcClient
import org.apache.activemq.artemis.api.core.client.ClientSession import org.apache.activemq.artemis.api.core.client.ClientSession
import org.junit.Rule
import org.junit.runners.Parameterized import org.junit.runners.Parameterized
open class AbstractRPCTest { open class AbstractRPCTest {
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule(true)
enum class RPCTestMode { enum class RPCTestMode {
InVm, InVm,
Netty Netty

View File

@ -5,19 +5,22 @@ import net.corda.core.messaging.RPCOps
import net.corda.core.utilities.millis import net.corda.core.utilities.millis
import net.corda.core.crypto.random63BitValue import net.corda.core.crypto.random63BitValue
import net.corda.core.internal.concurrent.fork import net.corda.core.internal.concurrent.fork
import net.corda.core.internal.concurrent.transpose
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.getOrThrow
import net.corda.node.services.messaging.RPCServerConfiguration import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.testing.internal.RPCDriverExposedDSLInterface import net.corda.testing.internal.RPCDriverExposedDSLInterface
import net.corda.testing.internal.rpcDriver import net.corda.testing.internal.rpcDriver
import net.corda.testing.internal.testThreadFactory
import org.apache.activemq.artemis.utils.collections.ConcurrentHashSet
import org.junit.After
import org.junit.Test import org.junit.Test
import org.junit.runner.RunWith import org.junit.runner.RunWith
import org.junit.runners.Parameterized import org.junit.runners.Parameterized
import rx.Observable import rx.Observable
import rx.subjects.UnicastSubject import rx.subjects.UnicastSubject
import java.util.* import java.util.*
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.*
import java.util.concurrent.CountDownLatch
import java.util.concurrent.ForkJoinPool
@RunWith(Parameterized::class) @RunWith(Parameterized::class)
class RPCConcurrencyTests : AbstractRPCTest() { class RPCConcurrencyTests : AbstractRPCTest() {
@ -36,7 +39,7 @@ class RPCConcurrencyTests : AbstractRPCTest() {
fun getParallelObservableTree(depth: Int, branchingFactor: Int): ObservableRose<Int> fun getParallelObservableTree(depth: Int, branchingFactor: Int): ObservableRose<Int>
} }
class TestOpsImpl : TestOps { class TestOpsImpl(private val pool: Executor) : TestOps {
private val latches = ConcurrentHashMap<Long, CountDownLatch>() private val latches = ConcurrentHashMap<Long, CountDownLatch>()
override val protocolVersion = 0 override val protocolVersion = 0
@ -68,24 +71,22 @@ class RPCConcurrencyTests : AbstractRPCTest() {
val branches = if (depth == 0) { val branches = if (depth == 0) {
Observable.empty<ObservableRose<Int>>() Observable.empty<ObservableRose<Int>>()
} else { } else {
val publish = UnicastSubject.create<ObservableRose<Int>>() UnicastSubject.create<ObservableRose<Int>>().also { publish ->
ForkJoinPool.commonPool().fork { (1..branchingFactor).map {
(1..branchingFactor).toList().parallelStream().forEach { pool.fork { publish.onNext(getParallelObservableTree(depth - 1, branchingFactor)) }
publish.onNext(getParallelObservableTree(depth - 1, branchingFactor)) }.transpose().then {
it.getOrThrow()
publish.onCompleted()
} }
publish.onCompleted()
} }
publish
} }
return ObservableRose(depth, branches) return ObservableRose(depth, branches)
} }
} }
private lateinit var testOpsImpl: TestOpsImpl
private fun RPCDriverExposedDSLInterface.testProxy(): TestProxy<TestOps> { private fun RPCDriverExposedDSLInterface.testProxy(): TestProxy<TestOps> {
testOpsImpl = TestOpsImpl()
return testProxy<TestOps>( return testProxy<TestOps>(
testOpsImpl, TestOpsImpl(pool),
clientConfiguration = RPCClientConfiguration.default.copy( clientConfiguration = RPCClientConfiguration.default.copy(
reapInterval = 100.millis, reapInterval = 100.millis,
cacheConcurrencyLevel = 16 cacheConcurrencyLevel = 16
@ -96,6 +97,12 @@ class RPCConcurrencyTests : AbstractRPCTest() {
) )
} }
private val pool = Executors.newFixedThreadPool(10, testThreadFactory())
@After
fun shutdown() {
pool.shutdown()
}
@Test @Test
fun `call multiple RPCs in parallel`() { fun `call multiple RPCs in parallel`() {
rpcDriver { rpcDriver {
@ -103,19 +110,17 @@ class RPCConcurrencyTests : AbstractRPCTest() {
val numberOfBlockedCalls = 2 val numberOfBlockedCalls = 2
val numberOfDownsRequired = 100 val numberOfDownsRequired = 100
val id = proxy.ops.newLatch(numberOfDownsRequired) val id = proxy.ops.newLatch(numberOfDownsRequired)
val done = CountDownLatch(numberOfBlockedCalls)
// Start a couple of blocking RPC calls // Start a couple of blocking RPC calls
(1..numberOfBlockedCalls).forEach { val done = (1..numberOfBlockedCalls).map {
ForkJoinPool.commonPool().fork { pool.fork {
proxy.ops.waitLatch(id) proxy.ops.waitLatch(id)
done.countDown()
} }
} }.transpose()
// Down the latch that the others are waiting for concurrently // Down the latch that the others are waiting for concurrently
(1..numberOfDownsRequired).toList().parallelStream().forEach { (1..numberOfDownsRequired).map {
proxy.ops.downLatch(id) pool.fork { proxy.ops.downLatch(id) }
} }.transpose().getOrThrow()
done.await() done.getOrThrow()
} }
} }
@ -146,7 +151,7 @@ class RPCConcurrencyTests : AbstractRPCTest() {
fun ObservableRose<Int>.subscribeToAll() { fun ObservableRose<Int>.subscribeToAll() {
remainingLatch.countDown() remainingLatch.countDown()
this.branches.subscribe { tree -> this.branches.subscribe { tree ->
(tree.value + 1..treeDepth - 1).forEach { (tree.value + 1 until treeDepth).forEach {
require(it in depthsSeen) { "Got ${tree.value} before $it" } require(it in depthsSeen) { "Got ${tree.value} before $it" }
} }
depthsSeen.add(tree.value) depthsSeen.add(tree.value)
@ -165,11 +170,11 @@ class RPCConcurrencyTests : AbstractRPCTest() {
val treeDepth = 2 val treeDepth = 2
val treeBranchingFactor = 10 val treeBranchingFactor = 10
val remainingLatch = CountDownLatch((intPower(treeBranchingFactor, treeDepth + 1) - 1) / (treeBranchingFactor - 1)) val remainingLatch = CountDownLatch((intPower(treeBranchingFactor, treeDepth + 1) - 1) / (treeBranchingFactor - 1))
val depthsSeen = Collections.synchronizedSet(HashSet<Int>()) val depthsSeen = ConcurrentHashSet<Int>()
fun ObservableRose<Int>.subscribeToAll() { fun ObservableRose<Int>.subscribeToAll() {
remainingLatch.countDown() remainingLatch.countDown()
branches.subscribe { tree -> branches.subscribe { tree ->
(tree.value + 1..treeDepth - 1).forEach { (tree.value + 1 until treeDepth).forEach {
require(it in depthsSeen) { "Got ${tree.value} before $it" } require(it in depthsSeen) { "Got ${tree.value} before $it" }
} }
depthsSeen.add(tree.value) depthsSeen.add(tree.value)

View File

@ -5,12 +5,18 @@ import net.corda.core.concurrent.CordaFuture
import net.corda.core.internal.concurrent.openFuture import net.corda.core.internal.concurrent.openFuture
import net.corda.core.messaging.* import net.corda.core.messaging.*
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.testing.SerializationEnvironmentRule
import net.corda.testing.internal.rpcDriver import net.corda.testing.internal.rpcDriver
import net.corda.testing.internal.startRpcClient import net.corda.testing.internal.startRpcClient
import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.Rule
import org.junit.Test import org.junit.Test
class RPCFailureTests { class RPCFailureTests {
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule(true)
class Unserializable class Unserializable
interface Ops : RPCOps { interface Ops : RPCOps {
fun getUnserializable(): Unserializable fun getUnserializable(): Unserializable

View File

@ -31,6 +31,8 @@ import java.time.Duration
import java.time.temporal.Temporal import java.time.temporal.Temporal
import java.util.* import java.util.*
import java.util.Spliterator.* import java.util.Spliterator.*
import java.util.concurrent.ExecutorService
import java.util.concurrent.TimeUnit
import java.util.stream.IntStream import java.util.stream.IntStream
import java.util.stream.Stream import java.util.stream.Stream
import java.util.stream.StreamSupport import java.util.stream.StreamSupport
@ -307,3 +309,10 @@ fun TransactionBuilder.toLedgerTransaction(services: ServiceHub, serializationCo
val KClass<*>.packageName: String get() = java.`package`.name val KClass<*>.packageName: String get() = java.`package`.name
fun URL.openHttpConnection(): HttpURLConnection = openConnection() as HttpURLConnection fun URL.openHttpConnection(): HttpURLConnection = openConnection() as HttpURLConnection
/** Analogous to [Thread.join]. */
fun ExecutorService.join() {
shutdown() // Do not change to shutdownNow, tests use this method to assert the executor has no more tasks.
while (!awaitTermination(1, TimeUnit.SECONDS)) {
// Try forever. Do not give up, tests use this method to assert the executor has no more tasks.
}
}

View File

@ -133,7 +133,7 @@ class ContractUpgradeFlowTest {
@Test @Test
fun `2 parties contract upgrade using RPC`() { fun `2 parties contract upgrade using RPC`() {
rpcDriver(initialiseSerialization = false) { rpcDriver {
// Create dummy contract. // Create dummy contract.
val twoPartyDummyContract = DummyContract.generateInitial(0, notary, alice.ref(1), bob.ref(1)) val twoPartyDummyContract = DummyContract.generateInitial(0, notary, alice.ref(1), bob.ref(1))
val signedByA = aliceNode.services.signInitialTransaction(twoPartyDummyContract) val signedByA = aliceNode.services.signInitialTransaction(twoPartyDummyContract)

View File

@ -14,7 +14,6 @@ import org.junit.runners.model.Statement
import org.slf4j.Logger import org.slf4j.Logger
import java.util.concurrent.ExecutorService import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertNull import kotlin.test.assertNull
@ -23,10 +22,7 @@ private fun <T> withSingleThreadExecutor(callable: ExecutorService.() -> T) = Ex
fork {}.getOrThrow() // Start the thread. fork {}.getOrThrow() // Start the thread.
callable() callable()
} finally { } finally {
shutdown() join()
while (!awaitTermination(1, TimeUnit.SECONDS)) {
// Do nothing.
}
} }
} }

View File

@ -2,13 +2,13 @@ package net.corda.core.internal.concurrent
import com.nhaarman.mockito_kotlin.* import com.nhaarman.mockito_kotlin.*
import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.CordaFuture
import net.corda.core.internal.join
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.testing.rigorousMock import net.corda.testing.rigorousMock
import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions
import org.junit.Test import org.junit.Test
import org.slf4j.Logger import org.slf4j.Logger
import java.util.concurrent.Executors import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicBoolean
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertFalse import kotlin.test.assertFalse
@ -108,10 +108,7 @@ class CordaFutureTest {
val throwable = Exception("Boom") val throwable = Exception("Boom")
val executor = Executors.newSingleThreadExecutor() val executor = Executors.newSingleThreadExecutor()
executor.fork { throw throwable }.andForget(log) executor.fork { throw throwable }.andForget(log)
executor.shutdown() executor.join()
while (!executor.awaitTermination(1, TimeUnit.SECONDS)) {
// Do nothing.
}
verify(log).error(any(), same(throwable)) verify(log).error(any(), same(throwable))
} }

View File

@ -20,6 +20,7 @@ import net.corda.core.context.Trace.InvocationId
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.LazyStickyPool import net.corda.core.internal.LazyStickyPool
import net.corda.core.internal.LifeCycle import net.corda.core.internal.LifeCycle
import net.corda.core.internal.join
import net.corda.core.messaging.RPCOps import net.corda.core.messaging.RPCOps
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationDefaults.RPC_SERVER_CONTEXT import net.corda.core.serialization.SerializationDefaults.RPC_SERVER_CONTEXT
@ -207,6 +208,7 @@ class RPCServer(
} }
fun close() { fun close() {
observationSendExecutor?.join()
reaperScheduledFuture?.cancel(false) reaperScheduledFuture?.cancel(false)
rpcExecutor?.shutdownNow() rpcExecutor?.shutdownNow()
reaperExecutor?.shutdownNow() reaperExecutor?.shutdownNow()

View File

@ -230,7 +230,6 @@ fun <A> rpcDriver(
debugPortAllocation: PortAllocation = globalDebugPortAllocation, debugPortAllocation: PortAllocation = globalDebugPortAllocation,
systemProperties: Map<String, String> = emptyMap(), systemProperties: Map<String, String> = emptyMap(),
useTestClock: Boolean = false, useTestClock: Boolean = false,
initialiseSerialization: Boolean = true,
startNodesInProcess: Boolean = false, startNodesInProcess: Boolean = false,
waitForNodesToFinish: Boolean = false, waitForNodesToFinish: Boolean = false,
extraCordappPackagesToScan: List<String> = emptyList(), extraCordappPackagesToScan: List<String> = emptyList(),
@ -254,7 +253,7 @@ fun <A> rpcDriver(
), ),
coerce = { it }, coerce = { it },
dsl = dsl, dsl = dsl,
initialiseSerialization = initialiseSerialization initialiseSerialization = false
) )
private class SingleUserSecurityManager(val rpcUser: User) : ActiveMQSecurityManager3 { private class SingleUserSecurityManager(val rpcUser: User) : ActiveMQSecurityManager3 {

View File

@ -1,26 +1,50 @@
package net.corda.testing package net.corda.testing
import com.nhaarman.mockito_kotlin.doNothing import com.nhaarman.mockito_kotlin.*
import com.nhaarman.mockito_kotlin.whenever
import net.corda.client.rpc.internal.KryoClientSerializationScheme import net.corda.client.rpc.internal.KryoClientSerializationScheme
import net.corda.core.internal.staticField
import net.corda.core.serialization.internal.* import net.corda.core.serialization.internal.*
import net.corda.node.serialization.KryoServerSerializationScheme import net.corda.node.serialization.KryoServerSerializationScheme
import net.corda.nodeapi.internal.serialization.* import net.corda.nodeapi.internal.serialization.*
import net.corda.nodeapi.internal.serialization.amqp.AMQPClientSerializationScheme import net.corda.nodeapi.internal.serialization.amqp.AMQPClientSerializationScheme
import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme
import net.corda.testing.common.internal.asContextEnv import net.corda.testing.common.internal.asContextEnv
import net.corda.testing.internal.testThreadFactory
import org.apache.activemq.artemis.core.remoting.impl.invm.InVMConnector
import org.junit.rules.TestRule import org.junit.rules.TestRule
import org.junit.runner.Description import org.junit.runner.Description
import org.junit.runners.model.Statement import org.junit.runners.model.Statement
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
private val inVMExecutors = ConcurrentHashMap<SerializationEnvironment, ExecutorService>()
/** @param inheritable whether new threads inherit the environment, use sparingly. */ /** @param inheritable whether new threads inherit the environment, use sparingly. */
class SerializationEnvironmentRule(private val inheritable: Boolean = false) : TestRule { class SerializationEnvironmentRule(private val inheritable: Boolean = false) : TestRule {
companion object {
init {
// Can't turn it off, and it creates threads that do serialization, so hack it:
InVMConnector::class.staticField<ExecutorService>("threadPoolExecutor").value = rigorousMock<ExecutorService>().also {
doAnswer {
inVMExecutors.computeIfAbsent(effectiveSerializationEnv) {
Executors.newCachedThreadPool(testThreadFactory(true)) // Close enough to what InVMConnector makes normally.
}.execute(it.arguments[0] as Runnable)
}.whenever(it).execute(any())
}
}
}
lateinit var env: SerializationEnvironment lateinit var env: SerializationEnvironment
override fun apply(base: Statement, description: Description): Statement { override fun apply(base: Statement, description: Description): Statement {
env = createTestSerializationEnv(description.toString()) env = createTestSerializationEnv(description.toString())
return object : Statement() { return object : Statement() {
override fun evaluate() = env.asContextEnv(inheritable) { override fun evaluate() {
base.evaluate() try {
env.asContextEnv(inheritable) { base.evaluate() }
} finally {
inVMExecutors.remove(env)
}
} }
} }
} }
@ -59,6 +83,7 @@ fun setGlobalSerialization(armed: Boolean): GlobalSerializationEnvironment {
object : GlobalSerializationEnvironment, SerializationEnvironment by createTestSerializationEnv("<global>") { object : GlobalSerializationEnvironment, SerializationEnvironment by createTestSerializationEnv("<global>") {
override fun unset() { override fun unset() {
_globalSerializationEnv.set(null) _globalSerializationEnv.set(null)
inVMExecutors.remove(this)
} }
}.also { }.also {
_globalSerializationEnv.set(it) _globalSerializationEnv.set(it)

View File

@ -0,0 +1,15 @@
package net.corda.testing.internal
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ThreadFactory
import java.util.concurrent.atomic.AtomicInteger
private val familyToNextPoolNumber = ConcurrentHashMap<String, AtomicInteger>()
fun Any.testThreadFactory(useEnclosingClassName: Boolean = false): ThreadFactory {
val poolFamily = javaClass.let { (if (useEnclosingClassName) it.enclosingClass else it).simpleName }
val poolNumber = familyToNextPoolNumber.computeIfAbsent(poolFamily) { AtomicInteger(1) }.getAndIncrement()
val nextThreadNumber = AtomicInteger(1)
return ThreadFactory { task ->
Thread(task, "$poolFamily-$poolNumber-${nextThreadNumber.getAndIncrement()}")
}
}