mirror of
https://github.com/corda/corda.git
synced 2025-02-24 02:41:22 +00:00
CORDA-806 Remove initialiseSerialization from rpcDriver (#2084)
and fix a leak or two
This commit is contained in:
parent
2525fb52be
commit
3c31fdf31d
@ -12,11 +12,14 @@ import net.corda.core.serialization.serialize
|
||||
import net.corda.core.utilities.*
|
||||
import net.corda.node.services.messaging.RPCServerConfiguration
|
||||
import net.corda.nodeapi.RPCApi
|
||||
import net.corda.testing.SerializationEnvironmentRule
|
||||
import net.corda.testing.driver.poll
|
||||
import net.corda.testing.internal.*
|
||||
import org.apache.activemq.artemis.api.core.SimpleString
|
||||
import org.junit.After
|
||||
import org.junit.Assert.assertEquals
|
||||
import org.junit.Assert.assertTrue
|
||||
import org.junit.Rule
|
||||
import org.junit.Test
|
||||
import rx.Observable
|
||||
import rx.subjects.PublishSubject
|
||||
@ -26,6 +29,14 @@ import java.util.concurrent.*
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
|
||||
class RPCStabilityTests {
|
||||
@Rule
|
||||
@JvmField
|
||||
val testSerialization = SerializationEnvironmentRule(true)
|
||||
private val pool = Executors.newFixedThreadPool(10, testThreadFactory())
|
||||
@After
|
||||
fun shutdown() {
|
||||
pool.shutdown()
|
||||
}
|
||||
|
||||
object DummyOps : RPCOps {
|
||||
override val protocolVersion = 0
|
||||
@ -197,9 +208,9 @@ class RPCStabilityTests {
|
||||
val proxy = startRpcClient<LeakObservableOps>(server.get().broker.hostAndPort!!).get()
|
||||
// Leak many observables
|
||||
val N = 200
|
||||
(1..N).toList().parallelStream().forEach {
|
||||
proxy.leakObservable()
|
||||
}
|
||||
(1..N).map {
|
||||
pool.fork { proxy.leakObservable(); Unit }
|
||||
}.transpose().getOrThrow()
|
||||
// In a loop force GC and check whether the server is notified
|
||||
while (true) {
|
||||
System.gc()
|
||||
@ -231,7 +242,7 @@ class RPCStabilityTests {
|
||||
assertEquals("pong", client.ping())
|
||||
serverFollower.shutdown()
|
||||
startRpcServer<ReconnectOps>(ops = ops, customPort = serverPort).getOrThrow()
|
||||
val pingFuture = ForkJoinPool.commonPool().fork(client::ping)
|
||||
val pingFuture = pool.fork(client::ping)
|
||||
assertEquals("pong", pingFuture.getOrThrow(10.seconds))
|
||||
clientFollower.shutdown() // Driver would do this after the new server, causing hang.
|
||||
}
|
||||
|
@ -6,14 +6,20 @@ import net.corda.core.internal.concurrent.map
|
||||
import net.corda.core.messaging.RPCOps
|
||||
import net.corda.node.services.messaging.RPCServerConfiguration
|
||||
import net.corda.nodeapi.User
|
||||
import net.corda.testing.SerializationEnvironmentRule
|
||||
import net.corda.testing.internal.RPCDriverExposedDSLInterface
|
||||
import net.corda.testing.internal.rpcTestUser
|
||||
import net.corda.testing.internal.startInVmRpcClient
|
||||
import net.corda.testing.internal.startRpcClient
|
||||
import org.apache.activemq.artemis.api.core.client.ClientSession
|
||||
import org.junit.Rule
|
||||
import org.junit.runners.Parameterized
|
||||
|
||||
open class AbstractRPCTest {
|
||||
@Rule
|
||||
@JvmField
|
||||
val testSerialization = SerializationEnvironmentRule(true)
|
||||
|
||||
enum class RPCTestMode {
|
||||
InVm,
|
||||
Netty
|
||||
|
@ -5,19 +5,22 @@ import net.corda.core.messaging.RPCOps
|
||||
import net.corda.core.utilities.millis
|
||||
import net.corda.core.crypto.random63BitValue
|
||||
import net.corda.core.internal.concurrent.fork
|
||||
import net.corda.core.internal.concurrent.transpose
|
||||
import net.corda.core.serialization.CordaSerializable
|
||||
import net.corda.core.utilities.getOrThrow
|
||||
import net.corda.node.services.messaging.RPCServerConfiguration
|
||||
import net.corda.testing.internal.RPCDriverExposedDSLInterface
|
||||
import net.corda.testing.internal.rpcDriver
|
||||
import net.corda.testing.internal.testThreadFactory
|
||||
import org.apache.activemq.artemis.utils.collections.ConcurrentHashSet
|
||||
import org.junit.After
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
import org.junit.runners.Parameterized
|
||||
import rx.Observable
|
||||
import rx.subjects.UnicastSubject
|
||||
import java.util.*
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.CountDownLatch
|
||||
import java.util.concurrent.ForkJoinPool
|
||||
import java.util.concurrent.*
|
||||
|
||||
@RunWith(Parameterized::class)
|
||||
class RPCConcurrencyTests : AbstractRPCTest() {
|
||||
@ -36,7 +39,7 @@ class RPCConcurrencyTests : AbstractRPCTest() {
|
||||
fun getParallelObservableTree(depth: Int, branchingFactor: Int): ObservableRose<Int>
|
||||
}
|
||||
|
||||
class TestOpsImpl : TestOps {
|
||||
class TestOpsImpl(private val pool: Executor) : TestOps {
|
||||
private val latches = ConcurrentHashMap<Long, CountDownLatch>()
|
||||
override val protocolVersion = 0
|
||||
|
||||
@ -68,24 +71,22 @@ class RPCConcurrencyTests : AbstractRPCTest() {
|
||||
val branches = if (depth == 0) {
|
||||
Observable.empty<ObservableRose<Int>>()
|
||||
} else {
|
||||
val publish = UnicastSubject.create<ObservableRose<Int>>()
|
||||
ForkJoinPool.commonPool().fork {
|
||||
(1..branchingFactor).toList().parallelStream().forEach {
|
||||
publish.onNext(getParallelObservableTree(depth - 1, branchingFactor))
|
||||
UnicastSubject.create<ObservableRose<Int>>().also { publish ->
|
||||
(1..branchingFactor).map {
|
||||
pool.fork { publish.onNext(getParallelObservableTree(depth - 1, branchingFactor)) }
|
||||
}.transpose().then {
|
||||
it.getOrThrow()
|
||||
publish.onCompleted()
|
||||
}
|
||||
publish.onCompleted()
|
||||
}
|
||||
publish
|
||||
}
|
||||
return ObservableRose(depth, branches)
|
||||
}
|
||||
}
|
||||
|
||||
private lateinit var testOpsImpl: TestOpsImpl
|
||||
private fun RPCDriverExposedDSLInterface.testProxy(): TestProxy<TestOps> {
|
||||
testOpsImpl = TestOpsImpl()
|
||||
return testProxy<TestOps>(
|
||||
testOpsImpl,
|
||||
TestOpsImpl(pool),
|
||||
clientConfiguration = RPCClientConfiguration.default.copy(
|
||||
reapInterval = 100.millis,
|
||||
cacheConcurrencyLevel = 16
|
||||
@ -96,6 +97,12 @@ class RPCConcurrencyTests : AbstractRPCTest() {
|
||||
)
|
||||
}
|
||||
|
||||
private val pool = Executors.newFixedThreadPool(10, testThreadFactory())
|
||||
@After
|
||||
fun shutdown() {
|
||||
pool.shutdown()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `call multiple RPCs in parallel`() {
|
||||
rpcDriver {
|
||||
@ -103,19 +110,17 @@ class RPCConcurrencyTests : AbstractRPCTest() {
|
||||
val numberOfBlockedCalls = 2
|
||||
val numberOfDownsRequired = 100
|
||||
val id = proxy.ops.newLatch(numberOfDownsRequired)
|
||||
val done = CountDownLatch(numberOfBlockedCalls)
|
||||
// Start a couple of blocking RPC calls
|
||||
(1..numberOfBlockedCalls).forEach {
|
||||
ForkJoinPool.commonPool().fork {
|
||||
val done = (1..numberOfBlockedCalls).map {
|
||||
pool.fork {
|
||||
proxy.ops.waitLatch(id)
|
||||
done.countDown()
|
||||
}
|
||||
}
|
||||
}.transpose()
|
||||
// Down the latch that the others are waiting for concurrently
|
||||
(1..numberOfDownsRequired).toList().parallelStream().forEach {
|
||||
proxy.ops.downLatch(id)
|
||||
}
|
||||
done.await()
|
||||
(1..numberOfDownsRequired).map {
|
||||
pool.fork { proxy.ops.downLatch(id) }
|
||||
}.transpose().getOrThrow()
|
||||
done.getOrThrow()
|
||||
}
|
||||
}
|
||||
|
||||
@ -146,7 +151,7 @@ class RPCConcurrencyTests : AbstractRPCTest() {
|
||||
fun ObservableRose<Int>.subscribeToAll() {
|
||||
remainingLatch.countDown()
|
||||
this.branches.subscribe { tree ->
|
||||
(tree.value + 1..treeDepth - 1).forEach {
|
||||
(tree.value + 1 until treeDepth).forEach {
|
||||
require(it in depthsSeen) { "Got ${tree.value} before $it" }
|
||||
}
|
||||
depthsSeen.add(tree.value)
|
||||
@ -165,11 +170,11 @@ class RPCConcurrencyTests : AbstractRPCTest() {
|
||||
val treeDepth = 2
|
||||
val treeBranchingFactor = 10
|
||||
val remainingLatch = CountDownLatch((intPower(treeBranchingFactor, treeDepth + 1) - 1) / (treeBranchingFactor - 1))
|
||||
val depthsSeen = Collections.synchronizedSet(HashSet<Int>())
|
||||
val depthsSeen = ConcurrentHashSet<Int>()
|
||||
fun ObservableRose<Int>.subscribeToAll() {
|
||||
remainingLatch.countDown()
|
||||
branches.subscribe { tree ->
|
||||
(tree.value + 1..treeDepth - 1).forEach {
|
||||
(tree.value + 1 until treeDepth).forEach {
|
||||
require(it in depthsSeen) { "Got ${tree.value} before $it" }
|
||||
}
|
||||
depthsSeen.add(tree.value)
|
||||
|
@ -5,12 +5,18 @@ import net.corda.core.concurrent.CordaFuture
|
||||
import net.corda.core.internal.concurrent.openFuture
|
||||
import net.corda.core.messaging.*
|
||||
import net.corda.core.utilities.getOrThrow
|
||||
import net.corda.testing.SerializationEnvironmentRule
|
||||
import net.corda.testing.internal.rpcDriver
|
||||
import net.corda.testing.internal.startRpcClient
|
||||
import org.assertj.core.api.Assertions.assertThatThrownBy
|
||||
import org.junit.Rule
|
||||
import org.junit.Test
|
||||
|
||||
class RPCFailureTests {
|
||||
@Rule
|
||||
@JvmField
|
||||
val testSerialization = SerializationEnvironmentRule(true)
|
||||
|
||||
class Unserializable
|
||||
interface Ops : RPCOps {
|
||||
fun getUnserializable(): Unserializable
|
||||
|
@ -31,6 +31,8 @@ import java.time.Duration
|
||||
import java.time.temporal.Temporal
|
||||
import java.util.*
|
||||
import java.util.Spliterator.*
|
||||
import java.util.concurrent.ExecutorService
|
||||
import java.util.concurrent.TimeUnit
|
||||
import java.util.stream.IntStream
|
||||
import java.util.stream.Stream
|
||||
import java.util.stream.StreamSupport
|
||||
@ -307,3 +309,10 @@ fun TransactionBuilder.toLedgerTransaction(services: ServiceHub, serializationCo
|
||||
val KClass<*>.packageName: String get() = java.`package`.name
|
||||
|
||||
fun URL.openHttpConnection(): HttpURLConnection = openConnection() as HttpURLConnection
|
||||
/** Analogous to [Thread.join]. */
|
||||
fun ExecutorService.join() {
|
||||
shutdown() // Do not change to shutdownNow, tests use this method to assert the executor has no more tasks.
|
||||
while (!awaitTermination(1, TimeUnit.SECONDS)) {
|
||||
// Try forever. Do not give up, tests use this method to assert the executor has no more tasks.
|
||||
}
|
||||
}
|
||||
|
@ -133,7 +133,7 @@ class ContractUpgradeFlowTest {
|
||||
|
||||
@Test
|
||||
fun `2 parties contract upgrade using RPC`() {
|
||||
rpcDriver(initialiseSerialization = false) {
|
||||
rpcDriver {
|
||||
// Create dummy contract.
|
||||
val twoPartyDummyContract = DummyContract.generateInitial(0, notary, alice.ref(1), bob.ref(1))
|
||||
val signedByA = aliceNode.services.signInitialTransaction(twoPartyDummyContract)
|
||||
|
@ -14,7 +14,6 @@ import org.junit.runners.model.Statement
|
||||
import org.slf4j.Logger
|
||||
import java.util.concurrent.ExecutorService
|
||||
import java.util.concurrent.Executors
|
||||
import java.util.concurrent.TimeUnit
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertNull
|
||||
|
||||
@ -23,10 +22,7 @@ private fun <T> withSingleThreadExecutor(callable: ExecutorService.() -> T) = Ex
|
||||
fork {}.getOrThrow() // Start the thread.
|
||||
callable()
|
||||
} finally {
|
||||
shutdown()
|
||||
while (!awaitTermination(1, TimeUnit.SECONDS)) {
|
||||
// Do nothing.
|
||||
}
|
||||
join()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2,13 +2,13 @@ package net.corda.core.internal.concurrent
|
||||
|
||||
import com.nhaarman.mockito_kotlin.*
|
||||
import net.corda.core.concurrent.CordaFuture
|
||||
import net.corda.core.internal.join
|
||||
import net.corda.core.utilities.getOrThrow
|
||||
import net.corda.testing.rigorousMock
|
||||
import org.assertj.core.api.Assertions
|
||||
import org.junit.Test
|
||||
import org.slf4j.Logger
|
||||
import java.util.concurrent.Executors
|
||||
import java.util.concurrent.TimeUnit
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFalse
|
||||
@ -108,10 +108,7 @@ class CordaFutureTest {
|
||||
val throwable = Exception("Boom")
|
||||
val executor = Executors.newSingleThreadExecutor()
|
||||
executor.fork { throw throwable }.andForget(log)
|
||||
executor.shutdown()
|
||||
while (!executor.awaitTermination(1, TimeUnit.SECONDS)) {
|
||||
// Do nothing.
|
||||
}
|
||||
executor.join()
|
||||
verify(log).error(any(), same(throwable))
|
||||
}
|
||||
|
||||
|
@ -20,6 +20,7 @@ import net.corda.core.context.Trace.InvocationId
|
||||
import net.corda.core.identity.CordaX500Name
|
||||
import net.corda.core.internal.LazyStickyPool
|
||||
import net.corda.core.internal.LifeCycle
|
||||
import net.corda.core.internal.join
|
||||
import net.corda.core.messaging.RPCOps
|
||||
import net.corda.core.serialization.SerializationContext
|
||||
import net.corda.core.serialization.SerializationDefaults.RPC_SERVER_CONTEXT
|
||||
@ -207,6 +208,7 @@ class RPCServer(
|
||||
}
|
||||
|
||||
fun close() {
|
||||
observationSendExecutor?.join()
|
||||
reaperScheduledFuture?.cancel(false)
|
||||
rpcExecutor?.shutdownNow()
|
||||
reaperExecutor?.shutdownNow()
|
||||
|
@ -230,7 +230,6 @@ fun <A> rpcDriver(
|
||||
debugPortAllocation: PortAllocation = globalDebugPortAllocation,
|
||||
systemProperties: Map<String, String> = emptyMap(),
|
||||
useTestClock: Boolean = false,
|
||||
initialiseSerialization: Boolean = true,
|
||||
startNodesInProcess: Boolean = false,
|
||||
waitForNodesToFinish: Boolean = false,
|
||||
extraCordappPackagesToScan: List<String> = emptyList(),
|
||||
@ -254,7 +253,7 @@ fun <A> rpcDriver(
|
||||
),
|
||||
coerce = { it },
|
||||
dsl = dsl,
|
||||
initialiseSerialization = initialiseSerialization
|
||||
initialiseSerialization = false
|
||||
)
|
||||
|
||||
private class SingleUserSecurityManager(val rpcUser: User) : ActiveMQSecurityManager3 {
|
||||
|
@ -1,26 +1,50 @@
|
||||
package net.corda.testing
|
||||
|
||||
import com.nhaarman.mockito_kotlin.doNothing
|
||||
import com.nhaarman.mockito_kotlin.whenever
|
||||
import com.nhaarman.mockito_kotlin.*
|
||||
import net.corda.client.rpc.internal.KryoClientSerializationScheme
|
||||
import net.corda.core.internal.staticField
|
||||
import net.corda.core.serialization.internal.*
|
||||
import net.corda.node.serialization.KryoServerSerializationScheme
|
||||
import net.corda.nodeapi.internal.serialization.*
|
||||
import net.corda.nodeapi.internal.serialization.amqp.AMQPClientSerializationScheme
|
||||
import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme
|
||||
import net.corda.testing.common.internal.asContextEnv
|
||||
import net.corda.testing.internal.testThreadFactory
|
||||
import org.apache.activemq.artemis.core.remoting.impl.invm.InVMConnector
|
||||
import org.junit.rules.TestRule
|
||||
import org.junit.runner.Description
|
||||
import org.junit.runners.model.Statement
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.ExecutorService
|
||||
import java.util.concurrent.Executors
|
||||
|
||||
private val inVMExecutors = ConcurrentHashMap<SerializationEnvironment, ExecutorService>()
|
||||
|
||||
/** @param inheritable whether new threads inherit the environment, use sparingly. */
|
||||
class SerializationEnvironmentRule(private val inheritable: Boolean = false) : TestRule {
|
||||
companion object {
|
||||
init {
|
||||
// Can't turn it off, and it creates threads that do serialization, so hack it:
|
||||
InVMConnector::class.staticField<ExecutorService>("threadPoolExecutor").value = rigorousMock<ExecutorService>().also {
|
||||
doAnswer {
|
||||
inVMExecutors.computeIfAbsent(effectiveSerializationEnv) {
|
||||
Executors.newCachedThreadPool(testThreadFactory(true)) // Close enough to what InVMConnector makes normally.
|
||||
}.execute(it.arguments[0] as Runnable)
|
||||
}.whenever(it).execute(any())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lateinit var env: SerializationEnvironment
|
||||
override fun apply(base: Statement, description: Description): Statement {
|
||||
env = createTestSerializationEnv(description.toString())
|
||||
return object : Statement() {
|
||||
override fun evaluate() = env.asContextEnv(inheritable) {
|
||||
base.evaluate()
|
||||
override fun evaluate() {
|
||||
try {
|
||||
env.asContextEnv(inheritable) { base.evaluate() }
|
||||
} finally {
|
||||
inVMExecutors.remove(env)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -59,6 +83,7 @@ fun setGlobalSerialization(armed: Boolean): GlobalSerializationEnvironment {
|
||||
object : GlobalSerializationEnvironment, SerializationEnvironment by createTestSerializationEnv("<global>") {
|
||||
override fun unset() {
|
||||
_globalSerializationEnv.set(null)
|
||||
inVMExecutors.remove(this)
|
||||
}
|
||||
}.also {
|
||||
_globalSerializationEnv.set(it)
|
||||
|
@ -0,0 +1,15 @@
|
||||
package net.corda.testing.internal
|
||||
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.ThreadFactory
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
|
||||
private val familyToNextPoolNumber = ConcurrentHashMap<String, AtomicInteger>()
|
||||
fun Any.testThreadFactory(useEnclosingClassName: Boolean = false): ThreadFactory {
|
||||
val poolFamily = javaClass.let { (if (useEnclosingClassName) it.enclosingClass else it).simpleName }
|
||||
val poolNumber = familyToNextPoolNumber.computeIfAbsent(poolFamily) { AtomicInteger(1) }.getAndIncrement()
|
||||
val nextThreadNumber = AtomicInteger(1)
|
||||
return ThreadFactory { task ->
|
||||
Thread(task, "$poolFamily-$poolNumber-${nextThreadNumber.getAndIncrement()}")
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user