diff --git a/.ci/api-current.txt b/.ci/api-current.txt index 03b214c40d..a725671072 100644 --- a/.ci/api-current.txt +++ b/.ci/api-current.txt @@ -2430,19 +2430,13 @@ public static final class net.corda.core.serialization.SerializationContext$UseC public static net.corda.core.serialization.SerializationContext$UseCase valueOf(String) public static net.corda.core.serialization.SerializationContext$UseCase[] values() ## -public final class net.corda.core.serialization.SerializationDefaults extends java.lang.Object implements net.corda.core.serialization.internal.SerializationEnvironment - @org.jetbrains.annotations.NotNull public net.corda.core.serialization.SerializationContext getCHECKPOINT_CONTEXT() - @org.jetbrains.annotations.NotNull public net.corda.core.serialization.SerializationContext getP2P_CONTEXT() - @org.jetbrains.annotations.NotNull public net.corda.core.serialization.SerializationContext getRPC_CLIENT_CONTEXT() - @org.jetbrains.annotations.NotNull public net.corda.core.serialization.SerializationContext getRPC_SERVER_CONTEXT() - @org.jetbrains.annotations.NotNull public net.corda.core.serialization.SerializationFactory getSERIALIZATION_FACTORY() - @org.jetbrains.annotations.NotNull public net.corda.core.serialization.SerializationContext getSTORAGE_CONTEXT() - public void setCHECKPOINT_CONTEXT(net.corda.core.serialization.SerializationContext) - public void setP2P_CONTEXT(net.corda.core.serialization.SerializationContext) - public void setRPC_CLIENT_CONTEXT(net.corda.core.serialization.SerializationContext) - public void setRPC_SERVER_CONTEXT(net.corda.core.serialization.SerializationContext) - public void setSERIALIZATION_FACTORY(net.corda.core.serialization.SerializationFactory) - public void setSTORAGE_CONTEXT(net.corda.core.serialization.SerializationContext) +public final class net.corda.core.serialization.SerializationDefaults extends java.lang.Object + @org.jetbrains.annotations.NotNull public final net.corda.core.serialization.SerializationContext getCHECKPOINT_CONTEXT() + @org.jetbrains.annotations.NotNull public final net.corda.core.serialization.SerializationContext getP2P_CONTEXT() + @org.jetbrains.annotations.NotNull public final net.corda.core.serialization.SerializationContext getRPC_CLIENT_CONTEXT() + @org.jetbrains.annotations.NotNull public final net.corda.core.serialization.SerializationContext getRPC_SERVER_CONTEXT() + @org.jetbrains.annotations.NotNull public final net.corda.core.serialization.SerializationFactory getSERIALIZATION_FACTORY() + @org.jetbrains.annotations.NotNull public final net.corda.core.serialization.SerializationContext getSTORAGE_CONTEXT() public static final net.corda.core.serialization.SerializationDefaults INSTANCE ## public abstract class net.corda.core.serialization.SerializationFactory extends java.lang.Object diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt index 1ef330d3e1..7326a5a182 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt @@ -4,10 +4,10 @@ import net.corda.client.rpc.internal.KryoClientSerializationScheme import net.corda.client.rpc.internal.RPCClient import net.corda.client.rpc.internal.RPCClientConfiguration import net.corda.core.messaging.CordaRPCOps +import net.corda.core.serialization.internal.effectiveSerializationEnv import net.corda.core.utilities.NetworkHostAndPort import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport import net.corda.nodeapi.ConnectionDirection -import net.corda.nodeapi.internal.serialization.AMQP_RPC_CLIENT_CONTEXT import net.corda.nodeapi.internal.serialization.KRYO_RPC_CLIENT_CONTEXT import java.time.Duration @@ -71,8 +71,15 @@ class CordaRPCClient @JvmOverloads constructor( configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.DEFAULT ) { init { - // TODO: allow clients to have serialization factory etc injected and align with RPC protocol version? - KryoClientSerializationScheme.initialiseSerialization() + try { + effectiveSerializationEnv + } catch (e: IllegalStateException) { + try { + KryoClientSerializationScheme.initialiseSerialization() + } catch (e: IllegalStateException) { + // Race e.g. two of these constructed in parallel, ignore. + } + } } private val rpcClient = RPCClient( diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/KryoClientSerializationScheme.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/KryoClientSerializationScheme.kt index 039768185c..0e830b3a99 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/KryoClientSerializationScheme.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/KryoClientSerializationScheme.kt @@ -2,15 +2,17 @@ package net.corda.client.rpc.internal import com.esotericsoftware.kryo.pool.KryoPool import net.corda.core.serialization.SerializationContext -import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.internal.SerializationEnvironmentImpl +import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.utilities.ByteSequence -import net.corda.nodeapi.internal.serialization.* +import net.corda.nodeapi.internal.serialization.KRYO_P2P_CONTEXT +import net.corda.nodeapi.internal.serialization.KRYO_RPC_CLIENT_CONTEXT +import net.corda.nodeapi.internal.serialization.SerializationFactoryImpl import net.corda.nodeapi.internal.serialization.amqp.AMQPClientSerializationScheme import net.corda.nodeapi.internal.serialization.kryo.AbstractKryoSerializationScheme import net.corda.nodeapi.internal.serialization.kryo.DefaultKryoCustomizer import net.corda.nodeapi.internal.serialization.kryo.KryoHeaderV0_1 import net.corda.nodeapi.internal.serialization.kryo.RPCKryo -import java.util.concurrent.atomic.AtomicBoolean class KryoClientSerializationScheme : AbstractKryoSerializationScheme() { override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean { @@ -29,25 +31,15 @@ class KryoClientSerializationScheme : AbstractKryoSerializationScheme() { override fun rpcServerKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException() companion object { - val isInitialised = AtomicBoolean(false) + /** Call from main only. */ fun initialiseSerialization() { - if (!isInitialised.compareAndSet(false, true)) return - try { - SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply { - registerScheme(KryoClientSerializationScheme()) - registerScheme(AMQPClientSerializationScheme()) - } - SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT - SerializationDefaults.RPC_CLIENT_CONTEXT = KRYO_RPC_CLIENT_CONTEXT - } catch (e: IllegalStateException) { - // Check that it's registered as we expect - val factory = SerializationDefaults.SERIALIZATION_FACTORY - val checkedFactory = factory as? SerializationFactoryImpl - ?: throw IllegalStateException("RPC client encountered conflicting configuration of serialization subsystem: $factory") - check(checkedFactory.alreadyRegisteredSchemes.any { it is KryoClientSerializationScheme }) { - "RPC client encountered conflicting configuration of serialization subsystem." - } - } + nodeSerializationEnv = SerializationEnvironmentImpl( + SerializationFactoryImpl().apply { + registerScheme(KryoClientSerializationScheme()) + registerScheme(AMQPClientSerializationScheme()) + }, + KRYO_P2P_CONTEXT, + rpcClientContext = KRYO_RPC_CLIENT_CONTEXT) } } } \ No newline at end of file diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt index 773c511da9..9f925e540c 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt @@ -8,6 +8,7 @@ import net.corda.core.internal.uncheckedCast import net.corda.core.messaging.RPCOps import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.loggerFor import net.corda.core.utilities.minutes @@ -110,6 +111,7 @@ class RPCClient( maxRetryInterval = rpcConfiguration.connectionMaxRetryInterval.toMillis() reconnectAttempts = rpcConfiguration.maxReconnectAttempts minLargeMessageSize = rpcConfiguration.maxFileSize + isUseGlobalPools = nodeSerializationEnv != null } val proxyHandler = RPCClientProxyHandler(rpcConfiguration, username, password, serverLocator, clientAddress, rpcOpsClass, serializationContext) diff --git a/core/src/main/kotlin/net/corda/core/internal/ToggleField.kt b/core/src/main/kotlin/net/corda/core/internal/ToggleField.kt new file mode 100644 index 0000000000..6020cdf40e --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/internal/ToggleField.kt @@ -0,0 +1,66 @@ +package net.corda.core.internal + +import java.util.concurrent.atomic.AtomicReference +import kotlin.reflect.KProperty + +/** May go from null to non-null and vice-versa, and that's it. */ +abstract class ToggleField(val name: String) { + private val writeMutex = Any() // Protects the toggle logic only. + abstract fun get(): T? + fun set(value: T?) = synchronized(writeMutex) { + if (value != null) { + check(get() == null) { "$name already has a value." } + setImpl(value) + } else { + check(get() != null) { "$name is already null." } + clear() + } + } + + protected abstract fun setImpl(value: T) + protected abstract fun clear() + operator fun getValue(thisRef: Any?, property: KProperty<*>) = get() + operator fun setValue(thisRef: Any?, property: KProperty<*>, value: T?) = set(value) +} + +class SimpleToggleField(name: String, private val once: Boolean = false) : ToggleField(name) { + private val holder = AtomicReference() // Force T? in API for safety. + override fun get() = holder.get() + override fun setImpl(value: T) = holder.set(value) + override fun clear() { + check(!once) { "Value of $name cannot be changed." } + holder.set(null) + } +} + +class ThreadLocalToggleField(name: String) : ToggleField(name) { + private val threadLocal = ThreadLocal() + override fun get() = threadLocal.get() + override fun setImpl(value: T) = threadLocal.set(value) + override fun clear() = threadLocal.remove() +} + +/** The named thread has leaked from a previous test. */ +class ThreadLeakException : RuntimeException("Leaked thread detected: ${Thread.currentThread().name}") + +class InheritableThreadLocalToggleField(name: String) : ToggleField(name) { + private class Holder(value: T) : AtomicReference(value) { + fun valueOrDeclareLeak() = get() ?: throw ThreadLeakException() + } + + private val threadLocal = object : InheritableThreadLocal?>() { + override fun childValue(holder: Holder?): Holder? { + // The Holder itself may be null due to prior events, a leak is not implied in that case: + holder?.valueOrDeclareLeak() // Fail fast. + return holder // What super does. + } + } + + override fun get() = threadLocal.get()?.valueOrDeclareLeak() + override fun setImpl(value: T) = threadLocal.set(Holder(value)) + override fun clear() = threadLocal.run { + val holder = get()!! + remove() + holder.set(null) // Threads that inherited the holder are now considered to have escaped from the test. + } +} diff --git a/core/src/main/kotlin/net/corda/core/internal/WriteOnceProperty.kt b/core/src/main/kotlin/net/corda/core/internal/WriteOnceProperty.kt deleted file mode 100644 index ae815be6ca..0000000000 --- a/core/src/main/kotlin/net/corda/core/internal/WriteOnceProperty.kt +++ /dev/null @@ -1,18 +0,0 @@ -package net.corda.core.internal - -import kotlin.reflect.KProperty - -/** - * A write-once property to be used as delegate for Kotlin var properties. The expectation is that this is initialised - * prior to the spawning of any threads that may access it and so there's no need for it to be volatile. - */ -class WriteOnceProperty(private val defaultValue: T? = null) { - private var v: T? = defaultValue - - operator fun getValue(thisRef: Any?, property: KProperty<*>) = v ?: throw IllegalStateException("Write-once property $property not set.") - - operator fun setValue(thisRef: Any?, property: KProperty<*>, value: T) { - check(v == defaultValue || v === value) { "Cannot set write-once property $property more than once." } - v = value - } -} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt index bd5a0cc2e5..5b5aa35a6f 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt @@ -2,8 +2,7 @@ package net.corda.core.serialization import net.corda.core.crypto.SecureHash import net.corda.core.crypto.sha256 -import net.corda.core.internal.WriteOnceProperty -import net.corda.core.serialization.internal.SerializationEnvironment +import net.corda.core.serialization.internal.effectiveSerializationEnv import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.sequence @@ -53,7 +52,7 @@ abstract class SerializationFactory { * A context to use as a default if you do not require a specially configured context. It will be the current context * if the use is somehow nested (see [currentContext]). */ - val defaultContext: SerializationContext get() = currentContext ?: SerializationDefaults.P2P_CONTEXT + val defaultContext: SerializationContext get() = currentContext ?: effectiveSerializationEnv.p2pContext private val _currentContext = ThreadLocal() @@ -90,7 +89,7 @@ abstract class SerializationFactory { /** * A default factory for serialization/deserialization, taking into account the [currentFactory] if set. */ - val defaultFactory: SerializationFactory get() = currentFactory ?: SerializationDefaults.SERIALIZATION_FACTORY + val defaultFactory: SerializationFactory get() = currentFactory ?: effectiveSerializationEnv.serializationFactory /** * If there is a need to nest serialization/deserialization with a modified context during serialization or deserialization, @@ -173,13 +172,13 @@ interface SerializationContext { /** * Global singletons to be used as defaults that are injected elsewhere (generally, in the node or in RPC client). */ -object SerializationDefaults : SerializationEnvironment { - override var SERIALIZATION_FACTORY: SerializationFactory by WriteOnceProperty() - override var P2P_CONTEXT: SerializationContext by WriteOnceProperty() - override var RPC_SERVER_CONTEXT: SerializationContext by WriteOnceProperty() - override var RPC_CLIENT_CONTEXT: SerializationContext by WriteOnceProperty() - override var STORAGE_CONTEXT: SerializationContext by WriteOnceProperty() - override var CHECKPOINT_CONTEXT: SerializationContext by WriteOnceProperty() +object SerializationDefaults { + val SERIALIZATION_FACTORY get() = effectiveSerializationEnv.serializationFactory + val P2P_CONTEXT get() = effectiveSerializationEnv.p2pContext + val RPC_SERVER_CONTEXT get() = effectiveSerializationEnv.rpcServerContext + val RPC_CLIENT_CONTEXT get() = effectiveSerializationEnv.rpcClientContext + val STORAGE_CONTEXT get() = effectiveSerializationEnv.storageContext + val CHECKPOINT_CONTEXT get() = effectiveSerializationEnv.checkpointContext } /** diff --git a/core/src/main/kotlin/net/corda/core/serialization/internal/SerializationEnvironment.kt b/core/src/main/kotlin/net/corda/core/serialization/internal/SerializationEnvironment.kt index 9585551bd0..8c33e1b3a7 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/internal/SerializationEnvironment.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/internal/SerializationEnvironment.kt @@ -1,13 +1,55 @@ package net.corda.core.serialization.internal +import net.corda.core.internal.InheritableThreadLocalToggleField +import net.corda.core.internal.SimpleToggleField +import net.corda.core.internal.ThreadLocalToggleField +import net.corda.core.internal.VisibleForTesting import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationFactory interface SerializationEnvironment { - val SERIALIZATION_FACTORY: SerializationFactory - val P2P_CONTEXT: SerializationContext - val RPC_SERVER_CONTEXT: SerializationContext - val RPC_CLIENT_CONTEXT: SerializationContext - val STORAGE_CONTEXT: SerializationContext - val CHECKPOINT_CONTEXT: SerializationContext + val serializationFactory: SerializationFactory + val p2pContext: SerializationContext + val rpcServerContext: SerializationContext + val rpcClientContext: SerializationContext + val storageContext: SerializationContext + val checkpointContext: SerializationContext } + +class SerializationEnvironmentImpl( + override val serializationFactory: SerializationFactory, + override val p2pContext: SerializationContext, + rpcServerContext: SerializationContext? = null, + rpcClientContext: SerializationContext? = null, + storageContext: SerializationContext? = null, + checkpointContext: SerializationContext? = null) : SerializationEnvironment { + // Those that are passed in as null are never inited: + override lateinit var rpcServerContext: SerializationContext + override lateinit var rpcClientContext: SerializationContext + override lateinit var storageContext: SerializationContext + override lateinit var checkpointContext: SerializationContext + + init { + rpcServerContext?.let { this.rpcServerContext = it } + rpcClientContext?.let { this.rpcClientContext = it } + storageContext?.let { this.storageContext = it } + checkpointContext?.let { this.checkpointContext = it } + } +} + +private val _nodeSerializationEnv = SimpleToggleField("nodeSerializationEnv", true) +@VisibleForTesting +val _globalSerializationEnv = SimpleToggleField("globalSerializationEnv") +@VisibleForTesting +val _contextSerializationEnv = ThreadLocalToggleField("contextSerializationEnv") +@VisibleForTesting +val _inheritableContextSerializationEnv = InheritableThreadLocalToggleField("inheritableContextSerializationEnv") +private val serializationEnvProperties = listOf(_nodeSerializationEnv, _globalSerializationEnv, _contextSerializationEnv, _inheritableContextSerializationEnv) +val effectiveSerializationEnv: SerializationEnvironment + get() = serializationEnvProperties.map { Pair(it, it.get()) }.filter { it.second != null }.run { + singleOrNull()?.run { + second!! + } ?: throw IllegalStateException("Expected exactly 1 of {${serializationEnvProperties.joinToString(", ") { it.name }}} but got: {${joinToString(", ") { it.first.name }}}") + } +/** Should be set once in main. */ +var nodeSerializationEnv by _nodeSerializationEnv diff --git a/core/src/test/kotlin/net/corda/core/internal/ToggleFieldTest.kt b/core/src/test/kotlin/net/corda/core/internal/ToggleFieldTest.kt new file mode 100644 index 0000000000..0ada6af5a2 --- /dev/null +++ b/core/src/test/kotlin/net/corda/core/internal/ToggleFieldTest.kt @@ -0,0 +1,125 @@ +package net.corda.core.internal + +import net.corda.core.internal.concurrent.fork +import net.corda.core.utilities.getOrThrow +import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.Test +import java.util.concurrent.ExecutorService +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit +import kotlin.test.assertEquals +import kotlin.test.assertNull + +private fun withSingleThreadExecutor(callable: ExecutorService.() -> T) = Executors.newSingleThreadExecutor().run { + try { + fork {}.getOrThrow() // Start the thread. + callable() + } finally { + shutdown() + while (!awaitTermination(1, TimeUnit.SECONDS)) { + // Do nothing. + } + } +} + +class ToggleFieldTest { + @Test + fun `toggle is enforced`() { + listOf(SimpleToggleField("simple"), ThreadLocalToggleField("local"), InheritableThreadLocalToggleField("inheritable")).forEach { field -> + assertNull(field.get()) + assertThatThrownBy { field.set(null) }.isInstanceOf(IllegalStateException::class.java) + field.set("hello") + assertEquals("hello", field.get()) + assertThatThrownBy { field.set("world") }.isInstanceOf(IllegalStateException::class.java) + assertEquals("hello", field.get()) + assertThatThrownBy { field.set("hello") }.isInstanceOf(IllegalStateException::class.java) + field.set(null) + assertNull(field.get()) + } + } + + @Test + fun `write-at-most-once field works`() { + val field = SimpleToggleField("field", true) + assertNull(field.get()) + assertThatThrownBy { field.set(null) }.isInstanceOf(IllegalStateException::class.java) + field.set("finalValue") + assertEquals("finalValue", field.get()) + listOf("otherValue", "finalValue", null).forEach { value -> + assertThatThrownBy { field.set(value) }.isInstanceOf(IllegalStateException::class.java) + assertEquals("finalValue", field.get()) + } + } + + @Test + fun `thread local works`() { + val field = ThreadLocalToggleField("field") + assertNull(field.get()) + field.set("hello") + assertEquals("hello", field.get()) + withSingleThreadExecutor { + assertNull(fork(field::get).getOrThrow()) + } + field.set(null) + assertNull(field.get()) + } + + @Test + fun `inheritable thread local works`() { + val field = InheritableThreadLocalToggleField("field") + assertNull(field.get()) + field.set("hello") + assertEquals("hello", field.get()) + withSingleThreadExecutor { + assertEquals("hello", fork(field::get).getOrThrow()) + } + field.set(null) + assertNull(field.get()) + } + + @Test + fun `existing threads do not inherit`() { + val field = InheritableThreadLocalToggleField("field") + withSingleThreadExecutor { + field.set("hello") + assertEquals("hello", field.get()) + assertNull(fork(field::get).getOrThrow()) + } + } + + @Test + fun `inherited values are poisoned on clear`() { + val field = InheritableThreadLocalToggleField("field") + field.set("hello") + withSingleThreadExecutor { + assertEquals("hello", fork(field::get).getOrThrow()) + val threadName = fork { Thread.currentThread().name }.getOrThrow() + listOf(null, "world").forEach { value -> + field.set(value) + assertEquals(value, field.get()) + val future = fork(field::get) + assertThatThrownBy { future.getOrThrow() } + .isInstanceOf(ThreadLeakException::class.java) + .hasMessageContaining(threadName) + } + } + withSingleThreadExecutor { + assertEquals("world", fork(field::get).getOrThrow()) + } + } + + @Test + fun `leaked thread is detected as soon as it tries to create another`() { + val field = InheritableThreadLocalToggleField("field") + field.set("hello") + withSingleThreadExecutor { + assertEquals("hello", fork(field::get).getOrThrow()) + field.set(null) // The executor thread is now considered leaked. + val threadName = fork { Thread.currentThread().name }.getOrThrow() + val future = fork(::Thread) + assertThatThrownBy { future.getOrThrow() } + .isInstanceOf(ThreadLeakException::class.java) + .hasMessageContaining(threadName) + } + } +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/ArtemisTcpTransport.kt b/node-api/src/main/kotlin/net/corda/nodeapi/ArtemisTcpTransport.kt index ba2dd298d7..947560c257 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/ArtemisTcpTransport.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/ArtemisTcpTransport.kt @@ -1,6 +1,7 @@ package net.corda.nodeapi import net.corda.core.identity.CordaX500Name +import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.utilities.NetworkHostAndPort import net.corda.nodeapi.config.SSLConfiguration import org.apache.activemq.artemis.api.core.TransportConfiguration @@ -48,7 +49,8 @@ class ArtemisTcpTransport { // Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop. // It does not use AMQP messages for its own messages e.g. topology and heartbeats. // TODO further investigate how to ensure we use a well defined wire level protocol for Node to Node communications. - TransportConstants.PROTOCOLS_PROP_NAME to "CORE,AMQP" + TransportConstants.PROTOCOLS_PROP_NAME to "CORE,AMQP", + TransportConstants.USE_GLOBAL_WORKER_POOL_PROP_NAME to (nodeSerializationEnv != null) ) if (config != null && enableSSL) { diff --git a/node-api/src/test/java/net/corda/nodeapi/internal/serialization/ForbiddenLambdaSerializationTests.java b/node-api/src/test/java/net/corda/nodeapi/internal/serialization/ForbiddenLambdaSerializationTests.java index 0cb878ff49..5e7dffce8d 100644 --- a/node-api/src/test/java/net/corda/nodeapi/internal/serialization/ForbiddenLambdaSerializationTests.java +++ b/node-api/src/test/java/net/corda/nodeapi/internal/serialization/ForbiddenLambdaSerializationTests.java @@ -25,7 +25,7 @@ public final class ForbiddenLambdaSerializationTests { @Before public void setup() { - factory = testSerialization.env.getSERIALIZATION_FACTORY(); + factory = testSerialization.getEnv().getSerializationFactory(); } @Test diff --git a/node-api/src/test/java/net/corda/nodeapi/internal/serialization/LambdaCheckpointSerializationTest.java b/node-api/src/test/java/net/corda/nodeapi/internal/serialization/LambdaCheckpointSerializationTest.java index 62a71ff630..6e25a3f53d 100644 --- a/node-api/src/test/java/net/corda/nodeapi/internal/serialization/LambdaCheckpointSerializationTest.java +++ b/node-api/src/test/java/net/corda/nodeapi/internal/serialization/LambdaCheckpointSerializationTest.java @@ -25,7 +25,7 @@ public final class LambdaCheckpointSerializationTest { @Before public void setup() { - factory = testSerialization.env.getSERIALIZATION_FACTORY(); + factory = testSerialization.getEnv().getSerializationFactory(); context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoHeaderV0_1(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, SerializationContext.UseCase.Checkpoint); } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ContractAttachmentSerializerTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ContractAttachmentSerializerTest.kt index 9b7dda7c79..0f9d2dd116 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ContractAttachmentSerializerTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ContractAttachmentSerializerTest.kt @@ -27,9 +27,8 @@ class ContractAttachmentSerializerTest { @Before fun setup() { - factory = testSerialization.env.SERIALIZATION_FACTORY - context = testSerialization.env.CHECKPOINT_CONTEXT - + factory = testSerialization.env.serializationFactory + context = testSerialization.env.checkpointContext contextWithToken = context.withTokenContext(SerializeAsTokenContextImpl(Any(), factory, context, mockServices)) } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SerializationTokenTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SerializationTokenTest.kt index bcca06f6de..a51f5934e1 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SerializationTokenTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SerializationTokenTest.kt @@ -8,7 +8,6 @@ import net.corda.core.utilities.OpaqueBytes import net.corda.nodeapi.internal.serialization.kryo.CordaKryo import net.corda.nodeapi.internal.serialization.kryo.DefaultKryoCustomizer import net.corda.nodeapi.internal.serialization.kryo.KryoHeaderV0_1 -import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.rigorousMock import net.corda.testing.SerializationEnvironmentRule import org.assertj.core.api.Assertions.assertThat @@ -26,8 +25,8 @@ class SerializationTokenTest { @Before fun setup() { - factory = testSerialization.env.SERIALIZATION_FACTORY - context = testSerialization.env.CHECKPOINT_CONTEXT.withWhitelisted(SingletonSerializationToken::class.java) + factory = testSerialization.env.serializationFactory + context = testSerialization.env.checkpointContext.withWhitelisted(SingletonSerializationToken::class.java) } // Large tokenizable object so we can tell from the smaller number of serialized bytes it was actually tokenized diff --git a/node/src/integration-test/kotlin/net/corda/node/services/AttachmentLoadingTests.kt b/node/src/integration-test/kotlin/net/corda/node/services/AttachmentLoadingTests.kt index 0b233b9fb3..0b09f55804 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/AttachmentLoadingTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/AttachmentLoadingTests.kt @@ -17,26 +17,21 @@ import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.loggerFor import net.corda.node.internal.cordapp.CordappLoader import net.corda.node.internal.cordapp.CordappProviderImpl +import net.corda.testing.* import net.corda.testing.DUMMY_BANK_A import net.corda.testing.DUMMY_NOTARY -import net.corda.testing.SerializationEnvironmentRule import net.corda.testing.driver.DriverDSLExposedInterface import net.corda.testing.driver.NodeHandle import net.corda.testing.driver.driver import net.corda.testing.node.MockServices import org.junit.Assert.assertEquals import org.junit.Before -import org.junit.Rule import org.junit.Test import java.net.URLClassLoader import java.nio.file.Files import kotlin.test.assertFailsWith class AttachmentLoadingTests { - @Rule - @JvmField - val testSerialization = SerializationEnvironmentRule() - private class Services : MockServices() { private val provider = CordappProviderImpl(CordappLoader.createDevMode(listOf(isolatedJAR)), attachments) private val cordapp get() = provider.cordapps.first() @@ -83,7 +78,7 @@ class AttachmentLoadingTests { } @Test - fun `test a wire transaction has loaded the correct attachment`() { + fun `test a wire transaction has loaded the correct attachment`() = withTestSerialization { val appClassLoader = services.appContext.classLoader val contractClass = appClassLoader.loadClass(ISOLATED_CONTRACT_ID).asSubclass(Contract::class.java) val generateInitialMethod = contractClass.getDeclaredMethod("generateInitial", PartyAndReference::class.java, Integer.TYPE, Party::class.java) @@ -101,7 +96,7 @@ class AttachmentLoadingTests { @Test fun `test that attachments retrieved over the network are not used for code`() { - driver(initialiseSerialization = false) { + driver { installIsolatedCordappTo(bankAName) val (bankA, bankB) = createTwoNodes() assertFailsWith("Party C=CH,L=Zurich,O=BankB rejected session request: Don't know net.corda.finance.contracts.isolated.IsolatedDummyFlow\$Initiator") { @@ -112,7 +107,7 @@ class AttachmentLoadingTests { @Test fun `tests that if the attachment is loaded on both sides already that a flow can run`() { - driver(initialiseSerialization = false) { + driver { installIsolatedCordappTo(bankAName) installIsolatedCordappTo(bankBName) val (bankA, bankB) = createTwoNodes() diff --git a/node/src/integration-test/kotlin/net/corda/node/services/network/NodeInfoWatcherTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/network/NodeInfoWatcherTest.kt index 897504d62d..d1bf50d6aa 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/network/NodeInfoWatcherTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/network/NodeInfoWatcherTest.kt @@ -10,11 +10,13 @@ import net.corda.nodeapi.NodeInfoFilesCopier import net.corda.testing.ALICE import net.corda.testing.ALICE_KEY import net.corda.testing.getTestPartyAndCertificate -import net.corda.testing.internal.NodeBasedTest +import net.corda.testing.* import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.contentOf import org.junit.Before +import org.junit.Rule import org.junit.Test +import org.junit.rules.TemporaryFolder import rx.observers.TestSubscriber import rx.schedulers.TestScheduler import java.nio.file.Path @@ -22,11 +24,17 @@ import java.util.concurrent.TimeUnit import kotlin.test.assertEquals import kotlin.test.assertTrue -class NodeInfoWatcherTest : NodeBasedTest() { +class NodeInfoWatcherTest { companion object { val nodeInfo = NodeInfo(listOf(), listOf(getTestPartyAndCertificate(ALICE)), 0, 0) } + @Rule + @JvmField + val testSerialization = SerializationEnvironmentRule() + @Rule + @JvmField + val tempFolder = TemporaryFolder() private lateinit var nodeInfoPath: Path private val scheduler = TestScheduler() private val testSubscriber = TestSubscriber() diff --git a/node/src/main/kotlin/net/corda/node/internal/Node.kt b/node/src/main/kotlin/net/corda/node/internal/Node.kt index bb617bb55a..0761169df7 100644 --- a/node/src/main/kotlin/net/corda/node/internal/Node.kt +++ b/node/src/main/kotlin/net/corda/node/internal/Node.kt @@ -9,9 +9,10 @@ import net.corda.core.internal.concurrent.thenMatch import net.corda.core.internal.uncheckedCast import net.corda.core.messaging.RPCOps import net.corda.core.node.ServiceHub -import net.corda.core.serialization.SerializationDefaults import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.loggerFor +import net.corda.core.serialization.internal.SerializationEnvironmentImpl +import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.node.VersionInfo import net.corda.node.internal.cordapp.CordappLoader import net.corda.node.serialization.KryoServerSerializationScheme @@ -25,7 +26,6 @@ import net.corda.node.services.messaging.NodeMessagingClient import net.corda.node.utilities.AddressUtils import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.DemoClock -import net.corda.nodeapi.ArtemisMessagingComponent import net.corda.nodeapi.internal.ShutdownHook import net.corda.nodeapi.internal.addShutdownHook import net.corda.nodeapi.internal.serialization.* @@ -274,14 +274,15 @@ open class Node(configuration: NodeConfiguration, private fun initialiseSerialization() { val classloader = cordappLoader.appClassLoader - SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply { - registerScheme(KryoServerSerializationScheme()) - registerScheme(AMQPServerSerializationScheme()) - } - SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT.withClassLoader(classloader) - SerializationDefaults.RPC_SERVER_CONTEXT = KRYO_RPC_SERVER_CONTEXT.withClassLoader(classloader) - SerializationDefaults.STORAGE_CONTEXT = KRYO_STORAGE_CONTEXT.withClassLoader(classloader) - SerializationDefaults.CHECKPOINT_CONTEXT = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader) + nodeSerializationEnv = SerializationEnvironmentImpl( + SerializationFactoryImpl().apply { + registerScheme(KryoServerSerializationScheme()) + registerScheme(AMQPServerSerializationScheme()) + }, + KRYO_P2P_CONTEXT.withClassLoader(classloader), + rpcServerContext = KRYO_RPC_SERVER_CONTEXT.withClassLoader(classloader), + storageContext = KRYO_STORAGE_CONTEXT.withClassLoader(classloader), + checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader)) } /** Starts a blocking event loop for message dispatch. */ diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt b/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt index 39becee90a..48b3095282 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt @@ -11,6 +11,7 @@ import net.corda.core.node.services.PartyInfo import net.corda.core.node.services.TransactionVerifierService import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.deserialize +import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.serialization.serialize import net.corda.core.transactions.LedgerTransaction import net.corda.core.utilities.NetworkHostAndPort @@ -217,6 +218,7 @@ class NodeMessagingClient(override val config: NodeConfiguration, locator.connectionTTL = -1 locator.clientFailureCheckPeriod = -1 locator.minLargeMessageSize = ArtemisMessagingServer.MAX_FILE_SIZE + locator.isUseGlobalPools = nodeSerializationEnv != null sessionFactory = locator.createSessionFactory() // Login using the node username. The broker will authentiate us as its node (as opposed to another peer) diff --git a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt index 9bce320c38..c53c8b750d 100644 --- a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt @@ -60,7 +60,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() { @Rule @JvmField - val testSerialization = SerializationEnvironmentRule() + val testSerialization = SerializationEnvironmentRule(true) private val realClock: Clock = Clock.systemUTC() private val stoppedClock: Clock = Clock.fixed(realClock.instant(), realClock.zone) private val testClock = TestClock(stoppedClock) diff --git a/node/src/test/kotlin/net/corda/node/services/network/HTTPNetworkMapClientTest.kt b/node/src/test/kotlin/net/corda/node/services/network/HTTPNetworkMapClientTest.kt index f2b3bd4a70..25a4e77855 100644 --- a/node/src/test/kotlin/net/corda/node/services/network/HTTPNetworkMapClientTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/network/HTTPNetworkMapClientTest.kt @@ -41,7 +41,7 @@ import kotlin.test.assertEquals class HTTPNetworkMapClientTest { @Rule @JvmField - val testSerialization = SerializationEnvironmentRule() + val testSerialization = SerializationEnvironmentRule(true) private lateinit var server: Server private lateinit var networkMapClient: NetworkMapClient diff --git a/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt b/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt index f2ca28c8ec..e95243479a 100644 --- a/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt @@ -30,7 +30,7 @@ class DistributedImmutableMapTests { @Rule @JvmField - val testSerialization = SerializationEnvironmentRule() + val testSerialization = SerializationEnvironmentRule(true) lateinit var cluster: List lateinit var transaction: DatabaseTransaction private val databases: MutableList = mutableListOf() diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultWithCashTest.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultWithCashTest.kt index dfc4db15e1..5ed6fc323e 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultWithCashTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultWithCashTest.kt @@ -1,9 +1,12 @@ package net.corda.node.services.vault import net.corda.core.contracts.ContractState +import net.corda.core.contracts.InsufficientBalanceException import net.corda.core.contracts.LinearState import net.corda.core.contracts.UniqueIdentifier import net.corda.core.identity.AnonymousParty +import net.corda.core.internal.concurrent.fork +import net.corda.core.internal.concurrent.transpose import net.corda.core.internal.packageName import net.corda.core.node.services.Vault import net.corda.core.node.services.VaultService @@ -11,6 +14,7 @@ import net.corda.core.node.services.queryBy import net.corda.core.node.services.vault.QueryCriteria import net.corda.core.node.services.vault.QueryCriteria.VaultQueryCriteria import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.getOrThrow import net.corda.finance.* import net.corda.finance.contracts.asset.Cash import net.corda.finance.contracts.asset.DUMMY_CASH_ISSUER @@ -29,9 +33,9 @@ import org.junit.Before import org.junit.Rule import org.junit.Test import java.util.* -import java.util.concurrent.CountDownLatch import java.util.concurrent.Executors import kotlin.test.assertEquals +import kotlin.test.fail // TODO: Move this to the cash contract tests once mock services are further split up. @@ -42,7 +46,7 @@ class VaultWithCashTest { @Rule @JvmField - val testSerialization = SerializationEnvironmentRule() + val testSerialization = SerializationEnvironmentRule(true) lateinit var services: MockServices lateinit var issuerServices: MockServices val vaultService: VaultService get() = services.vaultService @@ -150,82 +154,74 @@ class VaultWithCashTest { } val backgroundExecutor = Executors.newFixedThreadPool(2) - val countDown = CountDownLatch(2) - // 1st tx that spends our money. - backgroundExecutor.submit { + val first = backgroundExecutor.fork { database.transaction { - try { - val txn1Builder = TransactionBuilder(DUMMY_NOTARY) - Cash.generateSpend(services, txn1Builder, 60.DOLLARS, BOB) - val ptxn1 = notaryServices.signInitialTransaction(txn1Builder) - val txn1 = services.addSignature(ptxn1, freshKey) - println("txn1: ${txn1.id} spent ${((txn1.tx.outputs[0].data) as Cash.State).amount}") - val unconsumedStates1 = vaultService.queryBy() - val consumedStates1 = vaultService.queryBy(VaultQueryCriteria(status = Vault.StateStatus.CONSUMED)) - val lockedStates1 = vaultService.queryBy(criteriaLocked).states - println("""txn1 states: + val txn1Builder = TransactionBuilder(DUMMY_NOTARY) + Cash.generateSpend(services, txn1Builder, 60.DOLLARS, BOB) + val ptxn1 = notaryServices.signInitialTransaction(txn1Builder) + val txn1 = services.addSignature(ptxn1, freshKey) + println("txn1: ${txn1.id} spent ${((txn1.tx.outputs[0].data) as Cash.State).amount}") + val unconsumedStates1 = vaultService.queryBy() + val consumedStates1 = vaultService.queryBy(VaultQueryCriteria(status = Vault.StateStatus.CONSUMED)) + val lockedStates1 = vaultService.queryBy(criteriaLocked).states + println("""txn1 states: UNCONSUMED: ${unconsumedStates1.totalStatesAvailable} : $unconsumedStates1, CONSUMED: ${consumedStates1.totalStatesAvailable} : $consumedStates1, LOCKED: ${lockedStates1.count()} : $lockedStates1 """) - services.recordTransactions(txn1) - println("txn1: Cash balance: ${services.getCashBalance(USD)}") - val unconsumedStates2 = vaultService.queryBy() - val consumedStates2 = vaultService.queryBy(VaultQueryCriteria(status = Vault.StateStatus.CONSUMED)) - val lockedStates2 = vaultService.queryBy(criteriaLocked).states - println("""txn1 states: + services.recordTransactions(txn1) + println("txn1: Cash balance: ${services.getCashBalance(USD)}") + val unconsumedStates2 = vaultService.queryBy() + val consumedStates2 = vaultService.queryBy(VaultQueryCriteria(status = Vault.StateStatus.CONSUMED)) + val lockedStates2 = vaultService.queryBy(criteriaLocked).states + println("""txn1 states: UNCONSUMED: ${unconsumedStates2.totalStatesAvailable} : $unconsumedStates2, CONSUMED: ${consumedStates2.totalStatesAvailable} : $consumedStates2, LOCKED: ${lockedStates2.count()} : $lockedStates2 """) - txn1 - } catch (e: Exception) { - println(e) - } + txn1 } println("txn1 COMMITTED!") - countDown.countDown() } // 2nd tx that attempts to spend same money - backgroundExecutor.submit { + val second = backgroundExecutor.fork { database.transaction { - try { - val txn2Builder = TransactionBuilder(DUMMY_NOTARY) - Cash.generateSpend(services, txn2Builder, 80.DOLLARS, BOB) - val ptxn2 = notaryServices.signInitialTransaction(txn2Builder) - val txn2 = services.addSignature(ptxn2, freshKey) - println("txn2: ${txn2.id} spent ${((txn2.tx.outputs[0].data) as Cash.State).amount}") - val unconsumedStates1 = vaultService.queryBy() - val consumedStates1 = vaultService.queryBy(VaultQueryCriteria(status = Vault.StateStatus.CONSUMED)) - val lockedStates1 = vaultService.queryBy(criteriaLocked).states - println("""txn2 states: + val txn2Builder = TransactionBuilder(DUMMY_NOTARY) + Cash.generateSpend(services, txn2Builder, 80.DOLLARS, BOB) + val ptxn2 = notaryServices.signInitialTransaction(txn2Builder) + val txn2 = services.addSignature(ptxn2, freshKey) + println("txn2: ${txn2.id} spent ${((txn2.tx.outputs[0].data) as Cash.State).amount}") + val unconsumedStates1 = vaultService.queryBy() + val consumedStates1 = vaultService.queryBy(VaultQueryCriteria(status = Vault.StateStatus.CONSUMED)) + val lockedStates1 = vaultService.queryBy(criteriaLocked).states + println("""txn2 states: UNCONSUMED: ${unconsumedStates1.totalStatesAvailable} : $unconsumedStates1, CONSUMED: ${consumedStates1.totalStatesAvailable} : $consumedStates1, LOCKED: ${lockedStates1.count()} : $lockedStates1 """) - services.recordTransactions(txn2) - println("txn2: Cash balance: ${services.getCashBalance(USD)}") - val unconsumedStates2 = vaultService.queryBy() - val consumedStates2 = vaultService.queryBy() - val lockedStates2 = vaultService.queryBy(criteriaLocked).states - println("""txn2 states: + services.recordTransactions(txn2) + println("txn2: Cash balance: ${services.getCashBalance(USD)}") + val unconsumedStates2 = vaultService.queryBy() + val consumedStates2 = vaultService.queryBy() + val lockedStates2 = vaultService.queryBy(criteriaLocked).states + println("""txn2 states: UNCONSUMED: ${unconsumedStates2.totalStatesAvailable} : $unconsumedStates2, CONSUMED: ${consumedStates2.totalStatesAvailable} : $consumedStates2, LOCKED: ${lockedStates2.count()} : $lockedStates2 """) - txn2 - } catch (e: Exception) { - println(e) - } + txn2 } println("txn2 COMMITTED!") - - countDown.countDown() } - - countDown.await() + val both = listOf(first, second).transpose() + try { + both.getOrThrow() + fail("Expected insufficient balance.") + } catch (e: InsufficientBalanceException) { + assertEquals(0, e.suppressed.size) // One should succeed. + } database.transaction { println("Cash balance: ${services.getCashBalance(USD)}") assertThat(services.getCashBalance(USD)).isIn(DOLLARS(20), DOLLARS(40)) diff --git a/samples/irs-demo/cordapp/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt b/samples/irs-demo/cordapp/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt index 193eb8ec08..96caf92fff 100644 --- a/samples/irs-demo/cordapp/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt +++ b/samples/irs-demo/cordapp/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt @@ -203,8 +203,8 @@ class NodeInterestRatesTest { } @Test - fun `network tearoff`() { - val mockNet = MockNetwork(initialiseSerialization = false, cordappPackages = listOf("net.corda.finance.contracts", "net.corda.irs")) + fun `network tearoff`() = withoutTestSerialization { + val mockNet = MockNetwork(cordappPackages = listOf("net.corda.finance.contracts", "net.corda.irs")) val aliceNode = mockNet.createPartyNode(ALICE.name) val oracleNode = mockNet.createNode().apply { internals.registerInitiatedFlow(NodeInterestRates.FixQueryHandler::class.java) diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/driver/Driver.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/driver/Driver.kt index 6fad6792ad..605a4deba0 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/driver/Driver.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/driver/Driver.kt @@ -37,6 +37,7 @@ import net.corda.nodeapi.internal.addShutdownHook import net.corda.testing.* import net.corda.testing.common.internal.NetworkParametersCopier import net.corda.testing.common.internal.testNetworkParameters +import net.corda.testing.setGlobalSerialization import net.corda.testing.internal.ProcessUtilities import net.corda.testing.node.ClusterSpec import net.corda.testing.node.MockServices.Companion.MOCK_VERSION_INFO @@ -413,7 +414,7 @@ fun genericD coerce: (D) -> DI, dsl: DI.() -> A ): A { - val serializationEnv = initialiseTestSerialization(initialiseSerialization) + val serializationEnv = setGlobalSerialization(initialiseSerialization) val shutdownHook = addShutdownHook(driverDsl::shutdown) try { driverDsl.start() @@ -424,7 +425,7 @@ fun genericD } finally { driverDsl.shutdown() shutdownHook.cancel() - serializationEnv.resetTestSerialization() + serializationEnv.unset() } } @@ -451,7 +452,7 @@ fun genericD driverDslWrapper: (DriverDSL) -> D, coerce: (D) -> DI, dsl: DI.() -> A ): A { - val serializationEnv = initialiseTestSerialization(initialiseSerialization) + val serializationEnv = setGlobalSerialization(initialiseSerialization) val driverDsl = driverDslWrapper( DriverDSL( portAllocation = portAllocation, @@ -475,7 +476,7 @@ fun genericD } finally { driverDsl.shutdown() shutdownHook.cancel() - serializationEnv.resetTestSerialization() + serializationEnv.unset() } } diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/internal/NodeBasedTest.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/internal/NodeBasedTest.kt index 3c0c0da6e0..8a107a0fe6 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/internal/NodeBasedTest.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/internal/NodeBasedTest.kt @@ -38,7 +38,7 @@ abstract class NodeBasedTest(private val cordappPackages: List = emptyLi @Rule @JvmField - val testSerialization = SerializationEnvironmentRule() + val testSerialization = SerializationEnvironmentRule(true) @Rule @JvmField val tempFolder = TemporaryFolder() @@ -63,16 +63,20 @@ abstract class NodeBasedTest(private val cordappPackages: List = emptyLi @After fun stopAllNodes() { val shutdownExecutor = Executors.newScheduledThreadPool(nodes.size) - nodes.map { shutdownExecutor.fork(it::dispose) }.transpose().getOrThrow() - // Wait until ports are released - val portNotBoundChecks = nodes.flatMap { - listOf( - it.internals.configuration.p2pAddress.let { addressMustNotBeBoundFuture(shutdownExecutor, it) }, - it.internals.configuration.rpcAddress?.let { addressMustNotBeBoundFuture(shutdownExecutor, it) } - ) - }.filterNotNull() - nodes.clear() - portNotBoundChecks.transpose().getOrThrow() + try { + nodes.map { shutdownExecutor.fork(it::dispose) }.transpose().getOrThrow() + // Wait until ports are released + val portNotBoundChecks = nodes.flatMap { + listOf( + it.internals.configuration.p2pAddress.let { addressMustNotBeBoundFuture(shutdownExecutor, it) }, + it.internals.configuration.rpcAddress?.let { addressMustNotBeBoundFuture(shutdownExecutor, it) } + ) + }.filterNotNull() + nodes.clear() + portNotBoundChecks.transpose().getOrThrow() + } finally { + shutdownExecutor.shutdown() + } } @JvmOverloads diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNode.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNode.kt index a44afbeb79..38f1a9553a 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNode.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNode.kt @@ -39,7 +39,7 @@ import net.corda.node.utilities.ServiceIdentityGenerator import net.corda.testing.DUMMY_NOTARY import net.corda.testing.common.internal.NetworkParametersCopier import net.corda.testing.common.internal.testNetworkParameters -import net.corda.testing.initialiseTestSerialization +import net.corda.testing.setGlobalSerialization import net.corda.testing.node.MockServices.Companion.MOCK_VERSION_INFO import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties import net.corda.testing.testNodeConfiguration @@ -136,9 +136,8 @@ class MockNetwork(defaultParameters: MockNetworkParameters = MockNetworkParamete private val networkId = random63BitValue() private val networkParameters: NetworkParametersCopier private val _nodes = mutableListOf() - private val serializationEnv = initialiseTestSerialization(initialiseSerialization) + private val serializationEnv = setGlobalSerialization(initialiseSerialization) private val sharedUserCount = AtomicInteger(0) - /** A read only view of the current set of executing nodes. */ val nodes: List get() = _nodes @@ -419,7 +418,7 @@ class MockNetwork(defaultParameters: MockNetworkParameters = MockNetworkParamete fun stopNodes() { nodes.forEach { it.started?.dispose() } - serializationEnv.resetTestSerialization() + serializationEnv.unset() } // Test method to block until all scheduled activity, active flows diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt index 4738175ae4..57139d6250 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt @@ -3,10 +3,7 @@ package net.corda.testing import com.nhaarman.mockito_kotlin.doNothing import com.nhaarman.mockito_kotlin.whenever import net.corda.client.rpc.internal.KryoClientSerializationScheme -import net.corda.core.crypto.SecureHash -import net.corda.core.serialization.* -import net.corda.core.serialization.internal.SerializationEnvironment -import net.corda.core.utilities.ByteSequence +import net.corda.core.serialization.internal.* import net.corda.node.serialization.KryoServerSerializationScheme import net.corda.nodeapi.internal.serialization.* import net.corda.nodeapi.internal.serialization.amqp.AMQPClientSerializationScheme @@ -15,183 +12,84 @@ import org.junit.rules.TestRule import org.junit.runner.Description import org.junit.runners.model.Statement -class SerializationEnvironmentRule : TestRule { - lateinit var env: SerializationEnvironment +/** @param inheritable whether new threads inherit the environment, use sparingly. */ +class SerializationEnvironmentRule(private val inheritable: Boolean = false) : TestRule { + val env: SerializationEnvironment = createTestSerializationEnv() override fun apply(base: Statement, description: Description?) = object : Statement() { - override fun evaluate() = withTestSerialization { - env = it + override fun evaluate() = env.asContextEnv(inheritable) { base.evaluate() } } } -interface TestSerializationEnvironment : SerializationEnvironment { - fun resetTestSerialization() +interface GlobalSerializationEnvironment : SerializationEnvironment { + /** Unset this environment. */ + fun unset() } -fun withTestSerialization(block: (SerializationEnvironment) -> T): T { - val env = initialiseTestSerializationImpl() +/** @param inheritable whether new threads inherit the environment, use sparingly. */ +fun withTestSerialization(inheritable: Boolean = false, callable: (SerializationEnvironment) -> T): T { + return createTestSerializationEnv().asContextEnv(inheritable, callable) +} + +private fun SerializationEnvironment.asContextEnv(inheritable: Boolean, callable: (SerializationEnvironment) -> T): T { + val property = if (inheritable) _inheritableContextSerializationEnv else _contextSerializationEnv + property.set(this) try { - return block(env) + return callable(this) } finally { - env.resetTestSerialization() + property.set(null) } } -/** @param armed true to init, false to do nothing and return a dummy env. */ -fun initialiseTestSerialization(armed: Boolean): TestSerializationEnvironment { +/** + * For example your test class uses [SerializationEnvironmentRule] but you want to turn it off for one method. + * Use sparingly, ideally a test class shouldn't mix serialization init mechanisms. + */ +fun withoutTestSerialization(callable: () -> T): T { + val (property, env) = listOf(_contextSerializationEnv, _inheritableContextSerializationEnv).map { Pair(it, it.get()) }.single { it.second != null } + property.set(null) + try { + return callable() + } finally { + property.set(env) + } +} + +/** + * Should only be used by Driver and MockNode. + * @param armed true to install, false to do nothing and return a dummy env. + */ +fun setGlobalSerialization(armed: Boolean): GlobalSerializationEnvironment { return if (armed) { - val env = initialiseTestSerializationImpl() - object : TestSerializationEnvironment, SerializationEnvironment by env { - override fun resetTestSerialization() = env.resetTestSerialization() + object : GlobalSerializationEnvironment, SerializationEnvironment by createTestSerializationEnv() { + override fun unset() { + _globalSerializationEnv.set(null) + } + }.also { + _globalSerializationEnv.set(it) } } else { - rigorousMock().also { - doNothing().whenever(it).resetTestSerialization() + rigorousMock().also { + doNothing().whenever(it).unset() } } } -private fun initialiseTestSerializationImpl() = SerializationDefaults.apply { - // Stop the CordaRPCClient from trying to setup the defaults as we're about to do it now - KryoClientSerializationScheme.isInitialised.set(true) - // Check that everything is configured for testing with mutable delegating instances. - try { - check(SERIALIZATION_FACTORY is TestSerializationFactory) - } catch (e: IllegalStateException) { - SERIALIZATION_FACTORY = TestSerializationFactory() - } - try { - check(P2P_CONTEXT is TestSerializationContext) - } catch (e: IllegalStateException) { - P2P_CONTEXT = TestSerializationContext() - } - try { - check(RPC_SERVER_CONTEXT is TestSerializationContext) - } catch (e: IllegalStateException) { - RPC_SERVER_CONTEXT = TestSerializationContext() - } - try { - check(RPC_CLIENT_CONTEXT is TestSerializationContext) - } catch (e: IllegalStateException) { - RPC_CLIENT_CONTEXT = TestSerializationContext() - } - try { - check(STORAGE_CONTEXT is TestSerializationContext) - } catch (e: IllegalStateException) { - STORAGE_CONTEXT = TestSerializationContext() - } - try { - check(CHECKPOINT_CONTEXT is TestSerializationContext) - } catch (e: IllegalStateException) { - CHECKPOINT_CONTEXT = TestSerializationContext() - } - - // Check that the previous test, if there was one, cleaned up after itself. - // IF YOU SEE THESE MESSAGES, THEN IT MEANS A TEST HAS NOT CALLED resetTestSerialization() - check((SERIALIZATION_FACTORY as TestSerializationFactory).delegate == null, { "Expected uninitialised serialization framework but found it set from: $SERIALIZATION_FACTORY" }) - check((P2P_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: $P2P_CONTEXT" }) - check((RPC_SERVER_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: $RPC_SERVER_CONTEXT" }) - check((RPC_CLIENT_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: $RPC_CLIENT_CONTEXT" }) - check((STORAGE_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: $STORAGE_CONTEXT" }) - check((CHECKPOINT_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: $CHECKPOINT_CONTEXT" }) - - // Now configure all the testing related delegates. - (SERIALIZATION_FACTORY as TestSerializationFactory).delegate = SerializationFactoryImpl().apply { - registerScheme(KryoClientSerializationScheme()) - registerScheme(KryoServerSerializationScheme()) - registerScheme(AMQPClientSerializationScheme()) - registerScheme(AMQPServerSerializationScheme()) - } - - (P2P_CONTEXT as TestSerializationContext).delegate = if (isAmqpEnabled()) AMQP_P2P_CONTEXT else KRYO_P2P_CONTEXT - (RPC_SERVER_CONTEXT as TestSerializationContext).delegate = KRYO_RPC_SERVER_CONTEXT - (RPC_CLIENT_CONTEXT as TestSerializationContext).delegate = KRYO_RPC_CLIENT_CONTEXT - (STORAGE_CONTEXT as TestSerializationContext).delegate = if (isAmqpEnabled()) AMQP_STORAGE_CONTEXT else KRYO_STORAGE_CONTEXT - (CHECKPOINT_CONTEXT as TestSerializationContext).delegate = KRYO_CHECKPOINT_CONTEXT -} +private fun createTestSerializationEnv() = SerializationEnvironmentImpl( + SerializationFactoryImpl().apply { + registerScheme(KryoClientSerializationScheme()) + registerScheme(KryoServerSerializationScheme()) + registerScheme(AMQPClientSerializationScheme()) + registerScheme(AMQPServerSerializationScheme()) + }, + if (isAmqpEnabled()) AMQP_P2P_CONTEXT else KRYO_P2P_CONTEXT, + KRYO_RPC_SERVER_CONTEXT, + KRYO_RPC_CLIENT_CONTEXT, + if (isAmqpEnabled()) AMQP_STORAGE_CONTEXT else KRYO_STORAGE_CONTEXT, + KRYO_CHECKPOINT_CONTEXT) private const val AMQP_ENABLE_PROP_NAME = "net.corda.testing.amqp.enable" // TODO: Remove usages of this function when we fully switched to AMQP private fun isAmqpEnabled(): Boolean = java.lang.Boolean.getBoolean(AMQP_ENABLE_PROP_NAME) - -private fun SerializationDefaults.resetTestSerialization() { - (SERIALIZATION_FACTORY as TestSerializationFactory).delegate = null - (P2P_CONTEXT as TestSerializationContext).delegate = null - (RPC_SERVER_CONTEXT as TestSerializationContext).delegate = null - (RPC_CLIENT_CONTEXT as TestSerializationContext).delegate = null - (STORAGE_CONTEXT as TestSerializationContext).delegate = null - (CHECKPOINT_CONTEXT as TestSerializationContext).delegate = null -} - -class TestSerializationFactory : SerializationFactory() { - var delegate: SerializationFactory? = null - set(value) { - field = value - stackTrace = Exception().stackTrace.asList() - } - private var stackTrace: List? = null - - override fun toString(): String = stackTrace?.joinToString("\n") ?: "null" - - override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T { - return delegate!!.deserialize(byteSequence, clazz, context) - } - - override fun deserializeWithCompatibleContext(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): ObjectWithCompatibleContext { - return delegate!!.deserializeWithCompatibleContext(byteSequence, clazz, context) - } - - override fun serialize(obj: T, context: SerializationContext): SerializedBytes { - return delegate!!.serialize(obj, context) - } -} - -class TestSerializationContext : SerializationContext { - var delegate: SerializationContext? = null - set(value) { - field = value - stackTrace = Exception().stackTrace.asList() - } - private var stackTrace: List? = null - - override fun toString(): String = stackTrace?.joinToString("\n") ?: "null" - - override val preferredSerializationVersion: ByteSequence - get() = delegate!!.preferredSerializationVersion - override val deserializationClassLoader: ClassLoader - get() = delegate!!.deserializationClassLoader - override val whitelist: ClassWhitelist - get() = delegate!!.whitelist - override val properties: Map - get() = delegate!!.properties - override val objectReferencesEnabled: Boolean - get() = delegate!!.objectReferencesEnabled - override val useCase: SerializationContext.UseCase - get() = delegate!!.useCase - - override fun withProperty(property: Any, value: Any): SerializationContext { - return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withProperty(property, value) } - } - - override fun withoutReferences(): SerializationContext { - return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withoutReferences() } - } - - override fun withClassLoader(classLoader: ClassLoader): SerializationContext { - return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withClassLoader(classLoader) } - } - - override fun withWhitelisted(clazz: Class<*>): SerializationContext { - return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withWhitelisted(clazz) } - } - - override fun withPreferredSerializationVersion(versionHeader: VersionHeader): SerializationContext { - return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withPreferredSerializationVersion(versionHeader) } - } - - override fun withAttachmentsClassLoader(attachmentHashes: List): SerializationContext { - return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withAttachmentsClassLoader(attachmentHashes) } - } -} diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/messaging/SimpleMQClient.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/messaging/SimpleMQClient.kt index 15115b67fd..6109baffc8 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/messaging/SimpleMQClient.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/messaging/SimpleMQClient.kt @@ -1,6 +1,7 @@ package net.corda.testing.messaging import net.corda.core.identity.CordaX500Name +import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.utilities.NetworkHostAndPort import net.corda.nodeapi.ArtemisMessagingComponent import net.corda.nodeapi.ArtemisTcpTransport @@ -27,6 +28,7 @@ class SimpleMQClient(val target: NetworkHostAndPort, val locator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport).apply { isBlockOnNonDurableSend = true threadPoolMaxSize = 1 + isUseGlobalPools = nodeSerializationEnv != null } sessionFactory = locator.createSessionFactory() session = sessionFactory.createSession(username, password, false, true, true, locator.isPreAcknowledge, locator.ackBatchSize) diff --git a/verifier/src/main/kotlin/net/corda/verifier/Verifier.kt b/verifier/src/main/kotlin/net/corda/verifier/Verifier.kt index ccc9a967db..d4db1bc540 100644 --- a/verifier/src/main/kotlin/net/corda/verifier/Verifier.kt +++ b/verifier/src/main/kotlin/net/corda/verifier/Verifier.kt @@ -5,7 +5,8 @@ import com.typesafe.config.ConfigFactory import com.typesafe.config.ConfigParseOptions import net.corda.core.internal.div import net.corda.core.serialization.SerializationContext -import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.internal.SerializationEnvironmentImpl +import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.debug @@ -89,15 +90,16 @@ class Verifier { } private fun initialiseSerialization() { - SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply { - registerScheme(KryoVerifierSerializationScheme) - registerScheme(AMQPVerifierSerializationScheme) - } - /** - * Even though default context is set to Kryo P2P, the encoding will be adjusted depending on the incoming - * request received, see use of [context] in [main] method. - */ - SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT + nodeSerializationEnv = SerializationEnvironmentImpl( + SerializationFactoryImpl().apply { + registerScheme(KryoVerifierSerializationScheme) + registerScheme(AMQPVerifierSerializationScheme) + }, + /** + * Even though default context is set to Kryo P2P, the encoding will be adjusted depending on the incoming + * request received, see use of [context] in [main] method. + */ + KRYO_P2P_CONTEXT) } }