CORDA-716 Make serialization init less static (#1996)

This commit is contained in:
Andrzej Cichocki 2017-11-10 15:44:43 +00:00 committed by GitHub
parent cc4c732a48
commit 052124bbe0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 463 additions and 346 deletions

View File

@ -2430,19 +2430,13 @@ public static final class net.corda.core.serialization.SerializationContext$UseC
public static net.corda.core.serialization.SerializationContext$UseCase valueOf(String) public static net.corda.core.serialization.SerializationContext$UseCase valueOf(String)
public static net.corda.core.serialization.SerializationContext$UseCase[] values() public static net.corda.core.serialization.SerializationContext$UseCase[] values()
## ##
public final class net.corda.core.serialization.SerializationDefaults extends java.lang.Object implements net.corda.core.serialization.internal.SerializationEnvironment public final class net.corda.core.serialization.SerializationDefaults extends java.lang.Object
@org.jetbrains.annotations.NotNull public net.corda.core.serialization.SerializationContext getCHECKPOINT_CONTEXT() @org.jetbrains.annotations.NotNull public final net.corda.core.serialization.SerializationContext getCHECKPOINT_CONTEXT()
@org.jetbrains.annotations.NotNull public net.corda.core.serialization.SerializationContext getP2P_CONTEXT() @org.jetbrains.annotations.NotNull public final net.corda.core.serialization.SerializationContext getP2P_CONTEXT()
@org.jetbrains.annotations.NotNull public net.corda.core.serialization.SerializationContext getRPC_CLIENT_CONTEXT() @org.jetbrains.annotations.NotNull public final net.corda.core.serialization.SerializationContext getRPC_CLIENT_CONTEXT()
@org.jetbrains.annotations.NotNull public net.corda.core.serialization.SerializationContext getRPC_SERVER_CONTEXT() @org.jetbrains.annotations.NotNull public final net.corda.core.serialization.SerializationContext getRPC_SERVER_CONTEXT()
@org.jetbrains.annotations.NotNull public net.corda.core.serialization.SerializationFactory getSERIALIZATION_FACTORY() @org.jetbrains.annotations.NotNull public final net.corda.core.serialization.SerializationFactory getSERIALIZATION_FACTORY()
@org.jetbrains.annotations.NotNull public net.corda.core.serialization.SerializationContext getSTORAGE_CONTEXT() @org.jetbrains.annotations.NotNull public final net.corda.core.serialization.SerializationContext getSTORAGE_CONTEXT()
public void setCHECKPOINT_CONTEXT(net.corda.core.serialization.SerializationContext)
public void setP2P_CONTEXT(net.corda.core.serialization.SerializationContext)
public void setRPC_CLIENT_CONTEXT(net.corda.core.serialization.SerializationContext)
public void setRPC_SERVER_CONTEXT(net.corda.core.serialization.SerializationContext)
public void setSERIALIZATION_FACTORY(net.corda.core.serialization.SerializationFactory)
public void setSTORAGE_CONTEXT(net.corda.core.serialization.SerializationContext)
public static final net.corda.core.serialization.SerializationDefaults INSTANCE public static final net.corda.core.serialization.SerializationDefaults INSTANCE
## ##
public abstract class net.corda.core.serialization.SerializationFactory extends java.lang.Object public abstract class net.corda.core.serialization.SerializationFactory extends java.lang.Object

View File

@ -4,10 +4,10 @@ import net.corda.client.rpc.internal.KryoClientSerializationScheme
import net.corda.client.rpc.internal.RPCClient import net.corda.client.rpc.internal.RPCClient
import net.corda.client.rpc.internal.RPCClientConfiguration import net.corda.client.rpc.internal.RPCClientConfiguration
import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.CordaRPCOps
import net.corda.core.serialization.internal.effectiveSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport
import net.corda.nodeapi.ConnectionDirection import net.corda.nodeapi.ConnectionDirection
import net.corda.nodeapi.internal.serialization.AMQP_RPC_CLIENT_CONTEXT
import net.corda.nodeapi.internal.serialization.KRYO_RPC_CLIENT_CONTEXT import net.corda.nodeapi.internal.serialization.KRYO_RPC_CLIENT_CONTEXT
import java.time.Duration import java.time.Duration
@ -71,8 +71,15 @@ class CordaRPCClient @JvmOverloads constructor(
configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.DEFAULT configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.DEFAULT
) { ) {
init { init {
// TODO: allow clients to have serialization factory etc injected and align with RPC protocol version? try {
effectiveSerializationEnv
} catch (e: IllegalStateException) {
try {
KryoClientSerializationScheme.initialiseSerialization() KryoClientSerializationScheme.initialiseSerialization()
} catch (e: IllegalStateException) {
// Race e.g. two of these constructed in parallel, ignore.
}
}
} }
private val rpcClient = RPCClient<CordaRPCOps>( private val rpcClient = RPCClient<CordaRPCOps>(

View File

@ -2,15 +2,17 @@ package net.corda.client.rpc.internal
import com.esotericsoftware.kryo.pool.KryoPool import com.esotericsoftware.kryo.pool.KryoPool
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.ByteSequence
import net.corda.nodeapi.internal.serialization.* import net.corda.nodeapi.internal.serialization.KRYO_P2P_CONTEXT
import net.corda.nodeapi.internal.serialization.KRYO_RPC_CLIENT_CONTEXT
import net.corda.nodeapi.internal.serialization.SerializationFactoryImpl
import net.corda.nodeapi.internal.serialization.amqp.AMQPClientSerializationScheme import net.corda.nodeapi.internal.serialization.amqp.AMQPClientSerializationScheme
import net.corda.nodeapi.internal.serialization.kryo.AbstractKryoSerializationScheme import net.corda.nodeapi.internal.serialization.kryo.AbstractKryoSerializationScheme
import net.corda.nodeapi.internal.serialization.kryo.DefaultKryoCustomizer import net.corda.nodeapi.internal.serialization.kryo.DefaultKryoCustomizer
import net.corda.nodeapi.internal.serialization.kryo.KryoHeaderV0_1 import net.corda.nodeapi.internal.serialization.kryo.KryoHeaderV0_1
import net.corda.nodeapi.internal.serialization.kryo.RPCKryo import net.corda.nodeapi.internal.serialization.kryo.RPCKryo
import java.util.concurrent.atomic.AtomicBoolean
class KryoClientSerializationScheme : AbstractKryoSerializationScheme() { class KryoClientSerializationScheme : AbstractKryoSerializationScheme() {
override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean { override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean {
@ -29,25 +31,15 @@ class KryoClientSerializationScheme : AbstractKryoSerializationScheme() {
override fun rpcServerKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException() override fun rpcServerKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
companion object { companion object {
val isInitialised = AtomicBoolean(false) /** Call from main only. */
fun initialiseSerialization() { fun initialiseSerialization() {
if (!isInitialised.compareAndSet(false, true)) return nodeSerializationEnv = SerializationEnvironmentImpl(
try { SerializationFactoryImpl().apply {
SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply {
registerScheme(KryoClientSerializationScheme()) registerScheme(KryoClientSerializationScheme())
registerScheme(AMQPClientSerializationScheme()) registerScheme(AMQPClientSerializationScheme())
} },
SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT KRYO_P2P_CONTEXT,
SerializationDefaults.RPC_CLIENT_CONTEXT = KRYO_RPC_CLIENT_CONTEXT rpcClientContext = KRYO_RPC_CLIENT_CONTEXT)
} catch (e: IllegalStateException) {
// Check that it's registered as we expect
val factory = SerializationDefaults.SERIALIZATION_FACTORY
val checkedFactory = factory as? SerializationFactoryImpl
?: throw IllegalStateException("RPC client encountered conflicting configuration of serialization subsystem: $factory")
check(checkedFactory.alreadyRegisteredSchemes.any { it is KryoClientSerializationScheme }) {
"RPC client encountered conflicting configuration of serialization subsystem."
}
}
} }
} }
} }

View File

@ -8,6 +8,7 @@ import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.RPCOps import net.corda.core.messaging.RPCOps
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.loggerFor import net.corda.core.utilities.loggerFor
import net.corda.core.utilities.minutes import net.corda.core.utilities.minutes
@ -110,6 +111,7 @@ class RPCClient<I : RPCOps>(
maxRetryInterval = rpcConfiguration.connectionMaxRetryInterval.toMillis() maxRetryInterval = rpcConfiguration.connectionMaxRetryInterval.toMillis()
reconnectAttempts = rpcConfiguration.maxReconnectAttempts reconnectAttempts = rpcConfiguration.maxReconnectAttempts
minLargeMessageSize = rpcConfiguration.maxFileSize minLargeMessageSize = rpcConfiguration.maxFileSize
isUseGlobalPools = nodeSerializationEnv != null
} }
val proxyHandler = RPCClientProxyHandler(rpcConfiguration, username, password, serverLocator, clientAddress, rpcOpsClass, serializationContext) val proxyHandler = RPCClientProxyHandler(rpcConfiguration, username, password, serverLocator, clientAddress, rpcOpsClass, serializationContext)

View File

@ -0,0 +1,66 @@
package net.corda.core.internal
import java.util.concurrent.atomic.AtomicReference
import kotlin.reflect.KProperty
/** May go from null to non-null and vice-versa, and that's it. */
abstract class ToggleField<T>(val name: String) {
private val writeMutex = Any() // Protects the toggle logic only.
abstract fun get(): T?
fun set(value: T?) = synchronized(writeMutex) {
if (value != null) {
check(get() == null) { "$name already has a value." }
setImpl(value)
} else {
check(get() != null) { "$name is already null." }
clear()
}
}
protected abstract fun setImpl(value: T)
protected abstract fun clear()
operator fun getValue(thisRef: Any?, property: KProperty<*>) = get()
operator fun setValue(thisRef: Any?, property: KProperty<*>, value: T?) = set(value)
}
class SimpleToggleField<T>(name: String, private val once: Boolean = false) : ToggleField<T>(name) {
private val holder = AtomicReference<T?>() // Force T? in API for safety.
override fun get() = holder.get()
override fun setImpl(value: T) = holder.set(value)
override fun clear() {
check(!once) { "Value of $name cannot be changed." }
holder.set(null)
}
}
class ThreadLocalToggleField<T>(name: String) : ToggleField<T>(name) {
private val threadLocal = ThreadLocal<T?>()
override fun get() = threadLocal.get()
override fun setImpl(value: T) = threadLocal.set(value)
override fun clear() = threadLocal.remove()
}
/** The named thread has leaked from a previous test. */
class ThreadLeakException : RuntimeException("Leaked thread detected: ${Thread.currentThread().name}")
class InheritableThreadLocalToggleField<T>(name: String) : ToggleField<T>(name) {
private class Holder<T>(value: T) : AtomicReference<T?>(value) {
fun valueOrDeclareLeak() = get() ?: throw ThreadLeakException()
}
private val threadLocal = object : InheritableThreadLocal<Holder<T>?>() {
override fun childValue(holder: Holder<T>?): Holder<T>? {
// The Holder itself may be null due to prior events, a leak is not implied in that case:
holder?.valueOrDeclareLeak() // Fail fast.
return holder // What super does.
}
}
override fun get() = threadLocal.get()?.valueOrDeclareLeak()
override fun setImpl(value: T) = threadLocal.set(Holder(value))
override fun clear() = threadLocal.run {
val holder = get()!!
remove()
holder.set(null) // Threads that inherited the holder are now considered to have escaped from the test.
}
}

View File

@ -1,18 +0,0 @@
package net.corda.core.internal
import kotlin.reflect.KProperty
/**
* A write-once property to be used as delegate for Kotlin var properties. The expectation is that this is initialised
* prior to the spawning of any threads that may access it and so there's no need for it to be volatile.
*/
class WriteOnceProperty<T : Any>(private val defaultValue: T? = null) {
private var v: T? = defaultValue
operator fun getValue(thisRef: Any?, property: KProperty<*>) = v ?: throw IllegalStateException("Write-once property $property not set.")
operator fun setValue(thisRef: Any?, property: KProperty<*>, value: T) {
check(v == defaultValue || v === value) { "Cannot set write-once property $property more than once." }
v = value
}
}

View File

@ -2,8 +2,7 @@ package net.corda.core.serialization
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.sha256 import net.corda.core.crypto.sha256
import net.corda.core.internal.WriteOnceProperty import net.corda.core.serialization.internal.effectiveSerializationEnv
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.sequence import net.corda.core.utilities.sequence
@ -53,7 +52,7 @@ abstract class SerializationFactory {
* A context to use as a default if you do not require a specially configured context. It will be the current context * A context to use as a default if you do not require a specially configured context. It will be the current context
* if the use is somehow nested (see [currentContext]). * if the use is somehow nested (see [currentContext]).
*/ */
val defaultContext: SerializationContext get() = currentContext ?: SerializationDefaults.P2P_CONTEXT val defaultContext: SerializationContext get() = currentContext ?: effectiveSerializationEnv.p2pContext
private val _currentContext = ThreadLocal<SerializationContext?>() private val _currentContext = ThreadLocal<SerializationContext?>()
@ -90,7 +89,7 @@ abstract class SerializationFactory {
/** /**
* A default factory for serialization/deserialization, taking into account the [currentFactory] if set. * A default factory for serialization/deserialization, taking into account the [currentFactory] if set.
*/ */
val defaultFactory: SerializationFactory get() = currentFactory ?: SerializationDefaults.SERIALIZATION_FACTORY val defaultFactory: SerializationFactory get() = currentFactory ?: effectiveSerializationEnv.serializationFactory
/** /**
* If there is a need to nest serialization/deserialization with a modified context during serialization or deserialization, * If there is a need to nest serialization/deserialization with a modified context during serialization or deserialization,
@ -173,13 +172,13 @@ interface SerializationContext {
/** /**
* Global singletons to be used as defaults that are injected elsewhere (generally, in the node or in RPC client). * Global singletons to be used as defaults that are injected elsewhere (generally, in the node or in RPC client).
*/ */
object SerializationDefaults : SerializationEnvironment { object SerializationDefaults {
override var SERIALIZATION_FACTORY: SerializationFactory by WriteOnceProperty() val SERIALIZATION_FACTORY get() = effectiveSerializationEnv.serializationFactory
override var P2P_CONTEXT: SerializationContext by WriteOnceProperty() val P2P_CONTEXT get() = effectiveSerializationEnv.p2pContext
override var RPC_SERVER_CONTEXT: SerializationContext by WriteOnceProperty() val RPC_SERVER_CONTEXT get() = effectiveSerializationEnv.rpcServerContext
override var RPC_CLIENT_CONTEXT: SerializationContext by WriteOnceProperty() val RPC_CLIENT_CONTEXT get() = effectiveSerializationEnv.rpcClientContext
override var STORAGE_CONTEXT: SerializationContext by WriteOnceProperty() val STORAGE_CONTEXT get() = effectiveSerializationEnv.storageContext
override var CHECKPOINT_CONTEXT: SerializationContext by WriteOnceProperty() val CHECKPOINT_CONTEXT get() = effectiveSerializationEnv.checkpointContext
} }
/** /**

View File

@ -1,13 +1,55 @@
package net.corda.core.serialization.internal package net.corda.core.serialization.internal
import net.corda.core.internal.InheritableThreadLocalToggleField
import net.corda.core.internal.SimpleToggleField
import net.corda.core.internal.ThreadLocalToggleField
import net.corda.core.internal.VisibleForTesting
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationFactory import net.corda.core.serialization.SerializationFactory
interface SerializationEnvironment { interface SerializationEnvironment {
val SERIALIZATION_FACTORY: SerializationFactory val serializationFactory: SerializationFactory
val P2P_CONTEXT: SerializationContext val p2pContext: SerializationContext
val RPC_SERVER_CONTEXT: SerializationContext val rpcServerContext: SerializationContext
val RPC_CLIENT_CONTEXT: SerializationContext val rpcClientContext: SerializationContext
val STORAGE_CONTEXT: SerializationContext val storageContext: SerializationContext
val CHECKPOINT_CONTEXT: SerializationContext val checkpointContext: SerializationContext
} }
class SerializationEnvironmentImpl(
override val serializationFactory: SerializationFactory,
override val p2pContext: SerializationContext,
rpcServerContext: SerializationContext? = null,
rpcClientContext: SerializationContext? = null,
storageContext: SerializationContext? = null,
checkpointContext: SerializationContext? = null) : SerializationEnvironment {
// Those that are passed in as null are never inited:
override lateinit var rpcServerContext: SerializationContext
override lateinit var rpcClientContext: SerializationContext
override lateinit var storageContext: SerializationContext
override lateinit var checkpointContext: SerializationContext
init {
rpcServerContext?.let { this.rpcServerContext = it }
rpcClientContext?.let { this.rpcClientContext = it }
storageContext?.let { this.storageContext = it }
checkpointContext?.let { this.checkpointContext = it }
}
}
private val _nodeSerializationEnv = SimpleToggleField<SerializationEnvironment>("nodeSerializationEnv", true)
@VisibleForTesting
val _globalSerializationEnv = SimpleToggleField<SerializationEnvironment>("globalSerializationEnv")
@VisibleForTesting
val _contextSerializationEnv = ThreadLocalToggleField<SerializationEnvironment>("contextSerializationEnv")
@VisibleForTesting
val _inheritableContextSerializationEnv = InheritableThreadLocalToggleField<SerializationEnvironment>("inheritableContextSerializationEnv")
private val serializationEnvProperties = listOf(_nodeSerializationEnv, _globalSerializationEnv, _contextSerializationEnv, _inheritableContextSerializationEnv)
val effectiveSerializationEnv: SerializationEnvironment
get() = serializationEnvProperties.map { Pair(it, it.get()) }.filter { it.second != null }.run {
singleOrNull()?.run {
second!!
} ?: throw IllegalStateException("Expected exactly 1 of {${serializationEnvProperties.joinToString(", ") { it.name }}} but got: {${joinToString(", ") { it.first.name }}}")
}
/** Should be set once in main. */
var nodeSerializationEnv by _nodeSerializationEnv

View File

@ -0,0 +1,125 @@
package net.corda.core.internal
import net.corda.core.internal.concurrent.fork
import net.corda.core.utilities.getOrThrow
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.Test
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import kotlin.test.assertEquals
import kotlin.test.assertNull
private fun <T> withSingleThreadExecutor(callable: ExecutorService.() -> T) = Executors.newSingleThreadExecutor().run {
try {
fork {}.getOrThrow() // Start the thread.
callable()
} finally {
shutdown()
while (!awaitTermination(1, TimeUnit.SECONDS)) {
// Do nothing.
}
}
}
class ToggleFieldTest {
@Test
fun `toggle is enforced`() {
listOf(SimpleToggleField<String>("simple"), ThreadLocalToggleField<String>("local"), InheritableThreadLocalToggleField("inheritable")).forEach { field ->
assertNull(field.get())
assertThatThrownBy { field.set(null) }.isInstanceOf(IllegalStateException::class.java)
field.set("hello")
assertEquals("hello", field.get())
assertThatThrownBy { field.set("world") }.isInstanceOf(IllegalStateException::class.java)
assertEquals("hello", field.get())
assertThatThrownBy { field.set("hello") }.isInstanceOf(IllegalStateException::class.java)
field.set(null)
assertNull(field.get())
}
}
@Test
fun `write-at-most-once field works`() {
val field = SimpleToggleField<String>("field", true)
assertNull(field.get())
assertThatThrownBy { field.set(null) }.isInstanceOf(IllegalStateException::class.java)
field.set("finalValue")
assertEquals("finalValue", field.get())
listOf("otherValue", "finalValue", null).forEach { value ->
assertThatThrownBy { field.set(value) }.isInstanceOf(IllegalStateException::class.java)
assertEquals("finalValue", field.get())
}
}
@Test
fun `thread local works`() {
val field = ThreadLocalToggleField<String>("field")
assertNull(field.get())
field.set("hello")
assertEquals("hello", field.get())
withSingleThreadExecutor {
assertNull(fork(field::get).getOrThrow())
}
field.set(null)
assertNull(field.get())
}
@Test
fun `inheritable thread local works`() {
val field = InheritableThreadLocalToggleField<String>("field")
assertNull(field.get())
field.set("hello")
assertEquals("hello", field.get())
withSingleThreadExecutor {
assertEquals("hello", fork(field::get).getOrThrow())
}
field.set(null)
assertNull(field.get())
}
@Test
fun `existing threads do not inherit`() {
val field = InheritableThreadLocalToggleField<String>("field")
withSingleThreadExecutor {
field.set("hello")
assertEquals("hello", field.get())
assertNull(fork(field::get).getOrThrow())
}
}
@Test
fun `inherited values are poisoned on clear`() {
val field = InheritableThreadLocalToggleField<String>("field")
field.set("hello")
withSingleThreadExecutor {
assertEquals("hello", fork(field::get).getOrThrow())
val threadName = fork { Thread.currentThread().name }.getOrThrow()
listOf(null, "world").forEach { value ->
field.set(value)
assertEquals(value, field.get())
val future = fork(field::get)
assertThatThrownBy { future.getOrThrow() }
.isInstanceOf(ThreadLeakException::class.java)
.hasMessageContaining(threadName)
}
}
withSingleThreadExecutor {
assertEquals("world", fork(field::get).getOrThrow())
}
}
@Test
fun `leaked thread is detected as soon as it tries to create another`() {
val field = InheritableThreadLocalToggleField<String>("field")
field.set("hello")
withSingleThreadExecutor {
assertEquals("hello", fork(field::get).getOrThrow())
field.set(null) // The executor thread is now considered leaked.
val threadName = fork { Thread.currentThread().name }.getOrThrow()
val future = fork(::Thread)
assertThatThrownBy { future.getOrThrow() }
.isInstanceOf(ThreadLeakException::class.java)
.hasMessageContaining(threadName)
}
}
}

View File

@ -1,6 +1,7 @@
package net.corda.nodeapi package net.corda.nodeapi
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.config.SSLConfiguration import net.corda.nodeapi.config.SSLConfiguration
import org.apache.activemq.artemis.api.core.TransportConfiguration import org.apache.activemq.artemis.api.core.TransportConfiguration
@ -48,7 +49,8 @@ class ArtemisTcpTransport {
// Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop. // Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop.
// It does not use AMQP messages for its own messages e.g. topology and heartbeats. // It does not use AMQP messages for its own messages e.g. topology and heartbeats.
// TODO further investigate how to ensure we use a well defined wire level protocol for Node to Node communications. // TODO further investigate how to ensure we use a well defined wire level protocol for Node to Node communications.
TransportConstants.PROTOCOLS_PROP_NAME to "CORE,AMQP" TransportConstants.PROTOCOLS_PROP_NAME to "CORE,AMQP",
TransportConstants.USE_GLOBAL_WORKER_POOL_PROP_NAME to (nodeSerializationEnv != null)
) )
if (config != null && enableSSL) { if (config != null && enableSSL) {

View File

@ -25,7 +25,7 @@ public final class ForbiddenLambdaSerializationTests {
@Before @Before
public void setup() { public void setup() {
factory = testSerialization.env.getSERIALIZATION_FACTORY(); factory = testSerialization.getEnv().getSerializationFactory();
} }
@Test @Test

View File

@ -25,7 +25,7 @@ public final class LambdaCheckpointSerializationTest {
@Before @Before
public void setup() { public void setup() {
factory = testSerialization.env.getSERIALIZATION_FACTORY(); factory = testSerialization.getEnv().getSerializationFactory();
context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoHeaderV0_1(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, SerializationContext.UseCase.Checkpoint); context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoHeaderV0_1(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, SerializationContext.UseCase.Checkpoint);
} }

View File

@ -27,9 +27,8 @@ class ContractAttachmentSerializerTest {
@Before @Before
fun setup() { fun setup() {
factory = testSerialization.env.SERIALIZATION_FACTORY factory = testSerialization.env.serializationFactory
context = testSerialization.env.CHECKPOINT_CONTEXT context = testSerialization.env.checkpointContext
contextWithToken = context.withTokenContext(SerializeAsTokenContextImpl(Any(), factory, context, mockServices)) contextWithToken = context.withTokenContext(SerializeAsTokenContextImpl(Any(), factory, context, mockServices))
} }

View File

@ -8,7 +8,6 @@ import net.corda.core.utilities.OpaqueBytes
import net.corda.nodeapi.internal.serialization.kryo.CordaKryo import net.corda.nodeapi.internal.serialization.kryo.CordaKryo
import net.corda.nodeapi.internal.serialization.kryo.DefaultKryoCustomizer import net.corda.nodeapi.internal.serialization.kryo.DefaultKryoCustomizer
import net.corda.nodeapi.internal.serialization.kryo.KryoHeaderV0_1 import net.corda.nodeapi.internal.serialization.kryo.KryoHeaderV0_1
import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.rigorousMock import net.corda.testing.rigorousMock
import net.corda.testing.SerializationEnvironmentRule import net.corda.testing.SerializationEnvironmentRule
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
@ -26,8 +25,8 @@ class SerializationTokenTest {
@Before @Before
fun setup() { fun setup() {
factory = testSerialization.env.SERIALIZATION_FACTORY factory = testSerialization.env.serializationFactory
context = testSerialization.env.CHECKPOINT_CONTEXT.withWhitelisted(SingletonSerializationToken::class.java) context = testSerialization.env.checkpointContext.withWhitelisted(SingletonSerializationToken::class.java)
} }
// Large tokenizable object so we can tell from the smaller number of serialized bytes it was actually tokenized // Large tokenizable object so we can tell from the smaller number of serialized bytes it was actually tokenized

View File

@ -17,26 +17,21 @@ import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.loggerFor import net.corda.core.utilities.loggerFor
import net.corda.node.internal.cordapp.CordappLoader import net.corda.node.internal.cordapp.CordappLoader
import net.corda.node.internal.cordapp.CordappProviderImpl import net.corda.node.internal.cordapp.CordappProviderImpl
import net.corda.testing.*
import net.corda.testing.DUMMY_BANK_A import net.corda.testing.DUMMY_BANK_A
import net.corda.testing.DUMMY_NOTARY import net.corda.testing.DUMMY_NOTARY
import net.corda.testing.SerializationEnvironmentRule
import net.corda.testing.driver.DriverDSLExposedInterface import net.corda.testing.driver.DriverDSLExposedInterface
import net.corda.testing.driver.NodeHandle import net.corda.testing.driver.NodeHandle
import net.corda.testing.driver.driver import net.corda.testing.driver.driver
import net.corda.testing.node.MockServices import net.corda.testing.node.MockServices
import org.junit.Assert.assertEquals import org.junit.Assert.assertEquals
import org.junit.Before import org.junit.Before
import org.junit.Rule
import org.junit.Test import org.junit.Test
import java.net.URLClassLoader import java.net.URLClassLoader
import java.nio.file.Files import java.nio.file.Files
import kotlin.test.assertFailsWith import kotlin.test.assertFailsWith
class AttachmentLoadingTests { class AttachmentLoadingTests {
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule()
private class Services : MockServices() { private class Services : MockServices() {
private val provider = CordappProviderImpl(CordappLoader.createDevMode(listOf(isolatedJAR)), attachments) private val provider = CordappProviderImpl(CordappLoader.createDevMode(listOf(isolatedJAR)), attachments)
private val cordapp get() = provider.cordapps.first() private val cordapp get() = provider.cordapps.first()
@ -83,7 +78,7 @@ class AttachmentLoadingTests {
} }
@Test @Test
fun `test a wire transaction has loaded the correct attachment`() { fun `test a wire transaction has loaded the correct attachment`() = withTestSerialization {
val appClassLoader = services.appContext.classLoader val appClassLoader = services.appContext.classLoader
val contractClass = appClassLoader.loadClass(ISOLATED_CONTRACT_ID).asSubclass(Contract::class.java) val contractClass = appClassLoader.loadClass(ISOLATED_CONTRACT_ID).asSubclass(Contract::class.java)
val generateInitialMethod = contractClass.getDeclaredMethod("generateInitial", PartyAndReference::class.java, Integer.TYPE, Party::class.java) val generateInitialMethod = contractClass.getDeclaredMethod("generateInitial", PartyAndReference::class.java, Integer.TYPE, Party::class.java)
@ -101,7 +96,7 @@ class AttachmentLoadingTests {
@Test @Test
fun `test that attachments retrieved over the network are not used for code`() { fun `test that attachments retrieved over the network are not used for code`() {
driver(initialiseSerialization = false) { driver {
installIsolatedCordappTo(bankAName) installIsolatedCordappTo(bankAName)
val (bankA, bankB) = createTwoNodes() val (bankA, bankB) = createTwoNodes()
assertFailsWith<UnexpectedFlowEndException>("Party C=CH,L=Zurich,O=BankB rejected session request: Don't know net.corda.finance.contracts.isolated.IsolatedDummyFlow\$Initiator") { assertFailsWith<UnexpectedFlowEndException>("Party C=CH,L=Zurich,O=BankB rejected session request: Don't know net.corda.finance.contracts.isolated.IsolatedDummyFlow\$Initiator") {
@ -112,7 +107,7 @@ class AttachmentLoadingTests {
@Test @Test
fun `tests that if the attachment is loaded on both sides already that a flow can run`() { fun `tests that if the attachment is loaded on both sides already that a flow can run`() {
driver(initialiseSerialization = false) { driver {
installIsolatedCordappTo(bankAName) installIsolatedCordappTo(bankAName)
installIsolatedCordappTo(bankBName) installIsolatedCordappTo(bankBName)
val (bankA, bankB) = createTwoNodes() val (bankA, bankB) = createTwoNodes()

View File

@ -10,11 +10,13 @@ import net.corda.nodeapi.NodeInfoFilesCopier
import net.corda.testing.ALICE import net.corda.testing.ALICE
import net.corda.testing.ALICE_KEY import net.corda.testing.ALICE_KEY
import net.corda.testing.getTestPartyAndCertificate import net.corda.testing.getTestPartyAndCertificate
import net.corda.testing.internal.NodeBasedTest import net.corda.testing.*
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.contentOf import org.assertj.core.api.Assertions.contentOf
import org.junit.Before import org.junit.Before
import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.junit.rules.TemporaryFolder
import rx.observers.TestSubscriber import rx.observers.TestSubscriber
import rx.schedulers.TestScheduler import rx.schedulers.TestScheduler
import java.nio.file.Path import java.nio.file.Path
@ -22,11 +24,17 @@ import java.util.concurrent.TimeUnit
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertTrue import kotlin.test.assertTrue
class NodeInfoWatcherTest : NodeBasedTest() { class NodeInfoWatcherTest {
companion object { companion object {
val nodeInfo = NodeInfo(listOf(), listOf(getTestPartyAndCertificate(ALICE)), 0, 0) val nodeInfo = NodeInfo(listOf(), listOf(getTestPartyAndCertificate(ALICE)), 0, 0)
} }
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule()
@Rule
@JvmField
val tempFolder = TemporaryFolder()
private lateinit var nodeInfoPath: Path private lateinit var nodeInfoPath: Path
private val scheduler = TestScheduler() private val scheduler = TestScheduler()
private val testSubscriber = TestSubscriber<NodeInfo>() private val testSubscriber = TestSubscriber<NodeInfo>()

View File

@ -9,9 +9,10 @@ import net.corda.core.internal.concurrent.thenMatch
import net.corda.core.internal.uncheckedCast import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.RPCOps import net.corda.core.messaging.RPCOps
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.loggerFor import net.corda.core.utilities.loggerFor
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.node.VersionInfo import net.corda.node.VersionInfo
import net.corda.node.internal.cordapp.CordappLoader import net.corda.node.internal.cordapp.CordappLoader
import net.corda.node.serialization.KryoServerSerializationScheme import net.corda.node.serialization.KryoServerSerializationScheme
@ -25,7 +26,6 @@ import net.corda.node.services.messaging.NodeMessagingClient
import net.corda.node.utilities.AddressUtils import net.corda.node.utilities.AddressUtils
import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.DemoClock import net.corda.node.utilities.DemoClock
import net.corda.nodeapi.ArtemisMessagingComponent
import net.corda.nodeapi.internal.ShutdownHook import net.corda.nodeapi.internal.ShutdownHook
import net.corda.nodeapi.internal.addShutdownHook import net.corda.nodeapi.internal.addShutdownHook
import net.corda.nodeapi.internal.serialization.* import net.corda.nodeapi.internal.serialization.*
@ -274,14 +274,15 @@ open class Node(configuration: NodeConfiguration,
private fun initialiseSerialization() { private fun initialiseSerialization() {
val classloader = cordappLoader.appClassLoader val classloader = cordappLoader.appClassLoader
SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply { nodeSerializationEnv = SerializationEnvironmentImpl(
SerializationFactoryImpl().apply {
registerScheme(KryoServerSerializationScheme()) registerScheme(KryoServerSerializationScheme())
registerScheme(AMQPServerSerializationScheme()) registerScheme(AMQPServerSerializationScheme())
} },
SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT.withClassLoader(classloader) KRYO_P2P_CONTEXT.withClassLoader(classloader),
SerializationDefaults.RPC_SERVER_CONTEXT = KRYO_RPC_SERVER_CONTEXT.withClassLoader(classloader) rpcServerContext = KRYO_RPC_SERVER_CONTEXT.withClassLoader(classloader),
SerializationDefaults.STORAGE_CONTEXT = KRYO_STORAGE_CONTEXT.withClassLoader(classloader) storageContext = KRYO_STORAGE_CONTEXT.withClassLoader(classloader),
SerializationDefaults.CHECKPOINT_CONTEXT = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader) checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader))
} }
/** Starts a blocking event loop for message dispatch. */ /** Starts a blocking event loop for message dispatch. */

View File

@ -11,6 +11,7 @@ import net.corda.core.node.services.PartyInfo
import net.corda.core.node.services.TransactionVerifierService import net.corda.core.node.services.TransactionVerifierService
import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.LedgerTransaction
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
@ -217,6 +218,7 @@ class NodeMessagingClient(override val config: NodeConfiguration,
locator.connectionTTL = -1 locator.connectionTTL = -1
locator.clientFailureCheckPeriod = -1 locator.clientFailureCheckPeriod = -1
locator.minLargeMessageSize = ArtemisMessagingServer.MAX_FILE_SIZE locator.minLargeMessageSize = ArtemisMessagingServer.MAX_FILE_SIZE
locator.isUseGlobalPools = nodeSerializationEnv != null
sessionFactory = locator.createSessionFactory() sessionFactory = locator.createSessionFactory()
// Login using the node username. The broker will authentiate us as its node (as opposed to another peer) // Login using the node username. The broker will authentiate us as its node (as opposed to another peer)

View File

@ -60,7 +60,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
@Rule @Rule
@JvmField @JvmField
val testSerialization = SerializationEnvironmentRule() val testSerialization = SerializationEnvironmentRule(true)
private val realClock: Clock = Clock.systemUTC() private val realClock: Clock = Clock.systemUTC()
private val stoppedClock: Clock = Clock.fixed(realClock.instant(), realClock.zone) private val stoppedClock: Clock = Clock.fixed(realClock.instant(), realClock.zone)
private val testClock = TestClock(stoppedClock) private val testClock = TestClock(stoppedClock)

View File

@ -41,7 +41,7 @@ import kotlin.test.assertEquals
class HTTPNetworkMapClientTest { class HTTPNetworkMapClientTest {
@Rule @Rule
@JvmField @JvmField
val testSerialization = SerializationEnvironmentRule() val testSerialization = SerializationEnvironmentRule(true)
private lateinit var server: Server private lateinit var server: Server
private lateinit var networkMapClient: NetworkMapClient private lateinit var networkMapClient: NetworkMapClient

View File

@ -30,7 +30,7 @@ class DistributedImmutableMapTests {
@Rule @Rule
@JvmField @JvmField
val testSerialization = SerializationEnvironmentRule() val testSerialization = SerializationEnvironmentRule(true)
lateinit var cluster: List<Member> lateinit var cluster: List<Member>
lateinit var transaction: DatabaseTransaction lateinit var transaction: DatabaseTransaction
private val databases: MutableList<CordaPersistence> = mutableListOf() private val databases: MutableList<CordaPersistence> = mutableListOf()

View File

@ -1,9 +1,12 @@
package net.corda.node.services.vault package net.corda.node.services.vault
import net.corda.core.contracts.ContractState import net.corda.core.contracts.ContractState
import net.corda.core.contracts.InsufficientBalanceException
import net.corda.core.contracts.LinearState import net.corda.core.contracts.LinearState
import net.corda.core.contracts.UniqueIdentifier import net.corda.core.contracts.UniqueIdentifier
import net.corda.core.identity.AnonymousParty import net.corda.core.identity.AnonymousParty
import net.corda.core.internal.concurrent.fork
import net.corda.core.internal.concurrent.transpose
import net.corda.core.internal.packageName import net.corda.core.internal.packageName
import net.corda.core.node.services.Vault import net.corda.core.node.services.Vault
import net.corda.core.node.services.VaultService import net.corda.core.node.services.VaultService
@ -11,6 +14,7 @@ import net.corda.core.node.services.queryBy
import net.corda.core.node.services.vault.QueryCriteria import net.corda.core.node.services.vault.QueryCriteria
import net.corda.core.node.services.vault.QueryCriteria.VaultQueryCriteria import net.corda.core.node.services.vault.QueryCriteria.VaultQueryCriteria
import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.getOrThrow
import net.corda.finance.* import net.corda.finance.*
import net.corda.finance.contracts.asset.Cash import net.corda.finance.contracts.asset.Cash
import net.corda.finance.contracts.asset.DUMMY_CASH_ISSUER import net.corda.finance.contracts.asset.DUMMY_CASH_ISSUER
@ -29,9 +33,9 @@ import org.junit.Before
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import java.util.* import java.util.*
import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executors import java.util.concurrent.Executors
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.fail
// TODO: Move this to the cash contract tests once mock services are further split up. // TODO: Move this to the cash contract tests once mock services are further split up.
@ -42,7 +46,7 @@ class VaultWithCashTest {
@Rule @Rule
@JvmField @JvmField
val testSerialization = SerializationEnvironmentRule() val testSerialization = SerializationEnvironmentRule(true)
lateinit var services: MockServices lateinit var services: MockServices
lateinit var issuerServices: MockServices lateinit var issuerServices: MockServices
val vaultService: VaultService get() = services.vaultService val vaultService: VaultService get() = services.vaultService
@ -150,12 +154,9 @@ class VaultWithCashTest {
} }
val backgroundExecutor = Executors.newFixedThreadPool(2) val backgroundExecutor = Executors.newFixedThreadPool(2)
val countDown = CountDownLatch(2)
// 1st tx that spends our money. // 1st tx that spends our money.
backgroundExecutor.submit { val first = backgroundExecutor.fork {
database.transaction { database.transaction {
try {
val txn1Builder = TransactionBuilder(DUMMY_NOTARY) val txn1Builder = TransactionBuilder(DUMMY_NOTARY)
Cash.generateSpend(services, txn1Builder, 60.DOLLARS, BOB) Cash.generateSpend(services, txn1Builder, 60.DOLLARS, BOB)
val ptxn1 = notaryServices.signInitialTransaction(txn1Builder) val ptxn1 = notaryServices.signInitialTransaction(txn1Builder)
@ -180,18 +181,13 @@ class VaultWithCashTest {
LOCKED: ${lockedStates2.count()} : $lockedStates2 LOCKED: ${lockedStates2.count()} : $lockedStates2
""") """)
txn1 txn1
} catch (e: Exception) {
println(e)
}
} }
println("txn1 COMMITTED!") println("txn1 COMMITTED!")
countDown.countDown()
} }
// 2nd tx that attempts to spend same money // 2nd tx that attempts to spend same money
backgroundExecutor.submit { val second = backgroundExecutor.fork {
database.transaction { database.transaction {
try {
val txn2Builder = TransactionBuilder(DUMMY_NOTARY) val txn2Builder = TransactionBuilder(DUMMY_NOTARY)
Cash.generateSpend(services, txn2Builder, 80.DOLLARS, BOB) Cash.generateSpend(services, txn2Builder, 80.DOLLARS, BOB)
val ptxn2 = notaryServices.signInitialTransaction(txn2Builder) val ptxn2 = notaryServices.signInitialTransaction(txn2Builder)
@ -216,16 +212,16 @@ class VaultWithCashTest {
LOCKED: ${lockedStates2.count()} : $lockedStates2 LOCKED: ${lockedStates2.count()} : $lockedStates2
""") """)
txn2 txn2
} catch (e: Exception) {
println(e)
}
} }
println("txn2 COMMITTED!") println("txn2 COMMITTED!")
countDown.countDown()
} }
val both = listOf(first, second).transpose()
countDown.await() try {
both.getOrThrow()
fail("Expected insufficient balance.")
} catch (e: InsufficientBalanceException) {
assertEquals(0, e.suppressed.size) // One should succeed.
}
database.transaction { database.transaction {
println("Cash balance: ${services.getCashBalance(USD)}") println("Cash balance: ${services.getCashBalance(USD)}")
assertThat(services.getCashBalance(USD)).isIn(DOLLARS(20), DOLLARS(40)) assertThat(services.getCashBalance(USD)).isIn(DOLLARS(20), DOLLARS(40))

View File

@ -203,8 +203,8 @@ class NodeInterestRatesTest {
} }
@Test @Test
fun `network tearoff`() { fun `network tearoff`() = withoutTestSerialization {
val mockNet = MockNetwork(initialiseSerialization = false, cordappPackages = listOf("net.corda.finance.contracts", "net.corda.irs")) val mockNet = MockNetwork(cordappPackages = listOf("net.corda.finance.contracts", "net.corda.irs"))
val aliceNode = mockNet.createPartyNode(ALICE.name) val aliceNode = mockNet.createPartyNode(ALICE.name)
val oracleNode = mockNet.createNode().apply { val oracleNode = mockNet.createNode().apply {
internals.registerInitiatedFlow(NodeInterestRates.FixQueryHandler::class.java) internals.registerInitiatedFlow(NodeInterestRates.FixQueryHandler::class.java)

View File

@ -37,6 +37,7 @@ import net.corda.nodeapi.internal.addShutdownHook
import net.corda.testing.* import net.corda.testing.*
import net.corda.testing.common.internal.NetworkParametersCopier import net.corda.testing.common.internal.NetworkParametersCopier
import net.corda.testing.common.internal.testNetworkParameters import net.corda.testing.common.internal.testNetworkParameters
import net.corda.testing.setGlobalSerialization
import net.corda.testing.internal.ProcessUtilities import net.corda.testing.internal.ProcessUtilities
import net.corda.testing.node.ClusterSpec import net.corda.testing.node.ClusterSpec
import net.corda.testing.node.MockServices.Companion.MOCK_VERSION_INFO import net.corda.testing.node.MockServices.Companion.MOCK_VERSION_INFO
@ -413,7 +414,7 @@ fun <DI : DriverDSLExposedInterface, D : DriverDSLInternalInterface, A> genericD
coerce: (D) -> DI, coerce: (D) -> DI,
dsl: DI.() -> A dsl: DI.() -> A
): A { ): A {
val serializationEnv = initialiseTestSerialization(initialiseSerialization) val serializationEnv = setGlobalSerialization(initialiseSerialization)
val shutdownHook = addShutdownHook(driverDsl::shutdown) val shutdownHook = addShutdownHook(driverDsl::shutdown)
try { try {
driverDsl.start() driverDsl.start()
@ -424,7 +425,7 @@ fun <DI : DriverDSLExposedInterface, D : DriverDSLInternalInterface, A> genericD
} finally { } finally {
driverDsl.shutdown() driverDsl.shutdown()
shutdownHook.cancel() shutdownHook.cancel()
serializationEnv.resetTestSerialization() serializationEnv.unset()
} }
} }
@ -451,7 +452,7 @@ fun <DI : DriverDSLExposedInterface, D : DriverDSLInternalInterface, A> genericD
driverDslWrapper: (DriverDSL) -> D, driverDslWrapper: (DriverDSL) -> D,
coerce: (D) -> DI, dsl: DI.() -> A coerce: (D) -> DI, dsl: DI.() -> A
): A { ): A {
val serializationEnv = initialiseTestSerialization(initialiseSerialization) val serializationEnv = setGlobalSerialization(initialiseSerialization)
val driverDsl = driverDslWrapper( val driverDsl = driverDslWrapper(
DriverDSL( DriverDSL(
portAllocation = portAllocation, portAllocation = portAllocation,
@ -475,7 +476,7 @@ fun <DI : DriverDSLExposedInterface, D : DriverDSLInternalInterface, A> genericD
} finally { } finally {
driverDsl.shutdown() driverDsl.shutdown()
shutdownHook.cancel() shutdownHook.cancel()
serializationEnv.resetTestSerialization() serializationEnv.unset()
} }
} }

View File

@ -38,7 +38,7 @@ abstract class NodeBasedTest(private val cordappPackages: List<String> = emptyLi
@Rule @Rule
@JvmField @JvmField
val testSerialization = SerializationEnvironmentRule() val testSerialization = SerializationEnvironmentRule(true)
@Rule @Rule
@JvmField @JvmField
val tempFolder = TemporaryFolder() val tempFolder = TemporaryFolder()
@ -63,6 +63,7 @@ abstract class NodeBasedTest(private val cordappPackages: List<String> = emptyLi
@After @After
fun stopAllNodes() { fun stopAllNodes() {
val shutdownExecutor = Executors.newScheduledThreadPool(nodes.size) val shutdownExecutor = Executors.newScheduledThreadPool(nodes.size)
try {
nodes.map { shutdownExecutor.fork(it::dispose) }.transpose().getOrThrow() nodes.map { shutdownExecutor.fork(it::dispose) }.transpose().getOrThrow()
// Wait until ports are released // Wait until ports are released
val portNotBoundChecks = nodes.flatMap { val portNotBoundChecks = nodes.flatMap {
@ -73,6 +74,9 @@ abstract class NodeBasedTest(private val cordappPackages: List<String> = emptyLi
}.filterNotNull() }.filterNotNull()
nodes.clear() nodes.clear()
portNotBoundChecks.transpose().getOrThrow() portNotBoundChecks.transpose().getOrThrow()
} finally {
shutdownExecutor.shutdown()
}
} }
@JvmOverloads @JvmOverloads

View File

@ -39,7 +39,7 @@ import net.corda.node.utilities.ServiceIdentityGenerator
import net.corda.testing.DUMMY_NOTARY import net.corda.testing.DUMMY_NOTARY
import net.corda.testing.common.internal.NetworkParametersCopier import net.corda.testing.common.internal.NetworkParametersCopier
import net.corda.testing.common.internal.testNetworkParameters import net.corda.testing.common.internal.testNetworkParameters
import net.corda.testing.initialiseTestSerialization import net.corda.testing.setGlobalSerialization
import net.corda.testing.node.MockServices.Companion.MOCK_VERSION_INFO import net.corda.testing.node.MockServices.Companion.MOCK_VERSION_INFO
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import net.corda.testing.testNodeConfiguration import net.corda.testing.testNodeConfiguration
@ -136,9 +136,8 @@ class MockNetwork(defaultParameters: MockNetworkParameters = MockNetworkParamete
private val networkId = random63BitValue() private val networkId = random63BitValue()
private val networkParameters: NetworkParametersCopier private val networkParameters: NetworkParametersCopier
private val _nodes = mutableListOf<MockNode>() private val _nodes = mutableListOf<MockNode>()
private val serializationEnv = initialiseTestSerialization(initialiseSerialization) private val serializationEnv = setGlobalSerialization(initialiseSerialization)
private val sharedUserCount = AtomicInteger(0) private val sharedUserCount = AtomicInteger(0)
/** A read only view of the current set of executing nodes. */ /** A read only view of the current set of executing nodes. */
val nodes: List<MockNode> get() = _nodes val nodes: List<MockNode> get() = _nodes
@ -419,7 +418,7 @@ class MockNetwork(defaultParameters: MockNetworkParameters = MockNetworkParamete
fun stopNodes() { fun stopNodes() {
nodes.forEach { it.started?.dispose() } nodes.forEach { it.started?.dispose() }
serializationEnv.resetTestSerialization() serializationEnv.unset()
} }
// Test method to block until all scheduled activity, active flows // Test method to block until all scheduled activity, active flows

View File

@ -3,10 +3,7 @@ package net.corda.testing
import com.nhaarman.mockito_kotlin.doNothing import com.nhaarman.mockito_kotlin.doNothing
import com.nhaarman.mockito_kotlin.whenever import com.nhaarman.mockito_kotlin.whenever
import net.corda.client.rpc.internal.KryoClientSerializationScheme import net.corda.client.rpc.internal.KryoClientSerializationScheme
import net.corda.core.crypto.SecureHash import net.corda.core.serialization.internal.*
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.utilities.ByteSequence
import net.corda.node.serialization.KryoServerSerializationScheme import net.corda.node.serialization.KryoServerSerializationScheme
import net.corda.nodeapi.internal.serialization.* import net.corda.nodeapi.internal.serialization.*
import net.corda.nodeapi.internal.serialization.amqp.AMQPClientSerializationScheme import net.corda.nodeapi.internal.serialization.amqp.AMQPClientSerializationScheme
@ -15,183 +12,84 @@ import org.junit.rules.TestRule
import org.junit.runner.Description import org.junit.runner.Description
import org.junit.runners.model.Statement import org.junit.runners.model.Statement
class SerializationEnvironmentRule : TestRule { /** @param inheritable whether new threads inherit the environment, use sparingly. */
lateinit var env: SerializationEnvironment class SerializationEnvironmentRule(private val inheritable: Boolean = false) : TestRule {
val env: SerializationEnvironment = createTestSerializationEnv()
override fun apply(base: Statement, description: Description?) = object : Statement() { override fun apply(base: Statement, description: Description?) = object : Statement() {
override fun evaluate() = withTestSerialization { override fun evaluate() = env.asContextEnv(inheritable) {
env = it
base.evaluate() base.evaluate()
} }
} }
} }
interface TestSerializationEnvironment : SerializationEnvironment { interface GlobalSerializationEnvironment : SerializationEnvironment {
fun resetTestSerialization() /** Unset this environment. */
fun unset()
} }
fun <T> withTestSerialization(block: (SerializationEnvironment) -> T): T { /** @param inheritable whether new threads inherit the environment, use sparingly. */
val env = initialiseTestSerializationImpl() fun <T> withTestSerialization(inheritable: Boolean = false, callable: (SerializationEnvironment) -> T): T {
return createTestSerializationEnv().asContextEnv(inheritable, callable)
}
private fun <T> SerializationEnvironment.asContextEnv(inheritable: Boolean, callable: (SerializationEnvironment) -> T): T {
val property = if (inheritable) _inheritableContextSerializationEnv else _contextSerializationEnv
property.set(this)
try { try {
return block(env) return callable(this)
} finally { } finally {
env.resetTestSerialization() property.set(null)
} }
} }
/** @param armed true to init, false to do nothing and return a dummy env. */ /**
fun initialiseTestSerialization(armed: Boolean): TestSerializationEnvironment { * For example your test class uses [SerializationEnvironmentRule] but you want to turn it off for one method.
* Use sparingly, ideally a test class shouldn't mix serialization init mechanisms.
*/
fun <T> withoutTestSerialization(callable: () -> T): T {
val (property, env) = listOf(_contextSerializationEnv, _inheritableContextSerializationEnv).map { Pair(it, it.get()) }.single { it.second != null }
property.set(null)
try {
return callable()
} finally {
property.set(env)
}
}
/**
* Should only be used by Driver and MockNode.
* @param armed true to install, false to do nothing and return a dummy env.
*/
fun setGlobalSerialization(armed: Boolean): GlobalSerializationEnvironment {
return if (armed) { return if (armed) {
val env = initialiseTestSerializationImpl() object : GlobalSerializationEnvironment, SerializationEnvironment by createTestSerializationEnv() {
object : TestSerializationEnvironment, SerializationEnvironment by env { override fun unset() {
override fun resetTestSerialization() = env.resetTestSerialization() _globalSerializationEnv.set(null)
}
}.also {
_globalSerializationEnv.set(it)
} }
} else { } else {
rigorousMock<TestSerializationEnvironment>().also { rigorousMock<GlobalSerializationEnvironment>().also {
doNothing().whenever(it).resetTestSerialization() doNothing().whenever(it).unset()
} }
} }
} }
private fun initialiseTestSerializationImpl() = SerializationDefaults.apply { private fun createTestSerializationEnv() = SerializationEnvironmentImpl(
// Stop the CordaRPCClient from trying to setup the defaults as we're about to do it now SerializationFactoryImpl().apply {
KryoClientSerializationScheme.isInitialised.set(true)
// Check that everything is configured for testing with mutable delegating instances.
try {
check(SERIALIZATION_FACTORY is TestSerializationFactory)
} catch (e: IllegalStateException) {
SERIALIZATION_FACTORY = TestSerializationFactory()
}
try {
check(P2P_CONTEXT is TestSerializationContext)
} catch (e: IllegalStateException) {
P2P_CONTEXT = TestSerializationContext()
}
try {
check(RPC_SERVER_CONTEXT is TestSerializationContext)
} catch (e: IllegalStateException) {
RPC_SERVER_CONTEXT = TestSerializationContext()
}
try {
check(RPC_CLIENT_CONTEXT is TestSerializationContext)
} catch (e: IllegalStateException) {
RPC_CLIENT_CONTEXT = TestSerializationContext()
}
try {
check(STORAGE_CONTEXT is TestSerializationContext)
} catch (e: IllegalStateException) {
STORAGE_CONTEXT = TestSerializationContext()
}
try {
check(CHECKPOINT_CONTEXT is TestSerializationContext)
} catch (e: IllegalStateException) {
CHECKPOINT_CONTEXT = TestSerializationContext()
}
// Check that the previous test, if there was one, cleaned up after itself.
// IF YOU SEE THESE MESSAGES, THEN IT MEANS A TEST HAS NOT CALLED resetTestSerialization()
check((SERIALIZATION_FACTORY as TestSerializationFactory).delegate == null, { "Expected uninitialised serialization framework but found it set from: $SERIALIZATION_FACTORY" })
check((P2P_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: $P2P_CONTEXT" })
check((RPC_SERVER_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: $RPC_SERVER_CONTEXT" })
check((RPC_CLIENT_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: $RPC_CLIENT_CONTEXT" })
check((STORAGE_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: $STORAGE_CONTEXT" })
check((CHECKPOINT_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: $CHECKPOINT_CONTEXT" })
// Now configure all the testing related delegates.
(SERIALIZATION_FACTORY as TestSerializationFactory).delegate = SerializationFactoryImpl().apply {
registerScheme(KryoClientSerializationScheme()) registerScheme(KryoClientSerializationScheme())
registerScheme(KryoServerSerializationScheme()) registerScheme(KryoServerSerializationScheme())
registerScheme(AMQPClientSerializationScheme()) registerScheme(AMQPClientSerializationScheme())
registerScheme(AMQPServerSerializationScheme()) registerScheme(AMQPServerSerializationScheme())
} },
if (isAmqpEnabled()) AMQP_P2P_CONTEXT else KRYO_P2P_CONTEXT,
(P2P_CONTEXT as TestSerializationContext).delegate = if (isAmqpEnabled()) AMQP_P2P_CONTEXT else KRYO_P2P_CONTEXT KRYO_RPC_SERVER_CONTEXT,
(RPC_SERVER_CONTEXT as TestSerializationContext).delegate = KRYO_RPC_SERVER_CONTEXT KRYO_RPC_CLIENT_CONTEXT,
(RPC_CLIENT_CONTEXT as TestSerializationContext).delegate = KRYO_RPC_CLIENT_CONTEXT if (isAmqpEnabled()) AMQP_STORAGE_CONTEXT else KRYO_STORAGE_CONTEXT,
(STORAGE_CONTEXT as TestSerializationContext).delegate = if (isAmqpEnabled()) AMQP_STORAGE_CONTEXT else KRYO_STORAGE_CONTEXT KRYO_CHECKPOINT_CONTEXT)
(CHECKPOINT_CONTEXT as TestSerializationContext).delegate = KRYO_CHECKPOINT_CONTEXT
}
private const val AMQP_ENABLE_PROP_NAME = "net.corda.testing.amqp.enable" private const val AMQP_ENABLE_PROP_NAME = "net.corda.testing.amqp.enable"
// TODO: Remove usages of this function when we fully switched to AMQP // TODO: Remove usages of this function when we fully switched to AMQP
private fun isAmqpEnabled(): Boolean = java.lang.Boolean.getBoolean(AMQP_ENABLE_PROP_NAME) private fun isAmqpEnabled(): Boolean = java.lang.Boolean.getBoolean(AMQP_ENABLE_PROP_NAME)
private fun SerializationDefaults.resetTestSerialization() {
(SERIALIZATION_FACTORY as TestSerializationFactory).delegate = null
(P2P_CONTEXT as TestSerializationContext).delegate = null
(RPC_SERVER_CONTEXT as TestSerializationContext).delegate = null
(RPC_CLIENT_CONTEXT as TestSerializationContext).delegate = null
(STORAGE_CONTEXT as TestSerializationContext).delegate = null
(CHECKPOINT_CONTEXT as TestSerializationContext).delegate = null
}
class TestSerializationFactory : SerializationFactory() {
var delegate: SerializationFactory? = null
set(value) {
field = value
stackTrace = Exception().stackTrace.asList()
}
private var stackTrace: List<StackTraceElement>? = null
override fun toString(): String = stackTrace?.joinToString("\n") ?: "null"
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T {
return delegate!!.deserialize(byteSequence, clazz, context)
}
override fun <T : Any> deserializeWithCompatibleContext(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): ObjectWithCompatibleContext<T> {
return delegate!!.deserializeWithCompatibleContext(byteSequence, clazz, context)
}
override fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T> {
return delegate!!.serialize(obj, context)
}
}
class TestSerializationContext : SerializationContext {
var delegate: SerializationContext? = null
set(value) {
field = value
stackTrace = Exception().stackTrace.asList()
}
private var stackTrace: List<StackTraceElement>? = null
override fun toString(): String = stackTrace?.joinToString("\n") ?: "null"
override val preferredSerializationVersion: ByteSequence
get() = delegate!!.preferredSerializationVersion
override val deserializationClassLoader: ClassLoader
get() = delegate!!.deserializationClassLoader
override val whitelist: ClassWhitelist
get() = delegate!!.whitelist
override val properties: Map<Any, Any>
get() = delegate!!.properties
override val objectReferencesEnabled: Boolean
get() = delegate!!.objectReferencesEnabled
override val useCase: SerializationContext.UseCase
get() = delegate!!.useCase
override fun withProperty(property: Any, value: Any): SerializationContext {
return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withProperty(property, value) }
}
override fun withoutReferences(): SerializationContext {
return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withoutReferences() }
}
override fun withClassLoader(classLoader: ClassLoader): SerializationContext {
return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withClassLoader(classLoader) }
}
override fun withWhitelisted(clazz: Class<*>): SerializationContext {
return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withWhitelisted(clazz) }
}
override fun withPreferredSerializationVersion(versionHeader: VersionHeader): SerializationContext {
return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withPreferredSerializationVersion(versionHeader) }
}
override fun withAttachmentsClassLoader(attachmentHashes: List<SecureHash>): SerializationContext {
return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withAttachmentsClassLoader(attachmentHashes) }
}
}

View File

@ -1,6 +1,7 @@
package net.corda.testing.messaging package net.corda.testing.messaging
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.ArtemisMessagingComponent import net.corda.nodeapi.ArtemisMessagingComponent
import net.corda.nodeapi.ArtemisTcpTransport import net.corda.nodeapi.ArtemisTcpTransport
@ -27,6 +28,7 @@ class SimpleMQClient(val target: NetworkHostAndPort,
val locator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport).apply { val locator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport).apply {
isBlockOnNonDurableSend = true isBlockOnNonDurableSend = true
threadPoolMaxSize = 1 threadPoolMaxSize = 1
isUseGlobalPools = nodeSerializationEnv != null
} }
sessionFactory = locator.createSessionFactory() sessionFactory = locator.createSessionFactory()
session = sessionFactory.createSession(username, password, false, true, true, locator.isPreAcknowledge, locator.ackBatchSize) session = sessionFactory.createSession(username, password, false, true, true, locator.isPreAcknowledge, locator.ackBatchSize)

View File

@ -5,7 +5,8 @@ import com.typesafe.config.ConfigFactory
import com.typesafe.config.ConfigParseOptions import com.typesafe.config.ConfigParseOptions
import net.corda.core.internal.div import net.corda.core.internal.div
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
@ -89,15 +90,16 @@ class Verifier {
} }
private fun initialiseSerialization() { private fun initialiseSerialization() {
SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply { nodeSerializationEnv = SerializationEnvironmentImpl(
SerializationFactoryImpl().apply {
registerScheme(KryoVerifierSerializationScheme) registerScheme(KryoVerifierSerializationScheme)
registerScheme(AMQPVerifierSerializationScheme) registerScheme(AMQPVerifierSerializationScheme)
} },
/** /**
* Even though default context is set to Kryo P2P, the encoding will be adjusted depending on the incoming * Even though default context is set to Kryo P2P, the encoding will be adjusted depending on the incoming
* request received, see use of [context] in [main] method. * request received, see use of [context] in [main] method.
*/ */
SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT KRYO_P2P_CONTEXT)
} }
} }