mirror of
https://github.com/corda/corda.git
synced 2025-02-24 02:41:22 +00:00
CORDA-806 Remove initialiseSerialization from rpcDriver (#2084)
and fix a leak or two
This commit is contained in:
parent
2525fb52be
commit
3c31fdf31d
@ -12,11 +12,14 @@ import net.corda.core.serialization.serialize
|
|||||||
import net.corda.core.utilities.*
|
import net.corda.core.utilities.*
|
||||||
import net.corda.node.services.messaging.RPCServerConfiguration
|
import net.corda.node.services.messaging.RPCServerConfiguration
|
||||||
import net.corda.nodeapi.RPCApi
|
import net.corda.nodeapi.RPCApi
|
||||||
|
import net.corda.testing.SerializationEnvironmentRule
|
||||||
import net.corda.testing.driver.poll
|
import net.corda.testing.driver.poll
|
||||||
import net.corda.testing.internal.*
|
import net.corda.testing.internal.*
|
||||||
import org.apache.activemq.artemis.api.core.SimpleString
|
import org.apache.activemq.artemis.api.core.SimpleString
|
||||||
|
import org.junit.After
|
||||||
import org.junit.Assert.assertEquals
|
import org.junit.Assert.assertEquals
|
||||||
import org.junit.Assert.assertTrue
|
import org.junit.Assert.assertTrue
|
||||||
|
import org.junit.Rule
|
||||||
import org.junit.Test
|
import org.junit.Test
|
||||||
import rx.Observable
|
import rx.Observable
|
||||||
import rx.subjects.PublishSubject
|
import rx.subjects.PublishSubject
|
||||||
@ -26,6 +29,14 @@ import java.util.concurrent.*
|
|||||||
import java.util.concurrent.atomic.AtomicInteger
|
import java.util.concurrent.atomic.AtomicInteger
|
||||||
|
|
||||||
class RPCStabilityTests {
|
class RPCStabilityTests {
|
||||||
|
@Rule
|
||||||
|
@JvmField
|
||||||
|
val testSerialization = SerializationEnvironmentRule(true)
|
||||||
|
private val pool = Executors.newFixedThreadPool(10, testThreadFactory())
|
||||||
|
@After
|
||||||
|
fun shutdown() {
|
||||||
|
pool.shutdown()
|
||||||
|
}
|
||||||
|
|
||||||
object DummyOps : RPCOps {
|
object DummyOps : RPCOps {
|
||||||
override val protocolVersion = 0
|
override val protocolVersion = 0
|
||||||
@ -197,9 +208,9 @@ class RPCStabilityTests {
|
|||||||
val proxy = startRpcClient<LeakObservableOps>(server.get().broker.hostAndPort!!).get()
|
val proxy = startRpcClient<LeakObservableOps>(server.get().broker.hostAndPort!!).get()
|
||||||
// Leak many observables
|
// Leak many observables
|
||||||
val N = 200
|
val N = 200
|
||||||
(1..N).toList().parallelStream().forEach {
|
(1..N).map {
|
||||||
proxy.leakObservable()
|
pool.fork { proxy.leakObservable(); Unit }
|
||||||
}
|
}.transpose().getOrThrow()
|
||||||
// In a loop force GC and check whether the server is notified
|
// In a loop force GC and check whether the server is notified
|
||||||
while (true) {
|
while (true) {
|
||||||
System.gc()
|
System.gc()
|
||||||
@ -231,7 +242,7 @@ class RPCStabilityTests {
|
|||||||
assertEquals("pong", client.ping())
|
assertEquals("pong", client.ping())
|
||||||
serverFollower.shutdown()
|
serverFollower.shutdown()
|
||||||
startRpcServer<ReconnectOps>(ops = ops, customPort = serverPort).getOrThrow()
|
startRpcServer<ReconnectOps>(ops = ops, customPort = serverPort).getOrThrow()
|
||||||
val pingFuture = ForkJoinPool.commonPool().fork(client::ping)
|
val pingFuture = pool.fork(client::ping)
|
||||||
assertEquals("pong", pingFuture.getOrThrow(10.seconds))
|
assertEquals("pong", pingFuture.getOrThrow(10.seconds))
|
||||||
clientFollower.shutdown() // Driver would do this after the new server, causing hang.
|
clientFollower.shutdown() // Driver would do this after the new server, causing hang.
|
||||||
}
|
}
|
||||||
|
@ -6,14 +6,20 @@ import net.corda.core.internal.concurrent.map
|
|||||||
import net.corda.core.messaging.RPCOps
|
import net.corda.core.messaging.RPCOps
|
||||||
import net.corda.node.services.messaging.RPCServerConfiguration
|
import net.corda.node.services.messaging.RPCServerConfiguration
|
||||||
import net.corda.nodeapi.User
|
import net.corda.nodeapi.User
|
||||||
|
import net.corda.testing.SerializationEnvironmentRule
|
||||||
import net.corda.testing.internal.RPCDriverExposedDSLInterface
|
import net.corda.testing.internal.RPCDriverExposedDSLInterface
|
||||||
import net.corda.testing.internal.rpcTestUser
|
import net.corda.testing.internal.rpcTestUser
|
||||||
import net.corda.testing.internal.startInVmRpcClient
|
import net.corda.testing.internal.startInVmRpcClient
|
||||||
import net.corda.testing.internal.startRpcClient
|
import net.corda.testing.internal.startRpcClient
|
||||||
import org.apache.activemq.artemis.api.core.client.ClientSession
|
import org.apache.activemq.artemis.api.core.client.ClientSession
|
||||||
|
import org.junit.Rule
|
||||||
import org.junit.runners.Parameterized
|
import org.junit.runners.Parameterized
|
||||||
|
|
||||||
open class AbstractRPCTest {
|
open class AbstractRPCTest {
|
||||||
|
@Rule
|
||||||
|
@JvmField
|
||||||
|
val testSerialization = SerializationEnvironmentRule(true)
|
||||||
|
|
||||||
enum class RPCTestMode {
|
enum class RPCTestMode {
|
||||||
InVm,
|
InVm,
|
||||||
Netty
|
Netty
|
||||||
|
@ -5,19 +5,22 @@ import net.corda.core.messaging.RPCOps
|
|||||||
import net.corda.core.utilities.millis
|
import net.corda.core.utilities.millis
|
||||||
import net.corda.core.crypto.random63BitValue
|
import net.corda.core.crypto.random63BitValue
|
||||||
import net.corda.core.internal.concurrent.fork
|
import net.corda.core.internal.concurrent.fork
|
||||||
|
import net.corda.core.internal.concurrent.transpose
|
||||||
import net.corda.core.serialization.CordaSerializable
|
import net.corda.core.serialization.CordaSerializable
|
||||||
|
import net.corda.core.utilities.getOrThrow
|
||||||
import net.corda.node.services.messaging.RPCServerConfiguration
|
import net.corda.node.services.messaging.RPCServerConfiguration
|
||||||
import net.corda.testing.internal.RPCDriverExposedDSLInterface
|
import net.corda.testing.internal.RPCDriverExposedDSLInterface
|
||||||
import net.corda.testing.internal.rpcDriver
|
import net.corda.testing.internal.rpcDriver
|
||||||
|
import net.corda.testing.internal.testThreadFactory
|
||||||
|
import org.apache.activemq.artemis.utils.collections.ConcurrentHashSet
|
||||||
|
import org.junit.After
|
||||||
import org.junit.Test
|
import org.junit.Test
|
||||||
import org.junit.runner.RunWith
|
import org.junit.runner.RunWith
|
||||||
import org.junit.runners.Parameterized
|
import org.junit.runners.Parameterized
|
||||||
import rx.Observable
|
import rx.Observable
|
||||||
import rx.subjects.UnicastSubject
|
import rx.subjects.UnicastSubject
|
||||||
import java.util.*
|
import java.util.*
|
||||||
import java.util.concurrent.ConcurrentHashMap
|
import java.util.concurrent.*
|
||||||
import java.util.concurrent.CountDownLatch
|
|
||||||
import java.util.concurrent.ForkJoinPool
|
|
||||||
|
|
||||||
@RunWith(Parameterized::class)
|
@RunWith(Parameterized::class)
|
||||||
class RPCConcurrencyTests : AbstractRPCTest() {
|
class RPCConcurrencyTests : AbstractRPCTest() {
|
||||||
@ -36,7 +39,7 @@ class RPCConcurrencyTests : AbstractRPCTest() {
|
|||||||
fun getParallelObservableTree(depth: Int, branchingFactor: Int): ObservableRose<Int>
|
fun getParallelObservableTree(depth: Int, branchingFactor: Int): ObservableRose<Int>
|
||||||
}
|
}
|
||||||
|
|
||||||
class TestOpsImpl : TestOps {
|
class TestOpsImpl(private val pool: Executor) : TestOps {
|
||||||
private val latches = ConcurrentHashMap<Long, CountDownLatch>()
|
private val latches = ConcurrentHashMap<Long, CountDownLatch>()
|
||||||
override val protocolVersion = 0
|
override val protocolVersion = 0
|
||||||
|
|
||||||
@ -68,24 +71,22 @@ class RPCConcurrencyTests : AbstractRPCTest() {
|
|||||||
val branches = if (depth == 0) {
|
val branches = if (depth == 0) {
|
||||||
Observable.empty<ObservableRose<Int>>()
|
Observable.empty<ObservableRose<Int>>()
|
||||||
} else {
|
} else {
|
||||||
val publish = UnicastSubject.create<ObservableRose<Int>>()
|
UnicastSubject.create<ObservableRose<Int>>().also { publish ->
|
||||||
ForkJoinPool.commonPool().fork {
|
(1..branchingFactor).map {
|
||||||
(1..branchingFactor).toList().parallelStream().forEach {
|
pool.fork { publish.onNext(getParallelObservableTree(depth - 1, branchingFactor)) }
|
||||||
publish.onNext(getParallelObservableTree(depth - 1, branchingFactor))
|
}.transpose().then {
|
||||||
|
it.getOrThrow()
|
||||||
|
publish.onCompleted()
|
||||||
}
|
}
|
||||||
publish.onCompleted()
|
|
||||||
}
|
}
|
||||||
publish
|
|
||||||
}
|
}
|
||||||
return ObservableRose(depth, branches)
|
return ObservableRose(depth, branches)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private lateinit var testOpsImpl: TestOpsImpl
|
|
||||||
private fun RPCDriverExposedDSLInterface.testProxy(): TestProxy<TestOps> {
|
private fun RPCDriverExposedDSLInterface.testProxy(): TestProxy<TestOps> {
|
||||||
testOpsImpl = TestOpsImpl()
|
|
||||||
return testProxy<TestOps>(
|
return testProxy<TestOps>(
|
||||||
testOpsImpl,
|
TestOpsImpl(pool),
|
||||||
clientConfiguration = RPCClientConfiguration.default.copy(
|
clientConfiguration = RPCClientConfiguration.default.copy(
|
||||||
reapInterval = 100.millis,
|
reapInterval = 100.millis,
|
||||||
cacheConcurrencyLevel = 16
|
cacheConcurrencyLevel = 16
|
||||||
@ -96,6 +97,12 @@ class RPCConcurrencyTests : AbstractRPCTest() {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private val pool = Executors.newFixedThreadPool(10, testThreadFactory())
|
||||||
|
@After
|
||||||
|
fun shutdown() {
|
||||||
|
pool.shutdown()
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `call multiple RPCs in parallel`() {
|
fun `call multiple RPCs in parallel`() {
|
||||||
rpcDriver {
|
rpcDriver {
|
||||||
@ -103,19 +110,17 @@ class RPCConcurrencyTests : AbstractRPCTest() {
|
|||||||
val numberOfBlockedCalls = 2
|
val numberOfBlockedCalls = 2
|
||||||
val numberOfDownsRequired = 100
|
val numberOfDownsRequired = 100
|
||||||
val id = proxy.ops.newLatch(numberOfDownsRequired)
|
val id = proxy.ops.newLatch(numberOfDownsRequired)
|
||||||
val done = CountDownLatch(numberOfBlockedCalls)
|
|
||||||
// Start a couple of blocking RPC calls
|
// Start a couple of blocking RPC calls
|
||||||
(1..numberOfBlockedCalls).forEach {
|
val done = (1..numberOfBlockedCalls).map {
|
||||||
ForkJoinPool.commonPool().fork {
|
pool.fork {
|
||||||
proxy.ops.waitLatch(id)
|
proxy.ops.waitLatch(id)
|
||||||
done.countDown()
|
|
||||||
}
|
}
|
||||||
}
|
}.transpose()
|
||||||
// Down the latch that the others are waiting for concurrently
|
// Down the latch that the others are waiting for concurrently
|
||||||
(1..numberOfDownsRequired).toList().parallelStream().forEach {
|
(1..numberOfDownsRequired).map {
|
||||||
proxy.ops.downLatch(id)
|
pool.fork { proxy.ops.downLatch(id) }
|
||||||
}
|
}.transpose().getOrThrow()
|
||||||
done.await()
|
done.getOrThrow()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -146,7 +151,7 @@ class RPCConcurrencyTests : AbstractRPCTest() {
|
|||||||
fun ObservableRose<Int>.subscribeToAll() {
|
fun ObservableRose<Int>.subscribeToAll() {
|
||||||
remainingLatch.countDown()
|
remainingLatch.countDown()
|
||||||
this.branches.subscribe { tree ->
|
this.branches.subscribe { tree ->
|
||||||
(tree.value + 1..treeDepth - 1).forEach {
|
(tree.value + 1 until treeDepth).forEach {
|
||||||
require(it in depthsSeen) { "Got ${tree.value} before $it" }
|
require(it in depthsSeen) { "Got ${tree.value} before $it" }
|
||||||
}
|
}
|
||||||
depthsSeen.add(tree.value)
|
depthsSeen.add(tree.value)
|
||||||
@ -165,11 +170,11 @@ class RPCConcurrencyTests : AbstractRPCTest() {
|
|||||||
val treeDepth = 2
|
val treeDepth = 2
|
||||||
val treeBranchingFactor = 10
|
val treeBranchingFactor = 10
|
||||||
val remainingLatch = CountDownLatch((intPower(treeBranchingFactor, treeDepth + 1) - 1) / (treeBranchingFactor - 1))
|
val remainingLatch = CountDownLatch((intPower(treeBranchingFactor, treeDepth + 1) - 1) / (treeBranchingFactor - 1))
|
||||||
val depthsSeen = Collections.synchronizedSet(HashSet<Int>())
|
val depthsSeen = ConcurrentHashSet<Int>()
|
||||||
fun ObservableRose<Int>.subscribeToAll() {
|
fun ObservableRose<Int>.subscribeToAll() {
|
||||||
remainingLatch.countDown()
|
remainingLatch.countDown()
|
||||||
branches.subscribe { tree ->
|
branches.subscribe { tree ->
|
||||||
(tree.value + 1..treeDepth - 1).forEach {
|
(tree.value + 1 until treeDepth).forEach {
|
||||||
require(it in depthsSeen) { "Got ${tree.value} before $it" }
|
require(it in depthsSeen) { "Got ${tree.value} before $it" }
|
||||||
}
|
}
|
||||||
depthsSeen.add(tree.value)
|
depthsSeen.add(tree.value)
|
||||||
|
@ -5,12 +5,18 @@ import net.corda.core.concurrent.CordaFuture
|
|||||||
import net.corda.core.internal.concurrent.openFuture
|
import net.corda.core.internal.concurrent.openFuture
|
||||||
import net.corda.core.messaging.*
|
import net.corda.core.messaging.*
|
||||||
import net.corda.core.utilities.getOrThrow
|
import net.corda.core.utilities.getOrThrow
|
||||||
|
import net.corda.testing.SerializationEnvironmentRule
|
||||||
import net.corda.testing.internal.rpcDriver
|
import net.corda.testing.internal.rpcDriver
|
||||||
import net.corda.testing.internal.startRpcClient
|
import net.corda.testing.internal.startRpcClient
|
||||||
import org.assertj.core.api.Assertions.assertThatThrownBy
|
import org.assertj.core.api.Assertions.assertThatThrownBy
|
||||||
|
import org.junit.Rule
|
||||||
import org.junit.Test
|
import org.junit.Test
|
||||||
|
|
||||||
class RPCFailureTests {
|
class RPCFailureTests {
|
||||||
|
@Rule
|
||||||
|
@JvmField
|
||||||
|
val testSerialization = SerializationEnvironmentRule(true)
|
||||||
|
|
||||||
class Unserializable
|
class Unserializable
|
||||||
interface Ops : RPCOps {
|
interface Ops : RPCOps {
|
||||||
fun getUnserializable(): Unserializable
|
fun getUnserializable(): Unserializable
|
||||||
|
@ -31,6 +31,8 @@ import java.time.Duration
|
|||||||
import java.time.temporal.Temporal
|
import java.time.temporal.Temporal
|
||||||
import java.util.*
|
import java.util.*
|
||||||
import java.util.Spliterator.*
|
import java.util.Spliterator.*
|
||||||
|
import java.util.concurrent.ExecutorService
|
||||||
|
import java.util.concurrent.TimeUnit
|
||||||
import java.util.stream.IntStream
|
import java.util.stream.IntStream
|
||||||
import java.util.stream.Stream
|
import java.util.stream.Stream
|
||||||
import java.util.stream.StreamSupport
|
import java.util.stream.StreamSupport
|
||||||
@ -307,3 +309,10 @@ fun TransactionBuilder.toLedgerTransaction(services: ServiceHub, serializationCo
|
|||||||
val KClass<*>.packageName: String get() = java.`package`.name
|
val KClass<*>.packageName: String get() = java.`package`.name
|
||||||
|
|
||||||
fun URL.openHttpConnection(): HttpURLConnection = openConnection() as HttpURLConnection
|
fun URL.openHttpConnection(): HttpURLConnection = openConnection() as HttpURLConnection
|
||||||
|
/** Analogous to [Thread.join]. */
|
||||||
|
fun ExecutorService.join() {
|
||||||
|
shutdown() // Do not change to shutdownNow, tests use this method to assert the executor has no more tasks.
|
||||||
|
while (!awaitTermination(1, TimeUnit.SECONDS)) {
|
||||||
|
// Try forever. Do not give up, tests use this method to assert the executor has no more tasks.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -133,7 +133,7 @@ class ContractUpgradeFlowTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `2 parties contract upgrade using RPC`() {
|
fun `2 parties contract upgrade using RPC`() {
|
||||||
rpcDriver(initialiseSerialization = false) {
|
rpcDriver {
|
||||||
// Create dummy contract.
|
// Create dummy contract.
|
||||||
val twoPartyDummyContract = DummyContract.generateInitial(0, notary, alice.ref(1), bob.ref(1))
|
val twoPartyDummyContract = DummyContract.generateInitial(0, notary, alice.ref(1), bob.ref(1))
|
||||||
val signedByA = aliceNode.services.signInitialTransaction(twoPartyDummyContract)
|
val signedByA = aliceNode.services.signInitialTransaction(twoPartyDummyContract)
|
||||||
|
@ -14,7 +14,6 @@ import org.junit.runners.model.Statement
|
|||||||
import org.slf4j.Logger
|
import org.slf4j.Logger
|
||||||
import java.util.concurrent.ExecutorService
|
import java.util.concurrent.ExecutorService
|
||||||
import java.util.concurrent.Executors
|
import java.util.concurrent.Executors
|
||||||
import java.util.concurrent.TimeUnit
|
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
import kotlin.test.assertNull
|
import kotlin.test.assertNull
|
||||||
|
|
||||||
@ -23,10 +22,7 @@ private fun <T> withSingleThreadExecutor(callable: ExecutorService.() -> T) = Ex
|
|||||||
fork {}.getOrThrow() // Start the thread.
|
fork {}.getOrThrow() // Start the thread.
|
||||||
callable()
|
callable()
|
||||||
} finally {
|
} finally {
|
||||||
shutdown()
|
join()
|
||||||
while (!awaitTermination(1, TimeUnit.SECONDS)) {
|
|
||||||
// Do nothing.
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,13 +2,13 @@ package net.corda.core.internal.concurrent
|
|||||||
|
|
||||||
import com.nhaarman.mockito_kotlin.*
|
import com.nhaarman.mockito_kotlin.*
|
||||||
import net.corda.core.concurrent.CordaFuture
|
import net.corda.core.concurrent.CordaFuture
|
||||||
|
import net.corda.core.internal.join
|
||||||
import net.corda.core.utilities.getOrThrow
|
import net.corda.core.utilities.getOrThrow
|
||||||
import net.corda.testing.rigorousMock
|
import net.corda.testing.rigorousMock
|
||||||
import org.assertj.core.api.Assertions
|
import org.assertj.core.api.Assertions
|
||||||
import org.junit.Test
|
import org.junit.Test
|
||||||
import org.slf4j.Logger
|
import org.slf4j.Logger
|
||||||
import java.util.concurrent.Executors
|
import java.util.concurrent.Executors
|
||||||
import java.util.concurrent.TimeUnit
|
|
||||||
import java.util.concurrent.atomic.AtomicBoolean
|
import java.util.concurrent.atomic.AtomicBoolean
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
import kotlin.test.assertFalse
|
import kotlin.test.assertFalse
|
||||||
@ -108,10 +108,7 @@ class CordaFutureTest {
|
|||||||
val throwable = Exception("Boom")
|
val throwable = Exception("Boom")
|
||||||
val executor = Executors.newSingleThreadExecutor()
|
val executor = Executors.newSingleThreadExecutor()
|
||||||
executor.fork { throw throwable }.andForget(log)
|
executor.fork { throw throwable }.andForget(log)
|
||||||
executor.shutdown()
|
executor.join()
|
||||||
while (!executor.awaitTermination(1, TimeUnit.SECONDS)) {
|
|
||||||
// Do nothing.
|
|
||||||
}
|
|
||||||
verify(log).error(any(), same(throwable))
|
verify(log).error(any(), same(throwable))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ import net.corda.core.context.Trace.InvocationId
|
|||||||
import net.corda.core.identity.CordaX500Name
|
import net.corda.core.identity.CordaX500Name
|
||||||
import net.corda.core.internal.LazyStickyPool
|
import net.corda.core.internal.LazyStickyPool
|
||||||
import net.corda.core.internal.LifeCycle
|
import net.corda.core.internal.LifeCycle
|
||||||
|
import net.corda.core.internal.join
|
||||||
import net.corda.core.messaging.RPCOps
|
import net.corda.core.messaging.RPCOps
|
||||||
import net.corda.core.serialization.SerializationContext
|
import net.corda.core.serialization.SerializationContext
|
||||||
import net.corda.core.serialization.SerializationDefaults.RPC_SERVER_CONTEXT
|
import net.corda.core.serialization.SerializationDefaults.RPC_SERVER_CONTEXT
|
||||||
@ -207,6 +208,7 @@ class RPCServer(
|
|||||||
}
|
}
|
||||||
|
|
||||||
fun close() {
|
fun close() {
|
||||||
|
observationSendExecutor?.join()
|
||||||
reaperScheduledFuture?.cancel(false)
|
reaperScheduledFuture?.cancel(false)
|
||||||
rpcExecutor?.shutdownNow()
|
rpcExecutor?.shutdownNow()
|
||||||
reaperExecutor?.shutdownNow()
|
reaperExecutor?.shutdownNow()
|
||||||
|
@ -230,7 +230,6 @@ fun <A> rpcDriver(
|
|||||||
debugPortAllocation: PortAllocation = globalDebugPortAllocation,
|
debugPortAllocation: PortAllocation = globalDebugPortAllocation,
|
||||||
systemProperties: Map<String, String> = emptyMap(),
|
systemProperties: Map<String, String> = emptyMap(),
|
||||||
useTestClock: Boolean = false,
|
useTestClock: Boolean = false,
|
||||||
initialiseSerialization: Boolean = true,
|
|
||||||
startNodesInProcess: Boolean = false,
|
startNodesInProcess: Boolean = false,
|
||||||
waitForNodesToFinish: Boolean = false,
|
waitForNodesToFinish: Boolean = false,
|
||||||
extraCordappPackagesToScan: List<String> = emptyList(),
|
extraCordappPackagesToScan: List<String> = emptyList(),
|
||||||
@ -254,7 +253,7 @@ fun <A> rpcDriver(
|
|||||||
),
|
),
|
||||||
coerce = { it },
|
coerce = { it },
|
||||||
dsl = dsl,
|
dsl = dsl,
|
||||||
initialiseSerialization = initialiseSerialization
|
initialiseSerialization = false
|
||||||
)
|
)
|
||||||
|
|
||||||
private class SingleUserSecurityManager(val rpcUser: User) : ActiveMQSecurityManager3 {
|
private class SingleUserSecurityManager(val rpcUser: User) : ActiveMQSecurityManager3 {
|
||||||
|
@ -1,26 +1,50 @@
|
|||||||
package net.corda.testing
|
package net.corda.testing
|
||||||
|
|
||||||
import com.nhaarman.mockito_kotlin.doNothing
|
import com.nhaarman.mockito_kotlin.*
|
||||||
import com.nhaarman.mockito_kotlin.whenever
|
|
||||||
import net.corda.client.rpc.internal.KryoClientSerializationScheme
|
import net.corda.client.rpc.internal.KryoClientSerializationScheme
|
||||||
|
import net.corda.core.internal.staticField
|
||||||
import net.corda.core.serialization.internal.*
|
import net.corda.core.serialization.internal.*
|
||||||
import net.corda.node.serialization.KryoServerSerializationScheme
|
import net.corda.node.serialization.KryoServerSerializationScheme
|
||||||
import net.corda.nodeapi.internal.serialization.*
|
import net.corda.nodeapi.internal.serialization.*
|
||||||
import net.corda.nodeapi.internal.serialization.amqp.AMQPClientSerializationScheme
|
import net.corda.nodeapi.internal.serialization.amqp.AMQPClientSerializationScheme
|
||||||
import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme
|
import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme
|
||||||
import net.corda.testing.common.internal.asContextEnv
|
import net.corda.testing.common.internal.asContextEnv
|
||||||
|
import net.corda.testing.internal.testThreadFactory
|
||||||
|
import org.apache.activemq.artemis.core.remoting.impl.invm.InVMConnector
|
||||||
import org.junit.rules.TestRule
|
import org.junit.rules.TestRule
|
||||||
import org.junit.runner.Description
|
import org.junit.runner.Description
|
||||||
import org.junit.runners.model.Statement
|
import org.junit.runners.model.Statement
|
||||||
|
import java.util.concurrent.ConcurrentHashMap
|
||||||
|
import java.util.concurrent.ExecutorService
|
||||||
|
import java.util.concurrent.Executors
|
||||||
|
|
||||||
|
private val inVMExecutors = ConcurrentHashMap<SerializationEnvironment, ExecutorService>()
|
||||||
|
|
||||||
/** @param inheritable whether new threads inherit the environment, use sparingly. */
|
/** @param inheritable whether new threads inherit the environment, use sparingly. */
|
||||||
class SerializationEnvironmentRule(private val inheritable: Boolean = false) : TestRule {
|
class SerializationEnvironmentRule(private val inheritable: Boolean = false) : TestRule {
|
||||||
|
companion object {
|
||||||
|
init {
|
||||||
|
// Can't turn it off, and it creates threads that do serialization, so hack it:
|
||||||
|
InVMConnector::class.staticField<ExecutorService>("threadPoolExecutor").value = rigorousMock<ExecutorService>().also {
|
||||||
|
doAnswer {
|
||||||
|
inVMExecutors.computeIfAbsent(effectiveSerializationEnv) {
|
||||||
|
Executors.newCachedThreadPool(testThreadFactory(true)) // Close enough to what InVMConnector makes normally.
|
||||||
|
}.execute(it.arguments[0] as Runnable)
|
||||||
|
}.whenever(it).execute(any())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
lateinit var env: SerializationEnvironment
|
lateinit var env: SerializationEnvironment
|
||||||
override fun apply(base: Statement, description: Description): Statement {
|
override fun apply(base: Statement, description: Description): Statement {
|
||||||
env = createTestSerializationEnv(description.toString())
|
env = createTestSerializationEnv(description.toString())
|
||||||
return object : Statement() {
|
return object : Statement() {
|
||||||
override fun evaluate() = env.asContextEnv(inheritable) {
|
override fun evaluate() {
|
||||||
base.evaluate()
|
try {
|
||||||
|
env.asContextEnv(inheritable) { base.evaluate() }
|
||||||
|
} finally {
|
||||||
|
inVMExecutors.remove(env)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -59,6 +83,7 @@ fun setGlobalSerialization(armed: Boolean): GlobalSerializationEnvironment {
|
|||||||
object : GlobalSerializationEnvironment, SerializationEnvironment by createTestSerializationEnv("<global>") {
|
object : GlobalSerializationEnvironment, SerializationEnvironment by createTestSerializationEnv("<global>") {
|
||||||
override fun unset() {
|
override fun unset() {
|
||||||
_globalSerializationEnv.set(null)
|
_globalSerializationEnv.set(null)
|
||||||
|
inVMExecutors.remove(this)
|
||||||
}
|
}
|
||||||
}.also {
|
}.also {
|
||||||
_globalSerializationEnv.set(it)
|
_globalSerializationEnv.set(it)
|
||||||
|
@ -0,0 +1,15 @@
|
|||||||
|
package net.corda.testing.internal
|
||||||
|
|
||||||
|
import java.util.concurrent.ConcurrentHashMap
|
||||||
|
import java.util.concurrent.ThreadFactory
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger
|
||||||
|
|
||||||
|
private val familyToNextPoolNumber = ConcurrentHashMap<String, AtomicInteger>()
|
||||||
|
fun Any.testThreadFactory(useEnclosingClassName: Boolean = false): ThreadFactory {
|
||||||
|
val poolFamily = javaClass.let { (if (useEnclosingClassName) it.enclosingClass else it).simpleName }
|
||||||
|
val poolNumber = familyToNextPoolNumber.computeIfAbsent(poolFamily) { AtomicInteger(1) }.getAndIncrement()
|
||||||
|
val nextThreadNumber = AtomicInteger(1)
|
||||||
|
return ThreadFactory { task ->
|
||||||
|
Thread(task, "$poolFamily-$poolNumber-${nextThreadNumber.getAndIncrement()}")
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user