diff --git a/client/jackson/build.gradle b/client/jackson/build.gradle index 1457f01bb4..ed76eb3e89 100644 --- a/client/jackson/build.gradle +++ b/client/jackson/build.gradle @@ -6,6 +6,8 @@ apply plugin: 'com.jfrog.artifactory' dependencies { compile project(':core') compile project(':finance') + testCompile project(':test-utils') + compile "org.jetbrains.kotlin:kotlin-stdlib-jre8:$kotlin_version" testCompile "org.jetbrains.kotlin:kotlin-test:$kotlin_version" diff --git a/client/jackson/src/test/kotlin/net/corda/jackson/JacksonSupportTest.kt b/client/jackson/src/test/kotlin/net/corda/jackson/JacksonSupportTest.kt index b1bbfa8f67..12a226f7b0 100644 --- a/client/jackson/src/test/kotlin/net/corda/jackson/JacksonSupportTest.kt +++ b/client/jackson/src/test/kotlin/net/corda/jackson/JacksonSupportTest.kt @@ -7,6 +7,7 @@ import com.pholser.junit.quickcheck.runner.JUnitQuickcheck import net.corda.core.contracts.Amount import net.corda.core.contracts.USD import net.corda.core.testing.PublicKeyGenerator +import net.corda.testing.TestDependencyInjectionBase import net.i2p.crypto.eddsa.EdDSAPublicKey import org.junit.Test import org.junit.runner.RunWith @@ -15,7 +16,7 @@ import java.util.* import kotlin.test.assertEquals @RunWith(JUnitQuickcheck::class) -class JacksonSupportTest { +class JacksonSupportTest : TestDependencyInjectionBase() { companion object { val mapper = JacksonSupport.createNonRpcMapper() } diff --git a/client/jfx/src/integration-test/kotlin/net/corda/client/jfx/NodeMonitorModelTest.kt b/client/jfx/src/integration-test/kotlin/net/corda/client/jfx/NodeMonitorModelTest.kt index 87a3c4da54..29617faad5 100644 --- a/client/jfx/src/integration-test/kotlin/net/corda/client/jfx/NodeMonitorModelTest.kt +++ b/client/jfx/src/integration-test/kotlin/net/corda/client/jfx/NodeMonitorModelTest.kt @@ -51,7 +51,7 @@ class NodeMonitorModelTest : DriverBasedTest() { lateinit var networkMapUpdates: Observable lateinit var newNode: (X500Name) -> NodeInfo - override fun setup() = driver { + override fun setup() = driver(initialiseSerialization = false) { val cashUser = User("user1", "test", permissions = setOf( startFlowPermission(), startFlowPermission(), @@ -72,14 +72,14 @@ class NodeMonitorModelTest : DriverBasedTest() { vaultUpdates = monitor.vaultUpdates.bufferUntilSubscribed() networkMapUpdates = monitor.networkMap.bufferUntilSubscribed() - monitor.register(aliceNodeHandle.configuration.rpcAddress!!, cashUser.username, cashUser.password) + monitor.register(aliceNodeHandle.configuration.rpcAddress!!, cashUser.username, cashUser.password, initialiseSerialization = false) rpc = monitor.proxyObservable.value!! val bobNodeHandle = startNode(BOB.name, rpcUsers = listOf(cashUser)).getOrThrow() bobNode = bobNodeHandle.nodeInfo val monitorBob = NodeMonitorModel() stateMachineUpdatesBob = monitorBob.stateMachineUpdates.bufferUntilSubscribed() - monitorBob.register(bobNodeHandle.configuration.rpcAddress!!, cashUser.username, cashUser.password) + monitorBob.register(bobNodeHandle.configuration.rpcAddress!!, cashUser.username, cashUser.password, initialiseSerialization = false) rpcBob = monitorBob.proxyObservable.value!! runTest() } diff --git a/client/jfx/src/main/kotlin/net/corda/client/jfx/model/NodeMonitorModel.kt b/client/jfx/src/main/kotlin/net/corda/client/jfx/model/NodeMonitorModel.kt index 944c49643c..e952a524a6 100644 --- a/client/jfx/src/main/kotlin/net/corda/client/jfx/model/NodeMonitorModel.kt +++ b/client/jfx/src/main/kotlin/net/corda/client/jfx/model/NodeMonitorModel.kt @@ -51,12 +51,13 @@ class NodeMonitorModel { * Register for updates to/from a given vault. * TODO provide an unsubscribe mechanism */ - fun register(nodeHostAndPort: NetworkHostAndPort, username: String, password: String) { + fun register(nodeHostAndPort: NetworkHostAndPort, username: String, password: String, initialiseSerialization: Boolean = true) { val client = CordaRPCClient( hostAndPort = nodeHostAndPort, configuration = CordaRPCClientConfiguration.default.copy( connectionMaxRetryInterval = 10.seconds - ) + ), + initialiseSerialization = initialiseSerialization ) val connection = client.start(username, password) val proxy = connection.proxy diff --git a/client/rpc/src/integration-test/java/net/corda/client/rpc/CordaRPCJavaClientTest.java b/client/rpc/src/integration-test/java/net/corda/client/rpc/CordaRPCJavaClientTest.java index a09af59e85..626890b4ae 100644 --- a/client/rpc/src/integration-test/java/net/corda/client/rpc/CordaRPCJavaClientTest.java +++ b/client/rpc/src/integration-test/java/net/corda/client/rpc/CordaRPCJavaClientTest.java @@ -1,30 +1,38 @@ package net.corda.client.rpc; -import com.google.common.util.concurrent.*; -import net.corda.client.rpc.internal.*; -import net.corda.contracts.asset.*; -import net.corda.core.contracts.*; -import net.corda.core.messaging.*; -import net.corda.core.node.services.*; -import net.corda.core.node.services.vault.*; -import net.corda.core.utilities.*; -import net.corda.flows.*; -import net.corda.node.internal.*; -import net.corda.node.services.transactions.*; -import net.corda.nodeapi.*; -import net.corda.schemas.*; -import net.corda.testing.node.*; -import org.junit.*; +import com.google.common.util.concurrent.ListenableFuture; +import net.corda.client.rpc.internal.RPCClient; +import net.corda.contracts.asset.Cash; +import net.corda.core.contracts.Amount; +import net.corda.core.messaging.CordaRPCOps; +import net.corda.core.messaging.FlowHandle; +import net.corda.core.node.services.ServiceInfo; +import net.corda.core.node.services.Vault; +import net.corda.core.node.services.vault.Builder; +import net.corda.core.node.services.vault.QueryCriteria; +import net.corda.core.utilities.OpaqueBytes; +import net.corda.flows.AbstractCashFlow; +import net.corda.flows.CashIssueFlow; +import net.corda.flows.CashPaymentFlow; +import net.corda.node.internal.Node; +import net.corda.node.services.transactions.ValidatingNotaryService; +import net.corda.nodeapi.User; +import net.corda.schemas.CashSchemaV1; +import net.corda.testing.node.NodeBasedTest; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; -import java.io.*; -import java.lang.reflect.*; +import java.io.IOException; +import java.lang.reflect.Field; import java.util.*; -import java.util.concurrent.*; +import java.util.concurrent.ExecutionException; -import static kotlin.test.AssertionsKt.*; -import static net.corda.client.rpc.CordaRPCClientConfiguration.*; -import static net.corda.node.services.RPCUserServiceKt.*; -import static net.corda.testing.TestConstants.*; +import static kotlin.test.AssertionsKt.assertEquals; +import static net.corda.client.rpc.CordaRPCClientConfiguration.getDefault; +import static net.corda.node.services.RPCUserServiceKt.startFlowPermission; +import static net.corda.testing.TestConstants.getALICE; public class CordaRPCJavaClientTest extends NodeBasedTest { private List perms = Arrays.asList(startFlowPermission(CashPaymentFlow.class), startFlowPermission(CashIssueFlow.class)); @@ -46,7 +54,7 @@ public class CordaRPCJavaClientTest extends NodeBasedTest { Set services = new HashSet<>(Collections.singletonList(new ServiceInfo(ValidatingNotaryService.Companion.getType(), null))); ListenableFuture nodeFuture = startNode(getALICE().getName(), 1, services, Arrays.asList(rpcUser), Collections.emptyMap()); node = nodeFuture.get(); - client = new CordaRPCClient(node.getConfiguration().getRpcAddress(), null, getDefault()); + client = new CordaRPCClient(node.getConfiguration().getRpcAddress(), null, getDefault(), false); } @After diff --git a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt index 2aa147d3a9..aba7ddd5f7 100644 --- a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt +++ b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt @@ -49,7 +49,7 @@ class CordaRPCClientTest : NodeBasedTest() { @Before fun setUp() { node = startNode(ALICE.name, rpcUsers = listOf(rpcUser), advertisedServices = setOf(ServiceInfo(ValidatingNotaryService.type))).getOrThrow() - client = CordaRPCClient(node.configuration.rpcAddress!!) + client = CordaRPCClient(node.configuration.rpcAddress!!, initialiseSerialization = false) } @After diff --git a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt index bf20920824..51455b3062 100644 --- a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt +++ b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt @@ -1,10 +1,5 @@ package net.corda.client.rpc -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.Serializer -import com.esotericsoftware.kryo.io.Input -import com.esotericsoftware.kryo.io.Output -import com.esotericsoftware.kryo.pool.KryoPool import com.google.common.util.concurrent.Futures import net.corda.client.rpc.internal.RPCClient import net.corda.client.rpc.internal.RPCClientConfiguration @@ -14,11 +9,11 @@ import net.corda.core.getOrThrow import net.corda.core.messaging.RPCOps import net.corda.core.utilities.millis import net.corda.core.utilities.seconds +import net.corda.core.serialization.SerializationDefaults import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.Try import net.corda.node.services.messaging.RPCServerConfiguration import net.corda.nodeapi.RPCApi -import net.corda.nodeapi.RPCKryo import net.corda.testing.* import net.corda.testing.driver.poll import org.apache.activemq.artemis.api.core.SimpleString @@ -305,16 +300,8 @@ class RPCStabilityTests { return Observable.interval(interval.toMillis(), TimeUnit.MILLISECONDS).map { chunk } } } - val dummyObservableSerialiser = object : Serializer>() { - override fun write(kryo: Kryo?, output: Output?, `object`: Observable?) { - } - override fun read(kryo: Kryo?, input: Input?, type: Class>?): Observable { - return Observable.empty() - } - } @Test fun `slow consumers are kicked`() { - val kryoPool = KryoPool.Builder { RPCKryo(dummyObservableSerialiser) }.build() rpcDriver { val server = startRpcServer(maxBufferedBytesPerClient = 10 * 1024 * 1024, ops = SlowConsumerRPCOpsImpl()).get() @@ -339,7 +326,7 @@ class RPCStabilityTests { methodName = SlowConsumerRPCOps::streamAtInterval.name, arguments = listOf(10.millis, 123456) ) - request.writeToClientMessage(kryoPool, message) + request.writeToClientMessage(SerializationDefaults.RPC_SERVER_CONTEXT, message) producer.send(message) session.commit() diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt index 78baa8d906..584b8c7cd1 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt @@ -2,11 +2,16 @@ package net.corda.client.rpc import net.corda.client.rpc.internal.RPCClient import net.corda.client.rpc.internal.RPCClientConfiguration +import net.corda.client.rpc.serialization.KryoClientSerializationScheme import net.corda.core.messaging.CordaRPCOps +import net.corda.core.serialization.SerializationDefaults import net.corda.core.utilities.NetworkHostAndPort import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport import net.corda.nodeapi.ConnectionDirection import net.corda.nodeapi.config.SSLConfiguration +import net.corda.nodeapi.serialization.KRYO_P2P_CONTEXT +import net.corda.nodeapi.serialization.KRYO_RPC_CLIENT_CONTEXT +import net.corda.nodeapi.serialization.SerializationFactoryImpl import java.time.Duration /** @see RPCClient.RPCConnection */ @@ -35,11 +40,22 @@ data class CordaRPCClientConfiguration( class CordaRPCClient( hostAndPort: NetworkHostAndPort, sslConfiguration: SSLConfiguration? = null, - configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default + configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default, + initialiseSerialization: Boolean = true ) { + init { + // Init serialization. It's plausible there are multiple clients in a single JVM, so be tolerant of + // others having registered first. + // TODO: allow clients to have serialization factory etc injected and align with RPC protocol version? + if (initialiseSerialization) { + initialiseSerialization() + } + } + private val rpcClient = RPCClient( tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration), - configuration.toRpcClientConfiguration() + configuration.toRpcClientConfiguration(), + KRYO_RPC_CLIENT_CONTEXT ) fun start(username: String, password: String): CordaRPCConnection { @@ -49,4 +65,20 @@ class CordaRPCClient( inline fun use(username: String, password: String, block: (CordaRPCConnection) -> A): A { return start(username, password).use(block) } + + companion object { + fun initialiseSerialization() { + try { + SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply { + registerScheme(KryoClientSerializationScheme()) + } + SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT + SerializationDefaults.RPC_CLIENT_CONTEXT = KRYO_RPC_CLIENT_CONTEXT + } catch(e: IllegalStateException) { + // Check that it's registered as we expect + check(SerializationDefaults.SERIALIZATION_FACTORY is SerializationFactoryImpl) { "RPC client encountered conflicting configuration of serialization subsystem." } + check((SerializationDefaults.SERIALIZATION_FACTORY as SerializationFactoryImpl).alreadyRegisteredSchemes.any { it is KryoClientSerializationScheme }) { "RPC client encountered conflicting configuration of serialization subsystem." } + } + } + } } \ No newline at end of file diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt index 059f5d2be7..54a9964d92 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt @@ -1,9 +1,11 @@ package net.corda.client.rpc.internal +import net.corda.core.crypto.random63BitValue import net.corda.core.logElapsedTime import net.corda.core.messaging.RPCOps +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationDefaults import net.corda.core.utilities.minutes -import net.corda.core.crypto.random63BitValue import net.corda.core.utilities.seconds import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.loggerFor @@ -85,13 +87,15 @@ data class RPCClientConfiguration( */ class RPCClient( val transport: TransportConfiguration, - val rpcConfiguration: RPCClientConfiguration = RPCClientConfiguration.default + val rpcConfiguration: RPCClientConfiguration = RPCClientConfiguration.default, + val serializationContext: SerializationContext = SerializationDefaults.RPC_CLIENT_CONTEXT ) { constructor( hostAndPort: NetworkHostAndPort, sslConfiguration: SSLConfiguration? = null, - configuration: RPCClientConfiguration = RPCClientConfiguration.default - ) : this(tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration), configuration) + configuration: RPCClientConfiguration = RPCClientConfiguration.default, + serializationContext: SerializationContext = SerializationDefaults.RPC_CLIENT_CONTEXT + ) : this(tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration), configuration, serializationContext) companion object { private val log = loggerFor>() @@ -146,7 +150,7 @@ class RPCClient( minLargeMessageSize = rpcConfiguration.maxFileSize } - val proxyHandler = RPCClientProxyHandler(rpcConfiguration, username, password, serverLocator, clientAddress, rpcOpsClass) + val proxyHandler = RPCClientProxyHandler(rpcConfiguration, username, password, serverLocator, clientAddress, rpcOpsClass, serializationContext) try { proxyHandler.start() diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt index e0cb142d82..03cc1bec69 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt @@ -4,7 +4,6 @@ import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.Serializer import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output -import com.esotericsoftware.kryo.pool.KryoPool import com.google.common.cache.Cache import com.google.common.cache.CacheBuilder import com.google.common.cache.RemovalCause @@ -18,7 +17,7 @@ import net.corda.core.internal.LazyPool import net.corda.core.internal.LazyStickyPool import net.corda.core.internal.LifeCycle import net.corda.core.messaging.RPCOps -import net.corda.core.serialization.KryoPoolWithContext +import net.corda.core.serialization.SerializationContext import net.corda.core.utilities.* import net.corda.nodeapi.* import org.apache.activemq.artemis.api.core.SimpleString @@ -64,7 +63,8 @@ class RPCClientProxyHandler( private val rpcPassword: String, private val serverLocator: ServerLocator, private val clientAddress: SimpleString, - private val rpcOpsClass: Class + private val rpcOpsClass: Class, + serializationContext: SerializationContext ) : InvocationHandler { private enum class State { @@ -77,9 +77,6 @@ class RPCClientProxyHandler( private companion object { val log = loggerFor() - // Note that this KryoPool is not yet capable of deserialising Observables, it requires Proxy-specific context - // to do that. However it may still be used for serialisation of RPC requests and related messages. - val kryoPool: KryoPool = KryoPool.Builder { RPCKryo(RpcClientObservableSerializer) }.build() // To check whether toString() is being invoked val toStringMethod: Method = Object::toString.javaMethod!! } @@ -112,8 +109,7 @@ class RPCClientProxyHandler( private val observablesToReap = ThreadBox(object { var observables = ArrayList() }) - // A Kryo pool that automatically adds the observable context when an instance is requested. - private val kryoPoolWithObservableContext = RpcClientObservableSerializer.createPoolWithContext(kryoPool, observableContext) + private val serializationContextWithObservableContext = RpcClientObservableSerializer.createContext(serializationContext, observableContext) private fun createRpcObservableMap(): RpcObservableMap { val onObservableRemove = RemovalListener>> { @@ -197,7 +193,7 @@ class RPCClientProxyHandler( val replyFuture = SettableFuture.create() sessionAndProducerPool.run { val message = it.session.createMessage(false) - request.writeToClientMessage(kryoPool, message) + request.writeToClientMessage(serializationContextWithObservableContext, message) log.debug { val argumentsString = arguments?.joinToString() ?: "" @@ -224,7 +220,7 @@ class RPCClientProxyHandler( // The handler for Artemis messages. private fun artemisMessageHandler(message: ClientMessage) { - val serverToClient = RPCApi.ServerToClient.fromClientMessage(kryoPoolWithObservableContext, message) + val serverToClient = RPCApi.ServerToClient.fromClientMessage(serializationContextWithObservableContext, message) log.debug { "Got message from RPC server $serverToClient" } when (serverToClient) { is RPCApi.ServerToClient.RpcReply -> { @@ -351,7 +347,7 @@ private typealias CallSiteMap = ConcurrentHashMap * @param observableMap holds the Observables that are ultimately exposed to the user. * @param hardReferenceStore holds references to Observables we want to keep alive while they are subscribed to. */ -private data class ObservableContext( +data class ObservableContext( val callSiteMap: CallSiteMap?, val observableMap: RpcObservableMap, val hardReferenceStore: MutableSet> @@ -360,10 +356,11 @@ private data class ObservableContext( /** * A [Serializer] to deserialise Observables once the corresponding Kryo instance has been provided with an [ObservableContext]. */ -private object RpcClientObservableSerializer : Serializer>() { +object RpcClientObservableSerializer : Serializer>() { private object RpcObservableContextKey - fun createPoolWithContext(kryoPool: KryoPool, observableContext: ObservableContext): KryoPool { - return KryoPoolWithContext(kryoPool, RpcObservableContextKey, observableContext) + + fun createContext(serializationContext: SerializationContext, observableContext: ObservableContext): SerializationContext { + return serializationContext.withProperty(RpcObservableContextKey, observableContext) } override fun read(kryo: Kryo, input: Input, type: Class>): Observable { diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/serialization/SerializationScheme.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/serialization/SerializationScheme.kt new file mode 100644 index 0000000000..53b70c3b7a --- /dev/null +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/serialization/SerializationScheme.kt @@ -0,0 +1,27 @@ +package net.corda.client.rpc.serialization + +import com.esotericsoftware.kryo.pool.KryoPool +import net.corda.client.rpc.internal.RpcClientObservableSerializer +import net.corda.core.serialization.DefaultKryoCustomizer +import net.corda.core.serialization.SerializationContext +import net.corda.core.utilities.ByteSequence +import net.corda.nodeapi.RPCKryo +import net.corda.nodeapi.serialization.AbstractKryoSerializationScheme +import net.corda.nodeapi.serialization.KryoHeaderV0_1 + +class KryoClientSerializationScheme : AbstractKryoSerializationScheme() { + override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean { + return byteSequence == KryoHeaderV0_1 && (target == SerializationContext.UseCase.RPCClient || target == SerializationContext.UseCase.P2P) + } + + override fun rpcClientKryoPool(context: SerializationContext): KryoPool { + return KryoPool.Builder { + DefaultKryoCustomizer.customize(RPCKryo(RpcClientObservableSerializer, context.whitelist)).apply { classLoader = context.deserializationClassLoader } + }.build() + } + + // We're on the client and don't have access to server classes. + override fun rpcServerKryoPool(context: SerializationContext): KryoPool { + throw UnsupportedOperationException() + } +} \ No newline at end of file diff --git a/client/rpc/src/test/kotlin/net/corda/client/rpc/ClientRPCInfrastructureTests.kt b/client/rpc/src/test/kotlin/net/corda/client/rpc/ClientRPCInfrastructureTests.kt index 0117504c2e..6edb347ed1 100644 --- a/client/rpc/src/test/kotlin/net/corda/client/rpc/ClientRPCInfrastructureTests.kt +++ b/client/rpc/src/test/kotlin/net/corda/client/rpc/ClientRPCInfrastructureTests.kt @@ -27,7 +27,9 @@ import kotlin.test.assertTrue class ClientRPCInfrastructureTests : AbstractRPCTest() { // TODO: Test that timeouts work - private fun RPCDriverExposedDSLInterface.testProxy() = testProxy(TestOpsImpl()).ops + private fun RPCDriverExposedDSLInterface.testProxy(): TestOps { + return testProxy(TestOpsImpl()).ops + } interface TestOps : RPCOps { @Throws(IllegalArgumentException::class) diff --git a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPermissionsTests.kt b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPermissionsTests.kt index ebc9cef461..f31469bcb4 100644 --- a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPermissionsTests.kt +++ b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPermissionsTests.kt @@ -1,8 +1,8 @@ package net.corda.client.rpc import net.corda.core.messaging.RPCOps -import net.corda.node.services.messaging.requirePermission import net.corda.node.services.messaging.getRpcContext +import net.corda.node.services.messaging.requirePermission import net.corda.nodeapi.PermissionException import net.corda.nodeapi.User import net.corda.testing.RPCDriverExposedDSLInterface diff --git a/core/src/main/kotlin/net/corda/core/crypto/MetaData.kt b/core/src/main/kotlin/net/corda/core/crypto/MetaData.kt index edcf018e82..24735f8c24 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/MetaData.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/MetaData.kt @@ -1,8 +1,8 @@ package net.corda.core.crypto import net.corda.core.serialization.CordaSerializable -import net.corda.core.utilities.opaque import net.corda.core.serialization.serialize +import net.corda.core.utilities.sequence import java.security.PublicKey import java.time.Instant import java.util.* @@ -51,7 +51,7 @@ open class MetaData( if (timestamp != other.timestamp) return false if (visibleInputs != other.visibleInputs) return false if (signedInputs != other.signedInputs) return false - if (merkleRoot.opaque() != other.merkleRoot.opaque()) return false + if (merkleRoot.sequence() != other.merkleRoot.sequence()) return false if (publicKey != other.publicKey) return false return true } diff --git a/core/src/main/kotlin/net/corda/core/crypto/SignedData.kt b/core/src/main/kotlin/net/corda/core/crypto/SignedData.kt index f1262f84af..472a8a1024 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/SignedData.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/SignedData.kt @@ -23,7 +23,8 @@ open class SignedData(val raw: SerializedBytes, val sig: DigitalSign @Throws(SignatureException::class) fun verified(): T { sig.by.verify(raw.bytes, sig) - val data = raw.deserialize() + @Suppress("UNCHECKED_CAST") + val data = raw.deserialize() as T verifyData(data) return data } diff --git a/core/src/main/kotlin/net/corda/core/crypto/composite/CompositeKey.kt b/core/src/main/kotlin/net/corda/core/crypto/composite/CompositeKey.kt index 9e865ff098..56035333bb 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/composite/CompositeKey.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/composite/CompositeKey.kt @@ -4,13 +4,12 @@ import net.corda.core.crypto.Crypto import net.corda.core.crypto.composite.CompositeKey.NodeAndWeight import net.corda.core.crypto.keys import net.corda.core.crypto.provider.CordaObjectIdentifier -import net.corda.core.crypto.toSHA256Bytes import net.corda.core.crypto.toStringShort import net.corda.core.serialization.CordaSerializable +import net.corda.core.utilities.sequence import org.bouncycastle.asn1.* import org.bouncycastle.asn1.x509.AlgorithmIdentifier import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo -import java.nio.ByteBuffer import java.security.PublicKey import java.util.* @@ -145,7 +144,7 @@ class CompositeKey private constructor(val threshold: Int, children: List() { + private var v: T? = null + + operator fun getValue(thisRef: Any?, property: KProperty<*>) = v ?: throw IllegalStateException("Write-once property $property not set.") + + operator fun setValue(thisRef: Any?, property: KProperty<*>, value: T) { + check(v == null) { "Cannot set write-once property $property more than once." } + v = value + } +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/Kryo.kt b/core/src/main/kotlin/net/corda/core/serialization/Kryo.kt index f91119f9e8..fa0da924d7 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/Kryo.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/Kryo.kt @@ -3,17 +3,16 @@ package net.corda.core.serialization import com.esotericsoftware.kryo.* import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output -import com.esotericsoftware.kryo.pool.KryoCallback -import com.esotericsoftware.kryo.pool.KryoPool import com.esotericsoftware.kryo.util.MapReferenceResolver import com.google.common.annotations.VisibleForTesting import net.corda.core.contracts.* -import net.corda.core.crypto.* +import net.corda.core.crypto.Crypto +import net.corda.core.crypto.MetaData +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.SignatureType import net.corda.core.crypto.composite.CompositeKey import net.corda.core.identity.Party import net.corda.core.transactions.WireTransaction -import net.corda.core.internal.LazyPool -import net.corda.core.utilities.OpaqueBytes import net.i2p.crypto.eddsa.EdDSAPrivateKey import net.i2p.crypto.eddsa.EdDSAPublicKey import net.i2p.crypto.eddsa.spec.EdDSANamedCurveSpec @@ -25,11 +24,8 @@ import org.bouncycastle.cert.X509CertificateHolder import org.slf4j.Logger import org.slf4j.LoggerFactory import java.io.ByteArrayInputStream -import java.io.ByteArrayOutputStream import java.io.InputStream import java.lang.reflect.InvocationTargetException -import java.nio.file.Files -import java.nio.file.Path import java.security.PrivateKey import java.security.PublicKey import java.security.cert.CertPath @@ -77,58 +73,6 @@ import kotlin.reflect.jvm.javaType * TODO: eliminate internal, storage related whitelist issues, such as private keys in blob storage. */ -// A convenient instance of Kryo pre-configured with some useful things. Used as a default by various functions. -fun p2PKryo(): KryoPool = kryoPool - -// Same again, but this has whitelisting turned off for internal storage use only. -fun storageKryo(): KryoPool = internalKryoPool - - -/** - * A type safe wrapper around a byte array that contains a serialised object. You can call [SerializedBytes.deserialize] - * to get the original object back. - */ -@Suppress("unused") // Type parameter is just for documentation purposes. -class SerializedBytes(bytes: ByteArray, val internalOnly: Boolean = false) : OpaqueBytes(bytes) { - // It's OK to use lazy here because SerializedBytes is configured to use the ImmutableClassSerializer. - val hash: SecureHash by lazy { bytes.sha256() } - - fun writeToFile(path: Path): Path = Files.write(path, bytes) -} - -// "corda" + majorVersionByte + minorVersionMSB + minorVersionLSB -private val KryoHeaderV0_1: OpaqueBytes = OpaqueBytes("corda\u0000\u0000\u0001".toByteArray()) - -// Some extension functions that make deserialisation convenient and provide auto-casting of the result. -fun ByteArray.deserialize(kryo: KryoPool = p2PKryo()): T { - Input(this).use { - val header = OpaqueBytes(it.readBytes(8)) - if (header != KryoHeaderV0_1) { - throw KryoException("Serialized bytes header does not match any known format.") - } - @Suppress("UNCHECKED_CAST") - return kryo.run { k -> k.readClassAndObject(it) as T } - } -} - -// TODO: The preferred usage is with a pool. Try and eliminate use of this from RPC. -fun ByteArray.deserialize(kryo: Kryo): T = deserialize(kryo.asPool()) - -fun OpaqueBytes.deserialize(kryo: KryoPool = p2PKryo()): T { - return this.bytes.deserialize(kryo) -} - -// The more specific deserialize version results in the bytes being cached, which is faster. -@JvmName("SerializedBytesWireTransaction") -fun SerializedBytes.deserialize(kryo: KryoPool = p2PKryo()): WireTransaction = WireTransaction.deserialize(this, kryo) - -fun SerializedBytes.deserialize(kryo: KryoPool = if (internalOnly) storageKryo() else p2PKryo()): T = bytes.deserialize(kryo) - -fun SerializedBytes.deserialize(kryo: Kryo): T = bytes.deserialize(kryo.asPool()) - -// Internal adapter for use when we haven't yet converted to a pool, or for tests. -private fun Kryo.asPool(): KryoPool = (KryoPool.Builder { this }.build()) - /** * A serialiser that avoids writing the wrapper class to the byte stream, thus ensuring [SerializedBytes] is a pure * type safety hack. @@ -144,36 +88,6 @@ object SerializedBytesSerializer : Serializer>() { } } -/** - * Can be called on any object to convert it to a byte array (wrapped by [SerializedBytes]), regardless of whether - * the type is marked as serializable or was designed for it (so be careful!). - */ -fun T.serialize(kryo: KryoPool = p2PKryo(), internalOnly: Boolean = false): SerializedBytes { - return kryo.run { k -> serialize(k, internalOnly) } -} - - -private val serializeBufferPool = LazyPool( - newInstance = { ByteArray(64 * 1024) } -) -private val serializeOutputStreamPool = LazyPool( - clear = ByteArrayOutputStream::reset, - shouldReturnToPool = { it.size() < 256 * 1024 }, // Discard if it grew too large - newInstance = { ByteArrayOutputStream(64 * 1024) } -) -fun T.serialize(kryo: Kryo, internalOnly: Boolean = false): SerializedBytes { - return serializeOutputStreamPool.run { stream -> - serializeBufferPool.run { buffer -> - Output(buffer).use { - it.outputStream = stream - it.writeBytes(KryoHeaderV0_1.bytes) - kryo.writeClassAndObject(it, this) - } - SerializedBytes(stream.toByteArray(), internalOnly) - } - } -} - /** * Serializes properties and deserializes by using the constructor. This assumes that all backed properties are * set via the constructor and the class is immutable. @@ -463,14 +377,6 @@ inline fun readListOfLength(kryo: Kryo, input: Input, minLen: Int = return list } -// No ClassResolver only constructor. MapReferenceResolver is the default as used by Kryo in other constructors. -private val internalKryoPool = KryoPool.Builder { DefaultKryoCustomizer.customize(CordaKryo(makeAllButBlacklistedClassResolver())) }.build() -private val kryoPool = KryoPool.Builder { DefaultKryoCustomizer.customize(CordaKryo(makeStandardClassResolver())) }.build() - -// No ClassResolver only constructor. MapReferenceResolver is the default as used by Kryo in other constructors. -@VisibleForTesting -fun createTestKryo(): Kryo = DefaultKryoCustomizer.customize(CordaKryo(makeNoWhitelistClassResolver())) - /** * We need to disable whitelist checking during calls from our Kryo code to register a serializer, since it checks * for existing registrations and then will enter our [CordaClassResolver.getRegistration] method. @@ -649,25 +555,3 @@ object X509CertificateSerializer : Serializer() { output.writeBytes(obj.encoded) } } - -class KryoPoolWithContext(val baseKryoPool: KryoPool, val contextKey: Any, val context: Any) : KryoPool { - override fun run(callback: KryoCallback): T { - val kryo = borrow() - try { - return callback.execute(kryo) - } finally { - release(kryo) - } - } - - override fun borrow(): Kryo { - val kryo = baseKryoPool.borrow() - require(kryo.context.put(contextKey, context) == null) { "KryoPool already has context" } - return kryo - } - - override fun release(kryo: Kryo) { - requireNotNull(kryo.context.remove(contextKey)) { "Kryo instance lost context while borrowed" } - baseKryoPool.release(kryo) - } -} diff --git a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt new file mode 100644 index 0000000000..7ecb71f160 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt @@ -0,0 +1,141 @@ +package net.corda.core.serialization + +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.sha256 +import net.corda.core.internal.WriteOnceProperty +import net.corda.core.serialization.SerializationDefaults.P2P_CONTEXT +import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY +import net.corda.core.transactions.WireTransaction +import net.corda.core.utilities.ByteSequence +import net.corda.core.utilities.OpaqueBytes +import net.corda.core.utilities.sequence + +/** + * An abstraction for serializing and deserializing objects, with support for versioning of the wire format via + * a header / prefix in the bytes. + */ +interface SerializationFactory { + /** + * Deserialize the bytes in to an object, using the prefixed bytes to determine the format. + * + * @param byteSequence The bytes to deserialize, including a format header prefix. + * @param clazz The class or superclass or the object to be deserialized, or [Any] or [Object] if unknown. + * @param context A context that configures various parameters to deserialization. + */ + fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T + + /** + * Serialize an object to bytes using the preferred serialization format version from the context. + * + * @param obj The object to be serialized. + * @param context A context that configures various parameters to serialization, including the serialization format version. + */ + fun serialize(obj: T, context: SerializationContext): SerializedBytes +} + +/** + * Parameters to serialization and deserialization. + */ +interface SerializationContext { + /** + * When serializing, use the format this header sequence represents. + */ + val preferedSerializationVersion: ByteSequence + /** + * The class loader to use for deserialization. + */ + val deserializationClassLoader: ClassLoader + /** + * A whitelist that contains (mostly for security purposes) which classes can be serialized and deserialized. + */ + val whitelist: ClassWhitelist + /** + * A map of any addition properties specific to the particular use case. + */ + val properties: Map + /** + * Duplicate references to the same object preserved in the wire format and when deserialized when this is true, + * otherwise they appear as new copies of the object. + */ + val objectReferencesEnabled: Boolean + /** + * The use case we are serializing or deserializing for. See [UseCase]. + */ + val useCase: UseCase + /** + * Helper method to return a new context based on this context with the property added. + */ + fun withProperty(property: Any, value: Any): SerializationContext + + /** + * Helper method to return a new context based on this context with object references disabled. + */ + fun withoutReferences(): SerializationContext + + /** + * Helper method to return a new context based on this context with the deserialization class loader changed. + */ + fun withClassLoader(classLoader: ClassLoader): SerializationContext + + /** + * Helper method to return a new context based on this context with the given class specifically whitelisted. + */ + fun withWhitelisted(clazz: Class<*>): SerializationContext + + /** + * The use case that we are serializing for, since it influences the implementations chosen. + */ + enum class UseCase { P2P, RPCServer, RPCClient, Storage, Checkpoint } +} + +/** + * Global singletons to be used as defaults that are injected elsewhere (generally, in the node or in RPC client). + */ +object SerializationDefaults { + var SERIALIZATION_FACTORY: SerializationFactory by WriteOnceProperty() + var P2P_CONTEXT: SerializationContext by WriteOnceProperty() + var RPC_SERVER_CONTEXT: SerializationContext by WriteOnceProperty() + var RPC_CLIENT_CONTEXT: SerializationContext by WriteOnceProperty() + var STORAGE_CONTEXT: SerializationContext by WriteOnceProperty() + var CHECKPOINT_CONTEXT: SerializationContext by WriteOnceProperty() +} + +/** + * Convenience extension method for deserializing a ByteSequence, utilising the defaults. + */ +inline fun ByteSequence.deserialize(serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): T { + return serializationFactory.deserialize(this, T::class.java, context) +} + +/** + * Convenience extension method for deserializing SerializedBytes with type matching, utilising the defaults. + */ +inline fun SerializedBytes.deserialize(serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): T { + return serializationFactory.deserialize(this, T::class.java, context) +} + +/** + * Convenience extension method for deserializing a ByteArray, utilising the defaults. + */ +inline fun ByteArray.deserialize(serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): T = this.sequence().deserialize(serializationFactory, context) + +/** + * Convenience extension method for serializing an object of type T, utilising the defaults. + */ +fun T.serialize(serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): SerializedBytes { + return serializationFactory.serialize(this, context) +} + +/** + * A type safe wrapper around a byte array that contains a serialised object. You can call [SerializedBytes.deserialize] + * to get the original object back. + */ +@Suppress("unused") // Type parameter is just for documentation purposes. +class SerializedBytes(bytes: ByteArray) : OpaqueBytes(bytes) { + // It's OK to use lazy here because SerializedBytes is configured to use the ImmutableClassSerializer. + val hash: SecureHash by lazy { bytes.sha256() } +} + +// The more specific deserialize version results in the bytes being cached, which is faster. +@JvmName("SerializedBytesWireTransaction") +fun SerializedBytes.deserialize(serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): WireTransaction = WireTransaction.deserialize(this, serializationFactory, context) diff --git a/core/src/main/kotlin/net/corda/core/serialization/SerializationToken.kt b/core/src/main/kotlin/net/corda/core/serialization/SerializationToken.kt index c141435a4e..86d4fdfa1b 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/SerializationToken.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/SerializationToken.kt @@ -5,7 +5,6 @@ import com.esotericsoftware.kryo.KryoException import com.esotericsoftware.kryo.Serializer import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output -import com.esotericsoftware.kryo.pool.KryoPool import net.corda.core.node.ServiceHub import net.corda.core.serialization.SingletonSerializationToken.Companion.singletonSerializationToken @@ -57,17 +56,9 @@ class SerializeAsTokenSerializer : Serializer() { private val serializationContextKey = SerializeAsTokenContext::class.java -fun Kryo.serializationContext() = context.get(serializationContextKey) as? SerializeAsTokenContext +fun SerializationContext.withTokenContext(serializationContext: SerializeAsTokenContext): SerializationContext = this.withProperty(serializationContextKey, serializationContext) -fun Kryo.withSerializationContext(serializationContext: SerializeAsTokenContext, block: () -> T) = run { - context.containsKey(serializationContextKey) && throw IllegalStateException("There is already a serialization context.") - context.put(serializationContextKey, serializationContext) - try { - block() - } finally { - context.remove(serializationContextKey) - } -} +fun Kryo.serializationContext(): SerializeAsTokenContext? = context.get(serializationContextKey) as? SerializeAsTokenContext /** * A context for mapping SerializationTokens to/from SerializeAsTokens. @@ -79,12 +70,8 @@ fun Kryo.withSerializationContext(serializationContext: SerializeAsTokenCont * on the Kryo instance when serializing to enable/disable tokenization. */ class SerializeAsTokenContext internal constructor(val serviceHub: ServiceHub, init: SerializeAsTokenContext.() -> Unit) { - constructor(toBeTokenized: Any, kryoPool: KryoPool, serviceHub: ServiceHub) : this(serviceHub, { - kryoPool.run { kryo -> - kryo.withSerializationContext(this) { - toBeTokenized.serialize(kryo) - } - } + constructor(toBeTokenized: Any, serializationFactory: SerializationFactory, context: SerializationContext, serviceHub: ServiceHub) : this(serviceHub, { + serializationFactory.serialize(toBeTokenized, context.withTokenContext(this)) }) private val classNameToSingleton = mutableMapOf() diff --git a/core/src/main/kotlin/net/corda/core/transactions/MerkleTransaction.kt b/core/src/main/kotlin/net/corda/core/transactions/MerkleTransaction.kt index 0e611fb242..29f1c3d38b 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/MerkleTransaction.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/MerkleTransaction.kt @@ -7,14 +7,13 @@ import net.corda.core.crypto.PartialMerkleTree import net.corda.core.crypto.SecureHash import net.corda.core.identity.Party import net.corda.core.serialization.CordaSerializable -import net.corda.core.serialization.p2PKryo +import net.corda.core.serialization.SerializationDefaults.P2P_CONTEXT import net.corda.core.serialization.serialize -import net.corda.core.serialization.withoutReferences import java.security.PublicKey import java.util.function.Predicate fun serializedHash(x: T): SecureHash { - return p2PKryo().run { kryo -> kryo.withoutReferences { x.serialize(kryo).hash } } + return x.serialize(context = P2P_CONTEXT.withoutReferences()).hash } /** diff --git a/core/src/main/kotlin/net/corda/core/transactions/WireTransaction.kt b/core/src/main/kotlin/net/corda/core/transactions/WireTransaction.kt index af16bd728a..6715365840 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/WireTransaction.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/WireTransaction.kt @@ -1,6 +1,5 @@ package net.corda.core.transactions -import com.esotericsoftware.kryo.pool.KryoPool import net.corda.core.contracts.* import net.corda.core.crypto.DigitalSignature import net.corda.core.crypto.MerkleTree @@ -9,10 +8,9 @@ import net.corda.core.crypto.keys import net.corda.core.identity.Party import net.corda.core.indexOfOrThrow import net.corda.core.node.ServicesForResolution -import net.corda.core.serialization.SerializedBytes -import net.corda.core.serialization.deserialize -import net.corda.core.serialization.p2PKryo -import net.corda.core.serialization.serialize +import net.corda.core.serialization.* +import net.corda.core.serialization.SerializationDefaults.P2P_CONTEXT +import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY import net.corda.core.internal.Emoji import java.security.PublicKey import java.security.SignatureException @@ -48,8 +46,8 @@ class WireTransaction( override val id: SecureHash by lazy { merkleTree.hash } companion object { - fun deserialize(data: SerializedBytes, kryo: KryoPool = p2PKryo()): WireTransaction { - val wtx = data.bytes.deserialize(kryo) + fun deserialize(data: SerializedBytes, serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): WireTransaction { + val wtx = data.deserialize(serializationFactory, context) wtx.cachedBytes = data return wtx } diff --git a/core/src/main/kotlin/net/corda/core/utilities/ByteArrays.kt b/core/src/main/kotlin/net/corda/core/utilities/ByteArrays.kt index 3102086b43..866170ed4f 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/ByteArrays.kt +++ b/core/src/main/kotlin/net/corda/core/utilities/ByteArrays.kt @@ -5,39 +5,143 @@ package net.corda.core.utilities import com.google.common.io.BaseEncoding import net.corda.core.serialization.CordaSerializable import java.io.ByteArrayInputStream -import java.util.* + +/** + * An abstraction of a byte array, with offset and size that does no copying of bytes unless asked to. + * + * The data of interest typically starts at position [offset] within the [bytes] and is [size] bytes long. + */ +@CordaSerializable +sealed class ByteSequence : Comparable { + /** + * The underlying bytes. + */ + abstract val bytes: ByteArray + /** + * The number of bytes this sequence represents. + */ + abstract val size: Int + /** + * The start position of the sequence within the byte array. + */ + abstract val offset: Int + /** Returns a [ByteArrayInputStream] of the bytes */ + fun open() = ByteArrayInputStream(bytes, offset, size) + + /** + * Create a sub-sequence backed by the same array. + * + * @param offset The offset within this sequence to start the new sequence. Note: not the offset within the backing array. + * @param size The size of the intended sub sequence. + */ + fun subSequence(offset: Int, size: Int): ByteSequence { + require(offset >= 0) + require(offset + size <= this.size) + return if (offset == 0 && size == this.size) this else of(bytes, this.offset + offset, size) + } + + companion object { + /** + * Construct a [ByteSequence] given a [ByteArray] and optional offset and size, that represents that potentially + * sub-sequence of bytes. The returned implementation is optimised when the whole [ByteArray] is the sequence. + */ + @JvmStatic + @JvmOverloads + fun of(bytes: ByteArray, offset: Int = 0, size: Int = bytes.size): ByteSequence { + return if (offset == 0 && size == bytes.size && size != 0) OpaqueBytes(bytes) else OpaqueBytesSubSequence(bytes, offset, size) + } + } + + /** + * Take the first n bytes of this sequence as a sub-sequence. See [subSequence] for further semantics. + */ + fun take(n: Int): ByteSequence { + require(size >= n) + return subSequence(0, n) + } + + /** + * Copy this sequence, complete with new backing array. This can be helpful to break references to potentially + * large backing arrays from small sub-sequences. + */ + fun copy(): ByteSequence = of(bytes.copyOfRange(offset, offset + size)) + + /** + * Compare byte arrays byte by byte. Arrays that are shorter are deemed less than longer arrays if all the bytes + * of the shorter array equal those in the same position of the longer array. + */ + override fun compareTo(other: ByteSequence): Int { + val min = minOf(this.size, other.size) + // Compare min bytes + for (index in 0 until min) { + val unsignedThis = java.lang.Byte.toUnsignedInt(this.bytes[this.offset + index]) + val unsignedOther = java.lang.Byte.toUnsignedInt(other.bytes[other.offset + index]) + if (unsignedThis != unsignedOther) { + return Integer.signum(unsignedThis - unsignedOther) + } + } + // First min bytes is the same, so now resort to size + return Integer.signum(this.size - other.size) + } + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is ByteSequence) return false + if (this.size != other.size) return false + return subArraysEqual(this.bytes, this.offset, this.size, other.bytes, other.offset) + } + + private fun subArraysEqual(a: ByteArray, aOffset: Int, length: Int, b: ByteArray, bOffset: Int): Boolean { + var bytesRemaining = length + var aPos = aOffset + var bPos = bOffset + while (bytesRemaining-- > 0) { + if (a[aPos++] != b[bPos++]) return false + } + return true + } + + override fun hashCode(): Int { + var result = 1 + for (index in offset until (offset + size)) { + result = 31 * result + bytes[index] + } + return result + } + + override fun toString(): String = "[${BaseEncoding.base16().encode(bytes, offset, size)}]" +} /** * A simple class that wraps a byte array and makes the equals/hashCode/toString methods work as you actually expect. * In an ideal JVM this would be a value type and be completely overhead free. Project Valhalla is adding such * functionality to Java, but it won't arrive for a few years yet! */ -@CordaSerializable -open class OpaqueBytes(val bytes: ByteArray) { +open class OpaqueBytes(override val bytes: ByteArray) : ByteSequence() { companion object { @JvmStatic fun of(vararg b: Byte) = OpaqueBytes(byteArrayOf(*b)) } init { - check(bytes.isNotEmpty()) + require(bytes.isNotEmpty()) } - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (other !is OpaqueBytes) return false - return Arrays.equals(bytes, other.bytes) - } - - override fun hashCode() = Arrays.hashCode(bytes) - override fun toString() = "[" + bytes.toHexString() + "]" - - val size: Int get() = bytes.size - - /** Returns a [ByteArrayInputStream] of the bytes */ - fun open() = ByteArrayInputStream(bytes) + override val size: Int get() = bytes.size + override val offset: Int get() = 0 } +@Deprecated("Use sequence instead") fun ByteArray.opaque(): OpaqueBytes = OpaqueBytes(this) + +fun ByteArray.sequence(offset: Int = 0, size: Int = this.size) = ByteSequence.of(this, offset, size) + fun ByteArray.toHexString(): String = BaseEncoding.base16().encode(this) fun String.parseAsHex(): ByteArray = BaseEncoding.base16().decode(this) + +private class OpaqueBytesSubSequence(override val bytes: ByteArray, override val offset: Int, override val size: Int) : ByteSequence() { + init { + require(offset >= 0 && offset < bytes.size) + require(size >= 0 && size <= bytes.size) + } +} \ No newline at end of file diff --git a/core/src/test/kotlin/net/corda/core/contracts/TransactionEncumbranceTests.kt b/core/src/test/kotlin/net/corda/core/contracts/TransactionEncumbranceTests.kt index 9e6be70bec..7352579be3 100644 --- a/core/src/test/kotlin/net/corda/core/contracts/TransactionEncumbranceTests.kt +++ b/core/src/test/kotlin/net/corda/core/contracts/TransactionEncumbranceTests.kt @@ -7,7 +7,6 @@ import net.corda.core.transactions.LedgerTransaction import net.corda.testing.MEGA_CORP import net.corda.testing.MINI_CORP import net.corda.testing.ledger -import net.corda.testing.transaction import org.junit.Test import java.time.Instant import java.time.temporal.ChronoUnit @@ -115,22 +114,26 @@ class TransactionEncumbranceTests { @Test fun `state cannot be encumbered by itself`() { - transaction { - input { state } - output(encumbrance = 0) { stateWithNewOwner } - command(MEGA_CORP.owningKey) { Cash.Commands.Move() } - this `fails with` "Missing required encumbrance 0 in OUTPUT" + ledger { + transaction { + input { state } + output(encumbrance = 0) { stateWithNewOwner } + command(MEGA_CORP.owningKey) { Cash.Commands.Move() } + this `fails with` "Missing required encumbrance 0 in OUTPUT" + } } } @Test fun `encumbrance state index must be valid`() { - transaction { - input { state } - output(encumbrance = 2) { stateWithNewOwner } - output { timeLock } - command(MEGA_CORP.owningKey) { Cash.Commands.Move() } - this `fails with` "Missing required encumbrance 2 in OUTPUT" + ledger { + transaction { + input { state } + output(encumbrance = 2) { stateWithNewOwner } + output { timeLock } + command(MEGA_CORP.owningKey) { Cash.Commands.Move() } + this `fails with` "Missing required encumbrance 2 in OUTPUT" + } } } diff --git a/core/src/test/kotlin/net/corda/core/contracts/TransactionGraphSearchTests.kt b/core/src/test/kotlin/net/corda/core/contracts/TransactionGraphSearchTests.kt index 8f55aa7317..0fa477f166 100644 --- a/core/src/test/kotlin/net/corda/core/contracts/TransactionGraphSearchTests.kt +++ b/core/src/test/kotlin/net/corda/core/contracts/TransactionGraphSearchTests.kt @@ -1,20 +1,17 @@ package net.corda.core.contracts -import net.corda.testing.contracts.DummyContract -import net.corda.testing.contracts.DummyState import net.corda.core.crypto.newSecureRandom import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.WireTransaction -import net.corda.testing.DUMMY_NOTARY -import net.corda.testing.DUMMY_NOTARY_KEY -import net.corda.testing.MEGA_CORP_KEY -import net.corda.testing.MEGA_CORP_PUBKEY +import net.corda.testing.* +import net.corda.testing.contracts.DummyContract +import net.corda.testing.contracts.DummyState import net.corda.testing.node.MockServices import net.corda.testing.node.MockTransactionStorage import org.junit.Test import kotlin.test.assertEquals -class TransactionGraphSearchTests { +class TransactionGraphSearchTests : TestDependencyInjectionBase() { class GraphTransactionStorage(val originTx: SignedTransaction, val inputTx: SignedTransaction) : MockTransactionStorage() { init { addTransaction(originTx) diff --git a/core/src/test/kotlin/net/corda/core/contracts/TransactionTests.kt b/core/src/test/kotlin/net/corda/core/contracts/TransactionTests.kt index 400fc7fff7..36f37a6acf 100644 --- a/core/src/test/kotlin/net/corda/core/contracts/TransactionTests.kt +++ b/core/src/test/kotlin/net/corda/core/contracts/TransactionTests.kt @@ -1,9 +1,8 @@ package net.corda.core.contracts import net.corda.contracts.asset.DUMMY_CASH_ISSUER_KEY -import net.corda.testing.contracts.DummyContract -import net.corda.core.crypto.composite.CompositeKey import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.composite.CompositeKey import net.corda.core.crypto.generateKeyPair import net.corda.core.crypto.sign import net.corda.core.identity.Party @@ -12,13 +11,13 @@ import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.WireTransaction import net.corda.testing.* +import net.corda.testing.contracts.DummyContract import org.junit.Test import java.security.KeyPair import kotlin.test.assertEquals import kotlin.test.assertFailsWith -class TransactionTests { - +class TransactionTests : TestDependencyInjectionBase() { private fun makeSigned(wtx: WireTransaction, vararg keys: KeyPair): SignedTransaction { val bytes: SerializedBytes = wtx.serialized return SignedTransaction(bytes, keys.map { it.sign(wtx.id.bytes) }) diff --git a/core/src/test/kotlin/net/corda/core/crypto/CompositeKeyTests.kt b/core/src/test/kotlin/net/corda/core/crypto/CompositeKeyTests.kt index 9498f15b03..2da44f4aa5 100644 --- a/core/src/test/kotlin/net/corda/core/crypto/CompositeKeyTests.kt +++ b/core/src/test/kotlin/net/corda/core/crypto/CompositeKeyTests.kt @@ -11,6 +11,7 @@ import net.corda.core.utilities.OpaqueBytes import net.corda.node.utilities.loadKeyStore import net.corda.node.utilities.loadOrCreateKeyStore import net.corda.node.utilities.save +import net.corda.testing.TestDependencyInjectionBase import org.bouncycastle.asn1.x500.X500Name import org.junit.Rule import org.junit.Test @@ -21,7 +22,7 @@ import kotlin.test.assertFailsWith import kotlin.test.assertFalse import kotlin.test.assertTrue -class CompositeKeyTests { +class CompositeKeyTests : TestDependencyInjectionBase() { @Rule @JvmField val tempFolder: TemporaryFolder = TemporaryFolder() diff --git a/core/src/test/kotlin/net/corda/core/crypto/PartialMerkleTreeTest.kt b/core/src/test/kotlin/net/corda/core/crypto/PartialMerkleTreeTest.kt index 6b69c32e13..b333fe0ba2 100644 --- a/core/src/test/kotlin/net/corda/core/crypto/PartialMerkleTreeTest.kt +++ b/core/src/test/kotlin/net/corda/core/crypto/PartialMerkleTreeTest.kt @@ -6,21 +6,24 @@ import net.corda.contracts.asset.Cash import net.corda.core.contracts.* import net.corda.core.crypto.SecureHash.Companion.zeroHash import net.corda.core.identity.Party -import net.corda.core.serialization.p2PKryo import net.corda.core.serialization.serialize import net.corda.core.transactions.WireTransaction -import net.corda.testing.DUMMY_NOTARY -import net.corda.testing.DUMMY_PUBKEY_1 -import net.corda.testing.TEST_TX_TIME import net.corda.testing.* import org.junit.Test import java.security.PublicKey import java.util.function.Predicate import kotlin.test.* -class PartialMerkleTreeTest { +class PartialMerkleTreeTest : TestDependencyInjectionBase() { val nodes = "abcdef" - val hashed = nodes.map { it.serialize().sha256() } + val hashed = nodes.map { + initialiseTestSerialization() + try { + it.serialize().sha256() + } finally { + resetTestSerialization() + } + } val expectedRoot = MerkleTree.getMerkleTree(hashed.toMutableList() + listOf(zeroHash, zeroHash)).hash val merkleTree = MerkleTree.getMerkleTree(hashed) @@ -215,9 +218,7 @@ class PartialMerkleTreeTest { @Test(expected = KryoException::class) fun `hash map serialization not allowed`() { val hm1 = hashMapOf("a" to 1, "b" to 2, "c" to 3, "e" to 4) - p2PKryo().run { kryo -> - hm1.serialize(kryo) - } + hm1.serialize() } private fun makeSimpleCashWtx(notary: Party, timeWindow: TimeWindow? = null, attachments: List = emptyList()): WireTransaction { diff --git a/core/src/test/kotlin/net/corda/core/crypto/SignedDataTest.kt b/core/src/test/kotlin/net/corda/core/crypto/SignedDataTest.kt index cb83d847da..c8f35a77a5 100644 --- a/core/src/test/kotlin/net/corda/core/crypto/SignedDataTest.kt +++ b/core/src/test/kotlin/net/corda/core/crypto/SignedDataTest.kt @@ -1,13 +1,21 @@ package net.corda.core.crypto +import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.serialize +import net.corda.testing.TestDependencyInjectionBase +import org.junit.Before import org.junit.Test import java.security.SignatureException import kotlin.test.assertEquals -class SignedDataTest { +class SignedDataTest : TestDependencyInjectionBase() { + @Before + fun initialise() { + serialized = data.serialize() + } + val data = "Just a simple test string" - val serialized = data.serialize() + lateinit var serialized: SerializedBytes @Test fun `make sure correctly signed data is released`() { diff --git a/core/src/test/kotlin/net/corda/core/crypto/TransactionSignatureTest.kt b/core/src/test/kotlin/net/corda/core/crypto/TransactionSignatureTest.kt index 4aad5b6580..6ef8aa9921 100644 --- a/core/src/test/kotlin/net/corda/core/crypto/TransactionSignatureTest.kt +++ b/core/src/test/kotlin/net/corda/core/crypto/TransactionSignatureTest.kt @@ -1,5 +1,6 @@ package net.corda.core.crypto +import net.corda.testing.TestDependencyInjectionBase import org.junit.Test import java.security.SignatureException import java.time.Instant @@ -8,7 +9,7 @@ import kotlin.test.assertTrue /** * Digital signature MetaData tests */ -class TransactionSignatureTest { +class TransactionSignatureTest : TestDependencyInjectionBase() { val testBytes = "12345678901234567890123456789012".toByteArray() diff --git a/core/src/test/kotlin/net/corda/core/flows/ContractUpgradeFlowTest.kt b/core/src/test/kotlin/net/corda/core/flows/ContractUpgradeFlowTest.kt index acac7c2574..f7f7e126f3 100644 --- a/core/src/test/kotlin/net/corda/core/flows/ContractUpgradeFlowTest.kt +++ b/core/src/test/kotlin/net/corda/core/flows/ContractUpgradeFlowTest.kt @@ -116,7 +116,7 @@ class ContractUpgradeFlowTest { @Test fun `2 parties contract upgrade using RPC`() { - rpcDriver { + rpcDriver(initialiseSerialization = false) { // Create dummy contract. val twoPartyDummyContract = DummyContract.generateInitial(0, notary, a.info.legalIdentity.ref(1), b.info.legalIdentity.ref(1)) val signedByA = a.services.signInitialTransaction(twoPartyDummyContract) diff --git a/core/src/test/kotlin/net/corda/core/flows/ResolveTransactionsFlowTest.kt b/core/src/test/kotlin/net/corda/core/flows/ResolveTransactionsFlowTest.kt index b11d93564f..27c857714c 100644 --- a/core/src/test/kotlin/net/corda/core/flows/ResolveTransactionsFlowTest.kt +++ b/core/src/test/kotlin/net/corda/core/flows/ResolveTransactionsFlowTest.kt @@ -1,15 +1,15 @@ package net.corda.core.flows -import net.corda.testing.contracts.DummyContract import net.corda.core.crypto.SecureHash import net.corda.core.getOrThrow import net.corda.core.identity.Party -import net.corda.core.utilities.opaque import net.corda.core.transactions.SignedTransaction +import net.corda.core.utilities.sequence import net.corda.testing.DUMMY_NOTARY_KEY import net.corda.testing.MEGA_CORP import net.corda.testing.MEGA_CORP_KEY import net.corda.testing.MINI_CORP +import net.corda.testing.contracts.DummyContract import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockServices import org.junit.After @@ -17,7 +17,6 @@ import org.junit.Before import org.junit.Test import java.io.ByteArrayOutputStream import java.io.InputStream -import java.security.SignatureException import java.util.jar.JarEntry import java.util.jar.JarOutputStream import kotlin.test.assertEquals @@ -141,7 +140,7 @@ class ResolveTransactionsFlowTest { jar.write("Some test file".toByteArray()) jar.closeEntry() jar.close() - return bs.toByteArray().opaque().open() + return bs.toByteArray().sequence().open() } // TODO: this operation should not require an explicit transaction val id = a.database.transaction { diff --git a/core/src/test/kotlin/net/corda/core/flows/TransactionKeyFlowTests.kt b/core/src/test/kotlin/net/corda/core/flows/TransactionKeyFlowTests.kt index b91e1a2744..2bb1c0f5f7 100644 --- a/core/src/test/kotlin/net/corda/core/flows/TransactionKeyFlowTests.kt +++ b/core/src/test/kotlin/net/corda/core/flows/TransactionKeyFlowTests.kt @@ -8,8 +8,6 @@ import net.corda.testing.ALICE import net.corda.testing.BOB import net.corda.testing.DUMMY_NOTARY import net.corda.testing.node.MockNetwork -import org.junit.After -import org.junit.Before import org.junit.Test import kotlin.test.assertEquals import kotlin.test.assertFalse @@ -17,22 +15,10 @@ import kotlin.test.assertNotEquals import kotlin.test.assertTrue class TransactionKeyFlowTests { - lateinit var mockNet: MockNetwork - - @Before - fun before() { - mockNet = MockNetwork(false) - } - - @After - fun cleanUp() { - mockNet.stopNodes() - } - @Test fun `issue key`() { // We run this in parallel threads to help catch any race conditions that may exist. - mockNet = MockNetwork(false, true) + val mockNet = MockNetwork(false, true) // Set up values we'll need val notaryNode = mockNet.createNotaryNode(null, DUMMY_NOTARY.name) @@ -66,5 +52,7 @@ class TransactionKeyFlowTests { assertTrue { bobAnonymousIdentity.party.owningKey in bobNode.services.keyManagementService.keys } assertFalse { aliceAnonymousIdentity.party.owningKey in bobNode.services.keyManagementService.keys } assertFalse { bobAnonymousIdentity.party.owningKey in aliceNode.services.keyManagementService.keys } + + mockNet.stopNodes() } } diff --git a/core/src/test/kotlin/net/corda/core/node/AttachmentClassLoaderTests.kt b/core/src/test/kotlin/net/corda/core/node/AttachmentClassLoaderTests.kt index efe0a254c0..33d64c768d 100644 --- a/core/src/test/kotlin/net/corda/core/node/AttachmentClassLoaderTests.kt +++ b/core/src/test/kotlin/net/corda/core/node/AttachmentClassLoaderTests.kt @@ -1,6 +1,5 @@ package net.corda.core.node -import com.esotericsoftware.kryo.Kryo import com.nhaarman.mockito_kotlin.mock import com.nhaarman.mockito_kotlin.whenever import net.corda.core.contracts.* @@ -9,14 +8,15 @@ import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party import net.corda.core.node.services.AttachmentStorage import net.corda.core.serialization.* +import net.corda.core.serialization.SerializationDefaults.P2P_CONTEXT import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.TransactionBuilder import net.corda.testing.DUMMY_NOTARY import net.corda.testing.MEGA_CORP +import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.node.MockAttachmentStorage import org.apache.commons.io.IOUtils import org.junit.Assert -import org.junit.Before import org.junit.Test import java.io.ByteArrayInputStream import java.io.ByteArrayOutputStream @@ -36,15 +36,14 @@ interface DummyContractBackdoor { val ATTACHMENT_TEST_PROGRAM_ID = AttachmentClassLoaderTests.AttachmentDummyContract() -class AttachmentClassLoaderTests { +class AttachmentClassLoaderTests : TestDependencyInjectionBase() { companion object { val ISOLATED_CONTRACTS_JAR_PATH: URL = AttachmentClassLoaderTests::class.java.getResource("isolated.jar") - private fun Kryo.withAttachmentStorage(attachmentStorage: AttachmentStorage, block: () -> T) = run { - context.put(WireTransactionSerializer.attachmentsClassLoaderEnabled, true) + private fun SerializationContext.withAttachmentStorage(attachmentStorage: AttachmentStorage): SerializationContext { val serviceHub = mock() whenever(serviceHub.attachments).thenReturn(attachmentStorage) - withSerializationContext(SerializeAsTokenContext(serviceHub) {}, block) + return this.withTokenContext(SerializeAsTokenContext(serviceHub) {}).withProperty(WireTransactionSerializer.attachmentsClassLoaderEnabled, true) } } @@ -89,16 +88,6 @@ class AttachmentClassLoaderTests { class ClassLoaderForTests : URLClassLoader(arrayOf(ISOLATED_CONTRACTS_JAR_PATH), FilteringClassLoader) - lateinit var kryo: Kryo - lateinit var kryo2: Kryo - - @Before - fun setup() { - // Do not release these back to the pool, since we do some unorthodox modifications to them below. - kryo = p2PKryo().borrow() - kryo2 = p2PKryo().borrow() - } - @Test fun `dynamically load AnotherDummyContract from isolated contracts jar`() { val child = ClassLoaderForTests() @@ -229,10 +218,8 @@ class AttachmentClassLoaderTests { val cl = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader) - kryo.classLoader = cl - kryo.addToWhitelist(contract.javaClass) - - val state2 = bytes.deserialize(kryo) + val context = P2P_CONTEXT.withClassLoader(cl).withWhitelisted(contract.javaClass) + val state2 = bytes.deserialize(context = context) assertTrue(state2.javaClass.classLoader is AttachmentsClassLoader) assertNotNull(state2) } @@ -247,8 +234,9 @@ class AttachmentClassLoaderTests { assertNotNull(data.contract) - kryo2.addToWhitelist(data.contract.javaClass) - val bytes = data.serialize(kryo2) + val context2 = P2P_CONTEXT.withWhitelisted(data.contract.javaClass) + + val bytes = data.serialize(context = context2) val storage = MockAttachmentStorage() @@ -258,20 +246,18 @@ class AttachmentClassLoaderTests { val cl = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader) - kryo.classLoader = cl - kryo.addToWhitelist(Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl)) + val context = P2P_CONTEXT.withClassLoader(cl).withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl)) - val state2 = bytes.deserialize(kryo) + val state2 = bytes.deserialize(context = context) assertEquals(cl, state2.contract.javaClass.classLoader) assertNotNull(state2) // We should be able to load same class from a different class loader and have them be distinct. val cl2 = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader) - kryo.classLoader = cl2 - kryo.addToWhitelist(Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl2)) + val context3 = P2P_CONTEXT.withClassLoader(cl2).withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl2)) - val state3 = bytes.deserialize(kryo) + val state3 = bytes.deserialize(context = context3) assertEquals(cl2, state3.contract.javaClass.classLoader) assertNotNull(state3) } @@ -295,30 +281,22 @@ class AttachmentClassLoaderTests { val contract = contractClass.newInstance() as DummyContractBackdoor val tx = contract.generateInitial(MEGA_CORP.ref(0), 42, DUMMY_NOTARY) val storage = MockAttachmentStorage() - kryo.addToWhitelist(contract.javaClass) - kryo.addToWhitelist(Class.forName("net.corda.contracts.isolated.AnotherDummyContract\$State", true, child)) - kryo.addToWhitelist(Class.forName("net.corda.contracts.isolated.AnotherDummyContract\$Commands\$Create", true, child)) + val context = P2P_CONTEXT.withWhitelisted(contract.javaClass) + .withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract\$State", true, child)) + .withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract\$Commands\$Create", true, child)) + .withAttachmentStorage(storage) // todo - think about better way to push attachmentStorage down to serializer - val bytes = kryo.withAttachmentStorage(storage) { - + val bytes = run { val attachmentRef = importJar(storage) - tx.addAttachment(storage.openAttachment(attachmentRef)!!.id) - val wireTransaction = tx.toWireTransaction() - - wireTransaction.serialize(kryo) - } - // use empty attachmentStorage - kryo2.withAttachmentStorage(storage) { - - val copiedWireTransaction = bytes.deserialize(kryo2) - - assertEquals(1, copiedWireTransaction.outputs.size) - val contract2 = copiedWireTransaction.outputs[0].data.contract as DummyContractBackdoor - assertEquals(42, contract2.inspectState(copiedWireTransaction.outputs[0].data)) + wireTransaction.serialize(context = context) } + val copiedWireTransaction = bytes.deserialize(context = context) + assertEquals(1, copiedWireTransaction.outputs.size) + val contract2 = copiedWireTransaction.outputs[0].data.contract as DummyContractBackdoor + assertEquals(42, contract2.inspectState(copiedWireTransaction.outputs[0].data)) } @Test @@ -331,21 +309,19 @@ class AttachmentClassLoaderTests { // todo - think about better way to push attachmentStorage down to serializer val attachmentRef = importJar(storage) - val bytes = kryo.withAttachmentStorage(storage) { + val bytes = run { tx.addAttachment(storage.openAttachment(attachmentRef)!!.id) val wireTransaction = tx.toWireTransaction() - wireTransaction.serialize(kryo) + wireTransaction.serialize(context = P2P_CONTEXT.withAttachmentStorage(storage)) } // use empty attachmentStorage - kryo2.withAttachmentStorage(MockAttachmentStorage()) { - val e = assertFailsWith(MissingAttachmentsException::class) { - bytes.deserialize(kryo2) - } - assertEquals(attachmentRef, e.ids.single()) + val e = assertFailsWith(MissingAttachmentsException::class) { + bytes.deserialize(context = P2P_CONTEXT.withAttachmentStorage(MockAttachmentStorage())) } + assertEquals(attachmentRef, e.ids.single()) } } diff --git a/core/src/test/kotlin/net/corda/core/serialization/KryoTests.kt b/core/src/test/kotlin/net/corda/core/serialization/KryoTests.kt index e27932c757..e227564aed 100644 --- a/core/src/test/kotlin/net/corda/core/serialization/KryoTests.kt +++ b/core/src/test/kotlin/net/corda/core/serialization/KryoTests.kt @@ -1,10 +1,13 @@ package net.corda.core.serialization -import com.esotericsoftware.kryo.Kryo import com.google.common.primitives.Ints import net.corda.core.crypto.* -import net.corda.core.utilities.opaque +import net.corda.core.utilities.sequence +import net.corda.node.serialization.KryoServerSerializationScheme import net.corda.node.services.persistence.NodeAttachmentService +import net.corda.nodeapi.serialization.KryoHeaderV0_1 +import net.corda.nodeapi.serialization.SerializationContextImpl +import net.corda.nodeapi.serialization.SerializationFactoryImpl import net.corda.testing.ALICE import net.corda.testing.BOB import net.corda.testing.BOB_PUBKEY @@ -19,62 +22,66 @@ import java.io.InputStream import java.security.cert.CertPath import java.security.cert.CertificateFactory import java.time.Instant -import java.util.* import kotlin.test.assertEquals import kotlin.test.assertTrue class KryoTests { - - private lateinit var kryo: Kryo + private lateinit var factory: SerializationFactory + private lateinit var context: SerializationContext @Before fun setup() { - // We deliberately do not return this, since we do some unorthodox registering below and do not want to pollute the pool. - kryo = p2PKryo().borrow() + factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme()) } + context = SerializationContextImpl(KryoHeaderV0_1, + javaClass.classLoader, + AllWhitelist, + emptyMap(), + true, + SerializationContext.UseCase.P2P) } @Test fun ok() { val birthday = Instant.parse("1984-04-17T00:30:00.00Z") val mike = Person("mike", birthday) - val bits = mike.serialize(kryo) - assertThat(bits.deserialize(kryo)).isEqualTo(Person("mike", birthday)) + val bits = mike.serialize(factory, context) + assertThat(bits.deserialize(factory, context)).isEqualTo(Person("mike", birthday)) } @Test fun nullables() { val bob = Person("bob", null) - val bits = bob.serialize(kryo) - assertThat(bits.deserialize(kryo)).isEqualTo(Person("bob", null)) + val bits = bob.serialize(factory, context) + assertThat(bits.deserialize(factory, context)).isEqualTo(Person("bob", null)) } @Test fun `serialised form is stable when the same object instance is added to the deserialised object graph`() { - kryo.noReferencesWithin>() - val obj = Ints.toByteArray(0x01234567).opaque() + val noReferencesContext = context.withoutReferences() + val obj = Ints.toByteArray(0x01234567).sequence() val originalList = arrayListOf(obj) - val deserialisedList = originalList.serialize(kryo).deserialize(kryo) + val deserialisedList = originalList.serialize(factory, noReferencesContext).deserialize(factory, noReferencesContext) originalList += obj deserialisedList += obj - assertThat(deserialisedList.serialize(kryo)).isEqualTo(originalList.serialize(kryo)) + assertThat(deserialisedList.serialize(factory, noReferencesContext)).isEqualTo(originalList.serialize(factory, noReferencesContext)) } @Test fun `serialised form is stable when the same object instance occurs more than once, and using java serialisation`() { - kryo.noReferencesWithin>() + val noReferencesContext = context.withoutReferences() val instant = Instant.ofEpochMilli(123) val instantCopy = Instant.ofEpochMilli(123) assertThat(instant).isNotSameAs(instantCopy) val listWithCopies = arrayListOf(instant, instantCopy) val listWithSameInstances = arrayListOf(instant, instant) - assertThat(listWithSameInstances.serialize(kryo)).isEqualTo(listWithCopies.serialize(kryo)) + assertThat(listWithSameInstances.serialize(factory, noReferencesContext)).isEqualTo(listWithCopies.serialize(factory, noReferencesContext)) } @Test fun `cyclic object graph`() { val cyclic = Cyclic(3) - val bits = cyclic.serialize(kryo) - assertThat(bits.deserialize(kryo)).isEqualTo(cyclic) + val bits = cyclic.serialize(factory, context) + assertThat(bits.deserialize(factory, context)).isEqualTo(cyclic) } @Test @@ -86,7 +93,7 @@ class KryoTests { signature.verify(bitsToSign) assertThatThrownBy { signature.verify(wrongBits) } - val deserialisedKeyPair = keyPair.serialize(kryo).deserialize(kryo) + val deserialisedKeyPair = keyPair.serialize(factory, context).deserialize(factory, context) val deserialisedSignature = deserialisedKeyPair.sign(bitsToSign) deserialisedSignature.verify(bitsToSign) assertThatThrownBy { deserialisedSignature.verify(wrongBits) } @@ -94,15 +101,15 @@ class KryoTests { @Test fun `write and read Kotlin object singleton`() { - val serialised = TestSingleton.serialize(kryo) - val deserialised = serialised.deserialize(kryo) + val serialised = TestSingleton.serialize(factory, context) + val deserialised = serialised.deserialize(factory, context) assertThat(deserialised).isSameAs(TestSingleton) } @Test fun `InputStream serialisation`() { val rubbish = ByteArray(12345, { (it * it * 0.12345).toByte() }) - val readRubbishStream: InputStream = rubbish.inputStream().serialize(kryo).deserialize(kryo) + val readRubbishStream: InputStream = rubbish.inputStream().serialize(factory, context).deserialize(factory, context) for (i in 0..12344) { assertEquals(rubbish[i], readRubbishStream.read().toByte()) } @@ -118,15 +125,16 @@ class KryoTests { bitSet.set(3) val meta = MetaData("ECDSA_SECP256K1_SHA256", "M9", SignatureType.FULL, Instant.now(), bitSet, bitSet, testBytes, keyPair1.public) - val serializedMetaData = meta.bytes() - val meta2 = serializedMetaData.deserialize() + val serializedMetaData = meta.serialize(factory, context).bytes + val meta2 = serializedMetaData.deserialize(factory, context) assertEquals(meta2, meta) } @Test fun `serialize - deserialize Logger`() { + val storageContext: SerializationContext = context // TODO: make it storage context val logger = LoggerFactory.getLogger("aName") - val logger2 = logger.serialize(storageKryo()).deserialize(storageKryo()) + val logger2 = logger.serialize(factory, storageContext).deserialize(factory, storageContext) assertEquals(logger.name, logger2.name) assertTrue(logger === logger2) } @@ -134,7 +142,7 @@ class KryoTests { @Test fun `HashCheckingStream (de)serialize`() { val rubbish = ByteArray(12345, { (it * it * 0.12345).toByte() }) - val readRubbishStream: InputStream = NodeAttachmentService.HashCheckingStream(SecureHash.sha256(rubbish), rubbish.size, ByteArrayInputStream(rubbish)).serialize(kryo).deserialize(kryo) + val readRubbishStream: InputStream = NodeAttachmentService.HashCheckingStream(SecureHash.sha256(rubbish), rubbish.size, ByteArrayInputStream(rubbish)).serialize(factory, context).deserialize(factory, context) for (i in 0..12344) { assertEquals(rubbish[i], readRubbishStream.read().toByte()) } @@ -144,8 +152,8 @@ class KryoTests { @Test fun `serialize - deserialize X509CertififcateHolder`() { val expected: X509CertificateHolder = X509Utilities.createSelfSignedCACertificate(ALICE.name, Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME)) - val serialized = expected.serialize(kryo).bytes - val actual: X509CertificateHolder = serialized.deserialize(kryo) + val serialized = expected.serialize(factory, context).bytes + val actual: X509CertificateHolder = serialized.deserialize(factory, context) assertEquals(expected, actual) } @@ -156,8 +164,8 @@ class KryoTests { val rootCACert = X509Utilities.createSelfSignedCACertificate(ALICE.name, rootCAKey) val certificate = X509Utilities.createCertificate(CertificateType.TLS, rootCACert, rootCAKey, BOB.name, BOB_PUBKEY) val expected = certFactory.generateCertPath(listOf(certificate.cert, rootCACert.cert)) - val serialized = expected.serialize(kryo).bytes - val actual: CertPath = serialized.deserialize(kryo) + val serialized = expected.serialize(factory, context).bytes + val actual: CertPath = serialized.deserialize(factory, context) assertEquals(expected, actual) } diff --git a/core/src/test/kotlin/net/corda/core/serialization/SerializationTokenTest.kt b/core/src/test/kotlin/net/corda/core/serialization/SerializationTokenTest.kt index 9b2517c8d5..0f48a5fb8c 100644 --- a/core/src/test/kotlin/net/corda/core/serialization/SerializationTokenTest.kt +++ b/core/src/test/kotlin/net/corda/core/serialization/SerializationTokenTest.kt @@ -6,24 +6,29 @@ import com.esotericsoftware.kryo.io.Output import com.nhaarman.mockito_kotlin.mock import net.corda.core.node.ServiceHub import net.corda.core.utilities.OpaqueBytes +import net.corda.node.serialization.KryoServerSerializationScheme +import net.corda.nodeapi.serialization.KryoHeaderV0_1 +import net.corda.nodeapi.serialization.SerializationContextImpl +import net.corda.nodeapi.serialization.SerializationFactoryImpl import org.assertj.core.api.Assertions.assertThat -import org.junit.After import org.junit.Before import org.junit.Test import java.io.ByteArrayOutputStream class SerializationTokenTest { - lateinit var kryo: Kryo + lateinit var factory: SerializationFactory + lateinit var context: SerializationContext @Before fun setup() { - kryo = storageKryo().borrow() - } - - @After - fun cleanup() { - storageKryo().release(kryo) + factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme()) } + context = SerializationContextImpl(KryoHeaderV0_1, + javaClass.classLoader, + AllWhitelist, + emptyMap(), + true, + SerializationContext.UseCase.P2P) } // Large tokenizable object so we can tell from the smaller number of serialized bytes it was actually tokenized @@ -38,20 +43,18 @@ class SerializationTokenTest { override fun equals(other: Any?) = other is LargeTokenizable && other.bytes.size == this.bytes.size } - companion object { - private fun serializeAsTokenContext(toBeTokenized: Any) = SerializeAsTokenContext(toBeTokenized, storageKryo(), mock()) - } + private fun serializeAsTokenContext(toBeTokenized: Any) = SerializeAsTokenContext(toBeTokenized, factory, context, mock()) @Test fun `write token and read tokenizable`() { val tokenizableBefore = LargeTokenizable() val context = serializeAsTokenContext(tokenizableBefore) - kryo.withSerializationContext(context) { - val serializedBytes = tokenizableBefore.serialize(kryo) - assertThat(serializedBytes.size).isLessThan(tokenizableBefore.numBytes) - val tokenizableAfter = serializedBytes.deserialize(kryo) - assertThat(tokenizableAfter).isSameAs(tokenizableBefore) - } + val testContext = this.context.withTokenContext(context) + + val serializedBytes = tokenizableBefore.serialize(factory, testContext) + assertThat(serializedBytes.size).isLessThan(tokenizableBefore.numBytes) + val tokenizableAfter = serializedBytes.deserialize(factory, testContext) + assertThat(tokenizableAfter).isSameAs(tokenizableBefore) } private class UnitSerializeAsToken : SingletonSerializeAsToken() @@ -60,51 +63,50 @@ class SerializationTokenTest { fun `write and read singleton`() { val tokenizableBefore = UnitSerializeAsToken() val context = serializeAsTokenContext(tokenizableBefore) - kryo.withSerializationContext(context) { - val serializedBytes = tokenizableBefore.serialize(kryo) - val tokenizableAfter = serializedBytes.deserialize(kryo) + val testContext = this.context.withTokenContext(context) + val serializedBytes = tokenizableBefore.serialize(factory, testContext) + val tokenizableAfter = serializedBytes.deserialize(factory, testContext) assertThat(tokenizableAfter).isSameAs(tokenizableBefore) - } } @Test(expected = UnsupportedOperationException::class) fun `new token encountered after context init`() { val tokenizableBefore = UnitSerializeAsToken() val context = serializeAsTokenContext(emptyList()) - kryo.withSerializationContext(context) { - tokenizableBefore.serialize(kryo) - } + val testContext = this.context.withTokenContext(context) + tokenizableBefore.serialize(factory, testContext) } @Test(expected = UnsupportedOperationException::class) fun `deserialize unregistered token`() { val tokenizableBefore = UnitSerializeAsToken() val context = serializeAsTokenContext(emptyList()) - kryo.withSerializationContext(context) { - val serializedBytes = tokenizableBefore.toToken(serializeAsTokenContext(emptyList())).serialize(kryo) - serializedBytes.deserialize(kryo) - } + val testContext = this.context.withTokenContext(context) + val serializedBytes = tokenizableBefore.toToken(serializeAsTokenContext(emptyList())).serialize(factory, testContext) + serializedBytes.deserialize(factory, testContext) } @Test(expected = KryoException::class) fun `no context set`() { val tokenizableBefore = UnitSerializeAsToken() - tokenizableBefore.serialize(kryo) + tokenizableBefore.serialize(factory, context) } @Test(expected = KryoException::class) fun `deserialize non-token`() { val tokenizableBefore = UnitSerializeAsToken() val context = serializeAsTokenContext(tokenizableBefore) - kryo.withSerializationContext(context) { - val stream = ByteArrayOutputStream() + val testContext = this.context.withTokenContext(context) + + val kryo: Kryo = DefaultKryoCustomizer.customize(CordaKryo(makeNoWhitelistClassResolver())) + val stream = ByteArrayOutputStream() Output(stream).use { + it.write(KryoHeaderV0_1.bytes) kryo.writeClass(it, SingletonSerializeAsToken::class.java) kryo.writeObject(it, emptyList()) } - val serializedBytes = SerializedBytes(stream.toByteArray()) - serializedBytes.deserialize(kryo) - } + val serializedBytes = SerializedBytes(stream.toByteArray()) + serializedBytes.deserialize(factory, testContext) } private class WrongTypeSerializeAsToken : SerializeAsToken { @@ -119,9 +121,8 @@ class SerializationTokenTest { fun `token returns unexpected type`() { val tokenizableBefore = WrongTypeSerializeAsToken() val context = serializeAsTokenContext(tokenizableBefore) - kryo.withSerializationContext(context) { - val serializedBytes = tokenizableBefore.serialize(kryo) - serializedBytes.deserialize(kryo) - } + val testContext = this.context.withTokenContext(context) + val serializedBytes = tokenizableBefore.serialize(factory, testContext) + serializedBytes.deserialize(factory, testContext) } } diff --git a/core/src/test/kotlin/net/corda/core/serialization/TransactionSerializationTests.kt b/core/src/test/kotlin/net/corda/core/serialization/TransactionSerializationTests.kt index 729ff16822..cbadb624a3 100644 --- a/core/src/test/kotlin/net/corda/core/serialization/TransactionSerializationTests.kt +++ b/core/src/test/kotlin/net/corda/core/serialization/TransactionSerializationTests.kt @@ -17,7 +17,7 @@ import kotlin.test.assertFailsWith val TEST_PROGRAM_ID = TransactionSerializationTests.TestCash() -class TransactionSerializationTests { +class TransactionSerializationTests : TestDependencyInjectionBase() { class TestCash : Contract { override val legalContractReference = SecureHash.sha256("TestCash") diff --git a/core/src/test/kotlin/net/corda/core/utilities/KotlinUtilsTest.kt b/core/src/test/kotlin/net/corda/core/utilities/KotlinUtilsTest.kt index b58c74a51f..ae4a67c093 100644 --- a/core/src/test/kotlin/net/corda/core/utilities/KotlinUtilsTest.kt +++ b/core/src/test/kotlin/net/corda/core/utilities/KotlinUtilsTest.kt @@ -4,10 +4,11 @@ import net.corda.core.crypto.random63BitValue import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize +import net.corda.testing.TestDependencyInjectionBase import org.assertj.core.api.Assertions.assertThat import org.junit.Test -class KotlinUtilsTest { +class KotlinUtilsTest : TestDependencyInjectionBase() { @Test fun `transient property which is null`() { val test = NullTransientProperty() diff --git a/core/src/test/kotlin/net/corda/core/utilities/NonEmptySetTest.kt b/core/src/test/kotlin/net/corda/core/utilities/NonEmptySetTest.kt index c6991c407a..299dec166e 100644 --- a/core/src/test/kotlin/net/corda/core/utilities/NonEmptySetTest.kt +++ b/core/src/test/kotlin/net/corda/core/utilities/NonEmptySetTest.kt @@ -7,6 +7,8 @@ import com.google.common.collect.testing.features.CollectionSize import junit.framework.TestSuite import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize +import net.corda.testing.initialiseTestSerialization +import net.corda.testing.resetTestSerialization import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.Test @@ -47,9 +49,15 @@ class NonEmptySetTest { @Test fun `serialize deserialize`() { - val original = NonEmptySet.of(-17, 22, 17) - val copy = original.serialize().deserialize() - assertThat(copy).isEqualTo(original).isNotSameAs(original) + initialiseTestSerialization() + try { + val original = NonEmptySet.of(-17, 22, 17) + val copy = original.serialize().deserialize() + + assertThat(copy).isEqualTo(original).isNotSameAs(original) + } finally { + resetTestSerialization() + } } } diff --git a/core/src/test/kotlin/net/corda/core/utilities/ProgressTrackerTest.kt b/core/src/test/kotlin/net/corda/core/utilities/ProgressTrackerTest.kt index 8b56b4ad13..55cc536086 100644 --- a/core/src/test/kotlin/net/corda/core/utilities/ProgressTrackerTest.kt +++ b/core/src/test/kotlin/net/corda/core/utilities/ProgressTrackerTest.kt @@ -4,8 +4,13 @@ import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.KryoSerializable import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output -import net.corda.core.serialization.createTestKryo +import net.corda.core.serialization.AllWhitelist +import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.serialize +import net.corda.node.serialization.KryoServerSerializationScheme +import net.corda.nodeapi.serialization.KryoHeaderV0_1 +import net.corda.nodeapi.serialization.SerializationContextImpl +import net.corda.nodeapi.serialization.SerializationFactoryImpl import org.junit.Before import org.junit.Test import java.util.* @@ -106,10 +111,6 @@ class ProgressTrackerTest { } } - val kryo = createTestKryo().apply { - // This is required to make sure Kryo walks through the auto-generated members for the lambda below. - fieldSerializerConfig.isIgnoreSyntheticFields = false - } pt.setChildProgressTracker(SimpleSteps.TWO, pt2) class Tmp { val unserializable = Unserializable() @@ -119,6 +120,13 @@ class ProgressTrackerTest { } } Tmp() - pt.serialize(kryo) + val factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme()) } + val context = SerializationContextImpl(KryoHeaderV0_1, + javaClass.classLoader, + AllWhitelist, + emptyMap(), + true, + SerializationContext.UseCase.P2P) + pt.serialize(factory, context) } } diff --git a/docs/source/example-code/src/integration-test/kotlin/net/corda/docs/IntegrationTestingTutorial.kt b/docs/source/example-code/src/integration-test/kotlin/net/corda/docs/IntegrationTestingTutorial.kt index e0e6f6be4a..bb8f9a84da 100644 --- a/docs/source/example-code/src/integration-test/kotlin/net/corda/docs/IntegrationTestingTutorial.kt +++ b/docs/source/example-code/src/integration-test/kotlin/net/corda/docs/IntegrationTestingTutorial.kt @@ -15,10 +15,10 @@ import net.corda.testing.BOB import net.corda.testing.DUMMY_NOTARY import net.corda.flows.CashIssueFlow import net.corda.flows.CashPaymentFlow -import net.corda.testing.driver.driver import net.corda.node.services.startFlowPermission import net.corda.node.services.transactions.ValidatingNotaryService import net.corda.nodeapi.User +import net.corda.testing.driver.driver import net.corda.testing.expect import net.corda.testing.expectEvents import net.corda.testing.parallel diff --git a/finance/src/test/kotlin/net/corda/contracts/CommercialPaperTests.kt b/finance/src/test/kotlin/net/corda/contracts/CommercialPaperTests.kt index 893d8ce3e4..f34fda4c0b 100644 --- a/finance/src/test/kotlin/net/corda/contracts/CommercialPaperTests.kt +++ b/finance/src/test/kotlin/net/corda/contracts/CommercialPaperTests.kt @@ -209,7 +209,7 @@ class CommercialPaperTestsGeneric { @Test fun `issue move and then redeem`() { - + initialiseTestSerialization() val dataSourcePropsAlice = makeTestDataSourceProperties() val databaseAlice = configureDatabase(dataSourcePropsAlice) databaseAlice.transaction { @@ -307,5 +307,6 @@ class CommercialPaperTestsGeneric { validRedemption.toLedgerTransaction(aliceServices).verify() // soft lock not released after success either!!! (as transaction not recorded) } + resetTestSerialization() } } diff --git a/finance/src/test/kotlin/net/corda/contracts/asset/CashTests.kt b/finance/src/test/kotlin/net/corda/contracts/asset/CashTests.kt index 0fa711a408..04518c251f 100644 --- a/finance/src/test/kotlin/net/corda/contracts/asset/CashTests.kt +++ b/finance/src/test/kotlin/net/corda/contracts/asset/CashTests.kt @@ -1,8 +1,6 @@ package net.corda.contracts.asset -import net.corda.testing.contracts.fillWithSomeTestCash import net.corda.core.contracts.* -import net.corda.testing.contracts.DummyState import net.corda.core.crypto.SecureHash import net.corda.core.crypto.generateKeyPair import net.corda.core.identity.AbstractParty @@ -10,13 +8,15 @@ import net.corda.core.identity.AnonymousParty import net.corda.core.identity.Party import net.corda.core.node.services.VaultService import net.corda.core.node.services.unconsumedStates -import net.corda.core.utilities.OpaqueBytes import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.WireTransaction +import net.corda.core.utilities.OpaqueBytes import net.corda.node.services.vault.NodeVaultService import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase import net.corda.testing.* +import net.corda.testing.contracts.DummyState +import net.corda.testing.contracts.fillWithSomeTestCash import net.corda.testing.node.MockKeyManagementService import net.corda.testing.node.MockServices import net.corda.testing.node.makeTestDataSourceProperties @@ -26,7 +26,7 @@ import java.security.KeyPair import java.util.* import kotlin.test.* -class CashTests { +class CashTests : TestDependencyInjectionBase() { val defaultRef = OpaqueBytes(ByteArray(1, { 1 })) val defaultIssuer = MEGA_CORP.ref(defaultRef) val inState = Cash.State( @@ -76,6 +76,7 @@ class CashTests { vaultStatesUnconsumed = miniCorpServices.vaultService.unconsumedStates().toList() } + resetTestSerialization() } @Test @@ -152,6 +153,7 @@ class CashTests { @Test fun generateIssueRaw() { + initialiseTestSerialization() // Test generation works. val tx: WireTransaction = TransactionType.General.Builder(notary = null).apply { Cash().generateIssue(this, 100.DOLLARS `issued by` MINI_CORP.ref(12, 34), owner = AnonymousParty(DUMMY_PUBKEY_1), notary = DUMMY_NOTARY) @@ -167,6 +169,7 @@ class CashTests { @Test fun generateIssueFromAmount() { + initialiseTestSerialization() // Test issuance from an issued amount val amount = 100.DOLLARS `issued by` MINI_CORP.ref(12, 34) val tx: WireTransaction = TransactionType.General.Builder(notary = null).apply { @@ -239,6 +242,7 @@ class CashTests { */ @Test(expected = IllegalStateException::class) fun `reject issuance with inputs`() { + initialiseTestSerialization() // Issue some cash var ptx = TransactionType.General.Builder(DUMMY_NOTARY) @@ -490,6 +494,7 @@ class CashTests { */ @Test fun generateSimpleExit() { + initialiseTestSerialization() val wtx = makeExit(100.DOLLARS, MEGA_CORP, 1) assertEquals(WALLET[0].ref, wtx.inputs[0]) assertEquals(0, wtx.outputs.size) @@ -505,6 +510,7 @@ class CashTests { */ @Test fun generatePartialExit() { + initialiseTestSerialization() val wtx = makeExit(50.DOLLARS, MEGA_CORP, 1) assertEquals(WALLET[0].ref, wtx.inputs[0]) assertEquals(1, wtx.outputs.size) @@ -516,6 +522,7 @@ class CashTests { */ @Test fun generateAbsentExit() { + initialiseTestSerialization() assertFailsWith { makeExit(100.POUNDS, MEGA_CORP, 1) } } @@ -524,6 +531,7 @@ class CashTests { */ @Test fun generateInvalidReferenceExit() { + initialiseTestSerialization() assertFailsWith { makeExit(100.POUNDS, MEGA_CORP, 2) } } @@ -532,6 +540,7 @@ class CashTests { */ @Test fun generateInsufficientExit() { + initialiseTestSerialization() assertFailsWith { makeExit(1000.DOLLARS, MEGA_CORP, 1) } } @@ -540,6 +549,7 @@ class CashTests { */ @Test fun generateOwnerWithNoStatesExit() { + initialiseTestSerialization() assertFailsWith { makeExit(100.POUNDS, CHARLIE, 1) } } @@ -548,6 +558,7 @@ class CashTests { */ @Test fun generateExitWithEmptyVault() { + initialiseTestSerialization() assertFailsWith { val tx = TransactionType.General.Builder(DUMMY_NOTARY) Cash().generateExit(tx, Amount(100, Issued(CHARLIE.ref(1), GBP)), emptyList()) @@ -556,9 +567,8 @@ class CashTests { @Test fun generateSimpleDirectSpend() { - + initialiseTestSerialization() database.transaction { - val wtx = makeSpend(100.DOLLARS, THEIR_IDENTITY_1) @Suppress("UNCHECKED_CAST") @@ -571,7 +581,7 @@ class CashTests { @Test fun generateSimpleSpendWithParties() { - + initialiseTestSerialization() database.transaction { val tx = TransactionType.General.Builder(DUMMY_NOTARY) @@ -583,7 +593,7 @@ class CashTests { @Test fun generateSimpleSpendWithChange() { - + initialiseTestSerialization() database.transaction { val wtx = makeSpend(10.DOLLARS, THEIR_IDENTITY_1) @@ -599,7 +609,7 @@ class CashTests { @Test fun generateSpendWithTwoInputs() { - + initialiseTestSerialization() database.transaction { val wtx = makeSpend(500.DOLLARS, THEIR_IDENTITY_1) @@ -615,7 +625,7 @@ class CashTests { @Test fun generateSpendMixedDeposits() { - + initialiseTestSerialization() database.transaction { val wtx = makeSpend(580.DOLLARS, THEIR_IDENTITY_1) assertEquals(3, wtx.inputs.size) @@ -636,7 +646,7 @@ class CashTests { @Test fun generateSpendInsufficientBalance() { - + initialiseTestSerialization() database.transaction { val e: InsufficientBalanceException = assertFailsWith("balance") { diff --git a/finance/src/test/kotlin/net/corda/contracts/asset/ObligationTests.kt b/finance/src/test/kotlin/net/corda/contracts/asset/ObligationTests.kt index fe3113b2af..e9f19c4db4 100644 --- a/finance/src/test/kotlin/net/corda/contracts/asset/ObligationTests.kt +++ b/finance/src/test/kotlin/net/corda/contracts/asset/ObligationTests.kt @@ -13,6 +13,7 @@ import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.days import net.corda.core.utilities.hours import net.corda.testing.* +import org.junit.After import net.corda.testing.contracts.DummyState import net.corda.testing.node.MockServices import org.junit.Test @@ -57,6 +58,11 @@ class ObligationTests { } } + @After + fun reset() { + resetTestSerialization() + } + @Test fun trivial() { transaction { @@ -127,6 +133,7 @@ class ObligationTests { this.verifies() } + initialiseTestSerialization() // Test generation works. val tx = TransactionType.General.Builder(notary = null).apply { Obligation().generateIssue(this, MINI_CORP, megaCorpDollarSettlement, 100.DOLLARS.quantity, @@ -142,6 +149,7 @@ class ObligationTests { assertEquals(tx.outputs[0].data, expected) assertTrue(tx.commands[0].value is Obligation.Commands.Issue) assertEquals(MINI_CORP_PUBKEY, tx.commands[0].signers[0]) + resetTestSerialization() // We can consume $1000 in a transaction and output $2000 as long as it's signed by an issuer. transaction { @@ -204,6 +212,7 @@ class ObligationTests { */ @Test(expected = IllegalStateException::class) fun `reject issuance with inputs`() { + initialiseTestSerialization() // Issue some obligation val tx = TransactionType.General.Builder(DUMMY_NOTARY).apply { Obligation().generateIssue(this, MINI_CORP, megaCorpDollarSettlement, 100.DOLLARS.quantity, @@ -221,6 +230,7 @@ class ObligationTests { /** Test generating a transaction to net two obligations of the same size, and therefore there are no outputs. */ @Test fun `generate close-out net transaction`() { + initialiseTestSerialization() val obligationAliceToBob = oneMillionDollars.OBLIGATION between Pair(ALICE, BOB) val obligationBobToAlice = oneMillionDollars.OBLIGATION between Pair(BOB, ALICE) val tx = TransactionType.General.Builder(DUMMY_NOTARY).apply { @@ -232,6 +242,7 @@ class ObligationTests { /** Test generating a transaction to net two obligations of the different sizes, and confirm the balance is correct. */ @Test fun `generate close-out net transaction with remainder`() { + initialiseTestSerialization() val obligationAliceToBob = (2000000.DOLLARS `issued by` defaultIssuer).OBLIGATION between Pair(ALICE, BOB) val obligationBobToAlice = oneMillionDollars.OBLIGATION between Pair(BOB, ALICE) val tx = TransactionType.General.Builder(DUMMY_NOTARY).apply { @@ -246,6 +257,7 @@ class ObligationTests { /** Test generating a transaction to net two obligations of the same size, and therefore there are no outputs. */ @Test fun `generate payment net transaction`() { + initialiseTestSerialization() val obligationAliceToBob = oneMillionDollars.OBLIGATION between Pair(ALICE, BOB) val obligationBobToAlice = oneMillionDollars.OBLIGATION between Pair(BOB, ALICE) val tx = TransactionType.General.Builder(DUMMY_NOTARY).apply { @@ -257,6 +269,7 @@ class ObligationTests { /** Test generating a transaction to two obligations, where one is bigger than the other and therefore there is a remainder. */ @Test fun `generate payment net transaction with remainder`() { + initialiseTestSerialization() val obligationAliceToBob = oneMillionDollars.OBLIGATION between Pair(ALICE, BOB) val obligationBobToAlice = (2000000.DOLLARS `issued by` defaultIssuer).OBLIGATION between Pair(BOB, ALICE) val tx = TransactionType.General.Builder(null).apply { @@ -271,6 +284,7 @@ class ObligationTests { /** Test generating a transaction to mark outputs as having defaulted. */ @Test fun `generate set lifecycle`() { + initialiseTestSerialization() // We don't actually verify the states, this is just here to make things look sensible val dueBefore = TEST_TX_TIME - 7.days @@ -309,6 +323,7 @@ class ObligationTests { /** Test generating a transaction to settle an obligation. */ @Test fun `generate settlement transaction`() { + initialiseTestSerialization() val cashTx = TransactionType.General.Builder(null).apply { Cash().generateIssue(this, 100.DOLLARS `issued by` defaultIssuer, MINI_CORP, DUMMY_NOTARY) }.toWireTransaction() @@ -857,6 +872,7 @@ class ObligationTests { @Test fun `summing balances due between parties`() { + initialiseTestSerialization() val simple: Map, Amount> = mapOf(Pair(Pair(ALICE, BOB), Amount(100000000, GBP))) val expected: Map = mapOf(Pair(ALICE, -100000000L), Pair(BOB, 100000000L)) val actual = sumAmountsDue(simple) diff --git a/finance/src/test/kotlin/net/corda/flows/BroadcastTransactionFlowTest.kt b/finance/src/test/kotlin/net/corda/flows/BroadcastTransactionFlowTest.kt index cc0ba55107..95f66909a8 100644 --- a/finance/src/test/kotlin/net/corda/flows/BroadcastTransactionFlowTest.kt +++ b/finance/src/test/kotlin/net/corda/flows/BroadcastTransactionFlowTest.kt @@ -10,6 +10,9 @@ import net.corda.contracts.testing.SignedTransactionGenerator import net.corda.core.flows.BroadcastTransactionFlow.NotifyTxRequest import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize +import net.corda.testing.initialiseTestSerialization +import net.corda.testing.resetTestSerialization +import org.junit.After import org.junit.runner.RunWith import kotlin.test.assertEquals @@ -18,10 +21,16 @@ class BroadcastTransactionFlowTest { class NotifyTxRequestMessageGenerator : Generator(NotifyTxRequest::class.java) { override fun generate(random: SourceOfRandomness, status: GenerationStatus): NotifyTxRequest { + initialiseTestSerialization() return NotifyTxRequest(tx = SignedTransactionGenerator().generate(random, status)) } } + @After + fun teardown() { + resetTestSerialization() + } + @Property fun serialiseDeserialiseOfNotifyMessageWorks(@From(NotifyTxRequestMessageGenerator::class) message: NotifyTxRequest) { val serialized = message.serialize().bytes diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt b/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt index 8e5aaf8cfb..58dbce9e94 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt @@ -1,7 +1,6 @@ package net.corda.nodeapi -import com.esotericsoftware.kryo.pool.KryoPool -import net.corda.core.serialization.KryoPoolWithContext +import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize import net.corda.core.utilities.Try @@ -96,12 +95,12 @@ object RPCApi { val methodName: String, val arguments: List ) : ClientToServer() { - fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage) { + fun writeToClientMessage(context: SerializationContext, message: ClientMessage) { MessageUtil.setJMSReplyTo(message, clientAddress) message.putIntProperty(TAG_FIELD_NAME, Tag.RPC_REQUEST.ordinal) message.putLongProperty(RPC_ID_FIELD_NAME, id.toLong) message.putStringProperty(METHOD_NAME_FIELD_NAME, methodName) - message.bodyBuffer.writeBytes(arguments.serialize(kryoPool).bytes) + message.bodyBuffer.writeBytes(arguments.serialize(context = context).bytes) } } @@ -119,14 +118,14 @@ object RPCApi { } companion object { - fun fromClientMessage(kryoPool: KryoPool, message: ClientMessage): ClientToServer { + fun fromClientMessage(context: SerializationContext, message: ClientMessage): ClientToServer { val tag = Tag.values()[message.getIntProperty(TAG_FIELD_NAME)] return when (tag) { RPCApi.ClientToServer.Tag.RPC_REQUEST -> RpcRequest( clientAddress = MessageUtil.getJMSReplyTo(message), id = RpcRequestId(message.getLongProperty(RPC_ID_FIELD_NAME)), methodName = message.getStringProperty(METHOD_NAME_FIELD_NAME), - arguments = message.getBodyAsByteArray().deserialize(kryoPool) + arguments = message.getBodyAsByteArray().deserialize(context = context) ) RPCApi.ClientToServer.Tag.OBSERVABLES_CLOSED -> { val ids = ArrayList() @@ -148,16 +147,16 @@ object RPCApi { OBSERVATION } - abstract fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage) + abstract fun writeToClientMessage(context: SerializationContext, message: ClientMessage) data class RpcReply( val id: RpcRequestId, val result: Try ) : ServerToClient() { - override fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage) { + override fun writeToClientMessage(context: SerializationContext, message: ClientMessage) { message.putIntProperty(TAG_FIELD_NAME, Tag.RPC_REPLY.ordinal) message.putLongProperty(RPC_ID_FIELD_NAME, id.toLong) - message.bodyBuffer.writeBytes(result.serialize(kryoPool).bytes) + message.bodyBuffer.writeBytes(result.serialize(context = context).bytes) } } @@ -165,31 +164,31 @@ object RPCApi { val id: ObservableId, val content: Notification ) : ServerToClient() { - override fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage) { + override fun writeToClientMessage(context: SerializationContext, message: ClientMessage) { message.putIntProperty(TAG_FIELD_NAME, Tag.OBSERVATION.ordinal) message.putLongProperty(OBSERVABLE_ID_FIELD_NAME, id.toLong) - message.bodyBuffer.writeBytes(content.serialize(kryoPool).bytes) + message.bodyBuffer.writeBytes(content.serialize(context = context).bytes) } } companion object { - fun fromClientMessage(kryoPool: KryoPool, message: ClientMessage): ServerToClient { + fun fromClientMessage(context: SerializationContext, message: ClientMessage): ServerToClient { val tag = Tag.values()[message.getIntProperty(TAG_FIELD_NAME)] return when (tag) { RPCApi.ServerToClient.Tag.RPC_REPLY -> { val id = RpcRequestId(message.getLongProperty(RPC_ID_FIELD_NAME)) - val poolWithIdContext = KryoPoolWithContext(kryoPool, RpcRequestOrObservableIdKey, id.toLong) + val poolWithIdContext = context.withProperty(RpcRequestOrObservableIdKey, id.toLong) RpcReply( id = id, - result = message.getBodyAsByteArray().deserialize(poolWithIdContext) + result = message.getBodyAsByteArray().deserialize(context = poolWithIdContext) ) } RPCApi.ServerToClient.Tag.OBSERVATION -> { val id = ObservableId(message.getLongProperty(OBSERVABLE_ID_FIELD_NAME)) - val poolWithIdContext = KryoPoolWithContext(kryoPool, RpcRequestOrObservableIdKey, id.toLong) + val poolWithIdContext = context.withProperty(RpcRequestOrObservableIdKey, id.toLong) Observation( id = id, - content = message.getBodyAsByteArray().deserialize(poolWithIdContext) + content = message.getBodyAsByteArray().deserialize(context = poolWithIdContext) ) } } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt b/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt index 0b5236f68f..796d0e38f3 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt @@ -46,7 +46,7 @@ class PermissionException(msg: String) : RuntimeException(msg) // The Kryo used for the RPC wire protocol. Every type in the wire protocol is listed here explicitly. // This is annoying to write out, but will make it easier to formalise the wire protocol when the time comes, // because we can see everything we're using in one place. -class RPCKryo(observableSerializer: Serializer>) : CordaKryo(makeStandardClassResolver()) { +class RPCKryo(observableSerializer: Serializer>, whitelist: ClassWhitelist) : CordaKryo(CordaClassResolver(whitelist)) { init { DefaultKryoCustomizer.customize(this) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/serialization/SerializationScheme.kt b/node-api/src/main/kotlin/net/corda/nodeapi/serialization/SerializationScheme.kt new file mode 100644 index 0000000000..1dbd5ce947 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/serialization/SerializationScheme.kt @@ -0,0 +1,263 @@ +package net.corda.nodeapi.serialization + +import co.paralleluniverse.fibers.Fiber +import co.paralleluniverse.io.serialization.kryo.KryoSerializer +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.KryoException +import com.esotericsoftware.kryo.Serializer +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output +import com.esotericsoftware.kryo.pool.KryoPool +import io.requery.util.CloseableIterator +import net.corda.core.internal.LazyPool +import net.corda.core.serialization.* +import net.corda.core.utilities.ByteSequence +import net.corda.core.utilities.OpaqueBytes +import java.io.ByteArrayOutputStream +import java.io.NotSerializableException +import java.util.* +import java.util.concurrent.ConcurrentHashMap + +object NotSupportedSeralizationScheme : SerializationScheme { + private fun doThrow(): Nothing = throw UnsupportedOperationException("Serialization scheme not supported.") + + override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean = doThrow() + + override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T = doThrow() + + override fun serialize(obj: T, context: SerializationContext): SerializedBytes = doThrow() +} + +data class SerializationContextImpl(override val preferedSerializationVersion: ByteSequence, + override val deserializationClassLoader: ClassLoader, + override val whitelist: ClassWhitelist, + override val properties: Map, + override val objectReferencesEnabled: Boolean, + override val useCase: SerializationContext.UseCase) : SerializationContext { + + override fun withProperty(property: Any, value: Any): SerializationContext { + return copy(properties = properties + (property to value)) + } + + override fun withoutReferences(): SerializationContext { + return copy(objectReferencesEnabled = false) + } + + override fun withClassLoader(classLoader: ClassLoader): SerializationContext { + return copy(deserializationClassLoader = classLoader) + } + + override fun withWhitelisted(clazz: Class<*>): SerializationContext { + return copy(whitelist = object : ClassWhitelist { + override fun hasListed(type: Class<*>): Boolean = whitelist.hasListed(type) || type.name == clazz.name + }) + } +} + +open class SerializationFactoryImpl : SerializationFactory { + private val creator: List = Exception().stackTrace.asList() + + private val registeredSchemes: MutableCollection = Collections.synchronizedCollection(mutableListOf()) + + // TODO: This is read-mostly. Probably a faster implementation to be found. + private val schemes: ConcurrentHashMap, SerializationScheme> = ConcurrentHashMap() + + private fun schemeFor(byteSequence: ByteSequence, target: SerializationContext.UseCase): SerializationScheme { + // truncate sequence to 8 bytes + return schemes.computeIfAbsent(byteSequence.take(8).copy() to target) { + for (scheme in registeredSchemes) { + if (scheme.canDeserializeVersion(it.first, it.second)) { + return@computeIfAbsent scheme + } + } + NotSupportedSeralizationScheme + } + } + + @Throws(NotSerializableException::class) + override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T = schemeFor(byteSequence, context.useCase).deserialize(byteSequence, clazz, context) + + override fun serialize(obj: T, context: SerializationContext): SerializedBytes { + return schemeFor(context.preferedSerializationVersion, context.useCase).serialize(obj, context) + } + + fun registerScheme(scheme: SerializationScheme) { + check(schemes.isEmpty()) { "All serialization schemes must be registered before any scheme is used." } + registeredSchemes += scheme + } + + val alreadyRegisteredSchemes: Collection get() = Collections.unmodifiableCollection(registeredSchemes) + + override fun toString(): String { + return "${this.javaClass.name} registeredSchemes=$registeredSchemes ${creator.joinToString("\n")}" + } + + override fun equals(other: Any?): Boolean { + return other is SerializationFactoryImpl && + other.registeredSchemes == this.registeredSchemes + } + + override fun hashCode(): Int = registeredSchemes.hashCode() +} + +private object AutoCloseableSerialisationDetector : Serializer() { + override fun write(kryo: Kryo, output: Output, closeable: AutoCloseable) { + val message = if (closeable is CloseableIterator<*>) { + "A live Iterator pointing to the database has been detected during flow checkpointing. This may be due " + + "to a Vault query - move it into a private method." + } else { + "${closeable.javaClass.name}, which is a closeable resource, has been detected during flow checkpointing. " + + "Restoring such resources across node restarts is not supported. Make sure code accessing it is " + + "confined to a private method or the reference is nulled out." + } + throw UnsupportedOperationException(message) + } + + override fun read(kryo: Kryo, input: Input, type: Class) = throw IllegalStateException("Should not reach here!") +} + +abstract class AbstractKryoSerializationScheme : SerializationScheme { + private val kryoPoolsForContexts = ConcurrentHashMap, KryoPool>() + + protected abstract fun rpcClientKryoPool(context: SerializationContext): KryoPool + protected abstract fun rpcServerKryoPool(context: SerializationContext): KryoPool + + private fun getPool(context: SerializationContext): KryoPool { + return kryoPoolsForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) { + when (context.useCase) { + SerializationContext.UseCase.Checkpoint -> + KryoPool.Builder { + val serializer = Fiber.getFiberSerializer(false) as KryoSerializer + val classResolver = makeNoWhitelistClassResolver().apply { setKryo(serializer.kryo) } + // TODO The ClassResolver can only be set in the Kryo constructor and Quasar doesn't provide us with a way of doing that + val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true } + serializer.kryo.apply { + field.set(this, classResolver) + DefaultKryoCustomizer.customize(this) + addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector) + classLoader = it.second + } + }.build() + SerializationContext.UseCase.RPCClient -> + rpcClientKryoPool(context) + SerializationContext.UseCase.RPCServer -> + rpcServerKryoPool(context) + else -> + KryoPool.Builder { + DefaultKryoCustomizer.customize(CordaKryo(CordaClassResolver(context.whitelist))).apply { classLoader = it.second } + }.build() + } + } + } + + private fun withContext(kryo: Kryo, context: SerializationContext, block: (Kryo) -> T): T { + kryo.context.ensureCapacity(context.properties.size) + context.properties.forEach { kryo.context.put(it.key, it.value) } + try { + return block(kryo) + } finally { + kryo.context.clear() + } + } + + override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T { + val pool = getPool(context) + Input(byteSequence.bytes, byteSequence.offset, byteSequence.size).use { input -> + val header = OpaqueBytes(input.readBytes(8)) + if (header != KryoHeaderV0_1) { + throw KryoException("Serialized bytes header does not match expected format.") + } + return pool.run { kryo -> + withContext(kryo, context) { + @Suppress("UNCHECKED_CAST") + if (context.objectReferencesEnabled) { + kryo.readClassAndObject(input) as T + } else { + kryo.withoutReferences { kryo.readClassAndObject(input) as T } + } + } + } + } + } + + override fun serialize(obj: T, context: SerializationContext): SerializedBytes { + val pool = getPool(context) + return pool.run { kryo -> + withContext(kryo, context) { + serializeOutputStreamPool.run { stream -> + serializeBufferPool.run { buffer -> + Output(buffer).use { + it.outputStream = stream + it.writeBytes(KryoHeaderV0_1.bytes) + if (context.objectReferencesEnabled) { + kryo.writeClassAndObject(it, obj) + } else { + kryo.withoutReferences { kryo.writeClassAndObject(it, obj) } + } + } + SerializedBytes(stream.toByteArray()) + } + } + } + } + } +} + +private val serializeBufferPool = LazyPool( + newInstance = { ByteArray(64 * 1024) } +) +private val serializeOutputStreamPool = LazyPool( + clear = ByteArrayOutputStream::reset, + shouldReturnToPool = { it.size() < 256 * 1024 }, // Discard if it grew too large + newInstance = { ByteArrayOutputStream(64 * 1024) } +) + +// "corda" + majorVersionByte + minorVersionMSB + minorVersionLSB +val KryoHeaderV0_1: OpaqueBytes = OpaqueBytes("corda\u0000\u0000\u0001".toByteArray(Charsets.UTF_8)) + + +val KRYO_P2P_CONTEXT = SerializationContextImpl(KryoHeaderV0_1, + SerializationDefaults.javaClass.classLoader, + GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), + emptyMap(), + true, + SerializationContext.UseCase.P2P) +val KRYO_RPC_SERVER_CONTEXT = SerializationContextImpl(KryoHeaderV0_1, + SerializationDefaults.javaClass.classLoader, + GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), + emptyMap(), + true, + SerializationContext.UseCase.RPCServer) +val KRYO_RPC_CLIENT_CONTEXT = SerializationContextImpl(KryoHeaderV0_1, + SerializationDefaults.javaClass.classLoader, + GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), + emptyMap(), + true, + SerializationContext.UseCase.RPCClient) +val KRYO_STORAGE_CONTEXT = SerializationContextImpl(KryoHeaderV0_1, + SerializationDefaults.javaClass.classLoader, + AllButBlacklisted, + emptyMap(), + true, + SerializationContext.UseCase.Storage) +val KRYO_CHECKPOINT_CONTEXT = SerializationContextImpl(KryoHeaderV0_1, + SerializationDefaults.javaClass.classLoader, + QuasarWhitelist, + emptyMap(), + true, + SerializationContext.UseCase.Checkpoint) + +object QuasarWhitelist : ClassWhitelist { + override fun hasListed(type: Class<*>): Boolean = true +} + +interface SerializationScheme { + // byteSequence expected to just be the 8 bytes necessary for versioning + fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean + + @Throws(NotSerializableException::class) + fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T + + @Throws(NotSerializableException::class) + fun serialize(obj: T, context: SerializationContext): SerializedBytes +} \ No newline at end of file diff --git a/node-schemas/src/test/kotlin/net/corda/node/services/vault/schemas/VaultSchemaTest.kt b/node-schemas/src/test/kotlin/net/corda/node/services/vault/schemas/VaultSchemaTest.kt index f53ad70a38..8b4a4d8adb 100644 --- a/node-schemas/src/test/kotlin/net/corda/node/services/vault/schemas/VaultSchemaTest.kt +++ b/node-schemas/src/test/kotlin/net/corda/node/services/vault/schemas/VaultSchemaTest.kt @@ -21,10 +21,7 @@ import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize import net.corda.core.transactions.LedgerTransaction import net.corda.node.services.vault.schemas.requery.* -import net.corda.testing.ALICE -import net.corda.testing.BOB -import net.corda.testing.DUMMY_NOTARY -import net.corda.testing.DUMMY_NOTARY_KEY +import net.corda.testing.* import net.corda.testing.contracts.DummyContract import org.h2.jdbcx.JdbcDataSource import org.junit.After @@ -40,7 +37,7 @@ import kotlin.test.assertNotNull import kotlin.test.assertNull import kotlin.test.assertTrue -class VaultSchemaTest { +class VaultSchemaTest : TestDependencyInjectionBase() { var instance: KotlinEntityDataStore? = null val data: KotlinEntityDataStore get() = instance!! diff --git a/node/src/integration-test/kotlin/net/corda/node/BootTests.kt b/node/src/integration-test/kotlin/net/corda/node/BootTests.kt index 3048c19417..545fa7e77b 100644 --- a/node/src/integration-test/kotlin/net/corda/node/BootTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/BootTests.kt @@ -9,13 +9,13 @@ import net.corda.core.messaging.startFlow import net.corda.core.node.services.ServiceInfo import net.corda.core.node.services.ServiceType import net.corda.testing.ALICE -import net.corda.testing.driver.driver import net.corda.node.internal.NodeStartup import net.corda.node.services.startFlowPermission import net.corda.nodeapi.User import net.corda.testing.driver.ListenProcessDeathException import net.corda.testing.driver.NetworkMapStartStrategy import net.corda.testing.ProjectStructure.projectRootDir +import net.corda.testing.driver.driver import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.Test diff --git a/node/src/integration-test/kotlin/net/corda/node/utilities/JDBCHashMapTestSuite.kt b/node/src/integration-test/kotlin/net/corda/node/utilities/JDBCHashMapTestSuite.kt index 00d5cc66ac..029eb68cff 100644 --- a/node/src/integration-test/kotlin/net/corda/node/utilities/JDBCHashMapTestSuite.kt +++ b/node/src/integration-test/kotlin/net/corda/node/utilities/JDBCHashMapTestSuite.kt @@ -10,7 +10,10 @@ import com.google.common.collect.testing.features.MapFeature import com.google.common.collect.testing.features.SetFeature import com.google.common.collect.testing.testers.* import junit.framework.TestSuite +import net.corda.testing.TestDependencyInjectionBase +import net.corda.testing.initialiseTestSerialization import net.corda.testing.node.makeTestDataSourceProperties +import net.corda.testing.resetTestSerialization import org.assertj.core.api.Assertions.assertThat import org.jetbrains.exposed.sql.Transaction import org.jetbrains.exposed.sql.transactions.TransactionManager @@ -42,6 +45,7 @@ class JDBCHashMapTestSuite { @JvmStatic @BeforeClass fun before() { + initialiseTestSerialization() database = configureDatabase(makeTestDataSourceProperties()) setUpDatabaseTx() loadOnInitFalseMap = JDBCHashMap("test_map_false", loadOnInit = false) @@ -57,6 +61,7 @@ class JDBCHashMapTestSuite { fun after() { closeDatabaseTx() database.close() + resetTestSerialization() } @JvmStatic @@ -198,7 +203,7 @@ class JDBCHashMapTestSuite { * * If the Map reloads, then so will the Set as it just delegates. */ - class MapCanBeReloaded { + class MapCanBeReloaded : TestDependencyInjectionBase() { private val ops = listOf(Triple(AddOrRemove.ADD, "A", "1"), Triple(AddOrRemove.ADD, "B", "2"), Triple(AddOrRemove.ADD, "C", "3"), @@ -235,7 +240,6 @@ class JDBCHashMapTestSuite { database.close() } - @Test fun `fill map and check content after reconstruction`() { database.transaction { diff --git a/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityTest.kt b/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityTest.kt index 9cae274260..87cc2d2363 100644 --- a/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityTest.kt +++ b/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityTest.kt @@ -151,7 +151,7 @@ abstract class MQSecurityTest : NodeBasedTest() { } fun loginToRPC(target: NetworkHostAndPort, rpcUser: User, sslConfiguration: SSLConfiguration? = null): CordaRPCOps { - return CordaRPCClient(target, sslConfiguration).start(rpcUser.username, rpcUser.password).proxy + return CordaRPCClient(target, sslConfiguration, initialiseSerialization = false).start(rpcUser.username, rpcUser.password).proxy } fun loginToRPCAndGetClientQueue(): String { diff --git a/node/src/main/kotlin/net/corda/node/internal/Node.kt b/node/src/main/kotlin/net/corda/node/internal/Node.kt index bdd44c655e..35183e4fdc 100644 --- a/node/src/main/kotlin/net/corda/node/internal/Node.kt +++ b/node/src/main/kotlin/net/corda/node/internal/Node.kt @@ -4,12 +4,15 @@ import com.codahale.metrics.JmxReporter import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.SettableFuture -import net.corda.core.* +import net.corda.core.flatMap import net.corda.core.messaging.RPCOps import net.corda.core.node.ServiceHub import net.corda.core.node.services.ServiceInfo +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.thenMatch import net.corda.core.utilities.* import net.corda.node.VersionInfo +import net.corda.node.serialization.KryoServerSerializationScheme import net.corda.node.serialization.NodeClock import net.corda.node.services.RPCUserService import net.corda.node.services.RPCUserServiceImpl @@ -29,6 +32,7 @@ import net.corda.nodeapi.ArtemisTcpTransport import net.corda.nodeapi.ConnectionDirection import net.corda.nodeapi.internal.ShutdownHook import net.corda.nodeapi.internal.addShutdownHook +import net.corda.nodeapi.serialization.* import org.apache.activemq.artemis.api.core.ActiveMQNotConnectedException import org.apache.activemq.artemis.api.core.RoutingType import org.apache.activemq.artemis.api.core.client.ActiveMQClient @@ -54,7 +58,8 @@ import kotlin.system.exitProcess open class Node(override val configuration: FullNodeConfiguration, advertisedServices: Set, val versionInfo: VersionInfo, - clock: Clock = NodeClock()) : AbstractNode(configuration, advertisedServices, clock) { + clock: Clock = NodeClock(), + val initialiseSerialization: Boolean = true) : AbstractNode(configuration, advertisedServices, clock) { companion object { private val logger = loggerFor() var renderBasicInfoToConsole = true @@ -290,6 +295,9 @@ open class Node(override val configuration: FullNodeConfiguration, val startupComplete: ListenableFuture = SettableFuture.create() override fun start(): Node { + if (initialiseSerialization) { + initialiseSerialization() + } super.start() networkMapRegistrationFuture.thenMatch({ @@ -321,6 +329,16 @@ open class Node(override val configuration: FullNodeConfiguration, return this } + private fun initialiseSerialization() { + SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply { + registerScheme(KryoServerSerializationScheme()) + } + SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT + SerializationDefaults.RPC_SERVER_CONTEXT = KRYO_RPC_SERVER_CONTEXT + SerializationDefaults.STORAGE_CONTEXT = KRYO_STORAGE_CONTEXT + SerializationDefaults.CHECKPOINT_CONTEXT = KRYO_CHECKPOINT_CONTEXT + } + /** Starts a blocking event loop for message dispatch. */ fun run() { (network as NodeMessagingClient).run(messageBroker!!.serverControl) diff --git a/node/src/main/kotlin/net/corda/node/serialization/SerializationScheme.kt b/node/src/main/kotlin/net/corda/node/serialization/SerializationScheme.kt new file mode 100644 index 0000000000..9fdb211720 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/serialization/SerializationScheme.kt @@ -0,0 +1,26 @@ +package net.corda.node.serialization + +import com.esotericsoftware.kryo.pool.KryoPool +import net.corda.core.serialization.DefaultKryoCustomizer +import net.corda.core.serialization.SerializationContext +import net.corda.core.utilities.ByteSequence +import net.corda.node.services.messaging.RpcServerObservableSerializer +import net.corda.nodeapi.RPCKryo +import net.corda.nodeapi.serialization.AbstractKryoSerializationScheme +import net.corda.nodeapi.serialization.KryoHeaderV0_1 + +class KryoServerSerializationScheme : AbstractKryoSerializationScheme() { + override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean { + return byteSequence.equals(KryoHeaderV0_1) && target != SerializationContext.UseCase.RPCClient + } + + override fun rpcClientKryoPool(context: SerializationContext): KryoPool { + throw UnsupportedOperationException() + } + + override fun rpcServerKryoPool(context: SerializationContext): KryoPool { + return KryoPool.Builder { + DefaultKryoCustomizer.customize(RPCKryo(RpcServerObservableSerializer, context.whitelist)).apply { classLoader = context.deserializationClassLoader } + }.build() + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt index baf8c623d3..58bf226723 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt @@ -154,7 +154,8 @@ fun MessagingService.onNext(topic: String, sessionId: Long): Listenabl val messageFuture = SettableFuture.create() runOnNextMessage(topic, sessionId) { message -> messageFuture.catch { - message.data.deserialize() + @Suppress("UNCHECKED_CAST") + message.data.deserialize() as M } } return messageFuture diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt b/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt index afd2df6204..81f9a47701 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt @@ -1,18 +1,21 @@ package net.corda.node.services.messaging import com.google.common.util.concurrent.ListenableFuture -import net.corda.core.* +import net.corda.core.ThreadBox +import net.corda.core.andForget import net.corda.core.crypto.random63BitValue +import net.corda.core.getOrThrow import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.RPCOps import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.node.services.PartyInfo import net.corda.core.node.services.TransactionVerifierService -import net.corda.core.utilities.opaque +import net.corda.core.thenMatch import net.corda.core.transactions.LedgerTransaction import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.loggerFor +import net.corda.core.utilities.sequence import net.corda.core.utilities.trace import net.corda.node.VersionInfo import net.corda.node.services.RPCUserService @@ -346,7 +349,7 @@ class NodeMessagingClient(override val config: NodeConfiguration, private val message: ClientMessage) : ReceivedMessage { override val data: ByteArray by lazy { ByteArray(message.bodySize).apply { message.bodyBuffer.readBytes(this) } } override val debugTimestamp: Instant get() = Instant.ofEpochMilli(message.timestamp) - override fun toString() = "${topicSession.topic}#${data.opaque()}" + override fun toString() = "${topicSession.topic}#${data.sequence()}" } private fun deliver(msg: ReceivedMessage): Boolean { diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt b/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt index cb42b35110..2fb64c02f4 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt @@ -4,7 +4,6 @@ import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.Serializer import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output -import com.esotericsoftware.kryo.pool.KryoPool import com.google.common.cache.Cache import com.google.common.cache.CacheBuilder import com.google.common.cache.RemovalListener @@ -17,7 +16,8 @@ import net.corda.core.internal.LazyStickyPool import net.corda.core.internal.LifeCycle import net.corda.core.messaging.RPCOps import net.corda.core.utilities.seconds -import net.corda.core.serialization.KryoPoolWithContext +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationDefaults.RPC_SERVER_CONTEXT import net.corda.core.utilities.* import net.corda.node.services.RPCUserService import net.corda.nodeapi.* @@ -81,7 +81,6 @@ class RPCServer( ) { private companion object { val log = loggerFor() - val kryoPool = KryoPool.Builder { RPCKryo(RpcServerObservableSerializer) }.build() } private enum class State { UNSTARTED, @@ -258,7 +257,7 @@ class RPCServer( private fun clientArtemisMessageHandler(artemisMessage: ClientMessage) { lifeCycle.requireState(State.STARTED) - val clientToServer = RPCApi.ClientToServer.fromClientMessage(kryoPool, artemisMessage) + val clientToServer = RPCApi.ClientToServer.fromClientMessage(RPC_SERVER_CONTEXT, artemisMessage) log.debug { "-> RPC -> $clientToServer" } when (clientToServer) { is RPCApi.ClientToServer.RpcRequest -> { @@ -302,8 +301,7 @@ class RPCServer( clientAddress, serverControl!!, sessionAndProducerPool, - observationSendExecutor!!, - kryoPool + observationSendExecutor!! ) val buffered = bufferIfQueueNotBound(clientAddress, reply, observableContext) @@ -385,19 +383,19 @@ class ObservableContext( val clientAddress: SimpleString, val serverControl: ActiveMQServerControl, val sessionAndProducerPool: LazyStickyPool, - val observationSendExecutor: ExecutorService, - kryoPool: KryoPool + val observationSendExecutor: ExecutorService ) { private companion object { val log = loggerFor() } - private val kryoPoolWithObservableContext = RpcServerObservableSerializer.createPoolWithContext(kryoPool, this) + private val serializationContextWithObservableContext = RpcServerObservableSerializer.createContext(this) + fun sendMessage(serverToClient: RPCApi.ServerToClient) { try { sessionAndProducerPool.run(rpcRequestId) { val artemisMessage = it.session.createMessage(false) - serverToClient.writeToClientMessage(kryoPoolWithObservableContext, artemisMessage) + serverToClient.writeToClientMessage(serializationContextWithObservableContext, artemisMessage) it.producer.send(clientAddress, artemisMessage) log.debug("<- RPC <- $serverToClient") } @@ -408,12 +406,12 @@ class ObservableContext( } } -private object RpcServerObservableSerializer : Serializer>() { +object RpcServerObservableSerializer : Serializer>() { private object RpcObservableContextKey private val log = loggerFor() - fun createPoolWithContext(kryoPool: KryoPool, observableContext: ObservableContext): KryoPool { - return KryoPoolWithContext(kryoPool, RpcObservableContextKey, observableContext) + fun createContext(observableContext: ObservableContext): SerializationContext { + return RPC_SERVER_CONTEXT.withProperty(RpcServerObservableSerializer.RpcObservableContextKey, observableContext) } override fun read(kryo: Kryo?, input: Input?, type: Class>?): Observable { diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt index 89db282424..1d4ccde166 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt @@ -1,10 +1,10 @@ package net.corda.node.services.persistence import net.corda.core.crypto.SecureHash +import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize -import net.corda.core.serialization.storageKryo import net.corda.node.services.api.Checkpoint import net.corda.node.services.api.CheckpointStorage import net.corda.node.utilities.* @@ -39,7 +39,7 @@ class DBCheckpointStorage : CheckpointStorage { private val checkpointStorage = synchronizedMap(CheckpointMap()) override fun addCheckpoint(checkpoint: Checkpoint) { - checkpointStorage.put(checkpoint.id, checkpoint.serialize(storageKryo(), true)) + checkpointStorage.put(checkpoint.id, checkpoint.serialize(context = CHECKPOINT_CONTEXT)) } override fun removeCheckpoint(checkpoint: Checkpoint) { @@ -49,7 +49,7 @@ class DBCheckpointStorage : CheckpointStorage { override fun forEach(block: (Checkpoint) -> Boolean) { synchronized(checkpointStorage) { for (checkpoint in checkpointStorage.values) { - if (!block(checkpoint.deserialize())) { + if (!block(checkpoint.deserialize(context = CHECKPOINT_CONTEXT))) { break } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index 1fa7885055..3104a3b9b2 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -2,20 +2,12 @@ package net.corda.node.services.statemachine import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.FiberExecutorScheduler -import co.paralleluniverse.io.serialization.kryo.KryoSerializer import co.paralleluniverse.strands.Strand import com.codahale.metrics.Gauge -import com.esotericsoftware.kryo.ClassResolver -import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.KryoException -import com.esotericsoftware.kryo.Serializer -import com.esotericsoftware.kryo.io.Input -import com.esotericsoftware.kryo.io.Output -import com.esotericsoftware.kryo.pool.KryoPool import com.google.common.collect.HashMultimap import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.MoreExecutors -import io.requery.util.CloseableIterator import net.corda.core.ThreadBox import net.corda.core.bufferUntilSubscribed import net.corda.core.crypto.SecureHash @@ -25,9 +17,10 @@ import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowLogic import net.corda.core.flows.StateMachineRunId import net.corda.core.identity.Party -import net.corda.core.internal.declaredField import net.corda.core.messaging.DataFeed import net.corda.core.serialization.* +import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT +import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY import net.corda.core.then import net.corda.core.utilities.Try import net.corda.core.utilities.debug @@ -85,34 +78,6 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, inner class FiberScheduler : FiberExecutorScheduler("Same thread scheduler", executor) - private val quasarKryoPool = KryoPool.Builder { - val serializer = Fiber.getFiberSerializer(false) as KryoSerializer - val classResolver = makeNoWhitelistClassResolver().apply { setKryo(serializer.kryo) } - serializer.kryo.apply { - // TODO The ClassResolver can only be set in the Kryo constructor and Quasar doesn't provide us with a way of doing that - declaredField(Kryo::class, "classResolver").value = classResolver - DefaultKryoCustomizer.customize(this) - addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector) - } - }.build() - - // TODO Move this into the blacklist and upgrade the blacklist to allow custom messages - private object AutoCloseableSerialisationDetector : Serializer() { - override fun write(kryo: Kryo, output: Output, closeable: AutoCloseable) { - val message = if (closeable is CloseableIterator<*>) { - "A live Iterator pointing to the database has been detected during flow checkpointing. This may be due " + - "to a Vault query - move it into a private method." - } else { - "${closeable.javaClass.name}, which is a closeable resource, has been detected during flow checkpointing. " + - "Restoring such resources across node restarts is not supported. Make sure code accessing it is " + - "confined to a private method or the reference is nulled out." - } - throw UnsupportedOperationException(message) - } - - override fun read(kryo: Kryo, input: Input, type: Class) = throw IllegalStateException("Should not reach here!") - } - companion object { private val logger = loggerFor() internal val sessionTopic = TopicSession("platform.session") @@ -173,7 +138,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, internal val tokenizableServices = ArrayList() // Context for tokenized services in checkpoints private val serializationContext by lazy { - SerializeAsTokenContext(tokenizableServices, quasarKryoPool, serviceHub) + SerializeAsTokenContext(tokenizableServices, SERIALIZATION_FACTORY, CHECKPOINT_CONTEXT, serviceHub) } /** Returns a list of all state machines executing the given flow logic at the top level (subflows do not count) */ @@ -410,22 +375,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, } private fun serializeFiber(fiber: FlowStateMachineImpl<*>): SerializedBytes> { - return quasarKryoPool.run { kryo -> - // add the map of tokens -> tokenizedServices to the kyro context - kryo.withSerializationContext(serializationContext) { - fiber.serialize(kryo) - } - } + return fiber.serialize(context = CHECKPOINT_CONTEXT.withTokenContext(serializationContext)) } private fun deserializeFiber(checkpoint: Checkpoint, logger: Logger): FlowStateMachineImpl<*>? { return try { - quasarKryoPool.run { kryo -> - // put the map of token -> tokenized into the kryo context - kryo.withSerializationContext(serializationContext) { - checkpoint.serializedFiber.deserialize(kryo) - }.apply { fromCheckpoint = true } - } + checkpoint.serializedFiber.deserialize>(context = CHECKPOINT_CONTEXT.withTokenContext(serializationContext)).apply { fromCheckpoint = true } } catch (t: Throwable) { logger.error("Encountered unrestorable checkpoint!", t) null diff --git a/node/src/main/kotlin/net/corda/node/services/vault/HibernateVaultQueryImpl.kt b/node/src/main/kotlin/net/corda/node/services/vault/HibernateVaultQueryImpl.kt index fe91e3f579..929312f753 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/HibernateVaultQueryImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/HibernateVaultQueryImpl.kt @@ -13,9 +13,9 @@ import net.corda.core.node.services.VaultQueryException import net.corda.core.node.services.VaultQueryService import net.corda.core.node.services.vault.* import net.corda.core.node.services.vault.QueryCriteria.VaultCustomQueryCriteria +import net.corda.core.serialization.SerializationDefaults.STORAGE_CONTEXT import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.deserialize -import net.corda.core.serialization.storageKryo import net.corda.core.utilities.debug import net.corda.core.utilities.loggerFor import net.corda.node.services.database.HibernateConfiguration @@ -96,7 +96,7 @@ class HibernateVaultQueryImpl(hibernateConfig: HibernateConfiguration, return@forEachIndexed val vaultState = result[0] as VaultSchemaV1.VaultStates val stateRef = StateRef(SecureHash.parse(vaultState.stateRef!!.txId!!), vaultState.stateRef!!.index!!) - val state = vaultState.contractState.deserialize>(storageKryo()) + val state = vaultState.contractState.deserialize>(context = STORAGE_CONTEXT) statesMeta.add(Vault.StateMetadata(stateRef, vaultState.contractStateClassName, vaultState.recordedTime, vaultState.consumedTime, vaultState.stateStatus, vaultState.notaryName, vaultState.notaryKey, vaultState.lockId, vaultState.lockUpdateTime)) statesAndRefs.add(StateAndRef(state, stateRef)) } diff --git a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt index cd049ace71..2e436f1a95 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt @@ -26,10 +26,10 @@ import net.corda.core.node.services.StatesNotAvailableException import net.corda.core.node.services.Vault import net.corda.core.node.services.VaultService import net.corda.core.node.services.unconsumedStates +import net.corda.core.serialization.SerializationDefaults.STORAGE_CONTEXT import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize -import net.corda.core.serialization.storageKryo import net.corda.core.tee import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.WireTransaction @@ -95,7 +95,7 @@ class NodeVaultService(private val services: ServiceHub, dataSourceProperties: P index = it.key.index stateStatus = Vault.StateStatus.UNCONSUMED contractStateClassName = it.value.state.data.javaClass.name - contractState = it.value.state.serialize(storageKryo()).bytes + contractState = it.value.state.serialize(context = STORAGE_CONTEXT).bytes notaryName = it.value.state.notary.name.toString() notaryKey = it.value.state.notary.owningKey.toBase58String() recordedTime = services.clock.instant() @@ -198,7 +198,7 @@ class NodeVaultService(private val services: ServiceHub, dataSourceProperties: P Sequence { iterator } .map { it -> val stateRef = StateRef(SecureHash.parse(it.txId), it.index) - val state = it.contractState.deserialize>(storageKryo()) + val state = it.contractState.deserialize>(context = STORAGE_CONTEXT) Vault.StateMetadata(stateRef, it.contractStateClassName, it.recordedTime, it.consumedTime, it.stateStatus, it.notaryName, it.notaryKey, it.lockId, it.lockUpdateTime) StateAndRef(state, stateRef) } @@ -217,7 +217,7 @@ class NodeVaultService(private val services: ServiceHub, dataSourceProperties: P .and(VaultSchema.VaultStates::index eq it.index) result.get()?.each { val stateRef = StateRef(SecureHash.parse(it.txId), it.index) - val state = it.contractState.deserialize>(storageKryo()) + val state = it.contractState.deserialize>(context = STORAGE_CONTEXT) results += StateAndRef(state, stateRef) } } @@ -380,7 +380,7 @@ class NodeVaultService(private val services: ServiceHub, dataSourceProperties: P val txHash = SecureHash.parse(rs.getString(1)) val index = rs.getInt(2) val stateRef = StateRef(txHash, index) - val state = rs.getBytes(3).deserialize>(storageKryo()) + val state = rs.getBytes(3).deserialize>(context = STORAGE_CONTEXT) val pennies = rs.getLong(4) totalPennies = rs.getLong(5) val rowLockId = rs.getString(6) @@ -435,7 +435,7 @@ class NodeVaultService(private val services: ServiceHub, dataSourceProperties: P query.get() .map { it -> val stateRef = StateRef(SecureHash.parse(it.txId), it.index) - val state = it.contractState.deserialize>(storageKryo()) + val state = it.contractState.deserialize>(context = STORAGE_CONTEXT) StateAndRef(state, stateRef) }.toList() } @@ -480,7 +480,7 @@ class NodeVaultService(private val services: ServiceHub, dataSourceProperties: P result.get().forEach { val txHash = SecureHash.parse(it.txId) val index = it.index - val state = it.contractState.deserialize>(storageKryo()) + val state = it.contractState.deserialize>(context = STORAGE_CONTEXT) consumedStates.add(StateAndRef(state, StateRef(txHash, index))) } } diff --git a/node/src/main/kotlin/net/corda/node/utilities/JDBCHashMap.kt b/node/src/main/kotlin/net/corda/node/utilities/JDBCHashMap.kt index 96147f2739..5b193ec2e8 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/JDBCHashMap.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/JDBCHashMap.kt @@ -1,9 +1,9 @@ package net.corda.node.utilities +import net.corda.core.serialization.SerializationDefaults.STORAGE_CONTEXT import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize -import net.corda.core.serialization.storageKryo import net.corda.core.utilities.loggerFor import net.corda.core.utilities.trace import org.jetbrains.exposed.sql.* @@ -65,17 +65,18 @@ fun bytesToBlob(value: SerializedBytes<*>, finalizables: MutableList<() -> Unit> return blob } -fun serializeToBlob(value: Any, finalizables: MutableList<() -> Unit>): Blob = bytesToBlob(value.serialize(storageKryo(), true), finalizables) +fun serializeToBlob(value: Any, finalizables: MutableList<() -> Unit>): Blob = bytesToBlob(value.serialize(context = STORAGE_CONTEXT), finalizables) fun bytesFromBlob(blob: Blob): SerializedBytes { try { - return SerializedBytes(blob.getBytes(0, blob.length().toInt()), true) + return SerializedBytes(blob.getBytes(0, blob.length().toInt())) } finally { blob.free() } } -fun deserializeFromBlob(blob: Blob): T = bytesFromBlob(blob).deserialize() +@Suppress("UNCHECKED_CAST") +fun deserializeFromBlob(blob: Blob): T = bytesFromBlob(blob).deserialize(context = STORAGE_CONTEXT) as T /** * A convenient JDBC table backed hash set with iteration order based on insertion order. diff --git a/node/src/main/kotlin/net/corda/node/utilities/ServiceIdentityGenerator.kt b/node/src/main/kotlin/net/corda/node/utilities/ServiceIdentityGenerator.kt index 43daa96cbc..fc87226b3d 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/ServiceIdentityGenerator.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/ServiceIdentityGenerator.kt @@ -37,7 +37,6 @@ object ServiceIdentityGenerator { keyPairs.zip(dirs) { keyPair, dir -> Files.createDirectories(dir) Files.write(dir.resolve(compositeKeyFile), notaryKey.encoded) - // Use storageKryo as our whitelist is not available in the gradle build environment: Files.write(dir.resolve(privateKeyFile), keyPair.private.encoded) Files.write(dir.resolve(publicKeyFile), keyPair.public.encoded) } diff --git a/node/src/test/java/net/corda/node/services/vault/VaultQueryJavaTests.java b/node/src/test/java/net/corda/node/services/vault/VaultQueryJavaTests.java index 5583482b51..0d5fff054a 100644 --- a/node/src/test/java/net/corda/node/services/vault/VaultQueryJavaTests.java +++ b/node/src/test/java/net/corda/node/services/vault/VaultQueryJavaTests.java @@ -1,47 +1,60 @@ package net.corda.node.services.vault; -import com.google.common.collect.*; -import kotlin.*; -import net.corda.contracts.*; -import net.corda.contracts.asset.*; +import com.google.common.collect.ImmutableSet; +import net.corda.contracts.DealState; +import net.corda.contracts.asset.Cash; import net.corda.core.contracts.*; -import net.corda.core.crypto.*; -import net.corda.core.identity.*; -import net.corda.core.messaging.*; -import net.corda.core.node.services.*; +import net.corda.core.crypto.EncodingUtils; +import net.corda.core.crypto.SecureHash; +import net.corda.core.identity.AbstractParty; +import net.corda.core.messaging.DataFeed; +import net.corda.core.node.services.Vault; +import net.corda.core.node.services.VaultQueryException; +import net.corda.core.node.services.VaultQueryService; +import net.corda.core.node.services.VaultService; import net.corda.core.node.services.vault.*; -import net.corda.core.node.services.vault.QueryCriteria.*; -import net.corda.testing.contracts.DummyLinearContract; -import net.corda.core.schemas.*; -import net.corda.core.transactions.*; -import net.corda.core.utilities.*; +import net.corda.core.node.services.vault.QueryCriteria.LinearStateQueryCriteria; +import net.corda.core.node.services.vault.QueryCriteria.VaultCustomQueryCriteria; +import net.corda.core.node.services.vault.QueryCriteria.VaultQueryCriteria; +import net.corda.core.schemas.MappedSchema; +import net.corda.core.transactions.SignedTransaction; +import net.corda.core.transactions.WireTransaction; +import net.corda.core.utilities.OpaqueBytes; +import net.corda.node.services.database.HibernateConfiguration; +import net.corda.node.services.schema.NodeSchemaService; import net.corda.node.utilities.CordaPersistence; -import net.corda.node.services.database.*; -import net.corda.node.services.schema.*; -import net.corda.schemas.*; -import net.corda.testing.*; -import net.corda.testing.contracts.*; -import net.corda.testing.node.*; +import net.corda.schemas.CashSchemaV1; +import net.corda.testing.TestConstants; +import net.corda.testing.TestDependencyInjectionBase; +import net.corda.testing.contracts.DummyLinearContract; +import net.corda.testing.contracts.VaultFiller; +import net.corda.testing.node.MockServices; import net.corda.testing.schemas.DummyLinearStateSchemaV1; -import org.jetbrains.annotations.*; -import org.junit.*; +import org.jetbrains.annotations.NotNull; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; import rx.Observable; -import java.io.*; -import java.lang.reflect.*; +import java.io.IOException; +import java.lang.reflect.Field; import java.util.*; -import java.util.stream.*; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; -import static net.corda.contracts.asset.CashKt.*; -import static net.corda.core.contracts.ContractsDSL.*; -import static net.corda.core.node.services.vault.QueryCriteriaUtils.*; +import static net.corda.contracts.asset.CashKt.getDUMMY_CASH_ISSUER; +import static net.corda.contracts.asset.CashKt.getDUMMY_CASH_ISSUER_KEY; +import static net.corda.core.contracts.ContractsDSL.USD; +import static net.corda.core.node.services.vault.QueryCriteriaUtils.DEFAULT_PAGE_NUM; +import static net.corda.core.node.services.vault.QueryCriteriaUtils.MAX_PAGE_SIZE; +import static net.corda.core.utilities.ByteArrays.toHexString; import static net.corda.node.utilities.CordaPersistenceKt.configureDatabase; import static net.corda.testing.CoreTestUtils.*; -import static net.corda.testing.node.MockServicesKt.*; -import static net.corda.core.utilities.ByteArrays.toHexString; -import static org.assertj.core.api.Assertions.*; +import static net.corda.testing.node.MockServicesKt.makeTestDataSourceProperties; +import static org.assertj.core.api.Assertions.assertThat; -public class VaultQueryJavaTests { +public class VaultQueryJavaTests extends TestDependencyInjectionBase { private MockServices services; private VaultService vaultSvc; diff --git a/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt b/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt index ad09bb67a8..7dbce0a062 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt @@ -8,6 +8,8 @@ import net.corda.node.services.messaging.createMessage import net.corda.node.services.network.NetworkMapService import net.corda.testing.node.MockNetwork import org.junit.After +import net.corda.testing.resetTestSerialization +import org.junit.Before import org.junit.Test import java.util.* import kotlin.test.assertEquals @@ -15,11 +17,20 @@ import kotlin.test.assertFails import kotlin.test.assertTrue class InMemoryMessagingTests { - val mockNet = MockNetwork() + lateinit var mockNet: MockNetwork + + @Before + fun setUp() { + mockNet = MockNetwork() + } @After - fun cleanUp() { - mockNet.stopNodes() + fun tearDown() { + if (mockNet.nodes.isNotEmpty()) { + mockNet.stopNodes() + } else { + resetTestSerialization() + } } @Test diff --git a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt index e261576ccd..97664e7f0a 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt @@ -76,7 +76,6 @@ class TwoPartyTradeFlowTests { @Before fun before() { - mockNet = MockNetwork(false) LogHelper.setLevel("platform.trade", "core.contract.TransactionGroup", "recordingmap") } @@ -93,7 +92,7 @@ class TwoPartyTradeFlowTests { // allow interruption half way through. mockNet = MockNetwork(false, true) - ledger { + ledger(initialiseSerialization = false) { val basketOfNodes = mockNet.createSomeNodes(3) val notaryNode = basketOfNodes.notaryNode val aliceNode = basketOfNodes.partyNodes[0] @@ -140,7 +139,7 @@ class TwoPartyTradeFlowTests { fun `trade cash for commercial paper fails using soft locking`() { mockNet = MockNetwork(false, true) - ledger { + ledger(initialiseSerialization = false) { val notaryNode = mockNet.createNotaryNode(null, DUMMY_NOTARY.name) val aliceNode = mockNet.createPartyNode(notaryNode.network.myAddress, ALICE.name) val bobNode = mockNet.createPartyNode(notaryNode.network.myAddress, BOB.name) @@ -191,7 +190,8 @@ class TwoPartyTradeFlowTests { @Test fun `shutdown and restore`() { - ledger { + mockNet = MockNetwork(false) + ledger(initialiseSerialization = false) { val notaryNode = mockNet.createNotaryNode(null, DUMMY_NOTARY.name) val aliceNode = mockNet.createPartyNode(notaryNode.network.myAddress, ALICE.name) var bobNode = mockNet.createPartyNode(notaryNode.network.myAddress, BOB.name) @@ -313,13 +313,15 @@ class TwoPartyTradeFlowTests { @Test fun `check dependencies of sale asset are resolved`() { + mockNet = MockNetwork(false) + val notaryNode = mockNet.createNotaryNode(null, DUMMY_NOTARY.name) val aliceNode = makeNodeWithTracking(notaryNode.network.myAddress, ALICE.name) val bobNode = makeNodeWithTracking(notaryNode.network.myAddress, BOB.name) val bankNode = makeNodeWithTracking(notaryNode.network.myAddress, BOC.name) val issuer = bankNode.info.legalIdentity.ref(1, 2, 3) - ledger(aliceNode.services) { + ledger(aliceNode.services, initialiseSerialization = false) { // Insert a prospectus type attachment into the commercial paper transaction. val stream = ByteArrayOutputStream() @@ -412,13 +414,15 @@ class TwoPartyTradeFlowTests { @Test fun `track works`() { + mockNet = MockNetwork(false) + val notaryNode = mockNet.createNotaryNode(null, DUMMY_NOTARY.name) val aliceNode = makeNodeWithTracking(notaryNode.network.myAddress, ALICE.name) val bobNode = makeNodeWithTracking(notaryNode.network.myAddress, BOB.name) val bankNode = makeNodeWithTracking(notaryNode.network.myAddress, BOC.name) val issuer = bankNode.info.legalIdentity.ref(1, 2, 3) - ledger(aliceNode.services) { + ledger(aliceNode.services, initialiseSerialization = false) { // Insert a prospectus type attachment into the commercial paper transaction. val stream = ByteArrayOutputStream() @@ -487,14 +491,16 @@ class TwoPartyTradeFlowTests { @Test fun `dependency with error on buyer side`() { - ledger { + mockNet = MockNetwork(false) + ledger(initialiseSerialization = false) { runWithError(true, false, "at least one asset input") } } @Test fun `dependency with error on seller side`() { - ledger { + mockNet = MockNetwork(false) + ledger(initialiseSerialization = false) { runWithError(false, true, "Issuances must have a time-window") } } diff --git a/node/src/test/kotlin/net/corda/node/services/NotaryChangeTests.kt b/node/src/test/kotlin/net/corda/node/services/NotaryChangeTests.kt index cfd8a53088..7aeadcf487 100644 --- a/node/src/test/kotlin/net/corda/node/services/NotaryChangeTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/NotaryChangeTests.kt @@ -34,7 +34,7 @@ class NotaryChangeTests { lateinit var clientNodeB: MockNetwork.MockNode @Before - fun setup() { + fun setUp() { mockNet = MockNetwork() oldNotaryNode = mockNet.createNode( legalName = DUMMY_NOTARY.name, diff --git a/node/src/test/kotlin/net/corda/node/services/database/HibernateConfigurationTest.kt b/node/src/test/kotlin/net/corda/node/services/database/HibernateConfigurationTest.kt index e26830be2f..fe17999d5c 100644 --- a/node/src/test/kotlin/net/corda/node/services/database/HibernateConfigurationTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/database/HibernateConfigurationTest.kt @@ -3,39 +3,32 @@ package net.corda.node.services.database import net.corda.contracts.asset.Cash import net.corda.contracts.asset.DUMMY_CASH_ISSUER import net.corda.contracts.asset.DummyFungibleContract -import net.corda.testing.contracts.consumeCash -import net.corda.testing.contracts.fillWithSomeTestCash -import net.corda.testing.contracts.fillWithSomeTestDeals -import net.corda.testing.contracts.fillWithSomeTestLinearStates import net.corda.core.contracts.* import net.corda.core.crypto.toBase58String import net.corda.core.node.services.Vault import net.corda.core.node.services.VaultService +import net.corda.core.schemas.CommonSchemaV1 import net.corda.core.schemas.PersistentStateRef -import net.corda.testing.schemas.DummyLinearStateSchemaV1 -import net.corda.testing.schemas.DummyLinearStateSchemaV2 -import net.corda.core.serialization.storageKryo +import net.corda.core.serialization.deserialize import net.corda.core.transactions.SignedTransaction -import net.corda.testing.ALICE -import net.corda.testing.BOB -import net.corda.testing.BOB_KEY -import net.corda.testing.DUMMY_NOTARY import net.corda.node.services.schema.HibernateObserver import net.corda.node.services.schema.NodeSchemaService import net.corda.node.services.vault.NodeVaultService -import net.corda.core.schemas.CommonSchemaV1 -import net.corda.core.serialization.deserialize import net.corda.node.services.vault.VaultSchemaV1 import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase import net.corda.schemas.CashSchemaV1 import net.corda.schemas.SampleCashSchemaV2 import net.corda.schemas.SampleCashSchemaV3 -import net.corda.testing.BOB_PUBKEY -import net.corda.testing.BOC -import net.corda.testing.BOC_KEY +import net.corda.testing.* +import net.corda.testing.contracts.consumeCash +import net.corda.testing.contracts.fillWithSomeTestCash +import net.corda.testing.contracts.fillWithSomeTestDeals +import net.corda.testing.contracts.fillWithSomeTestLinearStates import net.corda.testing.node.MockServices import net.corda.testing.node.makeTestDataSourceProperties +import net.corda.testing.schemas.DummyLinearStateSchemaV1 +import net.corda.testing.schemas.DummyLinearStateSchemaV2 import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions.assertThat import org.hibernate.SessionFactory @@ -48,7 +41,7 @@ import javax.persistence.EntityManager import javax.persistence.Tuple import javax.persistence.criteria.CriteriaBuilder -class HibernateConfigurationTest { +class HibernateConfigurationTest : TestDependencyInjectionBase() { lateinit var services: MockServices lateinit var database: CordaPersistence @@ -655,7 +648,7 @@ class HibernateConfigurationTest { val queryResults = entityManager.createQuery(criteriaQuery).resultList queryResults.forEach { - val contractState = it.contractState.deserialize>(storageKryo()) + val contractState = it.contractState.deserialize>() val cashState = contractState.data as Cash.State println("${it.stateRef} with owner: ${cashState.owner.owningKey.toBase58String()}") } @@ -739,7 +732,7 @@ class HibernateConfigurationTest { // execute query val queryResults = entityManager.createQuery(criteriaQuery).resultList queryResults.forEach { - val contractState = it.contractState.deserialize>(storageKryo()) + val contractState = it.contractState.deserialize>() val cashState = contractState.data as Cash.State println("${it.stateRef} with owner ${cashState.owner.owningKey.toBase58String()} and participants ${cashState.participants.map { it.owningKey.toBase58String() }}") } diff --git a/node/src/test/kotlin/net/corda/node/services/database/RequeryConfigurationTest.kt b/node/src/test/kotlin/net/corda/node/services/database/RequeryConfigurationTest.kt index a670cf2a4d..5153d9189e 100644 --- a/node/src/test/kotlin/net/corda/node/services/database/RequeryConfigurationTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/database/RequeryConfigurationTest.kt @@ -5,7 +5,6 @@ import io.requery.kotlin.eq import io.requery.sql.KotlinEntityDataStore import net.corda.core.contracts.StateRef import net.corda.core.contracts.TransactionType -import net.corda.testing.contracts.DummyContract import net.corda.core.crypto.DigitalSignature import net.corda.core.crypto.SecureHash import net.corda.core.crypto.testing.NullPublicKey @@ -13,11 +12,8 @@ import net.corda.core.crypto.toBase58String import net.corda.core.identity.AnonymousParty import net.corda.core.node.services.Vault import net.corda.core.serialization.serialize -import net.corda.core.serialization.storageKryo import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.WireTransaction -import net.corda.testing.DUMMY_NOTARY -import net.corda.testing.DUMMY_PUBKEY_1 import net.corda.node.services.persistence.DBTransactionStorage import net.corda.node.services.vault.schemas.requery.Models import net.corda.node.services.vault.schemas.requery.VaultCashBalancesEntity @@ -25,6 +21,10 @@ import net.corda.node.services.vault.schemas.requery.VaultSchema import net.corda.node.services.vault.schemas.requery.VaultStatesEntity import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase +import net.corda.testing.DUMMY_NOTARY +import net.corda.testing.DUMMY_PUBKEY_1 +import net.corda.testing.TestDependencyInjectionBase +import net.corda.testing.contracts.DummyContract import net.corda.testing.node.makeTestDataSourceProperties import org.assertj.core.api.Assertions import org.junit.After @@ -35,7 +35,7 @@ import org.junit.Test import java.time.Instant import java.util.* -class RequeryConfigurationTest { +class RequeryConfigurationTest : TestDependencyInjectionBase() { lateinit var database: CordaPersistence lateinit var transactionStorage: DBTransactionStorage @@ -175,7 +175,7 @@ class RequeryConfigurationTest { index = txnState.index stateStatus = Vault.StateStatus.UNCONSUMED contractStateClassName = DummyContract.SingleOwnerState::class.java.name - contractState = DummyContract.SingleOwnerState(owner = AnonymousParty(DUMMY_PUBKEY_1)).serialize(storageKryo()).bytes + contractState = DummyContract.SingleOwnerState(owner = AnonymousParty(DUMMY_PUBKEY_1)).serialize().bytes notaryName = txn.tx.notary!!.name.toString() notaryKey = txn.tx.notary!!.owningKey.toBase58String() recordedTime = Instant.now() diff --git a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt index abce8ae027..f3c60fae4f 100644 --- a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt @@ -25,6 +25,8 @@ import net.corda.testing.node.MockKeyManagementService import net.corda.testing.node.TestClock import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.testNodeConfiguration +import net.corda.testing.initialiseTestSerialization +import net.corda.testing.resetTestSerialization import org.assertj.core.api.Assertions.assertThat import org.bouncycastle.asn1.x500.X500Name import org.junit.After @@ -67,6 +69,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() { @Before fun setup() { + initialiseTestSerialization() countDown = CountDownLatch(1) smmHasRemovedAllFlows = CountDownLatch(1) calls = 0 @@ -114,6 +117,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() { smmExecutor.shutdown() smmExecutor.awaitTermination(60, TimeUnit.SECONDS) database.close() + resetTestSerialization() } class TestState(val flowLogicRef: FlowLogicRef, val instant: Instant) : LinearState, SchedulableState { diff --git a/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt b/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt index 59f71acde4..f9cd07f578 100644 --- a/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt @@ -37,7 +37,7 @@ import kotlin.test.assertEquals import kotlin.test.assertNull //TODO This needs to be merged into P2PMessagingTest as that creates a more realistic environment -class ArtemisMessagingTests { +class ArtemisMessagingTests : TestDependencyInjectionBase() { @Rule @JvmField val temporaryFolder = TemporaryFolder() val serverPort = freePort() diff --git a/node/src/test/kotlin/net/corda/node/services/network/InMemoryIdentityServiceTests.kt b/node/src/test/kotlin/net/corda/node/services/network/InMemoryIdentityServiceTests.kt index 602481e6c6..9fa1d0e0e5 100644 --- a/node/src/test/kotlin/net/corda/node/services/network/InMemoryIdentityServiceTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/network/InMemoryIdentityServiceTests.kt @@ -80,16 +80,18 @@ class InMemoryIdentityServiceTests { */ @Test fun `assert unknown anonymous key is unrecognised`() { - val rootKey = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) - val rootCert = X509Utilities.createSelfSignedCACertificate(ALICE.name, rootKey) - val txKey = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) - val service = InMemoryIdentityService(trustRoot = DUMMY_CA.certificate) - // TODO: Generate certificate with an EdDSA key rather than ECDSA - val identity = Party(CertificateAndKeyPair(rootCert, rootKey)) - val txIdentity = AnonymousParty(txKey.public) + withTestSerialization { + val rootKey = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) + val rootCert = X509Utilities.createSelfSignedCACertificate(ALICE.name, rootKey) + val txKey = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) + val service = InMemoryIdentityService(trustRoot = DUMMY_CA.certificate) + // TODO: Generate certificate with an EdDSA key rather than ECDSA + val identity = Party(CertificateAndKeyPair(rootCert, rootKey)) + val txIdentity = AnonymousParty(txKey.public) - assertFailsWith { - service.assertOwnership(identity, txIdentity) + assertFailsWith { + service.assertOwnership(identity, txIdentity) + } } } @@ -122,38 +124,40 @@ class InMemoryIdentityServiceTests { */ @Test fun `assert ownership`() { - val trustRoot = DUMMY_CA - val (alice, aliceTxIdentity) = createParty(ALICE.name, trustRoot) + withTestSerialization { + val trustRoot = DUMMY_CA + val (alice, aliceTxIdentity) = createParty(ALICE.name, trustRoot) - val certFactory = CertificateFactory.getInstance("X509") - val bobRootKey = Crypto.generateKeyPair() - val bobRoot = getTestPartyAndCertificate(BOB.name, bobRootKey.public) - val bobRootCert = bobRoot.certificate - val bobTxKey = Crypto.generateKeyPair() - val bobTxCert = X509Utilities.createCertificate(CertificateType.IDENTITY, bobRootCert, bobRootKey, BOB.name, bobTxKey.public) - val bobCertPath = certFactory.generateCertPath(listOf(bobTxCert.cert, bobRootCert.cert)) - val bob = PartyAndCertificate(BOB.name, bobRootKey.public, bobRootCert, bobCertPath) + val certFactory = CertificateFactory.getInstance("X509") + val bobRootKey = Crypto.generateKeyPair() + val bobRoot = getTestPartyAndCertificate(BOB.name, bobRootKey.public) + val bobRootCert = bobRoot.certificate + val bobTxKey = Crypto.generateKeyPair() + val bobTxCert = X509Utilities.createCertificate(CertificateType.IDENTITY, bobRootCert, bobRootKey, BOB.name, bobTxKey.public) + val bobCertPath = certFactory.generateCertPath(listOf(bobTxCert.cert, bobRootCert.cert)) + val bob = PartyAndCertificate(BOB.name, bobRootKey.public, bobRootCert, bobCertPath) - // Now we have identities, construct the service and let it know about both - val service = InMemoryIdentityService(setOf(alice, bob), emptyMap(), trustRoot.certificate.cert) - service.verifyAndRegisterAnonymousIdentity(aliceTxIdentity, alice.party) + // Now we have identities, construct the service and let it know about both + val service = InMemoryIdentityService(setOf(alice, bob), emptyMap(), trustRoot.certificate.cert) + service.verifyAndRegisterAnonymousIdentity(aliceTxIdentity, alice.party) - val anonymousBob = AnonymousPartyAndPath(AnonymousParty(bobTxKey.public),bobCertPath) - service.verifyAndRegisterAnonymousIdentity(anonymousBob, bob.party) + val anonymousBob = AnonymousPartyAndPath(AnonymousParty(bobTxKey.public),bobCertPath) + service.verifyAndRegisterAnonymousIdentity(anonymousBob, bob.party) - // Verify that paths are verified - service.assertOwnership(alice.party, aliceTxIdentity.party) - service.assertOwnership(bob.party, anonymousBob.party) - assertFailsWith { - service.assertOwnership(alice.party, anonymousBob.party) - } - assertFailsWith { - service.assertOwnership(bob.party, aliceTxIdentity.party) - } + // Verify that paths are verified + service.assertOwnership(alice.party, aliceTxIdentity.party) + service.assertOwnership(bob.party, anonymousBob.party) + assertFailsWith { + service.assertOwnership(alice.party, anonymousBob.party) + } + assertFailsWith { + service.assertOwnership(bob.party, aliceTxIdentity.party) + } - assertFailsWith { - val owningKey = Crypto.decodePublicKey(trustRoot.certificate.subjectPublicKeyInfo.encoded) - service.assertOwnership(Party(trustRoot.certificate.subject, owningKey), aliceTxIdentity.party) + assertFailsWith { + val owningKey = Crypto.decodePublicKey(trustRoot.certificate.subjectPublicKeyInfo.encoded) + service.assertOwnership(Party(trustRoot.certificate.subject, owningKey), aliceTxIdentity.party) + } } } diff --git a/node/src/test/kotlin/net/corda/node/services/network/InMemoryNetworkMapCacheTest.kt b/node/src/test/kotlin/net/corda/node/services/network/InMemoryNetworkMapCacheTest.kt index 4ac1ff00cc..2d25f28741 100644 --- a/node/src/test/kotlin/net/corda/node/services/network/InMemoryNetworkMapCacheTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/network/InMemoryNetworkMapCacheTest.kt @@ -7,12 +7,18 @@ import net.corda.testing.ALICE import net.corda.testing.BOB import net.corda.testing.node.MockNetwork import org.junit.After +import org.junit.Before import org.junit.Test import java.math.BigInteger import kotlin.test.assertEquals class InMemoryNetworkMapCacheTest { - private val mockNet = MockNetwork() + lateinit var mockNet: MockNetwork + + @Before + fun setUp() { + mockNet = MockNetwork() + } @After fun teardown() { diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt index 9ddc8f8b07..b303a8b425 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt @@ -2,12 +2,13 @@ package net.corda.node.services.persistence import com.google.common.primitives.Ints import net.corda.core.serialization.SerializedBytes -import net.corda.testing.LogHelper import net.corda.node.services.api.Checkpoint import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.transactions.PersistentUniquenessProvider import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase +import net.corda.testing.LogHelper +import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.node.makeTestDataSourceProperties import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatExceptionOfType @@ -24,7 +25,7 @@ internal fun CheckpointStorage.checkpoints(): List { return checkpoints } -class DBCheckpointStorageTests { +class DBCheckpointStorageTests : TestDependencyInjectionBase() { lateinit var checkpointStorage: DBCheckpointStorage lateinit var database: CordaPersistence diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageTests.kt index 352e0f52ba..f6c14b3deb 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageTests.kt @@ -8,11 +8,12 @@ import net.corda.core.crypto.testing.NullPublicKey import net.corda.core.toFuture import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.WireTransaction -import net.corda.testing.DUMMY_NOTARY -import net.corda.testing.LogHelper import net.corda.node.services.transactions.PersistentUniquenessProvider import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase +import net.corda.testing.DUMMY_NOTARY +import net.corda.testing.LogHelper +import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.node.makeTestDataSourceProperties import org.assertj.core.api.Assertions.assertThat import org.junit.After @@ -21,7 +22,7 @@ import org.junit.Test import java.util.concurrent.TimeUnit import kotlin.test.assertEquals -class DBTransactionStorageTests { +class DBTransactionStorageTests : TestDependencyInjectionBase() { lateinit var database: CordaPersistence lateinit var transactionStorage: DBTransactionStorage diff --git a/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt b/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt index 41c79f7373..bb0742468a 100644 --- a/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt @@ -8,10 +8,11 @@ import io.atomix.copycat.server.storage.Storage import io.atomix.copycat.server.storage.StorageLevel import net.corda.core.getOrThrow import net.corda.core.utilities.NetworkHostAndPort -import net.corda.testing.LogHelper import net.corda.node.services.network.NetworkMapService import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase +import net.corda.testing.LogHelper +import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.freeLocalHostAndPort import net.corda.testing.node.makeTestDataSourceProperties import org.jetbrains.exposed.sql.Transaction @@ -22,7 +23,7 @@ import java.util.concurrent.CompletableFuture import kotlin.test.assertEquals import kotlin.test.assertTrue -class DistributedImmutableMapTests { +class DistributedImmutableMapTests : TestDependencyInjectionBase() { data class Member(val client: CopycatClient, val server: CopycatServer) lateinit var cluster: List diff --git a/node/src/test/kotlin/net/corda/node/services/transactions/PersistentUniquenessProviderTests.kt b/node/src/test/kotlin/net/corda/node/services/transactions/PersistentUniquenessProviderTests.kt index 8bfb983ee9..858bae0539 100644 --- a/node/src/test/kotlin/net/corda/node/services/transactions/PersistentUniquenessProviderTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/transactions/PersistentUniquenessProviderTests.kt @@ -6,6 +6,7 @@ import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase import net.corda.testing.LogHelper import net.corda.testing.MEGA_CORP +import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.generateStateRef import net.corda.testing.node.makeTestDataSourceProperties import org.junit.After @@ -14,7 +15,7 @@ import org.junit.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith -class PersistentUniquenessProviderTests { +class PersistentUniquenessProviderTests : TestDependencyInjectionBase() { val identity = MEGA_CORP val txID = SecureHash.randomSHA256() diff --git a/node/src/test/kotlin/net/corda/node/services/vault/NodeVaultServiceTest.kt b/node/src/test/kotlin/net/corda/node/services/vault/NodeVaultServiceTest.kt index 2647bbbd89..08e4b913e3 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/NodeVaultServiceTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/NodeVaultServiceTest.kt @@ -34,7 +34,7 @@ import kotlin.test.assertFalse import kotlin.test.assertNull import kotlin.test.assertTrue -class NodeVaultServiceTest { +class NodeVaultServiceTest : TestDependencyInjectionBase() { lateinit var services: MockServices val vaultSvc: VaultService get() = services.vaultService lateinit var database: CordaPersistence diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt index af9e8c009b..2b2f6a6038 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt @@ -15,8 +15,8 @@ import net.corda.core.node.services.vault.* import net.corda.core.node.services.vault.QueryCriteria.* import net.corda.core.utilities.seconds import net.corda.core.transactions.SignedTransaction -import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.NonEmptySet +import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.toHexString import net.corda.node.services.database.HibernateConfiguration import net.corda.node.services.schema.NodeSchemaService @@ -46,7 +46,7 @@ import java.time.ZoneOffset import java.time.temporal.ChronoUnit import java.util.* -class VaultQueryTests { +class VaultQueryTests : TestDependencyInjectionBase() { lateinit var services: MockServices val vaultSvc: VaultService get() = services.vaultService diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultWithCashTest.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultWithCashTest.kt index 7f18cc6c81..9ab1cc3f8b 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultWithCashTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultWithCashTest.kt @@ -1,13 +1,8 @@ package net.corda.node.services.vault -import net.corda.testing.contracts.DummyDealContract import net.corda.contracts.asset.Cash import net.corda.contracts.asset.DUMMY_CASH_ISSUER -import net.corda.testing.contracts.fillWithSomeTestCash -import net.corda.testing.contracts.fillWithSomeTestDeals -import net.corda.testing.contracts.fillWithSomeTestLinearStates import net.corda.core.contracts.* -import net.corda.testing.contracts.DummyLinearContract import net.corda.core.identity.AnonymousParty import net.corda.core.node.services.VaultService import net.corda.core.node.services.consumedStates @@ -15,12 +10,8 @@ import net.corda.core.node.services.unconsumedStates import net.corda.core.transactions.SignedTransaction import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase -import net.corda.testing.BOB -import net.corda.testing.DUMMY_NOTARY -import net.corda.testing.DUMMY_NOTARY_KEY -import net.corda.testing.LogHelper -import net.corda.testing.MEGA_CORP -import net.corda.testing.MEGA_CORP_KEY +import net.corda.testing.* +import net.corda.testing.contracts.* import net.corda.testing.node.MockServices import net.corda.testing.node.makeTestDataSourceProperties import org.assertj.core.api.Assertions.assertThat @@ -36,7 +27,7 @@ import kotlin.test.assertNull // TODO: Move this to the cash contract tests once mock services are further split up. -class VaultWithCashTest { +class VaultWithCashTest : TestDependencyInjectionBase() { lateinit var services: MockServices val vault: VaultService get() = services.vaultService lateinit var database: CordaPersistence diff --git a/samples/irs-demo/src/integration-test/kotlin/net/corda/irs/IRSDemoTest.kt b/samples/irs-demo/src/integration-test/kotlin/net/corda/irs/IRSDemoTest.kt index 66dce242fe..8eb0df9e3b 100644 --- a/samples/irs-demo/src/integration-test/kotlin/net/corda/irs/IRSDemoTest.kt +++ b/samples/irs-demo/src/integration-test/kotlin/net/corda/irs/IRSDemoTest.kt @@ -79,7 +79,7 @@ class IRSDemoTest : IntegrationTestCategory { fun getFloatingLegFixCount(nodeApi: HttpApi) = getTrades(nodeApi)[0].calculation.floatingLegPaymentSchedule.count { it.value.rate.ratioUnit != null } fun getFixingDateObservable(config: FullNodeConfiguration): Observable { - val client = CordaRPCClient(config.rpcAddress!!) + val client = CordaRPCClient(config.rpcAddress!!, initialiseSerialization = false) val proxy = client.start("user", "password").proxy val vaultUpdates = proxy.vaultAndUpdates().second diff --git a/samples/irs-demo/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt b/samples/irs-demo/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt index 0c65e25f2f..ac2ee52fc1 100644 --- a/samples/irs-demo/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt +++ b/samples/irs-demo/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt @@ -14,7 +14,6 @@ import net.corda.core.getOrThrow import net.corda.core.identity.Party import net.corda.core.node.services.ServiceInfo import net.corda.core.transactions.TransactionBuilder -import net.corda.testing.LogHelper import net.corda.core.utilities.ProgressTracker import net.corda.irs.flows.RatesFixFlow import net.corda.node.utilities.CordaPersistence @@ -34,7 +33,7 @@ import kotlin.test.assertEquals import kotlin.test.assertFailsWith import kotlin.test.assertFalse -class NodeInterestRatesTest { +class NodeInterestRatesTest : TestDependencyInjectionBase() { val TEST_DATA = NodeInterestRates.parseFile(""" LIBOR 2016-03-16 1M = 0.678 LIBOR 2016-03-16 2M = 0.685 @@ -202,7 +201,7 @@ class NodeInterestRatesTest { @Test fun `network tearoff`() { - val mockNet = MockNetwork() + val mockNet = MockNetwork(initialiseSerialization = false) val n1 = mockNet.createNotaryNode() val n2 = mockNet.createNode(n1.network.myAddress, advertisedServices = ServiceInfo(NodeInterestRates.Oracle.type)) n2.registerInitiatedFlow(NodeInterestRates.FixQueryHandler::class.java) diff --git a/samples/irs-demo/src/test/kotlin/net/corda/irs/contract/IRSTests.kt b/samples/irs-demo/src/test/kotlin/net/corda/irs/contract/IRSTests.kt index 576dee021e..ce7e0b17ef 100644 --- a/samples/irs-demo/src/test/kotlin/net/corda/irs/contract/IRSTests.kt +++ b/samples/irs-demo/src/test/kotlin/net/corda/irs/contract/IRSTests.kt @@ -4,9 +4,6 @@ import net.corda.contracts.* import net.corda.core.contracts.* import net.corda.core.utilities.seconds import net.corda.core.transactions.SignedTransaction -import net.corda.testing.DUMMY_NOTARY -import net.corda.testing.DUMMY_NOTARY_KEY -import net.corda.testing.TEST_TX_TIME import net.corda.testing.* import net.corda.testing.node.MockServices import org.junit.Test @@ -200,7 +197,7 @@ fun createDummyIRS(irsSelect: Int): InterestRateSwap.State { } } -class IRSTests { +class IRSTests : TestDependencyInjectionBase() { val megaCorpServices = MockServices(MEGA_CORP_KEY) val miniCorpServices = MockServices(MINI_CORP_KEY) val notaryServices = MockServices(DUMMY_NOTARY_KEY) @@ -370,7 +367,7 @@ class IRSTests { val ld = LocalDate.of(2016, 3, 8) val bd = BigDecimal("0.0063518") - return ledger { + return ledger(initialiseSerialization = false) { transaction("Agreement") { output("irs post agreement") { singleIRS() } command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() } @@ -401,7 +398,7 @@ class IRSTests { @Test fun `ensure failure occurs when there are inbound states for an agreement command`() { val irs = singleIRS() - transaction { + transaction(initialiseSerialization = false) { input { irs } output("irs post agreement") { irs } command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() } @@ -414,7 +411,7 @@ class IRSTests { fun `ensure failure occurs when no events in fix schedule`() { val irs = singleIRS() val emptySchedule = mutableMapOf() - transaction { + transaction(initialiseSerialization = false) { output { irs.copy(calculation = irs.calculation.copy(fixedLegPaymentSchedule = emptySchedule)) } @@ -428,7 +425,7 @@ class IRSTests { fun `ensure failure occurs when no events in floating schedule`() { val irs = singleIRS() val emptySchedule = mutableMapOf() - transaction { + transaction(initialiseSerialization = false) { output { irs.copy(calculation = irs.calculation.copy(floatingLegPaymentSchedule = emptySchedule)) } @@ -441,7 +438,7 @@ class IRSTests { @Test fun `ensure notionals are non zero`() { val irs = singleIRS() - transaction { + transaction(initialiseSerialization = false) { output { irs.copy(irs.fixedLeg.copy(notional = irs.fixedLeg.notional.copy(quantity = 0))) } @@ -450,7 +447,7 @@ class IRSTests { this `fails with` "All notionals must be non zero" } - transaction { + transaction(initialiseSerialization = false) { output { irs.copy(irs.fixedLeg.copy(notional = irs.floatingLeg.notional.copy(quantity = 0))) } @@ -464,7 +461,7 @@ class IRSTests { fun `ensure positive rate on fixed leg`() { val irs = singleIRS() val modifiedIRS = irs.copy(fixedLeg = irs.fixedLeg.copy(fixedRate = FixedRate(PercentageRatioUnit("-0.1")))) - transaction { + transaction(initialiseSerialization = false) { output { modifiedIRS } @@ -481,7 +478,7 @@ class IRSTests { fun `ensure same currency notionals`() { val irs = singleIRS() val modifiedIRS = irs.copy(fixedLeg = irs.fixedLeg.copy(notional = Amount(irs.fixedLeg.notional.quantity, Currency.getInstance("JPY")))) - transaction { + transaction(initialiseSerialization = false) { output { modifiedIRS } @@ -495,7 +492,7 @@ class IRSTests { fun `ensure notional amounts are equal`() { val irs = singleIRS() val modifiedIRS = irs.copy(fixedLeg = irs.fixedLeg.copy(notional = Amount(irs.floatingLeg.notional.quantity + 1, irs.floatingLeg.notional.token))) - transaction { + transaction(initialiseSerialization = false) { output { modifiedIRS } @@ -509,7 +506,7 @@ class IRSTests { fun `ensure trade date and termination date checks are done pt1`() { val irs = singleIRS() val modifiedIRS1 = irs.copy(fixedLeg = irs.fixedLeg.copy(terminationDate = irs.fixedLeg.effectiveDate.minusDays(1))) - transaction { + transaction(initialiseSerialization = false) { output { modifiedIRS1 } @@ -519,7 +516,7 @@ class IRSTests { } val modifiedIRS2 = irs.copy(floatingLeg = irs.floatingLeg.copy(terminationDate = irs.floatingLeg.effectiveDate.minusDays(1))) - transaction { + transaction(initialiseSerialization = false) { output { modifiedIRS2 } @@ -534,7 +531,7 @@ class IRSTests { val irs = singleIRS() val modifiedIRS3 = irs.copy(floatingLeg = irs.floatingLeg.copy(terminationDate = irs.fixedLeg.terminationDate.minusDays(1))) - transaction { + transaction(initialiseSerialization = false) { output { modifiedIRS3 } @@ -545,7 +542,7 @@ class IRSTests { val modifiedIRS4 = irs.copy(floatingLeg = irs.floatingLeg.copy(effectiveDate = irs.fixedLeg.effectiveDate.minusDays(1))) - transaction { + transaction(initialiseSerialization = false) { output { modifiedIRS4 } @@ -561,7 +558,7 @@ class IRSTests { val ld = LocalDate.of(2016, 3, 8) val bd = BigDecimal("0.0063518") - transaction { + transaction(initialiseSerialization = false) { output("irs post agreement") { singleIRS() } command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() } timeWindow(TEST_TX_TIME) @@ -574,7 +571,7 @@ class IRSTests { oldIRS.calculation.applyFixing(ld, FixedRate(RatioUnit(bd))), oldIRS.common) - transaction { + transaction(initialiseSerialization = false) { input { oldIRS @@ -654,7 +651,7 @@ class IRSTests { val irs = singleIRS() - return ledger { + return ledger(initialiseSerialization = false) { transaction("Agreement") { output("irs post agreement1") { irs.copy( diff --git a/samples/trader-demo/src/integration-test/kotlin/net/corda/traderdemo/TraderDemoTest.kt b/samples/trader-demo/src/integration-test/kotlin/net/corda/traderdemo/TraderDemoTest.kt index ddd7127651..a4a9854376 100644 --- a/samples/trader-demo/src/integration-test/kotlin/net/corda/traderdemo/TraderDemoTest.kt +++ b/samples/trader-demo/src/integration-test/kotlin/net/corda/traderdemo/TraderDemoTest.kt @@ -10,11 +10,11 @@ import net.corda.testing.DUMMY_BANK_A import net.corda.testing.DUMMY_BANK_B import net.corda.testing.DUMMY_NOTARY import net.corda.flows.IssuerFlow -import net.corda.testing.driver.poll import net.corda.node.services.startFlowPermission import net.corda.node.services.transactions.SimpleNotaryService import net.corda.nodeapi.User import net.corda.testing.BOC +import net.corda.testing.driver.poll import net.corda.testing.node.NodeBasedTest import net.corda.traderdemo.flow.BuyerFlow import net.corda.traderdemo.flow.SellerFlow @@ -40,7 +40,7 @@ class TraderDemoTest : NodeBasedTest() { nodeA.registerInitiatedFlow(BuyerFlow::class.java) val (nodeARpc, nodeBRpc) = listOf(nodeA, nodeB).map { - val client = CordaRPCClient(it.configuration.rpcAddress!!) + val client = CordaRPCClient(it.configuration.rpcAddress!!, initialiseSerialization = false) client.start(demoUser[0].username, demoUser[0].password).proxy } diff --git a/test-utils/src/main/kotlin/net/corda/testing/CoreTestUtils.kt b/test-utils/src/main/kotlin/net/corda/testing/CoreTestUtils.kt index 6c36d638c6..b39a2a2e22 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/CoreTestUtils.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/CoreTestUtils.kt @@ -126,11 +126,17 @@ fun getFreeLocalPorts(hostName: String, numberToAlloc: Int): List.() -> Unit ): LedgerDSL { - val ledgerDsl = LedgerDSL(TestLedgerDSLInterpreter(services)) - dsl(ledgerDsl) - return ledgerDsl + if (initialiseSerialization) initialiseTestSerialization() + try { + val ledgerDsl = LedgerDSL(TestLedgerDSLInterpreter(services)) + dsl(ledgerDsl) + return ledgerDsl + } finally { + if (initialiseSerialization) resetTestSerialization() + } } /** @@ -141,8 +147,9 @@ fun getFreeLocalPorts(hostName: String, numberToAlloc: Int): List.() -> EnforceVerifyOrFail -) = ledger { this.transaction(transactionLabel, transactionBuilder, dsl) } +) = ledger(initialiseSerialization = initialiseSerialization) { this.transaction(transactionLabel, transactionBuilder, dsl) } fun testNodeConfiguration( baseDirectory: Path, diff --git a/test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt b/test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt index ff78e3deeb..a85426fefe 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt @@ -5,6 +5,7 @@ import net.corda.client.mock.Generator import net.corda.client.mock.generateOrFail import net.corda.client.mock.int import net.corda.client.mock.string +import net.corda.client.rpc.CordaRPCClient import net.corda.client.rpc.internal.RPCClient import net.corda.client.rpc.internal.RPCClientConfiguration import net.corda.core.internal.div @@ -22,6 +23,7 @@ import net.corda.nodeapi.ArtemisTcpTransport import net.corda.nodeapi.ConnectionDirection import net.corda.nodeapi.RPCApi import net.corda.nodeapi.User +import net.corda.nodeapi.serialization.KRYO_RPC_CLIENT_CONTEXT import net.corda.testing.driver.* import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.TransportConfiguration @@ -224,6 +226,7 @@ fun rpcDriver( debugPortAllocation: PortAllocation = globalDebugPortAllocation, systemProperties: Map = emptyMap(), useTestClock: Boolean = false, + initialiseSerialization: Boolean = true, networkMapStartStrategy: NetworkMapStartStrategy = NetworkMapStartStrategy.Dedicated(startAutomatically = false), startNodesInProcess: Boolean = false, dsl: RPCDriverExposedDSLInterface.() -> A @@ -241,7 +244,8 @@ fun rpcDriver( ) ), coerce = { it }, - dsl = dsl + dsl = dsl, + initialiseSerialization = initialiseSerialization ) private class SingleUserSecurityManager(val rpcUser: User) : ActiveMQSecurityManager3 { @@ -510,7 +514,8 @@ class RandomRpcUser { val hostAndPort = args[1].parseNetworkHostAndPort() val username = args[2] val password = args[3] - val handle = RPCClient(hostAndPort, null).start(rpcClass, username, password) + CordaRPCClient.initialiseSerialization() + val handle = RPCClient(hostAndPort, null, serializationContext = KRYO_RPC_CLIENT_CONTEXT).start(rpcClass, username, password) val callGenerators = rpcClass.declaredMethods.map { method -> Generator.sequence(method.parameters.map { generatorStore[it.type] ?: throw Exception("No generator for ${it.type}") diff --git a/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt b/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt new file mode 100644 index 0000000000..5a6f381cc1 --- /dev/null +++ b/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt @@ -0,0 +1,140 @@ +package net.corda.testing + +import net.corda.client.rpc.serialization.KryoClientSerializationScheme +import net.corda.core.serialization.* +import net.corda.core.utilities.ByteSequence +import net.corda.node.serialization.KryoServerSerializationScheme +import net.corda.nodeapi.serialization.* + +fun withTestSerialization(block: () -> T): T { + initialiseTestSerialization() + try { + return block() + } finally { + resetTestSerialization() + } +} + +fun initialiseTestSerialization() { + // Check that everything is configured for testing with mutable delegating instances. + try { + check(SerializationDefaults.SERIALIZATION_FACTORY is TestSerializationFactory) { + "Found non-test serialization configuration: ${SerializationDefaults.SERIALIZATION_FACTORY}" + } + } catch(e: IllegalStateException) { + SerializationDefaults.SERIALIZATION_FACTORY = TestSerializationFactory() + } + try { + check(SerializationDefaults.P2P_CONTEXT is TestSerializationContext) + } catch(e: IllegalStateException) { + SerializationDefaults.P2P_CONTEXT = TestSerializationContext() + } + try { + check(SerializationDefaults.RPC_SERVER_CONTEXT is TestSerializationContext) + } catch(e: IllegalStateException) { + SerializationDefaults.RPC_SERVER_CONTEXT = TestSerializationContext() + } + try { + check(SerializationDefaults.RPC_CLIENT_CONTEXT is TestSerializationContext) + } catch(e: IllegalStateException) { + SerializationDefaults.RPC_CLIENT_CONTEXT = TestSerializationContext() + } + try { + check(SerializationDefaults.STORAGE_CONTEXT is TestSerializationContext) + } catch(e: IllegalStateException) { + SerializationDefaults.STORAGE_CONTEXT = TestSerializationContext() + } + try { + check(SerializationDefaults.CHECKPOINT_CONTEXT is TestSerializationContext) + } catch(e: IllegalStateException) { + SerializationDefaults.CHECKPOINT_CONTEXT = TestSerializationContext() + } + + // Check that the previous test, if there was one, cleaned up after itself. + // IF YOU SEE THESE MESSAGES, THEN IT MEANS A TEST HAS NOT CALLED resetTestSerialization() + check((SerializationDefaults.SERIALIZATION_FACTORY as TestSerializationFactory).delegate == null, { "Expected uninitialised serialization framework but found it set from: ${SerializationDefaults.SERIALIZATION_FACTORY}" }) + check((SerializationDefaults.P2P_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: ${SerializationDefaults.P2P_CONTEXT}" }) + check((SerializationDefaults.RPC_SERVER_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: ${SerializationDefaults.RPC_SERVER_CONTEXT}" }) + check((SerializationDefaults.RPC_CLIENT_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: ${SerializationDefaults.RPC_CLIENT_CONTEXT}" }) + check((SerializationDefaults.STORAGE_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: ${SerializationDefaults.STORAGE_CONTEXT}" }) + check((SerializationDefaults.CHECKPOINT_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: ${SerializationDefaults.CHECKPOINT_CONTEXT}" }) + + // Now configure all the testing related delegates. + (SerializationDefaults.SERIALIZATION_FACTORY as TestSerializationFactory).delegate = SerializationFactoryImpl().apply { + registerScheme(KryoClientSerializationScheme()) + registerScheme(KryoServerSerializationScheme()) + } + (SerializationDefaults.P2P_CONTEXT as TestSerializationContext).delegate = KRYO_P2P_CONTEXT + (SerializationDefaults.RPC_SERVER_CONTEXT as TestSerializationContext).delegate = KRYO_RPC_SERVER_CONTEXT + (SerializationDefaults.RPC_CLIENT_CONTEXT as TestSerializationContext).delegate = KRYO_RPC_CLIENT_CONTEXT + (SerializationDefaults.STORAGE_CONTEXT as TestSerializationContext).delegate = KRYO_STORAGE_CONTEXT + (SerializationDefaults.CHECKPOINT_CONTEXT as TestSerializationContext).delegate = KRYO_CHECKPOINT_CONTEXT +} + +fun resetTestSerialization() { + (SerializationDefaults.SERIALIZATION_FACTORY as TestSerializationFactory).delegate = null + (SerializationDefaults.P2P_CONTEXT as TestSerializationContext).delegate = null + (SerializationDefaults.RPC_SERVER_CONTEXT as TestSerializationContext).delegate = null + (SerializationDefaults.RPC_CLIENT_CONTEXT as TestSerializationContext).delegate = null + (SerializationDefaults.STORAGE_CONTEXT as TestSerializationContext).delegate = null + (SerializationDefaults.CHECKPOINT_CONTEXT as TestSerializationContext).delegate = null +} + +class TestSerializationFactory : SerializationFactory { + var delegate: SerializationFactory? = null + set(value) { + field = value + stackTrace = Exception().stackTrace.asList() + } + private var stackTrace: List? = null + + override fun toString(): String = stackTrace?.joinToString("\n") ?: "null" + + override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T { + return delegate!!.deserialize(byteSequence, clazz, context) + } + + override fun serialize(obj: T, context: SerializationContext): SerializedBytes { + return delegate!!.serialize(obj, context) + } +} + +class TestSerializationContext : SerializationContext { + var delegate: SerializationContext? = null + set(value) { + field = value + stackTrace = Exception().stackTrace.asList() + } + private var stackTrace: List? = null + + override fun toString(): String = stackTrace?.joinToString("\n") ?: "null" + + override val preferedSerializationVersion: ByteSequence + get() = delegate!!.preferedSerializationVersion + override val deserializationClassLoader: ClassLoader + get() = delegate!!.deserializationClassLoader + override val whitelist: ClassWhitelist + get() = delegate!!.whitelist + override val properties: Map + get() = delegate!!.properties + override val objectReferencesEnabled: Boolean + get() = delegate!!.objectReferencesEnabled + override val useCase: SerializationContext.UseCase + get() = delegate!!.useCase + + override fun withProperty(property: Any, value: Any): SerializationContext { + return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withProperty(property, value) } + } + + override fun withoutReferences(): SerializationContext { + return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withoutReferences() } + } + + override fun withClassLoader(classLoader: ClassLoader): SerializationContext { + return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withClassLoader(classLoader) } + } + + override fun withWhitelisted(clazz: Class<*>): SerializationContext { + return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withWhitelisted(clazz) } + } +} diff --git a/test-utils/src/main/kotlin/net/corda/testing/TestDependencyInjectionBase.kt b/test-utils/src/main/kotlin/net/corda/testing/TestDependencyInjectionBase.kt new file mode 100644 index 0000000000..549cd2ac6d --- /dev/null +++ b/test-utils/src/main/kotlin/net/corda/testing/TestDependencyInjectionBase.kt @@ -0,0 +1,19 @@ +package net.corda.testing + +import org.junit.After +import org.junit.Before + +/** + * The beginnings of somewhere to inject implementations for unit tests. + */ +abstract class TestDependencyInjectionBase { + @Before + fun initialiseSerialization() { + initialiseTestSerialization() + } + + @After + fun resetInitialisation() { + resetTestSerialization() + } +} \ No newline at end of file diff --git a/test-utils/src/main/kotlin/net/corda/testing/driver/Driver.kt b/test-utils/src/main/kotlin/net/corda/testing/driver/Driver.kt index e279266d41..d75150b844 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/driver/Driver.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/driver/Driver.kt @@ -40,6 +40,8 @@ import net.corda.testing.BOB import net.corda.testing.DUMMY_BANK_A import net.corda.testing.DUMMY_NOTARY import net.corda.testing.node.MOCK_VERSION_INFO +import net.corda.testing.initialiseTestSerialization +import net.corda.testing.resetTestSerialization import okhttp3.OkHttpClient import okhttp3.Request import org.bouncycastle.asn1.x500.X500Name @@ -187,7 +189,7 @@ sealed class NodeHandle { val nodeThread: Thread ) : NodeHandle() - fun rpcClientToNode(): CordaRPCClient = CordaRPCClient(configuration.rpcAddress!!) + fun rpcClientToNode(): CordaRPCClient = CordaRPCClient(configuration.rpcAddress!!, initialiseSerialization = false) } data class WebserverHandle( @@ -250,6 +252,7 @@ fun driver( debugPortAllocation: PortAllocation = PortAllocation.Incremental(5005), systemProperties: Map = emptyMap(), useTestClock: Boolean = false, + initialiseSerialization: Boolean = true, networkMapStartStrategy: NetworkMapStartStrategy = NetworkMapStartStrategy.Dedicated(startAutomatically = true), startNodesInProcess: Boolean = false, dsl: DriverDSLExposedInterface.() -> A @@ -278,9 +281,11 @@ fun driver( */ fun genericDriver( driverDsl: D, + initialiseSerialization: Boolean = true, coerce: (D) -> DI, dsl: DI.() -> A ): A { + if (initialiseSerialization) initialiseTestSerialization() val shutdownHook = addShutdownHook(driverDsl::shutdown) try { driverDsl.start() @@ -291,6 +296,7 @@ fun genericD } finally { driverDsl.shutdown() shutdownHook.cancel() + if (initialiseSerialization) resetTestSerialization() } } @@ -511,7 +517,7 @@ class DriverDSL( } private fun establishRpc(nodeAddress: NetworkHostAndPort, sslConfig: SSLConfiguration, processDeathFuture: ListenableFuture): ListenableFuture { - val client = CordaRPCClient(nodeAddress, sslConfig) + val client = CordaRPCClient(nodeAddress, sslConfig, initialiseSerialization = false) val connectionFuture = poll(executorService, "RPC connection") { try { client.start(ArtemisMessagingComponent.NODE_USER, ArtemisMessagingComponent.NODE_USER) @@ -769,7 +775,7 @@ class DriverDSL( writeConfig(nodeConf.baseDirectory, "node.conf", config) val clock: Clock = if (nodeConf.useTestClock) TestClock() else NodeClock() // TODO pass the version in? - val node = Node(nodeConf, nodeConf.calculateServices(), MOCK_VERSION_INFO, clock) + val node = Node(nodeConf, nodeConf.calculateServices(), MOCK_VERSION_INFO, clock, initialiseSerialization = false) node.start() val nodeThread = thread(name = nodeConf.myLegalName.commonName) { node.run() diff --git a/test-utils/src/main/kotlin/net/corda/testing/node/MockNode.kt b/test-utils/src/main/kotlin/net/corda/testing/node/MockNode.kt index 28e856a203..34b3b53401 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/node/MockNode.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/node/MockNode.kt @@ -5,11 +5,11 @@ import com.google.common.jimfs.Jimfs import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.ListenableFuture import com.nhaarman.mockito_kotlin.whenever -import net.corda.core.* import net.corda.core.crypto.CertificateAndKeyPair import net.corda.core.crypto.cert import net.corda.core.crypto.entropyToKeyPair import net.corda.core.crypto.random63BitValue +import net.corda.core.getOrThrow import net.corda.core.identity.PartyAndCertificate import net.corda.core.internal.createDirectories import net.corda.core.internal.createDirectory @@ -19,11 +19,11 @@ import net.corda.core.messaging.RPCOps import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.node.CordaPluginRegistry import net.corda.core.node.ServiceEntry -import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.node.WorldMapLocation import net.corda.core.node.services.IdentityService import net.corda.core.node.services.KeyManagementService import net.corda.core.node.services.ServiceInfo +import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.loggerFor import net.corda.node.internal.AbstractNode import net.corda.node.services.config.NodeConfiguration @@ -67,7 +67,8 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, private val threadPerNode: Boolean = false, servicePeerAllocationStrategy: InMemoryMessagingNetwork.ServicePeerAllocationStrategy = InMemoryMessagingNetwork.ServicePeerAllocationStrategy.Random(), - private val defaultFactory: Factory = MockNetwork.DefaultFactory) { + private val defaultFactory: Factory = MockNetwork.DefaultFactory, + private val initialiseSerialization: Boolean = true) { val nextNodeId get() = _nextNodeId private var _nextNodeId = 0 @@ -85,6 +86,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, val nodes: List = _nodes init { + if (initialiseSerialization) initialiseTestSerialization() filesystem.getPath("/nodes").createDirectory() } @@ -396,6 +398,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, fun stopNodes() { nodes.forEach { if (it.started) it.stop() } + if (initialiseSerialization) resetTestSerialization() } // Test method to block until all scheduled activity, active flows diff --git a/test-utils/src/main/kotlin/net/corda/testing/node/NodeBasedTest.kt b/test-utils/src/main/kotlin/net/corda/testing/node/NodeBasedTest.kt index 7aa63a557f..a59e94a238 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/node/NodeBasedTest.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/node/NodeBasedTest.kt @@ -3,12 +3,14 @@ package net.corda.testing.node import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.MoreExecutors.listeningDecorator -import net.corda.core.* import net.corda.core.crypto.X509Utilities import net.corda.core.crypto.appendToCommonName import net.corda.core.crypto.commonName +import net.corda.core.flatMap +import net.corda.core.getOrThrow import net.corda.core.internal.createDirectories import net.corda.core.internal.div +import net.corda.core.map import net.corda.core.node.services.ServiceInfo import net.corda.core.node.services.ServiceType import net.corda.core.utilities.WHITESPACE @@ -23,6 +25,7 @@ import net.corda.node.utilities.ServiceIdentityGenerator import net.corda.nodeapi.User import net.corda.nodeapi.config.parseAs import net.corda.testing.DUMMY_MAP +import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.driver.addressMustNotBeBoundFuture import net.corda.testing.getFreeLocalPorts import org.apache.logging.log4j.Level @@ -39,7 +42,7 @@ import kotlin.concurrent.thread * purposes. Use the driver if you need to run the nodes in separate processes otherwise this class will suffice. */ // TODO Some of the logic here duplicates what's in the driver -abstract class NodeBasedTest { +abstract class NodeBasedTest : TestDependencyInjectionBase() { @Rule @JvmField val tempFolder = TemporaryFolder() @@ -161,7 +164,7 @@ abstract class NodeBasedTest { val parsedConfig = config.parseAs() val node = Node(parsedConfig, parsedConfig.calculateServices(), MOCK_VERSION_INFO.copy(platformVersion = platformVersion), - if (parsedConfig.useTestClock) TestClock() else NodeClock()) + if (parsedConfig.useTestClock) TestClock() else NodeClock(), initialiseSerialization = false) node.start() nodes += node thread(name = legalName.commonName) { diff --git a/verifier/src/main/kotlin/net/corda/verifier/Verifier.kt b/verifier/src/main/kotlin/net/corda/verifier/Verifier.kt index f5c54a013a..e8eb1cb910 100644 --- a/verifier/src/main/kotlin/net/corda/verifier/Verifier.kt +++ b/verifier/src/main/kotlin/net/corda/verifier/Verifier.kt @@ -1,9 +1,13 @@ package net.corda.verifier +import com.esotericsoftware.kryo.pool.KryoPool import com.typesafe.config.Config import com.typesafe.config.ConfigFactory import com.typesafe.config.ConfigParseOptions import net.corda.core.internal.div +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.debug import net.corda.core.utilities.loggerFor @@ -14,6 +18,10 @@ import net.corda.nodeapi.VerifierApi.VERIFICATION_REQUESTS_QUEUE_NAME import net.corda.nodeapi.config.NodeSSLConfiguration import net.corda.nodeapi.config.getValue import net.corda.nodeapi.internal.addShutdownHook +import net.corda.nodeapi.serialization.AbstractKryoSerializationScheme +import net.corda.nodeapi.serialization.KRYO_P2P_CONTEXT +import net.corda.nodeapi.serialization.KryoHeaderV0_1 +import net.corda.nodeapi.serialization.SerializationFactoryImpl import org.apache.activemq.artemis.api.core.client.ActiveMQClient import java.nio.file.Path import java.nio.file.Paths @@ -55,6 +63,7 @@ class Verifier { session.close() sessionFactory.close() } + initialiseSerialization() val consumer = session.createConsumer(VERIFICATION_REQUESTS_QUEUE_NAME) val replyProducer = session.createProducer() consumer.setMessageHandler { @@ -77,5 +86,26 @@ class Verifier { log.info("Verifier started") Thread.sleep(Long.MAX_VALUE) } + + private fun initialiseSerialization() { + SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply { + registerScheme(KryoVerifierSerializationScheme) + } + SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT + } + } + + object KryoVerifierSerializationScheme : AbstractKryoSerializationScheme() { + override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean { + return byteSequence.equals(KryoHeaderV0_1) && target == SerializationContext.UseCase.P2P + } + + override fun rpcClientKryoPool(context: SerializationContext): KryoPool { + throw UnsupportedOperationException() + } + + override fun rpcServerKryoPool(context: SerializationContext): KryoPool { + throw UnsupportedOperationException() + } } } \ No newline at end of file