CORDA-716 Make serialization init less static (#1996)

This commit is contained in:
Andrzej Cichocki 2017-11-10 15:44:43 +00:00 committed by GitHub
parent cc4c732a48
commit 052124bbe0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 463 additions and 346 deletions

View File

@ -2430,19 +2430,13 @@ public static final class net.corda.core.serialization.SerializationContext$UseC
public static net.corda.core.serialization.SerializationContext$UseCase valueOf(String)
public static net.corda.core.serialization.SerializationContext$UseCase[] values()
##
public final class net.corda.core.serialization.SerializationDefaults extends java.lang.Object implements net.corda.core.serialization.internal.SerializationEnvironment
@org.jetbrains.annotations.NotNull public net.corda.core.serialization.SerializationContext getCHECKPOINT_CONTEXT()
@org.jetbrains.annotations.NotNull public net.corda.core.serialization.SerializationContext getP2P_CONTEXT()
@org.jetbrains.annotations.NotNull public net.corda.core.serialization.SerializationContext getRPC_CLIENT_CONTEXT()
@org.jetbrains.annotations.NotNull public net.corda.core.serialization.SerializationContext getRPC_SERVER_CONTEXT()
@org.jetbrains.annotations.NotNull public net.corda.core.serialization.SerializationFactory getSERIALIZATION_FACTORY()
@org.jetbrains.annotations.NotNull public net.corda.core.serialization.SerializationContext getSTORAGE_CONTEXT()
public void setCHECKPOINT_CONTEXT(net.corda.core.serialization.SerializationContext)
public void setP2P_CONTEXT(net.corda.core.serialization.SerializationContext)
public void setRPC_CLIENT_CONTEXT(net.corda.core.serialization.SerializationContext)
public void setRPC_SERVER_CONTEXT(net.corda.core.serialization.SerializationContext)
public void setSERIALIZATION_FACTORY(net.corda.core.serialization.SerializationFactory)
public void setSTORAGE_CONTEXT(net.corda.core.serialization.SerializationContext)
public final class net.corda.core.serialization.SerializationDefaults extends java.lang.Object
@org.jetbrains.annotations.NotNull public final net.corda.core.serialization.SerializationContext getCHECKPOINT_CONTEXT()
@org.jetbrains.annotations.NotNull public final net.corda.core.serialization.SerializationContext getP2P_CONTEXT()
@org.jetbrains.annotations.NotNull public final net.corda.core.serialization.SerializationContext getRPC_CLIENT_CONTEXT()
@org.jetbrains.annotations.NotNull public final net.corda.core.serialization.SerializationContext getRPC_SERVER_CONTEXT()
@org.jetbrains.annotations.NotNull public final net.corda.core.serialization.SerializationFactory getSERIALIZATION_FACTORY()
@org.jetbrains.annotations.NotNull public final net.corda.core.serialization.SerializationContext getSTORAGE_CONTEXT()
public static final net.corda.core.serialization.SerializationDefaults INSTANCE
##
public abstract class net.corda.core.serialization.SerializationFactory extends java.lang.Object

View File

@ -4,10 +4,10 @@ import net.corda.client.rpc.internal.KryoClientSerializationScheme
import net.corda.client.rpc.internal.RPCClient
import net.corda.client.rpc.internal.RPCClientConfiguration
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.serialization.internal.effectiveSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport
import net.corda.nodeapi.ConnectionDirection
import net.corda.nodeapi.internal.serialization.AMQP_RPC_CLIENT_CONTEXT
import net.corda.nodeapi.internal.serialization.KRYO_RPC_CLIENT_CONTEXT
import java.time.Duration
@ -71,8 +71,15 @@ class CordaRPCClient @JvmOverloads constructor(
configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.DEFAULT
) {
init {
// TODO: allow clients to have serialization factory etc injected and align with RPC protocol version?
KryoClientSerializationScheme.initialiseSerialization()
try {
effectiveSerializationEnv
} catch (e: IllegalStateException) {
try {
KryoClientSerializationScheme.initialiseSerialization()
} catch (e: IllegalStateException) {
// Race e.g. two of these constructed in parallel, ignore.
}
}
}
private val rpcClient = RPCClient<CordaRPCOps>(

View File

@ -2,15 +2,17 @@ package net.corda.client.rpc.internal
import com.esotericsoftware.kryo.pool.KryoPool
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.ByteSequence
import net.corda.nodeapi.internal.serialization.*
import net.corda.nodeapi.internal.serialization.KRYO_P2P_CONTEXT
import net.corda.nodeapi.internal.serialization.KRYO_RPC_CLIENT_CONTEXT
import net.corda.nodeapi.internal.serialization.SerializationFactoryImpl
import net.corda.nodeapi.internal.serialization.amqp.AMQPClientSerializationScheme
import net.corda.nodeapi.internal.serialization.kryo.AbstractKryoSerializationScheme
import net.corda.nodeapi.internal.serialization.kryo.DefaultKryoCustomizer
import net.corda.nodeapi.internal.serialization.kryo.KryoHeaderV0_1
import net.corda.nodeapi.internal.serialization.kryo.RPCKryo
import java.util.concurrent.atomic.AtomicBoolean
class KryoClientSerializationScheme : AbstractKryoSerializationScheme() {
override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean {
@ -29,25 +31,15 @@ class KryoClientSerializationScheme : AbstractKryoSerializationScheme() {
override fun rpcServerKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
companion object {
val isInitialised = AtomicBoolean(false)
/** Call from main only. */
fun initialiseSerialization() {
if (!isInitialised.compareAndSet(false, true)) return
try {
SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply {
registerScheme(KryoClientSerializationScheme())
registerScheme(AMQPClientSerializationScheme())
}
SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT
SerializationDefaults.RPC_CLIENT_CONTEXT = KRYO_RPC_CLIENT_CONTEXT
} catch (e: IllegalStateException) {
// Check that it's registered as we expect
val factory = SerializationDefaults.SERIALIZATION_FACTORY
val checkedFactory = factory as? SerializationFactoryImpl
?: throw IllegalStateException("RPC client encountered conflicting configuration of serialization subsystem: $factory")
check(checkedFactory.alreadyRegisteredSchemes.any { it is KryoClientSerializationScheme }) {
"RPC client encountered conflicting configuration of serialization subsystem."
}
}
nodeSerializationEnv = SerializationEnvironmentImpl(
SerializationFactoryImpl().apply {
registerScheme(KryoClientSerializationScheme())
registerScheme(AMQPClientSerializationScheme())
},
KRYO_P2P_CONTEXT,
rpcClientContext = KRYO_RPC_CLIENT_CONTEXT)
}
}
}

View File

@ -8,6 +8,7 @@ import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.RPCOps
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.loggerFor
import net.corda.core.utilities.minutes
@ -110,6 +111,7 @@ class RPCClient<I : RPCOps>(
maxRetryInterval = rpcConfiguration.connectionMaxRetryInterval.toMillis()
reconnectAttempts = rpcConfiguration.maxReconnectAttempts
minLargeMessageSize = rpcConfiguration.maxFileSize
isUseGlobalPools = nodeSerializationEnv != null
}
val proxyHandler = RPCClientProxyHandler(rpcConfiguration, username, password, serverLocator, clientAddress, rpcOpsClass, serializationContext)

View File

@ -0,0 +1,66 @@
package net.corda.core.internal
import java.util.concurrent.atomic.AtomicReference
import kotlin.reflect.KProperty
/** May go from null to non-null and vice-versa, and that's it. */
abstract class ToggleField<T>(val name: String) {
private val writeMutex = Any() // Protects the toggle logic only.
abstract fun get(): T?
fun set(value: T?) = synchronized(writeMutex) {
if (value != null) {
check(get() == null) { "$name already has a value." }
setImpl(value)
} else {
check(get() != null) { "$name is already null." }
clear()
}
}
protected abstract fun setImpl(value: T)
protected abstract fun clear()
operator fun getValue(thisRef: Any?, property: KProperty<*>) = get()
operator fun setValue(thisRef: Any?, property: KProperty<*>, value: T?) = set(value)
}
class SimpleToggleField<T>(name: String, private val once: Boolean = false) : ToggleField<T>(name) {
private val holder = AtomicReference<T?>() // Force T? in API for safety.
override fun get() = holder.get()
override fun setImpl(value: T) = holder.set(value)
override fun clear() {
check(!once) { "Value of $name cannot be changed." }
holder.set(null)
}
}
class ThreadLocalToggleField<T>(name: String) : ToggleField<T>(name) {
private val threadLocal = ThreadLocal<T?>()
override fun get() = threadLocal.get()
override fun setImpl(value: T) = threadLocal.set(value)
override fun clear() = threadLocal.remove()
}
/** The named thread has leaked from a previous test. */
class ThreadLeakException : RuntimeException("Leaked thread detected: ${Thread.currentThread().name}")
class InheritableThreadLocalToggleField<T>(name: String) : ToggleField<T>(name) {
private class Holder<T>(value: T) : AtomicReference<T?>(value) {
fun valueOrDeclareLeak() = get() ?: throw ThreadLeakException()
}
private val threadLocal = object : InheritableThreadLocal<Holder<T>?>() {
override fun childValue(holder: Holder<T>?): Holder<T>? {
// The Holder itself may be null due to prior events, a leak is not implied in that case:
holder?.valueOrDeclareLeak() // Fail fast.
return holder // What super does.
}
}
override fun get() = threadLocal.get()?.valueOrDeclareLeak()
override fun setImpl(value: T) = threadLocal.set(Holder(value))
override fun clear() = threadLocal.run {
val holder = get()!!
remove()
holder.set(null) // Threads that inherited the holder are now considered to have escaped from the test.
}
}

View File

@ -1,18 +0,0 @@
package net.corda.core.internal
import kotlin.reflect.KProperty
/**
* A write-once property to be used as delegate for Kotlin var properties. The expectation is that this is initialised
* prior to the spawning of any threads that may access it and so there's no need for it to be volatile.
*/
class WriteOnceProperty<T : Any>(private val defaultValue: T? = null) {
private var v: T? = defaultValue
operator fun getValue(thisRef: Any?, property: KProperty<*>) = v ?: throw IllegalStateException("Write-once property $property not set.")
operator fun setValue(thisRef: Any?, property: KProperty<*>, value: T) {
check(v == defaultValue || v === value) { "Cannot set write-once property $property more than once." }
v = value
}
}

View File

@ -2,8 +2,7 @@ package net.corda.core.serialization
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.sha256
import net.corda.core.internal.WriteOnceProperty
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal.effectiveSerializationEnv
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.sequence
@ -53,7 +52,7 @@ abstract class SerializationFactory {
* A context to use as a default if you do not require a specially configured context. It will be the current context
* if the use is somehow nested (see [currentContext]).
*/
val defaultContext: SerializationContext get() = currentContext ?: SerializationDefaults.P2P_CONTEXT
val defaultContext: SerializationContext get() = currentContext ?: effectiveSerializationEnv.p2pContext
private val _currentContext = ThreadLocal<SerializationContext?>()
@ -90,7 +89,7 @@ abstract class SerializationFactory {
/**
* A default factory for serialization/deserialization, taking into account the [currentFactory] if set.
*/
val defaultFactory: SerializationFactory get() = currentFactory ?: SerializationDefaults.SERIALIZATION_FACTORY
val defaultFactory: SerializationFactory get() = currentFactory ?: effectiveSerializationEnv.serializationFactory
/**
* If there is a need to nest serialization/deserialization with a modified context during serialization or deserialization,
@ -173,13 +172,13 @@ interface SerializationContext {
/**
* Global singletons to be used as defaults that are injected elsewhere (generally, in the node or in RPC client).
*/
object SerializationDefaults : SerializationEnvironment {
override var SERIALIZATION_FACTORY: SerializationFactory by WriteOnceProperty()
override var P2P_CONTEXT: SerializationContext by WriteOnceProperty()
override var RPC_SERVER_CONTEXT: SerializationContext by WriteOnceProperty()
override var RPC_CLIENT_CONTEXT: SerializationContext by WriteOnceProperty()
override var STORAGE_CONTEXT: SerializationContext by WriteOnceProperty()
override var CHECKPOINT_CONTEXT: SerializationContext by WriteOnceProperty()
object SerializationDefaults {
val SERIALIZATION_FACTORY get() = effectiveSerializationEnv.serializationFactory
val P2P_CONTEXT get() = effectiveSerializationEnv.p2pContext
val RPC_SERVER_CONTEXT get() = effectiveSerializationEnv.rpcServerContext
val RPC_CLIENT_CONTEXT get() = effectiveSerializationEnv.rpcClientContext
val STORAGE_CONTEXT get() = effectiveSerializationEnv.storageContext
val CHECKPOINT_CONTEXT get() = effectiveSerializationEnv.checkpointContext
}
/**

View File

@ -1,13 +1,55 @@
package net.corda.core.serialization.internal
import net.corda.core.internal.InheritableThreadLocalToggleField
import net.corda.core.internal.SimpleToggleField
import net.corda.core.internal.ThreadLocalToggleField
import net.corda.core.internal.VisibleForTesting
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationFactory
interface SerializationEnvironment {
val SERIALIZATION_FACTORY: SerializationFactory
val P2P_CONTEXT: SerializationContext
val RPC_SERVER_CONTEXT: SerializationContext
val RPC_CLIENT_CONTEXT: SerializationContext
val STORAGE_CONTEXT: SerializationContext
val CHECKPOINT_CONTEXT: SerializationContext
val serializationFactory: SerializationFactory
val p2pContext: SerializationContext
val rpcServerContext: SerializationContext
val rpcClientContext: SerializationContext
val storageContext: SerializationContext
val checkpointContext: SerializationContext
}
class SerializationEnvironmentImpl(
override val serializationFactory: SerializationFactory,
override val p2pContext: SerializationContext,
rpcServerContext: SerializationContext? = null,
rpcClientContext: SerializationContext? = null,
storageContext: SerializationContext? = null,
checkpointContext: SerializationContext? = null) : SerializationEnvironment {
// Those that are passed in as null are never inited:
override lateinit var rpcServerContext: SerializationContext
override lateinit var rpcClientContext: SerializationContext
override lateinit var storageContext: SerializationContext
override lateinit var checkpointContext: SerializationContext
init {
rpcServerContext?.let { this.rpcServerContext = it }
rpcClientContext?.let { this.rpcClientContext = it }
storageContext?.let { this.storageContext = it }
checkpointContext?.let { this.checkpointContext = it }
}
}
private val _nodeSerializationEnv = SimpleToggleField<SerializationEnvironment>("nodeSerializationEnv", true)
@VisibleForTesting
val _globalSerializationEnv = SimpleToggleField<SerializationEnvironment>("globalSerializationEnv")
@VisibleForTesting
val _contextSerializationEnv = ThreadLocalToggleField<SerializationEnvironment>("contextSerializationEnv")
@VisibleForTesting
val _inheritableContextSerializationEnv = InheritableThreadLocalToggleField<SerializationEnvironment>("inheritableContextSerializationEnv")
private val serializationEnvProperties = listOf(_nodeSerializationEnv, _globalSerializationEnv, _contextSerializationEnv, _inheritableContextSerializationEnv)
val effectiveSerializationEnv: SerializationEnvironment
get() = serializationEnvProperties.map { Pair(it, it.get()) }.filter { it.second != null }.run {
singleOrNull()?.run {
second!!
} ?: throw IllegalStateException("Expected exactly 1 of {${serializationEnvProperties.joinToString(", ") { it.name }}} but got: {${joinToString(", ") { it.first.name }}}")
}
/** Should be set once in main. */
var nodeSerializationEnv by _nodeSerializationEnv

View File

@ -0,0 +1,125 @@
package net.corda.core.internal
import net.corda.core.internal.concurrent.fork
import net.corda.core.utilities.getOrThrow
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.Test
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import kotlin.test.assertEquals
import kotlin.test.assertNull
private fun <T> withSingleThreadExecutor(callable: ExecutorService.() -> T) = Executors.newSingleThreadExecutor().run {
try {
fork {}.getOrThrow() // Start the thread.
callable()
} finally {
shutdown()
while (!awaitTermination(1, TimeUnit.SECONDS)) {
// Do nothing.
}
}
}
class ToggleFieldTest {
@Test
fun `toggle is enforced`() {
listOf(SimpleToggleField<String>("simple"), ThreadLocalToggleField<String>("local"), InheritableThreadLocalToggleField("inheritable")).forEach { field ->
assertNull(field.get())
assertThatThrownBy { field.set(null) }.isInstanceOf(IllegalStateException::class.java)
field.set("hello")
assertEquals("hello", field.get())
assertThatThrownBy { field.set("world") }.isInstanceOf(IllegalStateException::class.java)
assertEquals("hello", field.get())
assertThatThrownBy { field.set("hello") }.isInstanceOf(IllegalStateException::class.java)
field.set(null)
assertNull(field.get())
}
}
@Test
fun `write-at-most-once field works`() {
val field = SimpleToggleField<String>("field", true)
assertNull(field.get())
assertThatThrownBy { field.set(null) }.isInstanceOf(IllegalStateException::class.java)
field.set("finalValue")
assertEquals("finalValue", field.get())
listOf("otherValue", "finalValue", null).forEach { value ->
assertThatThrownBy { field.set(value) }.isInstanceOf(IllegalStateException::class.java)
assertEquals("finalValue", field.get())
}
}
@Test
fun `thread local works`() {
val field = ThreadLocalToggleField<String>("field")
assertNull(field.get())
field.set("hello")
assertEquals("hello", field.get())
withSingleThreadExecutor {
assertNull(fork(field::get).getOrThrow())
}
field.set(null)
assertNull(field.get())
}
@Test
fun `inheritable thread local works`() {
val field = InheritableThreadLocalToggleField<String>("field")
assertNull(field.get())
field.set("hello")
assertEquals("hello", field.get())
withSingleThreadExecutor {
assertEquals("hello", fork(field::get).getOrThrow())
}
field.set(null)
assertNull(field.get())
}
@Test
fun `existing threads do not inherit`() {
val field = InheritableThreadLocalToggleField<String>("field")
withSingleThreadExecutor {
field.set("hello")
assertEquals("hello", field.get())
assertNull(fork(field::get).getOrThrow())
}
}
@Test
fun `inherited values are poisoned on clear`() {
val field = InheritableThreadLocalToggleField<String>("field")
field.set("hello")
withSingleThreadExecutor {
assertEquals("hello", fork(field::get).getOrThrow())
val threadName = fork { Thread.currentThread().name }.getOrThrow()
listOf(null, "world").forEach { value ->
field.set(value)
assertEquals(value, field.get())
val future = fork(field::get)
assertThatThrownBy { future.getOrThrow() }
.isInstanceOf(ThreadLeakException::class.java)
.hasMessageContaining(threadName)
}
}
withSingleThreadExecutor {
assertEquals("world", fork(field::get).getOrThrow())
}
}
@Test
fun `leaked thread is detected as soon as it tries to create another`() {
val field = InheritableThreadLocalToggleField<String>("field")
field.set("hello")
withSingleThreadExecutor {
assertEquals("hello", fork(field::get).getOrThrow())
field.set(null) // The executor thread is now considered leaked.
val threadName = fork { Thread.currentThread().name }.getOrThrow()
val future = fork(::Thread)
assertThatThrownBy { future.getOrThrow() }
.isInstanceOf(ThreadLeakException::class.java)
.hasMessageContaining(threadName)
}
}
}

View File

@ -1,6 +1,7 @@
package net.corda.nodeapi
import net.corda.core.identity.CordaX500Name
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.config.SSLConfiguration
import org.apache.activemq.artemis.api.core.TransportConfiguration
@ -48,7 +49,8 @@ class ArtemisTcpTransport {
// Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop.
// It does not use AMQP messages for its own messages e.g. topology and heartbeats.
// TODO further investigate how to ensure we use a well defined wire level protocol for Node to Node communications.
TransportConstants.PROTOCOLS_PROP_NAME to "CORE,AMQP"
TransportConstants.PROTOCOLS_PROP_NAME to "CORE,AMQP",
TransportConstants.USE_GLOBAL_WORKER_POOL_PROP_NAME to (nodeSerializationEnv != null)
)
if (config != null && enableSSL) {

View File

@ -25,7 +25,7 @@ public final class ForbiddenLambdaSerializationTests {
@Before
public void setup() {
factory = testSerialization.env.getSERIALIZATION_FACTORY();
factory = testSerialization.getEnv().getSerializationFactory();
}
@Test

View File

@ -25,7 +25,7 @@ public final class LambdaCheckpointSerializationTest {
@Before
public void setup() {
factory = testSerialization.env.getSERIALIZATION_FACTORY();
factory = testSerialization.getEnv().getSerializationFactory();
context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoHeaderV0_1(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, SerializationContext.UseCase.Checkpoint);
}

View File

@ -27,9 +27,8 @@ class ContractAttachmentSerializerTest {
@Before
fun setup() {
factory = testSerialization.env.SERIALIZATION_FACTORY
context = testSerialization.env.CHECKPOINT_CONTEXT
factory = testSerialization.env.serializationFactory
context = testSerialization.env.checkpointContext
contextWithToken = context.withTokenContext(SerializeAsTokenContextImpl(Any(), factory, context, mockServices))
}

View File

@ -8,7 +8,6 @@ import net.corda.core.utilities.OpaqueBytes
import net.corda.nodeapi.internal.serialization.kryo.CordaKryo
import net.corda.nodeapi.internal.serialization.kryo.DefaultKryoCustomizer
import net.corda.nodeapi.internal.serialization.kryo.KryoHeaderV0_1
import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.rigorousMock
import net.corda.testing.SerializationEnvironmentRule
import org.assertj.core.api.Assertions.assertThat
@ -26,8 +25,8 @@ class SerializationTokenTest {
@Before
fun setup() {
factory = testSerialization.env.SERIALIZATION_FACTORY
context = testSerialization.env.CHECKPOINT_CONTEXT.withWhitelisted(SingletonSerializationToken::class.java)
factory = testSerialization.env.serializationFactory
context = testSerialization.env.checkpointContext.withWhitelisted(SingletonSerializationToken::class.java)
}
// Large tokenizable object so we can tell from the smaller number of serialized bytes it was actually tokenized

View File

@ -17,26 +17,21 @@ import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.loggerFor
import net.corda.node.internal.cordapp.CordappLoader
import net.corda.node.internal.cordapp.CordappProviderImpl
import net.corda.testing.*
import net.corda.testing.DUMMY_BANK_A
import net.corda.testing.DUMMY_NOTARY
import net.corda.testing.SerializationEnvironmentRule
import net.corda.testing.driver.DriverDSLExposedInterface
import net.corda.testing.driver.NodeHandle
import net.corda.testing.driver.driver
import net.corda.testing.node.MockServices
import org.junit.Assert.assertEquals
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import java.net.URLClassLoader
import java.nio.file.Files
import kotlin.test.assertFailsWith
class AttachmentLoadingTests {
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule()
private class Services : MockServices() {
private val provider = CordappProviderImpl(CordappLoader.createDevMode(listOf(isolatedJAR)), attachments)
private val cordapp get() = provider.cordapps.first()
@ -83,7 +78,7 @@ class AttachmentLoadingTests {
}
@Test
fun `test a wire transaction has loaded the correct attachment`() {
fun `test a wire transaction has loaded the correct attachment`() = withTestSerialization {
val appClassLoader = services.appContext.classLoader
val contractClass = appClassLoader.loadClass(ISOLATED_CONTRACT_ID).asSubclass(Contract::class.java)
val generateInitialMethod = contractClass.getDeclaredMethod("generateInitial", PartyAndReference::class.java, Integer.TYPE, Party::class.java)
@ -101,7 +96,7 @@ class AttachmentLoadingTests {
@Test
fun `test that attachments retrieved over the network are not used for code`() {
driver(initialiseSerialization = false) {
driver {
installIsolatedCordappTo(bankAName)
val (bankA, bankB) = createTwoNodes()
assertFailsWith<UnexpectedFlowEndException>("Party C=CH,L=Zurich,O=BankB rejected session request: Don't know net.corda.finance.contracts.isolated.IsolatedDummyFlow\$Initiator") {
@ -112,7 +107,7 @@ class AttachmentLoadingTests {
@Test
fun `tests that if the attachment is loaded on both sides already that a flow can run`() {
driver(initialiseSerialization = false) {
driver {
installIsolatedCordappTo(bankAName)
installIsolatedCordappTo(bankBName)
val (bankA, bankB) = createTwoNodes()

View File

@ -10,11 +10,13 @@ import net.corda.nodeapi.NodeInfoFilesCopier
import net.corda.testing.ALICE
import net.corda.testing.ALICE_KEY
import net.corda.testing.getTestPartyAndCertificate
import net.corda.testing.internal.NodeBasedTest
import net.corda.testing.*
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.contentOf
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import rx.observers.TestSubscriber
import rx.schedulers.TestScheduler
import java.nio.file.Path
@ -22,11 +24,17 @@ import java.util.concurrent.TimeUnit
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class NodeInfoWatcherTest : NodeBasedTest() {
class NodeInfoWatcherTest {
companion object {
val nodeInfo = NodeInfo(listOf(), listOf(getTestPartyAndCertificate(ALICE)), 0, 0)
}
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule()
@Rule
@JvmField
val tempFolder = TemporaryFolder()
private lateinit var nodeInfoPath: Path
private val scheduler = TestScheduler()
private val testSubscriber = TestSubscriber<NodeInfo>()

View File

@ -9,9 +9,10 @@ import net.corda.core.internal.concurrent.thenMatch
import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.RPCOps
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.loggerFor
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.node.VersionInfo
import net.corda.node.internal.cordapp.CordappLoader
import net.corda.node.serialization.KryoServerSerializationScheme
@ -25,7 +26,6 @@ import net.corda.node.services.messaging.NodeMessagingClient
import net.corda.node.utilities.AddressUtils
import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.DemoClock
import net.corda.nodeapi.ArtemisMessagingComponent
import net.corda.nodeapi.internal.ShutdownHook
import net.corda.nodeapi.internal.addShutdownHook
import net.corda.nodeapi.internal.serialization.*
@ -274,14 +274,15 @@ open class Node(configuration: NodeConfiguration,
private fun initialiseSerialization() {
val classloader = cordappLoader.appClassLoader
SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply {
registerScheme(KryoServerSerializationScheme())
registerScheme(AMQPServerSerializationScheme())
}
SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT.withClassLoader(classloader)
SerializationDefaults.RPC_SERVER_CONTEXT = KRYO_RPC_SERVER_CONTEXT.withClassLoader(classloader)
SerializationDefaults.STORAGE_CONTEXT = KRYO_STORAGE_CONTEXT.withClassLoader(classloader)
SerializationDefaults.CHECKPOINT_CONTEXT = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader)
nodeSerializationEnv = SerializationEnvironmentImpl(
SerializationFactoryImpl().apply {
registerScheme(KryoServerSerializationScheme())
registerScheme(AMQPServerSerializationScheme())
},
KRYO_P2P_CONTEXT.withClassLoader(classloader),
rpcServerContext = KRYO_RPC_SERVER_CONTEXT.withClassLoader(classloader),
storageContext = KRYO_STORAGE_CONTEXT.withClassLoader(classloader),
checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader))
}
/** Starts a blocking event loop for message dispatch. */

View File

@ -11,6 +11,7 @@ import net.corda.core.node.services.PartyInfo
import net.corda.core.node.services.TransactionVerifierService
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.serialization.serialize
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.utilities.NetworkHostAndPort
@ -217,6 +218,7 @@ class NodeMessagingClient(override val config: NodeConfiguration,
locator.connectionTTL = -1
locator.clientFailureCheckPeriod = -1
locator.minLargeMessageSize = ArtemisMessagingServer.MAX_FILE_SIZE
locator.isUseGlobalPools = nodeSerializationEnv != null
sessionFactory = locator.createSessionFactory()
// Login using the node username. The broker will authentiate us as its node (as opposed to another peer)

View File

@ -60,7 +60,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule()
val testSerialization = SerializationEnvironmentRule(true)
private val realClock: Clock = Clock.systemUTC()
private val stoppedClock: Clock = Clock.fixed(realClock.instant(), realClock.zone)
private val testClock = TestClock(stoppedClock)

View File

@ -41,7 +41,7 @@ import kotlin.test.assertEquals
class HTTPNetworkMapClientTest {
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule()
val testSerialization = SerializationEnvironmentRule(true)
private lateinit var server: Server
private lateinit var networkMapClient: NetworkMapClient

View File

@ -30,7 +30,7 @@ class DistributedImmutableMapTests {
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule()
val testSerialization = SerializationEnvironmentRule(true)
lateinit var cluster: List<Member>
lateinit var transaction: DatabaseTransaction
private val databases: MutableList<CordaPersistence> = mutableListOf()

View File

@ -1,9 +1,12 @@
package net.corda.node.services.vault
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.InsufficientBalanceException
import net.corda.core.contracts.LinearState
import net.corda.core.contracts.UniqueIdentifier
import net.corda.core.identity.AnonymousParty
import net.corda.core.internal.concurrent.fork
import net.corda.core.internal.concurrent.transpose
import net.corda.core.internal.packageName
import net.corda.core.node.services.Vault
import net.corda.core.node.services.VaultService
@ -11,6 +14,7 @@ import net.corda.core.node.services.queryBy
import net.corda.core.node.services.vault.QueryCriteria
import net.corda.core.node.services.vault.QueryCriteria.VaultQueryCriteria
import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.getOrThrow
import net.corda.finance.*
import net.corda.finance.contracts.asset.Cash
import net.corda.finance.contracts.asset.DUMMY_CASH_ISSUER
@ -29,9 +33,9 @@ import org.junit.Before
import org.junit.Rule
import org.junit.Test
import java.util.*
import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executors
import kotlin.test.assertEquals
import kotlin.test.fail
// TODO: Move this to the cash contract tests once mock services are further split up.
@ -42,7 +46,7 @@ class VaultWithCashTest {
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule()
val testSerialization = SerializationEnvironmentRule(true)
lateinit var services: MockServices
lateinit var issuerServices: MockServices
val vaultService: VaultService get() = services.vaultService
@ -150,82 +154,74 @@ class VaultWithCashTest {
}
val backgroundExecutor = Executors.newFixedThreadPool(2)
val countDown = CountDownLatch(2)
// 1st tx that spends our money.
backgroundExecutor.submit {
val first = backgroundExecutor.fork {
database.transaction {
try {
val txn1Builder = TransactionBuilder(DUMMY_NOTARY)
Cash.generateSpend(services, txn1Builder, 60.DOLLARS, BOB)
val ptxn1 = notaryServices.signInitialTransaction(txn1Builder)
val txn1 = services.addSignature(ptxn1, freshKey)
println("txn1: ${txn1.id} spent ${((txn1.tx.outputs[0].data) as Cash.State).amount}")
val unconsumedStates1 = vaultService.queryBy<Cash.State>()
val consumedStates1 = vaultService.queryBy<Cash.State>(VaultQueryCriteria(status = Vault.StateStatus.CONSUMED))
val lockedStates1 = vaultService.queryBy<Cash.State>(criteriaLocked).states
println("""txn1 states:
val txn1Builder = TransactionBuilder(DUMMY_NOTARY)
Cash.generateSpend(services, txn1Builder, 60.DOLLARS, BOB)
val ptxn1 = notaryServices.signInitialTransaction(txn1Builder)
val txn1 = services.addSignature(ptxn1, freshKey)
println("txn1: ${txn1.id} spent ${((txn1.tx.outputs[0].data) as Cash.State).amount}")
val unconsumedStates1 = vaultService.queryBy<Cash.State>()
val consumedStates1 = vaultService.queryBy<Cash.State>(VaultQueryCriteria(status = Vault.StateStatus.CONSUMED))
val lockedStates1 = vaultService.queryBy<Cash.State>(criteriaLocked).states
println("""txn1 states:
UNCONSUMED: ${unconsumedStates1.totalStatesAvailable} : $unconsumedStates1,
CONSUMED: ${consumedStates1.totalStatesAvailable} : $consumedStates1,
LOCKED: ${lockedStates1.count()} : $lockedStates1
""")
services.recordTransactions(txn1)
println("txn1: Cash balance: ${services.getCashBalance(USD)}")
val unconsumedStates2 = vaultService.queryBy<Cash.State>()
val consumedStates2 = vaultService.queryBy<Cash.State>(VaultQueryCriteria(status = Vault.StateStatus.CONSUMED))
val lockedStates2 = vaultService.queryBy<Cash.State>(criteriaLocked).states
println("""txn1 states:
services.recordTransactions(txn1)
println("txn1: Cash balance: ${services.getCashBalance(USD)}")
val unconsumedStates2 = vaultService.queryBy<Cash.State>()
val consumedStates2 = vaultService.queryBy<Cash.State>(VaultQueryCriteria(status = Vault.StateStatus.CONSUMED))
val lockedStates2 = vaultService.queryBy<Cash.State>(criteriaLocked).states
println("""txn1 states:
UNCONSUMED: ${unconsumedStates2.totalStatesAvailable} : $unconsumedStates2,
CONSUMED: ${consumedStates2.totalStatesAvailable} : $consumedStates2,
LOCKED: ${lockedStates2.count()} : $lockedStates2
""")
txn1
} catch (e: Exception) {
println(e)
}
txn1
}
println("txn1 COMMITTED!")
countDown.countDown()
}
// 2nd tx that attempts to spend same money
backgroundExecutor.submit {
val second = backgroundExecutor.fork {
database.transaction {
try {
val txn2Builder = TransactionBuilder(DUMMY_NOTARY)
Cash.generateSpend(services, txn2Builder, 80.DOLLARS, BOB)
val ptxn2 = notaryServices.signInitialTransaction(txn2Builder)
val txn2 = services.addSignature(ptxn2, freshKey)
println("txn2: ${txn2.id} spent ${((txn2.tx.outputs[0].data) as Cash.State).amount}")
val unconsumedStates1 = vaultService.queryBy<Cash.State>()
val consumedStates1 = vaultService.queryBy<Cash.State>(VaultQueryCriteria(status = Vault.StateStatus.CONSUMED))
val lockedStates1 = vaultService.queryBy<Cash.State>(criteriaLocked).states
println("""txn2 states:
val txn2Builder = TransactionBuilder(DUMMY_NOTARY)
Cash.generateSpend(services, txn2Builder, 80.DOLLARS, BOB)
val ptxn2 = notaryServices.signInitialTransaction(txn2Builder)
val txn2 = services.addSignature(ptxn2, freshKey)
println("txn2: ${txn2.id} spent ${((txn2.tx.outputs[0].data) as Cash.State).amount}")
val unconsumedStates1 = vaultService.queryBy<Cash.State>()
val consumedStates1 = vaultService.queryBy<Cash.State>(VaultQueryCriteria(status = Vault.StateStatus.CONSUMED))
val lockedStates1 = vaultService.queryBy<Cash.State>(criteriaLocked).states
println("""txn2 states:
UNCONSUMED: ${unconsumedStates1.totalStatesAvailable} : $unconsumedStates1,
CONSUMED: ${consumedStates1.totalStatesAvailable} : $consumedStates1,
LOCKED: ${lockedStates1.count()} : $lockedStates1
""")
services.recordTransactions(txn2)
println("txn2: Cash balance: ${services.getCashBalance(USD)}")
val unconsumedStates2 = vaultService.queryBy<Cash.State>()
val consumedStates2 = vaultService.queryBy<Cash.State>()
val lockedStates2 = vaultService.queryBy<Cash.State>(criteriaLocked).states
println("""txn2 states:
services.recordTransactions(txn2)
println("txn2: Cash balance: ${services.getCashBalance(USD)}")
val unconsumedStates2 = vaultService.queryBy<Cash.State>()
val consumedStates2 = vaultService.queryBy<Cash.State>()
val lockedStates2 = vaultService.queryBy<Cash.State>(criteriaLocked).states
println("""txn2 states:
UNCONSUMED: ${unconsumedStates2.totalStatesAvailable} : $unconsumedStates2,
CONSUMED: ${consumedStates2.totalStatesAvailable} : $consumedStates2,
LOCKED: ${lockedStates2.count()} : $lockedStates2
""")
txn2
} catch (e: Exception) {
println(e)
}
txn2
}
println("txn2 COMMITTED!")
countDown.countDown()
}
countDown.await()
val both = listOf(first, second).transpose()
try {
both.getOrThrow()
fail("Expected insufficient balance.")
} catch (e: InsufficientBalanceException) {
assertEquals(0, e.suppressed.size) // One should succeed.
}
database.transaction {
println("Cash balance: ${services.getCashBalance(USD)}")
assertThat(services.getCashBalance(USD)).isIn(DOLLARS(20), DOLLARS(40))

View File

@ -203,8 +203,8 @@ class NodeInterestRatesTest {
}
@Test
fun `network tearoff`() {
val mockNet = MockNetwork(initialiseSerialization = false, cordappPackages = listOf("net.corda.finance.contracts", "net.corda.irs"))
fun `network tearoff`() = withoutTestSerialization {
val mockNet = MockNetwork(cordappPackages = listOf("net.corda.finance.contracts", "net.corda.irs"))
val aliceNode = mockNet.createPartyNode(ALICE.name)
val oracleNode = mockNet.createNode().apply {
internals.registerInitiatedFlow(NodeInterestRates.FixQueryHandler::class.java)

View File

@ -37,6 +37,7 @@ import net.corda.nodeapi.internal.addShutdownHook
import net.corda.testing.*
import net.corda.testing.common.internal.NetworkParametersCopier
import net.corda.testing.common.internal.testNetworkParameters
import net.corda.testing.setGlobalSerialization
import net.corda.testing.internal.ProcessUtilities
import net.corda.testing.node.ClusterSpec
import net.corda.testing.node.MockServices.Companion.MOCK_VERSION_INFO
@ -413,7 +414,7 @@ fun <DI : DriverDSLExposedInterface, D : DriverDSLInternalInterface, A> genericD
coerce: (D) -> DI,
dsl: DI.() -> A
): A {
val serializationEnv = initialiseTestSerialization(initialiseSerialization)
val serializationEnv = setGlobalSerialization(initialiseSerialization)
val shutdownHook = addShutdownHook(driverDsl::shutdown)
try {
driverDsl.start()
@ -424,7 +425,7 @@ fun <DI : DriverDSLExposedInterface, D : DriverDSLInternalInterface, A> genericD
} finally {
driverDsl.shutdown()
shutdownHook.cancel()
serializationEnv.resetTestSerialization()
serializationEnv.unset()
}
}
@ -451,7 +452,7 @@ fun <DI : DriverDSLExposedInterface, D : DriverDSLInternalInterface, A> genericD
driverDslWrapper: (DriverDSL) -> D,
coerce: (D) -> DI, dsl: DI.() -> A
): A {
val serializationEnv = initialiseTestSerialization(initialiseSerialization)
val serializationEnv = setGlobalSerialization(initialiseSerialization)
val driverDsl = driverDslWrapper(
DriverDSL(
portAllocation = portAllocation,
@ -475,7 +476,7 @@ fun <DI : DriverDSLExposedInterface, D : DriverDSLInternalInterface, A> genericD
} finally {
driverDsl.shutdown()
shutdownHook.cancel()
serializationEnv.resetTestSerialization()
serializationEnv.unset()
}
}

View File

@ -38,7 +38,7 @@ abstract class NodeBasedTest(private val cordappPackages: List<String> = emptyLi
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule()
val testSerialization = SerializationEnvironmentRule(true)
@Rule
@JvmField
val tempFolder = TemporaryFolder()
@ -63,16 +63,20 @@ abstract class NodeBasedTest(private val cordappPackages: List<String> = emptyLi
@After
fun stopAllNodes() {
val shutdownExecutor = Executors.newScheduledThreadPool(nodes.size)
nodes.map { shutdownExecutor.fork(it::dispose) }.transpose().getOrThrow()
// Wait until ports are released
val portNotBoundChecks = nodes.flatMap {
listOf(
it.internals.configuration.p2pAddress.let { addressMustNotBeBoundFuture(shutdownExecutor, it) },
it.internals.configuration.rpcAddress?.let { addressMustNotBeBoundFuture(shutdownExecutor, it) }
)
}.filterNotNull()
nodes.clear()
portNotBoundChecks.transpose().getOrThrow()
try {
nodes.map { shutdownExecutor.fork(it::dispose) }.transpose().getOrThrow()
// Wait until ports are released
val portNotBoundChecks = nodes.flatMap {
listOf(
it.internals.configuration.p2pAddress.let { addressMustNotBeBoundFuture(shutdownExecutor, it) },
it.internals.configuration.rpcAddress?.let { addressMustNotBeBoundFuture(shutdownExecutor, it) }
)
}.filterNotNull()
nodes.clear()
portNotBoundChecks.transpose().getOrThrow()
} finally {
shutdownExecutor.shutdown()
}
}
@JvmOverloads

View File

@ -39,7 +39,7 @@ import net.corda.node.utilities.ServiceIdentityGenerator
import net.corda.testing.DUMMY_NOTARY
import net.corda.testing.common.internal.NetworkParametersCopier
import net.corda.testing.common.internal.testNetworkParameters
import net.corda.testing.initialiseTestSerialization
import net.corda.testing.setGlobalSerialization
import net.corda.testing.node.MockServices.Companion.MOCK_VERSION_INFO
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import net.corda.testing.testNodeConfiguration
@ -136,9 +136,8 @@ class MockNetwork(defaultParameters: MockNetworkParameters = MockNetworkParamete
private val networkId = random63BitValue()
private val networkParameters: NetworkParametersCopier
private val _nodes = mutableListOf<MockNode>()
private val serializationEnv = initialiseTestSerialization(initialiseSerialization)
private val serializationEnv = setGlobalSerialization(initialiseSerialization)
private val sharedUserCount = AtomicInteger(0)
/** A read only view of the current set of executing nodes. */
val nodes: List<MockNode> get() = _nodes
@ -419,7 +418,7 @@ class MockNetwork(defaultParameters: MockNetworkParameters = MockNetworkParamete
fun stopNodes() {
nodes.forEach { it.started?.dispose() }
serializationEnv.resetTestSerialization()
serializationEnv.unset()
}
// Test method to block until all scheduled activity, active flows

View File

@ -3,10 +3,7 @@ package net.corda.testing
import com.nhaarman.mockito_kotlin.doNothing
import com.nhaarman.mockito_kotlin.whenever
import net.corda.client.rpc.internal.KryoClientSerializationScheme
import net.corda.core.crypto.SecureHash
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.utilities.ByteSequence
import net.corda.core.serialization.internal.*
import net.corda.node.serialization.KryoServerSerializationScheme
import net.corda.nodeapi.internal.serialization.*
import net.corda.nodeapi.internal.serialization.amqp.AMQPClientSerializationScheme
@ -15,183 +12,84 @@ import org.junit.rules.TestRule
import org.junit.runner.Description
import org.junit.runners.model.Statement
class SerializationEnvironmentRule : TestRule {
lateinit var env: SerializationEnvironment
/** @param inheritable whether new threads inherit the environment, use sparingly. */
class SerializationEnvironmentRule(private val inheritable: Boolean = false) : TestRule {
val env: SerializationEnvironment = createTestSerializationEnv()
override fun apply(base: Statement, description: Description?) = object : Statement() {
override fun evaluate() = withTestSerialization {
env = it
override fun evaluate() = env.asContextEnv(inheritable) {
base.evaluate()
}
}
}
interface TestSerializationEnvironment : SerializationEnvironment {
fun resetTestSerialization()
interface GlobalSerializationEnvironment : SerializationEnvironment {
/** Unset this environment. */
fun unset()
}
fun <T> withTestSerialization(block: (SerializationEnvironment) -> T): T {
val env = initialiseTestSerializationImpl()
/** @param inheritable whether new threads inherit the environment, use sparingly. */
fun <T> withTestSerialization(inheritable: Boolean = false, callable: (SerializationEnvironment) -> T): T {
return createTestSerializationEnv().asContextEnv(inheritable, callable)
}
private fun <T> SerializationEnvironment.asContextEnv(inheritable: Boolean, callable: (SerializationEnvironment) -> T): T {
val property = if (inheritable) _inheritableContextSerializationEnv else _contextSerializationEnv
property.set(this)
try {
return block(env)
return callable(this)
} finally {
env.resetTestSerialization()
property.set(null)
}
}
/** @param armed true to init, false to do nothing and return a dummy env. */
fun initialiseTestSerialization(armed: Boolean): TestSerializationEnvironment {
/**
* For example your test class uses [SerializationEnvironmentRule] but you want to turn it off for one method.
* Use sparingly, ideally a test class shouldn't mix serialization init mechanisms.
*/
fun <T> withoutTestSerialization(callable: () -> T): T {
val (property, env) = listOf(_contextSerializationEnv, _inheritableContextSerializationEnv).map { Pair(it, it.get()) }.single { it.second != null }
property.set(null)
try {
return callable()
} finally {
property.set(env)
}
}
/**
* Should only be used by Driver and MockNode.
* @param armed true to install, false to do nothing and return a dummy env.
*/
fun setGlobalSerialization(armed: Boolean): GlobalSerializationEnvironment {
return if (armed) {
val env = initialiseTestSerializationImpl()
object : TestSerializationEnvironment, SerializationEnvironment by env {
override fun resetTestSerialization() = env.resetTestSerialization()
object : GlobalSerializationEnvironment, SerializationEnvironment by createTestSerializationEnv() {
override fun unset() {
_globalSerializationEnv.set(null)
}
}.also {
_globalSerializationEnv.set(it)
}
} else {
rigorousMock<TestSerializationEnvironment>().also {
doNothing().whenever(it).resetTestSerialization()
rigorousMock<GlobalSerializationEnvironment>().also {
doNothing().whenever(it).unset()
}
}
}
private fun initialiseTestSerializationImpl() = SerializationDefaults.apply {
// Stop the CordaRPCClient from trying to setup the defaults as we're about to do it now
KryoClientSerializationScheme.isInitialised.set(true)
// Check that everything is configured for testing with mutable delegating instances.
try {
check(SERIALIZATION_FACTORY is TestSerializationFactory)
} catch (e: IllegalStateException) {
SERIALIZATION_FACTORY = TestSerializationFactory()
}
try {
check(P2P_CONTEXT is TestSerializationContext)
} catch (e: IllegalStateException) {
P2P_CONTEXT = TestSerializationContext()
}
try {
check(RPC_SERVER_CONTEXT is TestSerializationContext)
} catch (e: IllegalStateException) {
RPC_SERVER_CONTEXT = TestSerializationContext()
}
try {
check(RPC_CLIENT_CONTEXT is TestSerializationContext)
} catch (e: IllegalStateException) {
RPC_CLIENT_CONTEXT = TestSerializationContext()
}
try {
check(STORAGE_CONTEXT is TestSerializationContext)
} catch (e: IllegalStateException) {
STORAGE_CONTEXT = TestSerializationContext()
}
try {
check(CHECKPOINT_CONTEXT is TestSerializationContext)
} catch (e: IllegalStateException) {
CHECKPOINT_CONTEXT = TestSerializationContext()
}
// Check that the previous test, if there was one, cleaned up after itself.
// IF YOU SEE THESE MESSAGES, THEN IT MEANS A TEST HAS NOT CALLED resetTestSerialization()
check((SERIALIZATION_FACTORY as TestSerializationFactory).delegate == null, { "Expected uninitialised serialization framework but found it set from: $SERIALIZATION_FACTORY" })
check((P2P_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: $P2P_CONTEXT" })
check((RPC_SERVER_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: $RPC_SERVER_CONTEXT" })
check((RPC_CLIENT_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: $RPC_CLIENT_CONTEXT" })
check((STORAGE_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: $STORAGE_CONTEXT" })
check((CHECKPOINT_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: $CHECKPOINT_CONTEXT" })
// Now configure all the testing related delegates.
(SERIALIZATION_FACTORY as TestSerializationFactory).delegate = SerializationFactoryImpl().apply {
registerScheme(KryoClientSerializationScheme())
registerScheme(KryoServerSerializationScheme())
registerScheme(AMQPClientSerializationScheme())
registerScheme(AMQPServerSerializationScheme())
}
(P2P_CONTEXT as TestSerializationContext).delegate = if (isAmqpEnabled()) AMQP_P2P_CONTEXT else KRYO_P2P_CONTEXT
(RPC_SERVER_CONTEXT as TestSerializationContext).delegate = KRYO_RPC_SERVER_CONTEXT
(RPC_CLIENT_CONTEXT as TestSerializationContext).delegate = KRYO_RPC_CLIENT_CONTEXT
(STORAGE_CONTEXT as TestSerializationContext).delegate = if (isAmqpEnabled()) AMQP_STORAGE_CONTEXT else KRYO_STORAGE_CONTEXT
(CHECKPOINT_CONTEXT as TestSerializationContext).delegate = KRYO_CHECKPOINT_CONTEXT
}
private fun createTestSerializationEnv() = SerializationEnvironmentImpl(
SerializationFactoryImpl().apply {
registerScheme(KryoClientSerializationScheme())
registerScheme(KryoServerSerializationScheme())
registerScheme(AMQPClientSerializationScheme())
registerScheme(AMQPServerSerializationScheme())
},
if (isAmqpEnabled()) AMQP_P2P_CONTEXT else KRYO_P2P_CONTEXT,
KRYO_RPC_SERVER_CONTEXT,
KRYO_RPC_CLIENT_CONTEXT,
if (isAmqpEnabled()) AMQP_STORAGE_CONTEXT else KRYO_STORAGE_CONTEXT,
KRYO_CHECKPOINT_CONTEXT)
private const val AMQP_ENABLE_PROP_NAME = "net.corda.testing.amqp.enable"
// TODO: Remove usages of this function when we fully switched to AMQP
private fun isAmqpEnabled(): Boolean = java.lang.Boolean.getBoolean(AMQP_ENABLE_PROP_NAME)
private fun SerializationDefaults.resetTestSerialization() {
(SERIALIZATION_FACTORY as TestSerializationFactory).delegate = null
(P2P_CONTEXT as TestSerializationContext).delegate = null
(RPC_SERVER_CONTEXT as TestSerializationContext).delegate = null
(RPC_CLIENT_CONTEXT as TestSerializationContext).delegate = null
(STORAGE_CONTEXT as TestSerializationContext).delegate = null
(CHECKPOINT_CONTEXT as TestSerializationContext).delegate = null
}
class TestSerializationFactory : SerializationFactory() {
var delegate: SerializationFactory? = null
set(value) {
field = value
stackTrace = Exception().stackTrace.asList()
}
private var stackTrace: List<StackTraceElement>? = null
override fun toString(): String = stackTrace?.joinToString("\n") ?: "null"
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T {
return delegate!!.deserialize(byteSequence, clazz, context)
}
override fun <T : Any> deserializeWithCompatibleContext(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): ObjectWithCompatibleContext<T> {
return delegate!!.deserializeWithCompatibleContext(byteSequence, clazz, context)
}
override fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T> {
return delegate!!.serialize(obj, context)
}
}
class TestSerializationContext : SerializationContext {
var delegate: SerializationContext? = null
set(value) {
field = value
stackTrace = Exception().stackTrace.asList()
}
private var stackTrace: List<StackTraceElement>? = null
override fun toString(): String = stackTrace?.joinToString("\n") ?: "null"
override val preferredSerializationVersion: ByteSequence
get() = delegate!!.preferredSerializationVersion
override val deserializationClassLoader: ClassLoader
get() = delegate!!.deserializationClassLoader
override val whitelist: ClassWhitelist
get() = delegate!!.whitelist
override val properties: Map<Any, Any>
get() = delegate!!.properties
override val objectReferencesEnabled: Boolean
get() = delegate!!.objectReferencesEnabled
override val useCase: SerializationContext.UseCase
get() = delegate!!.useCase
override fun withProperty(property: Any, value: Any): SerializationContext {
return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withProperty(property, value) }
}
override fun withoutReferences(): SerializationContext {
return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withoutReferences() }
}
override fun withClassLoader(classLoader: ClassLoader): SerializationContext {
return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withClassLoader(classLoader) }
}
override fun withWhitelisted(clazz: Class<*>): SerializationContext {
return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withWhitelisted(clazz) }
}
override fun withPreferredSerializationVersion(versionHeader: VersionHeader): SerializationContext {
return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withPreferredSerializationVersion(versionHeader) }
}
override fun withAttachmentsClassLoader(attachmentHashes: List<SecureHash>): SerializationContext {
return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withAttachmentsClassLoader(attachmentHashes) }
}
}

View File

@ -1,6 +1,7 @@
package net.corda.testing.messaging
import net.corda.core.identity.CordaX500Name
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.ArtemisMessagingComponent
import net.corda.nodeapi.ArtemisTcpTransport
@ -27,6 +28,7 @@ class SimpleMQClient(val target: NetworkHostAndPort,
val locator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport).apply {
isBlockOnNonDurableSend = true
threadPoolMaxSize = 1
isUseGlobalPools = nodeSerializationEnv != null
}
sessionFactory = locator.createSessionFactory()
session = sessionFactory.createSession(username, password, false, true, true, locator.isPreAcknowledge, locator.ackBatchSize)

View File

@ -5,7 +5,8 @@ import com.typesafe.config.ConfigFactory
import com.typesafe.config.ConfigParseOptions
import net.corda.core.internal.div
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.debug
@ -89,15 +90,16 @@ class Verifier {
}
private fun initialiseSerialization() {
SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply {
registerScheme(KryoVerifierSerializationScheme)
registerScheme(AMQPVerifierSerializationScheme)
}
/**
* Even though default context is set to Kryo P2P, the encoding will be adjusted depending on the incoming
* request received, see use of [context] in [main] method.
*/
SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT
nodeSerializationEnv = SerializationEnvironmentImpl(
SerializationFactoryImpl().apply {
registerScheme(KryoVerifierSerializationScheme)
registerScheme(AMQPVerifierSerializationScheme)
},
/**
* Even though default context is set to Kryo P2P, the encoding will be adjusted depending on the incoming
* request received, see use of [context] in [main] method.
*/
KRYO_P2P_CONTEXT)
}
}