diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/serialization/amqp/AMQPClientSerializationScheme.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/serialization/amqp/AMQPClientSerializationScheme.kt index f920a8081d..08a9924d0c 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/serialization/amqp/AMQPClientSerializationScheme.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/serialization/amqp/AMQPClientSerializationScheme.kt @@ -21,11 +21,16 @@ class AMQPClientSerializationScheme( cordappSerializationWhitelists: Set, serializerFactoriesForContexts: MutableMap ) : AbstractAMQPSerializationScheme(cordappCustomSerializers, cordappSerializationWhitelists, serializerFactoriesForContexts) { - constructor(cordapps: List) : this(cordapps.customSerializers, cordapps.serializationWhitelists, AccessOrderLinkedHashMap(128).toSynchronised()) - constructor(cordapps: List, serializerFactoriesForContexts: MutableMap) : this(cordapps.customSerializers, cordapps.serializationWhitelists, serializerFactoriesForContexts) + constructor(cordapps: List) : this(cordapps.customSerializers, cordapps.serializationWhitelists) + constructor(cordapps: List, serializerFactoriesForContexts: MutableMap) + : this(cordapps.customSerializers, cordapps.serializationWhitelists, serializerFactoriesForContexts) + constructor( + cordappCustomSerializers: Set>, + cordappSerializationWhitelists: Set + ) : this(cordappCustomSerializers, cordappSerializationWhitelists, createDefaultSerializerFactoryCache()) @Suppress("UNUSED") - constructor() : this(emptySet(), emptySet(), AccessOrderLinkedHashMap(128).toSynchronised()) + constructor() : this(emptySet(), emptySet()) companion object { /** Call from main only. */ @@ -44,6 +49,10 @@ class AMQPClientSerializationScheme( rpcServerContext = AMQP_RPC_SERVER_CONTEXT ) } + + private fun createDefaultSerializerFactoryCache(): MutableMap { + return AccessOrderLinkedHashMap(128).toSynchronised() + } } override fun canDeserializeVersion(magic: CordaSerializationMagic, target: SerializationContext.UseCase): Boolean { diff --git a/constants.properties b/constants.properties index fa61a31d9a..6e4f55adfa 100644 --- a/constants.properties +++ b/constants.properties @@ -20,7 +20,7 @@ quasarVersion11=0.8.0_r3_rc1 jdkClassifier11=jdk11 proguardVersion=6.1.1 bouncycastleVersion=1.60 -classgraphVersion=4.8.41 +classgraphVersion=4.8.58 disruptorVersion=3.4.2 typesafeConfigVersion=1.3.4 jsr305Version=3.0.2 diff --git a/core/build.gradle b/core/build.gradle index 77d566bff2..6f2193f4b0 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -76,6 +76,8 @@ dependencies { compile group: "io.github.classgraph", name: "classgraph", version: class_graph_version + testCompile "org.ow2.asm:asm:$asm_version" + // JDK11: required by Quasar at run-time testRuntimeOnly "com.esotericsoftware:kryo:4.0.2" diff --git a/core/src/main/kotlin/net/corda/core/internal/ClassLoadingUtils.kt b/core/src/main/kotlin/net/corda/core/internal/ClassLoadingUtils.kt index 530c1383aa..00c62f9104 100644 --- a/core/src/main/kotlin/net/corda/core/internal/ClassLoadingUtils.kt +++ b/core/src/main/kotlin/net/corda/core/internal/ClassLoadingUtils.kt @@ -1,7 +1,9 @@ package net.corda.core.internal import io.github.classgraph.ClassGraph +import io.github.classgraph.ClassInfo import net.corda.core.StubOutForDJVM +import net.corda.core.serialization.internal.AttachmentURLStreamHandlerFactory.attachmentScheme /** * Creates instances of all the classes in the classpath of the provided classloader, which implement the interface of the provided class. @@ -17,16 +19,17 @@ import net.corda.core.StubOutForDJVM */ @StubOutForDJVM fun createInstancesOfClassesImplementing(classloader: ClassLoader, clazz: Class): Set { - return ClassGraph().addClassLoader(classloader) - .enableClassInfo() - .pooledScan() - .use { - it.getClassesImplementing(clazz.name) - .filterNot { it.isAbstract } - .map { classloader.loadClass(it.name).asSubclass(clazz) } - .map { it.kotlin.objectOrNewInstance() } - .toSet() - } + return ClassGraph().overrideClassLoaders(classloader) + .enableURLScheme(attachmentScheme) + .ignoreParentClassLoaders() + .enableClassInfo() + .pooledScan() + .use { result -> + result.getClassesImplementing(clazz.name) + .filterNot(ClassInfo::isAbstract) + .map { classloader.loadClass(it.name).asSubclass(clazz) } + .mapTo(LinkedHashSet()) { it.kotlin.objectOrNewInstance() } + } } fun executeWithThreadContextClassLoader(classloader: ClassLoader, fn: () -> T): T { diff --git a/core/src/main/kotlin/net/corda/core/serialization/internal/AttachmentsClassLoader.kt b/core/src/main/kotlin/net/corda/core/serialization/internal/AttachmentsClassLoader.kt index 0e5f1c5053..80a8c59297 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/internal/AttachmentsClassLoader.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/internal/AttachmentsClassLoader.kt @@ -351,7 +351,7 @@ object AttachmentsClassLoaderBuilder { * This will not be exposed as an API. */ object AttachmentURLStreamHandlerFactory : URLStreamHandlerFactory { - private const val attachmentScheme = "attachment" + internal const val attachmentScheme = "attachment" // TODO - what happens if this grows too large? private val loadedAttachments = mutableMapOf().toSynchronised() diff --git a/core/src/test/kotlin/net/corda/core/internal/AsmTools.kt b/core/src/test/kotlin/net/corda/core/internal/AsmTools.kt new file mode 100644 index 0000000000..561f3e41e9 --- /dev/null +++ b/core/src/test/kotlin/net/corda/core/internal/AsmTools.kt @@ -0,0 +1,31 @@ +@file:JvmName("AsmTools") +package net.corda.core.internal + +import org.objectweb.asm.ClassReader +import org.objectweb.asm.ClassVisitor +import org.objectweb.asm.ClassWriter +import org.objectweb.asm.ClassWriter.COMPUTE_MAXS +import org.objectweb.asm.Opcodes.ASM7 + +val String.asInternalName: String get() = replace('.', '/') +val Class<*>.resourceName: String get() = "${name.asInternalName}.class" +val Class<*>.byteCode: ByteArray get() = classLoader.getResourceAsStream(resourceName)!!.use { + it.readBytes() +} + +fun Class<*>.renameTo(newName: String): ByteArray { + return byteCode.accept { w -> RenamingWriter(newName, w) } +} + +fun ByteArray.accept(visitor: (ClassVisitor) -> ClassVisitor): ByteArray { + return ClassWriter(COMPUTE_MAXS).let { writer -> + ClassReader(this).accept(visitor(writer), 0) + writer.toByteArray() + } +} + +private class RenamingWriter(private val newName: String, visitor: ClassVisitor) : ClassVisitor(ASM7, visitor) { + override fun visit(version: Int, access: Int, name: String, signature: String?, superName: String?, interfaces: Array?) { + super.visit(version, access, newName, signature, superName, interfaces) + } +} diff --git a/core/src/test/kotlin/net/corda/core/internal/ClassLoadingUtilsTest.kt b/core/src/test/kotlin/net/corda/core/internal/ClassLoadingUtilsTest.kt index b77d6998e7..8378ff3213 100644 --- a/core/src/test/kotlin/net/corda/core/internal/ClassLoadingUtilsTest.kt +++ b/core/src/test/kotlin/net/corda/core/internal/ClassLoadingUtilsTest.kt @@ -1,23 +1,62 @@ package net.corda.core.internal import com.nhaarman.mockito_kotlin.mock +import net.corda.core.contracts.ContractAttachment +import net.corda.core.contracts.ContractClassName +import net.corda.core.crypto.SecureHash +import net.corda.core.identity.Party +import net.corda.core.serialization.internal.AttachmentURLStreamHandlerFactory +import net.corda.core.serialization.internal.AttachmentsClassLoader import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.Assert.assertEquals import org.junit.Test +import java.io.ByteArrayOutputStream +import java.net.URLClassLoader +import java.security.PublicKey +import java.util.jar.JarOutputStream +import java.util.jar.Manifest +import java.util.zip.Deflater.NO_COMPRESSION +import java.util.zip.ZipEntry +import java.util.zip.ZipEntry.DEFLATED +import java.util.zip.ZipEntry.STORED class ClassLoadingUtilsTest { + companion object { + const val STANDALONE_CLASS_NAME = "com.example.StandaloneClassWithEmptyConstructor" + const val PROGRAM_ID: ContractClassName = "net.corda.core.internal.DummyContract" + val contractAttachmentId = SecureHash.randomSHA256() + + fun directoryEntry(internalName: String) = ZipEntry("$internalName/").apply { + method = STORED + compressedSize = 0 + size = 0 + crc = 0 + } + + fun classEntry(internalName: String) = ZipEntry("$internalName.class").apply { + method = DEFLATED + } + + init { + // Register the "attachment://" URL scheme. + // You may only register new schemes ONCE per JVM! + AttachmentsClassLoader + } + } private val temporaryClassLoader = mock() - interface BaseInterface {} + interface BaseInterface - interface BaseInterface2 {} + interface BaseInterface2 - class ConcreteClassWithEmptyConstructor: BaseInterface {} + class ConcreteClassWithEmptyConstructor: BaseInterface abstract class AbstractClass: BaseInterface - class ConcreteClassWithNonEmptyConstructor(private val someData: Int): BaseInterface2 {} + @Suppress("unused") + class ConcreteClassWithNonEmptyConstructor(private val someData: Int): BaseInterface2 @Test fun predicateClassAreLoadedSuccessfully() { @@ -25,8 +64,9 @@ class ClassLoadingUtilsTest { val classNames = classes.map { it.javaClass.name } - assertThat(classNames).contains(ConcreteClassWithEmptyConstructor::class.java.name) - assertThat(classNames).doesNotContain(AbstractClass::class.java.name) + assertThat(classNames) + .contains(ConcreteClassWithEmptyConstructor::class.java.name) + .doesNotContain(AbstractClass::class.java.name) } @Test(expected = IllegalArgumentException::class) @@ -54,4 +94,42 @@ class ClassLoadingUtilsTest { assertThat(Thread.currentThread().contextClassLoader).isNotEqualTo(temporaryClassLoader) } -} \ No newline at end of file + @Test + fun `test locating classes inside attachment`() { + val jarData = with(ByteArrayOutputStream()) { + val internalName = STANDALONE_CLASS_NAME.asInternalName + JarOutputStream(this, Manifest()).use { + it.setLevel(NO_COMPRESSION) + it.setMethod(DEFLATED) + it.putNextEntry(directoryEntry("com")) + it.putNextEntry(directoryEntry("com/example")) + it.putNextEntry(classEntry(internalName)) + it.write(TemplateClassWithEmptyConstructor::class.java.renameTo(internalName)) + } + toByteArray() + } + val attachment = signedAttachment(jarData) + val url = AttachmentURLStreamHandlerFactory.toUrl(attachment) + + URLClassLoader(arrayOf(url)).use { cordappClassLoader -> + val standaloneClass = createInstancesOfClassesImplementing(cordappClassLoader, BaseInterface::class.java) + .map(Any::javaClass) + .single() + assertEquals(STANDALONE_CLASS_NAME, standaloneClass.name) + assertEquals(cordappClassLoader, standaloneClass.classLoader) + } + } + + private fun signedAttachment(data: ByteArray, vararg parties: Party) = ContractAttachment.create( + object : AbstractAttachment({ data }, "test") { + override val id: SecureHash get() = contractAttachmentId + + override val signerKeys: List get() = parties.map(Party::owningKey) + }, PROGRAM_ID, signerKeys = parties.map(Party::owningKey) + ) +} + +// Our dummy attachment will contain a class that is created from this one. +// This is because our attachment must contain a class that DOES NOT exist +// inside the application classloader. +class TemplateClassWithEmptyConstructor : ClassLoadingUtilsTest.BaseInterface diff --git a/node/src/integration-test/kotlin/net/corda/contracts/serialization/custom/CurrantsySerializer.kt b/node/src/integration-test/kotlin/net/corda/contracts/serialization/custom/CurrantsySerializer.kt new file mode 100644 index 0000000000..22c3e2056f --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/contracts/serialization/custom/CurrantsySerializer.kt @@ -0,0 +1,21 @@ +package net.corda.contracts.serialization.custom + +import net.corda.core.serialization.SerializationCustomSerializer + +@Suppress("unused") +class CurrantsySerializer : SerializationCustomSerializer { + data class Proxy(val currants: Long) + + override fun fromProxy(proxy: Proxy): Currantsy = Currantsy(proxy.currants) + override fun toProxy(obj: Currantsy) = Proxy(obj.currants) +} + +data class Currantsy(val currants: Long) : Comparable { + override fun toString(): String { + return "$currants juicy currants" + } + + override fun compareTo(other: Currantsy): Int { + return currants.compareTo(other.currants) + } +} diff --git a/node/src/integration-test/kotlin/net/corda/contracts/serialization/custom/CustomSerializerContract.kt b/node/src/integration-test/kotlin/net/corda/contracts/serialization/custom/CustomSerializerContract.kt new file mode 100644 index 0000000000..7f49606005 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/contracts/serialization/custom/CustomSerializerContract.kt @@ -0,0 +1,39 @@ +package net.corda.contracts.serialization.custom + +import net.corda.core.contracts.CommandData +import net.corda.core.contracts.Contract +import net.corda.core.contracts.ContractState +import net.corda.core.identity.AbstractParty +import net.corda.core.transactions.LedgerTransaction + +@Suppress("unused") +class CustomSerializerContract : Contract { + companion object { + const val MAX_CURRANT = 2000 + } + + override fun verify(tx: LedgerTransaction) { + val currantsyData = tx.outputsOfType(CurrantsyState::class.java) + require(currantsyData.isNotEmpty()) { + "Requires at least one currantsy state" + } + + currantsyData.forEach { + require(it.currantsy.currants in 0..MAX_CURRANT) { + "Too many currants! ${it.currantsy.currants} is unraisinable!" + } + } + } + + @Suppress("CanBeParameter", "MemberVisibilityCanBePrivate") + class CurrantsyState(val owner: AbstractParty, val currantsy: Currantsy) : ContractState { + override val participants: List = listOf(owner) + + @Override + override fun toString(): String { + return currantsy.toString() + } + } + + class Purchase : CommandData +} \ No newline at end of file diff --git a/node/src/integration-test/kotlin/net/corda/contracts/serialization/custom/CustomSerializerRegistry.kt b/node/src/integration-test/kotlin/net/corda/contracts/serialization/custom/CustomSerializerRegistry.kt new file mode 100644 index 0000000000..e2b35f6b99 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/contracts/serialization/custom/CustomSerializerRegistry.kt @@ -0,0 +1,7 @@ +package net.corda.contracts.serialization.custom + +import net.corda.core.serialization.SerializationWhitelist + +class CustomSerializerRegistry : SerializationWhitelist { + override val whitelist: List> = listOf(Currantsy::class.java) +} diff --git a/node/src/integration-test/kotlin/net/corda/flows/serialization/custom/CustomSerializerFlow.kt b/node/src/integration-test/kotlin/net/corda/flows/serialization/custom/CustomSerializerFlow.kt new file mode 100644 index 0000000000..e3491748e6 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/flows/serialization/custom/CustomSerializerFlow.kt @@ -0,0 +1,30 @@ +package net.corda.flows.serialization.custom + +import co.paralleluniverse.fibers.Suspendable +import net.corda.contracts.serialization.custom.Currantsy +import net.corda.contracts.serialization.custom.CustomSerializerContract.CurrantsyState +import net.corda.contracts.serialization.custom.CustomSerializerContract.Purchase +import net.corda.core.contracts.Command +import net.corda.core.crypto.SecureHash +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.InitiatingFlow +import net.corda.core.flows.StartableByRPC +import net.corda.core.transactions.TransactionBuilder + +@InitiatingFlow +@StartableByRPC +class CustomSerializerFlow( + private val purchase: Currantsy +) : FlowLogic() { + @Suspendable + override fun call(): SecureHash { + val notary = serviceHub.networkMapCache.notaryIdentities[0] + val stx = serviceHub.signInitialTransaction( + TransactionBuilder(notary) + .addOutputState(CurrantsyState(ourIdentity, purchase)) + .addCommand(Command(Purchase(), ourIdentity.owningKey)) + ) + stx.verify(serviceHub, checkSufficientSignatures = false) + return stx.id + } +} diff --git a/node/src/integration-test/kotlin/net/corda/node/ContractWithCustomSerializerTest.kt b/node/src/integration-test/kotlin/net/corda/node/ContractWithCustomSerializerTest.kt new file mode 100644 index 0000000000..e725316207 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/ContractWithCustomSerializerTest.kt @@ -0,0 +1,56 @@ +package net.corda.node + +import net.corda.client.rpc.CordaRPCClient +import net.corda.contracts.serialization.custom.Currantsy +import net.corda.core.contracts.TransactionVerificationException +import net.corda.core.messaging.startFlow +import net.corda.core.utilities.getOrThrow +import net.corda.flows.serialization.custom.CustomSerializerFlow +import net.corda.node.services.Permissions +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.DUMMY_NOTARY_NAME +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.driver +import net.corda.testing.driver.internal.incrementalPortAllocation +import net.corda.testing.node.NotarySpec +import net.corda.testing.node.User +import net.corda.testing.node.internal.CustomCordapp +import net.corda.testing.node.internal.cordappWithPackages +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test +import org.junit.jupiter.api.assertThrows + +@Suppress("FunctionName") +class ContractWithCustomSerializerTest { + companion object { + const val CURRANTS = 5000L + } + + @Test + fun `flow with custom serializer by rpc`() { + val user = User("u", "p", setOf(Permissions.all())) + driver(DriverParameters( + portAllocation = incrementalPortAllocation(), + startNodesInProcess = false, + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, validating = true)), + cordappsForAllNodes = listOf( + cordappWithPackages("net.corda.flows.serialization.custom"), + CustomCordapp( + packages = setOf("net.corda.contracts.serialization.custom"), + name = "has-custom-serializer" + ).signed() + ) + )) { + val alice = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)).getOrThrow() + val ex = assertThrows { + CordaRPCClient(hostAndPort = alice.rpcAddress) + .start(user.username, user.password) + .proxy + .startFlow(::CustomSerializerFlow, Currantsy(CURRANTS)) + .returnValue + .getOrThrow() + } + assertThat(ex).hasMessageContaining("Too many currants! $CURRANTS is unraisinable!") + } + } +} \ No newline at end of file diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/CustomCordapp.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/CustomCordapp.kt index cba5d9a523..e50d11226a 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/CustomCordapp.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/CustomCordapp.kt @@ -4,6 +4,7 @@ import io.github.classgraph.ClassGraph import net.corda.core.internal.* import net.corda.core.internal.cordapp.CordappImpl import net.corda.core.internal.cordapp.set +import net.corda.core.serialization.SerializationWhitelist import net.corda.core.utilities.contextLogger import net.corda.core.utilities.debug import net.corda.testing.core.internal.JarSignatureTestUtils.generateKey @@ -58,11 +59,20 @@ data class CustomCordapp( classGraph.whitelistClasses(*classes.map { it.name }.toTypedArray()) } - classGraph.pooledScan().use { scanResult -> + classGraph.enableClassInfo().pooledScan().use { scanResult -> + val whitelistService = SerializationWhitelist::class.java.name + val whitelists = scanResult.getClassesImplementing(whitelistService) + JarOutputStream(file.outputStream()).use { jos -> jos.addEntry(testEntry(JarFile.MANIFEST_NAME)) { createTestManifest(name, versionId, targetPlatformVersion).write(jos) } + if (whitelists.isNotEmpty()) { + jos.addEntry(directoryEntry("META-INF/services")) + jos.addEntry(testEntry("META-INF/services/$whitelistService")) { + jos.write(whitelists.names.joinToString(separator = "\r\n").toByteArray()) + } + } // The same resource may be found in different locations (this will happen when running from gradle) so just // pick the first one found. @@ -114,6 +124,16 @@ data class CustomCordapp( return ZipEntry(name).setCreationTime(epochFileTime).setLastAccessTime(epochFileTime).setLastModifiedTime(epochFileTime) } + private fun directoryEntry(name: String): ZipEntry { + val directoryName = if (name.endsWith('/')) name else "$name/" + return testEntry(directoryName).apply { + method = ZipEntry.STORED + compressedSize = 0 + size = 0 + crc = 0 + } + } + data class SigningInfo(val keyStorePath: Path?, val numberOfSignatures: Int, val keyAlgorithm: String) companion object { diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/DriverDSLImpl.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/DriverDSLImpl.kt index b22e7bc7e4..bb11bac824 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/DriverDSLImpl.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/DriverDSLImpl.kt @@ -35,7 +35,6 @@ import net.corda.core.node.NetworkParameters import net.corda.core.node.NotaryInfo import net.corda.core.node.services.NetworkMapCache import net.corda.core.utilities.NetworkHostAndPort -import net.corda.core.utilities.Try import net.corda.core.utilities.contextLogger import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.millis @@ -84,6 +83,7 @@ import rx.schedulers.Schedulers import java.io.File import java.net.ConnectException import java.net.URL +import java.net.URLClassLoader import java.nio.file.Path import java.security.cert.X509Certificate import java.time.Duration @@ -142,6 +142,12 @@ class DriverDSLImpl( private lateinit var _notaries: CordaFuture> override val notaryHandles: List get() = _notaries.getOrThrow() + override val cordappsClassLoader: ClassLoader? = if (!startNodesInProcess) { + createCordappsClassLoader(cordappsForAllNodes) + } else { + null + } + interface Waitable { @Throws(InterruptedException::class) fun waitFor() @@ -193,6 +199,7 @@ class DriverDSLImpl( } _shutdownManager?.shutdown() _executorService?.shutdownNow() + (cordappsClassLoader as? AutoCloseable)?.close() } private fun establishRpc(config: NodeConfig, processDeathFuture: CordaFuture): CordaFuture { @@ -982,6 +989,13 @@ class DriverDSLImpl( return config } + private fun createCordappsClassLoader(cordapps: Collection?): ClassLoader? { + if (cordapps == null || cordapps.isEmpty()) { + return null + } + return URLClassLoader(cordapps.map { it.jarFile.toUri().toURL() }.toTypedArray()) + } + private operator fun Config.plus(property: Pair) = withValue(property.first, ConfigValueFactory.fromAnyRef(property.second)) /** @@ -1069,6 +1083,8 @@ interface InternalDriverDSL : DriverDSL { val shutdownManager: ShutdownManager + val cordappsClassLoader: ClassLoader? + fun baseDirectory(nodeName: String): Path = baseDirectory(CordaX500Name.parse(nodeName)) /** @@ -1113,7 +1129,7 @@ fun genericDriver( coerce: (D) -> DI, dsl: DI.() -> A ): A { - val serializationEnv = setDriverSerialization() + val serializationEnv = setDriverSerialization(driverDsl.cordappsClassLoader) val shutdownHook = addShutdownHook(driverDsl::shutdown) try { driverDsl.start() diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalTestUtils.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalTestUtils.kt index 33df4fb933..478f82271e 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalTestUtils.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalTestUtils.kt @@ -282,15 +282,18 @@ fun DriverDSL.assertCheckpoints(name: CordaX500Name, expected: Long) { /** * Should only be used by Driver and MockNode. */ -fun setDriverSerialization(): AutoCloseable? { +fun setDriverSerialization(classLoader: ClassLoader?): AutoCloseable? { return if (_allEnabledSerializationEnvs.isEmpty()) { - DriverSerializationEnvironment().enable() + DriverSerializationEnvironment(classLoader).enable() } else { null } } -private class DriverSerializationEnvironment : SerializationEnvironment by createTestSerializationEnv(), AutoCloseable { +fun setDriverSerialization(): AutoCloseable? = setDriverSerialization(null) + +private class DriverSerializationEnvironment(classLoader: ClassLoader?) + : SerializationEnvironment by createTestSerializationEnv(classLoader), AutoCloseable { fun enable() = apply { _driverSerializationEnv.set(this) } override fun close() { _driverSerializationEnv.set(null) @@ -303,6 +306,8 @@ fun JarOutputStream.addEntry(entry: ZipEntry, input: InputStream) { addEntry(entry) { input.use { it.copyTo(this) } } } +fun JarOutputStream.addEntry(entry: ZipEntry) = addEntry(entry) {} + inline fun JarOutputStream.addEntry(entry: ZipEntry, write: () -> Unit) { putNextEntry(entry) write() diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalSerializationTestHelpers.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalSerializationTestHelpers.kt index 24ecaa2386..622ab76245 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalSerializationTestHelpers.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalSerializationTestHelpers.kt @@ -1,20 +1,35 @@ package net.corda.testing.internal import net.corda.client.rpc.internal.serialization.amqp.AMQPClientSerializationScheme +import net.corda.core.internal.createInstancesOfClassesImplementing +import net.corda.core.serialization.SerializationCustomSerializer +import net.corda.core.serialization.SerializationWhitelist import net.corda.core.serialization.internal.SerializationEnvironment import net.corda.node.serialization.amqp.AMQPServerSerializationScheme import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT import net.corda.node.serialization.kryo.KryoCheckpointSerializer import net.corda.serialization.internal.* import net.corda.testing.common.internal.asContextEnv +import java.util.ServiceLoader import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ExecutorService val inVMExecutors = ConcurrentHashMap() fun createTestSerializationEnv(): SerializationEnvironment { + return createTestSerializationEnv(null) +} + +fun createTestSerializationEnv(classLoader: ClassLoader?): SerializationEnvironment { + val clientSerializationScheme = if (classLoader != null) { + val customSerializers = createInstancesOfClassesImplementing(classLoader, SerializationCustomSerializer::class.java) + val serializationWhitelists = ServiceLoader.load(SerializationWhitelist::class.java, classLoader).toSet() + AMQPClientSerializationScheme(customSerializers, serializationWhitelists) + } else { + AMQPClientSerializationScheme(emptyList()) + } val factory = SerializationFactoryImpl().apply { - registerScheme(AMQPClientSerializationScheme(emptyList())) + registerScheme(clientSerializationScheme) registerScheme(AMQPServerSerializationScheme(emptyList())) } return SerializationEnvironment.with(