mirror of
https://github.com/corda/corda.git
synced 2024-12-19 21:17:58 +00:00
CORDA-3464: Also scan attachment:// URLs for custom serializers. (#5769)
* CORDA-3464: Also scan attachment:// URLs for custom serializers. * Only scan the given classloader - ignore this classloader's parents. * Upgrade to ClassGraph 4.8.58 - for "robustness fixes". * Register the attachment:// URL scheme using AttachmentsClassLoader. * Add integration test for custom serializer in contract state. * Rename Currancy -> Currantsy, just to make the point.
This commit is contained in:
parent
2abf22ccf9
commit
5a41ec9b82
@ -21,11 +21,16 @@ class AMQPClientSerializationScheme(
|
||||
cordappSerializationWhitelists: Set<SerializationWhitelist>,
|
||||
serializerFactoriesForContexts: MutableMap<SerializationFactoryCacheKey, SerializerFactory>
|
||||
) : AbstractAMQPSerializationScheme(cordappCustomSerializers, cordappSerializationWhitelists, serializerFactoriesForContexts) {
|
||||
constructor(cordapps: List<Cordapp>) : this(cordapps.customSerializers, cordapps.serializationWhitelists, AccessOrderLinkedHashMap<SerializationFactoryCacheKey, SerializerFactory>(128).toSynchronised())
|
||||
constructor(cordapps: List<Cordapp>, serializerFactoriesForContexts: MutableMap<SerializationFactoryCacheKey, SerializerFactory>) : this(cordapps.customSerializers, cordapps.serializationWhitelists, serializerFactoriesForContexts)
|
||||
constructor(cordapps: List<Cordapp>) : this(cordapps.customSerializers, cordapps.serializationWhitelists)
|
||||
constructor(cordapps: List<Cordapp>, serializerFactoriesForContexts: MutableMap<SerializationFactoryCacheKey, SerializerFactory>)
|
||||
: this(cordapps.customSerializers, cordapps.serializationWhitelists, serializerFactoriesForContexts)
|
||||
constructor(
|
||||
cordappCustomSerializers: Set<SerializationCustomSerializer<*,*>>,
|
||||
cordappSerializationWhitelists: Set<SerializationWhitelist>
|
||||
) : this(cordappCustomSerializers, cordappSerializationWhitelists, createDefaultSerializerFactoryCache())
|
||||
|
||||
@Suppress("UNUSED")
|
||||
constructor() : this(emptySet(), emptySet(), AccessOrderLinkedHashMap<SerializationFactoryCacheKey, SerializerFactory>(128).toSynchronised())
|
||||
constructor() : this(emptySet(), emptySet())
|
||||
|
||||
companion object {
|
||||
/** Call from main only. */
|
||||
@ -44,6 +49,10 @@ class AMQPClientSerializationScheme(
|
||||
rpcServerContext = AMQP_RPC_SERVER_CONTEXT
|
||||
)
|
||||
}
|
||||
|
||||
private fun createDefaultSerializerFactoryCache(): MutableMap<SerializationFactoryCacheKey, SerializerFactory> {
|
||||
return AccessOrderLinkedHashMap<SerializationFactoryCacheKey, SerializerFactory>(128).toSynchronised()
|
||||
}
|
||||
}
|
||||
|
||||
override fun canDeserializeVersion(magic: CordaSerializationMagic, target: SerializationContext.UseCase): Boolean {
|
||||
|
@ -20,7 +20,7 @@ quasarVersion11=0.8.0_r3_rc1
|
||||
jdkClassifier11=jdk11
|
||||
proguardVersion=6.1.1
|
||||
bouncycastleVersion=1.60
|
||||
classgraphVersion=4.8.41
|
||||
classgraphVersion=4.8.58
|
||||
disruptorVersion=3.4.2
|
||||
typesafeConfigVersion=1.3.4
|
||||
jsr305Version=3.0.2
|
||||
|
@ -76,6 +76,8 @@ dependencies {
|
||||
|
||||
compile group: "io.github.classgraph", name: "classgraph", version: class_graph_version
|
||||
|
||||
testCompile "org.ow2.asm:asm:$asm_version"
|
||||
|
||||
// JDK11: required by Quasar at run-time
|
||||
testRuntimeOnly "com.esotericsoftware:kryo:4.0.2"
|
||||
|
||||
|
@ -1,7 +1,9 @@
|
||||
package net.corda.core.internal
|
||||
|
||||
import io.github.classgraph.ClassGraph
|
||||
import io.github.classgraph.ClassInfo
|
||||
import net.corda.core.StubOutForDJVM
|
||||
import net.corda.core.serialization.internal.AttachmentURLStreamHandlerFactory.attachmentScheme
|
||||
|
||||
/**
|
||||
* Creates instances of all the classes in the classpath of the provided classloader, which implement the interface of the provided class.
|
||||
@ -17,16 +19,17 @@ import net.corda.core.StubOutForDJVM
|
||||
*/
|
||||
@StubOutForDJVM
|
||||
fun <T: Any> createInstancesOfClassesImplementing(classloader: ClassLoader, clazz: Class<T>): Set<T> {
|
||||
return ClassGraph().addClassLoader(classloader)
|
||||
.enableClassInfo()
|
||||
.pooledScan()
|
||||
.use {
|
||||
it.getClassesImplementing(clazz.name)
|
||||
.filterNot { it.isAbstract }
|
||||
.map { classloader.loadClass(it.name).asSubclass(clazz) }
|
||||
.map { it.kotlin.objectOrNewInstance() }
|
||||
.toSet()
|
||||
}
|
||||
return ClassGraph().overrideClassLoaders(classloader)
|
||||
.enableURLScheme(attachmentScheme)
|
||||
.ignoreParentClassLoaders()
|
||||
.enableClassInfo()
|
||||
.pooledScan()
|
||||
.use { result ->
|
||||
result.getClassesImplementing(clazz.name)
|
||||
.filterNot(ClassInfo::isAbstract)
|
||||
.map { classloader.loadClass(it.name).asSubclass(clazz) }
|
||||
.mapTo(LinkedHashSet()) { it.kotlin.objectOrNewInstance() }
|
||||
}
|
||||
}
|
||||
|
||||
fun <T: Any?> executeWithThreadContextClassLoader(classloader: ClassLoader, fn: () -> T): T {
|
||||
|
@ -351,7 +351,7 @@ object AttachmentsClassLoaderBuilder {
|
||||
* This will not be exposed as an API.
|
||||
*/
|
||||
object AttachmentURLStreamHandlerFactory : URLStreamHandlerFactory {
|
||||
private const val attachmentScheme = "attachment"
|
||||
internal const val attachmentScheme = "attachment"
|
||||
|
||||
// TODO - what happens if this grows too large?
|
||||
private val loadedAttachments = mutableMapOf<String, Attachment>().toSynchronised()
|
||||
|
31
core/src/test/kotlin/net/corda/core/internal/AsmTools.kt
Normal file
31
core/src/test/kotlin/net/corda/core/internal/AsmTools.kt
Normal file
@ -0,0 +1,31 @@
|
||||
@file:JvmName("AsmTools")
|
||||
package net.corda.core.internal
|
||||
|
||||
import org.objectweb.asm.ClassReader
|
||||
import org.objectweb.asm.ClassVisitor
|
||||
import org.objectweb.asm.ClassWriter
|
||||
import org.objectweb.asm.ClassWriter.COMPUTE_MAXS
|
||||
import org.objectweb.asm.Opcodes.ASM7
|
||||
|
||||
val String.asInternalName: String get() = replace('.', '/')
|
||||
val Class<*>.resourceName: String get() = "${name.asInternalName}.class"
|
||||
val Class<*>.byteCode: ByteArray get() = classLoader.getResourceAsStream(resourceName)!!.use {
|
||||
it.readBytes()
|
||||
}
|
||||
|
||||
fun Class<*>.renameTo(newName: String): ByteArray {
|
||||
return byteCode.accept { w -> RenamingWriter(newName, w) }
|
||||
}
|
||||
|
||||
fun ByteArray.accept(visitor: (ClassVisitor) -> ClassVisitor): ByteArray {
|
||||
return ClassWriter(COMPUTE_MAXS).let { writer ->
|
||||
ClassReader(this).accept(visitor(writer), 0)
|
||||
writer.toByteArray()
|
||||
}
|
||||
}
|
||||
|
||||
private class RenamingWriter(private val newName: String, visitor: ClassVisitor) : ClassVisitor(ASM7, visitor) {
|
||||
override fun visit(version: Int, access: Int, name: String, signature: String?, superName: String?, interfaces: Array<out String>?) {
|
||||
super.visit(version, access, newName, signature, superName, interfaces)
|
||||
}
|
||||
}
|
@ -1,23 +1,62 @@
|
||||
package net.corda.core.internal
|
||||
|
||||
import com.nhaarman.mockito_kotlin.mock
|
||||
import net.corda.core.contracts.ContractAttachment
|
||||
import net.corda.core.contracts.ContractClassName
|
||||
import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.serialization.internal.AttachmentURLStreamHandlerFactory
|
||||
import net.corda.core.serialization.internal.AttachmentsClassLoader
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.assertj.core.api.Assertions.assertThatThrownBy
|
||||
import org.junit.Assert.assertEquals
|
||||
import org.junit.Test
|
||||
import java.io.ByteArrayOutputStream
|
||||
import java.net.URLClassLoader
|
||||
import java.security.PublicKey
|
||||
import java.util.jar.JarOutputStream
|
||||
import java.util.jar.Manifest
|
||||
import java.util.zip.Deflater.NO_COMPRESSION
|
||||
import java.util.zip.ZipEntry
|
||||
import java.util.zip.ZipEntry.DEFLATED
|
||||
import java.util.zip.ZipEntry.STORED
|
||||
|
||||
class ClassLoadingUtilsTest {
|
||||
companion object {
|
||||
const val STANDALONE_CLASS_NAME = "com.example.StandaloneClassWithEmptyConstructor"
|
||||
const val PROGRAM_ID: ContractClassName = "net.corda.core.internal.DummyContract"
|
||||
val contractAttachmentId = SecureHash.randomSHA256()
|
||||
|
||||
fun directoryEntry(internalName: String) = ZipEntry("$internalName/").apply {
|
||||
method = STORED
|
||||
compressedSize = 0
|
||||
size = 0
|
||||
crc = 0
|
||||
}
|
||||
|
||||
fun classEntry(internalName: String) = ZipEntry("$internalName.class").apply {
|
||||
method = DEFLATED
|
||||
}
|
||||
|
||||
init {
|
||||
// Register the "attachment://" URL scheme.
|
||||
// You may only register new schemes ONCE per JVM!
|
||||
AttachmentsClassLoader
|
||||
}
|
||||
}
|
||||
|
||||
private val temporaryClassLoader = mock<ClassLoader>()
|
||||
|
||||
interface BaseInterface {}
|
||||
interface BaseInterface
|
||||
|
||||
interface BaseInterface2 {}
|
||||
interface BaseInterface2
|
||||
|
||||
class ConcreteClassWithEmptyConstructor: BaseInterface {}
|
||||
class ConcreteClassWithEmptyConstructor: BaseInterface
|
||||
|
||||
abstract class AbstractClass: BaseInterface
|
||||
|
||||
class ConcreteClassWithNonEmptyConstructor(private val someData: Int): BaseInterface2 {}
|
||||
@Suppress("unused")
|
||||
class ConcreteClassWithNonEmptyConstructor(private val someData: Int): BaseInterface2
|
||||
|
||||
@Test
|
||||
fun predicateClassAreLoadedSuccessfully() {
|
||||
@ -25,8 +64,9 @@ class ClassLoadingUtilsTest {
|
||||
|
||||
val classNames = classes.map { it.javaClass.name }
|
||||
|
||||
assertThat(classNames).contains(ConcreteClassWithEmptyConstructor::class.java.name)
|
||||
assertThat(classNames).doesNotContain(AbstractClass::class.java.name)
|
||||
assertThat(classNames)
|
||||
.contains(ConcreteClassWithEmptyConstructor::class.java.name)
|
||||
.doesNotContain(AbstractClass::class.java.name)
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException::class)
|
||||
@ -54,4 +94,42 @@ class ClassLoadingUtilsTest {
|
||||
assertThat(Thread.currentThread().contextClassLoader).isNotEqualTo(temporaryClassLoader)
|
||||
}
|
||||
|
||||
}
|
||||
@Test
|
||||
fun `test locating classes inside attachment`() {
|
||||
val jarData = with(ByteArrayOutputStream()) {
|
||||
val internalName = STANDALONE_CLASS_NAME.asInternalName
|
||||
JarOutputStream(this, Manifest()).use {
|
||||
it.setLevel(NO_COMPRESSION)
|
||||
it.setMethod(DEFLATED)
|
||||
it.putNextEntry(directoryEntry("com"))
|
||||
it.putNextEntry(directoryEntry("com/example"))
|
||||
it.putNextEntry(classEntry(internalName))
|
||||
it.write(TemplateClassWithEmptyConstructor::class.java.renameTo(internalName))
|
||||
}
|
||||
toByteArray()
|
||||
}
|
||||
val attachment = signedAttachment(jarData)
|
||||
val url = AttachmentURLStreamHandlerFactory.toUrl(attachment)
|
||||
|
||||
URLClassLoader(arrayOf(url)).use { cordappClassLoader ->
|
||||
val standaloneClass = createInstancesOfClassesImplementing(cordappClassLoader, BaseInterface::class.java)
|
||||
.map(Any::javaClass)
|
||||
.single()
|
||||
assertEquals(STANDALONE_CLASS_NAME, standaloneClass.name)
|
||||
assertEquals(cordappClassLoader, standaloneClass.classLoader)
|
||||
}
|
||||
}
|
||||
|
||||
private fun signedAttachment(data: ByteArray, vararg parties: Party) = ContractAttachment.create(
|
||||
object : AbstractAttachment({ data }, "test") {
|
||||
override val id: SecureHash get() = contractAttachmentId
|
||||
|
||||
override val signerKeys: List<PublicKey> get() = parties.map(Party::owningKey)
|
||||
}, PROGRAM_ID, signerKeys = parties.map(Party::owningKey)
|
||||
)
|
||||
}
|
||||
|
||||
// Our dummy attachment will contain a class that is created from this one.
|
||||
// This is because our attachment must contain a class that DOES NOT exist
|
||||
// inside the application classloader.
|
||||
class TemplateClassWithEmptyConstructor : ClassLoadingUtilsTest.BaseInterface
|
||||
|
@ -0,0 +1,21 @@
|
||||
package net.corda.contracts.serialization.custom
|
||||
|
||||
import net.corda.core.serialization.SerializationCustomSerializer
|
||||
|
||||
@Suppress("unused")
|
||||
class CurrantsySerializer : SerializationCustomSerializer<Currantsy, CurrantsySerializer.Proxy> {
|
||||
data class Proxy(val currants: Long)
|
||||
|
||||
override fun fromProxy(proxy: Proxy): Currantsy = Currantsy(proxy.currants)
|
||||
override fun toProxy(obj: Currantsy) = Proxy(obj.currants)
|
||||
}
|
||||
|
||||
data class Currantsy(val currants: Long) : Comparable<Currantsy> {
|
||||
override fun toString(): String {
|
||||
return "$currants juicy currants"
|
||||
}
|
||||
|
||||
override fun compareTo(other: Currantsy): Int {
|
||||
return currants.compareTo(other.currants)
|
||||
}
|
||||
}
|
@ -0,0 +1,39 @@
|
||||
package net.corda.contracts.serialization.custom
|
||||
|
||||
import net.corda.core.contracts.CommandData
|
||||
import net.corda.core.contracts.Contract
|
||||
import net.corda.core.contracts.ContractState
|
||||
import net.corda.core.identity.AbstractParty
|
||||
import net.corda.core.transactions.LedgerTransaction
|
||||
|
||||
@Suppress("unused")
|
||||
class CustomSerializerContract : Contract {
|
||||
companion object {
|
||||
const val MAX_CURRANT = 2000
|
||||
}
|
||||
|
||||
override fun verify(tx: LedgerTransaction) {
|
||||
val currantsyData = tx.outputsOfType(CurrantsyState::class.java)
|
||||
require(currantsyData.isNotEmpty()) {
|
||||
"Requires at least one currantsy state"
|
||||
}
|
||||
|
||||
currantsyData.forEach {
|
||||
require(it.currantsy.currants in 0..MAX_CURRANT) {
|
||||
"Too many currants! ${it.currantsy.currants} is unraisinable!"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("CanBeParameter", "MemberVisibilityCanBePrivate")
|
||||
class CurrantsyState(val owner: AbstractParty, val currantsy: Currantsy) : ContractState {
|
||||
override val participants: List<AbstractParty> = listOf(owner)
|
||||
|
||||
@Override
|
||||
override fun toString(): String {
|
||||
return currantsy.toString()
|
||||
}
|
||||
}
|
||||
|
||||
class Purchase : CommandData
|
||||
}
|
@ -0,0 +1,7 @@
|
||||
package net.corda.contracts.serialization.custom
|
||||
|
||||
import net.corda.core.serialization.SerializationWhitelist
|
||||
|
||||
class CustomSerializerRegistry : SerializationWhitelist {
|
||||
override val whitelist: List<Class<*>> = listOf(Currantsy::class.java)
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
package net.corda.flows.serialization.custom
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.contracts.serialization.custom.Currantsy
|
||||
import net.corda.contracts.serialization.custom.CustomSerializerContract.CurrantsyState
|
||||
import net.corda.contracts.serialization.custom.CustomSerializerContract.Purchase
|
||||
import net.corda.core.contracts.Command
|
||||
import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.flows.InitiatingFlow
|
||||
import net.corda.core.flows.StartableByRPC
|
||||
import net.corda.core.transactions.TransactionBuilder
|
||||
|
||||
@InitiatingFlow
|
||||
@StartableByRPC
|
||||
class CustomSerializerFlow(
|
||||
private val purchase: Currantsy
|
||||
) : FlowLogic<SecureHash>() {
|
||||
@Suspendable
|
||||
override fun call(): SecureHash {
|
||||
val notary = serviceHub.networkMapCache.notaryIdentities[0]
|
||||
val stx = serviceHub.signInitialTransaction(
|
||||
TransactionBuilder(notary)
|
||||
.addOutputState(CurrantsyState(ourIdentity, purchase))
|
||||
.addCommand(Command(Purchase(), ourIdentity.owningKey))
|
||||
)
|
||||
stx.verify(serviceHub, checkSufficientSignatures = false)
|
||||
return stx.id
|
||||
}
|
||||
}
|
@ -0,0 +1,56 @@
|
||||
package net.corda.node
|
||||
|
||||
import net.corda.client.rpc.CordaRPCClient
|
||||
import net.corda.contracts.serialization.custom.Currantsy
|
||||
import net.corda.core.contracts.TransactionVerificationException
|
||||
import net.corda.core.messaging.startFlow
|
||||
import net.corda.core.utilities.getOrThrow
|
||||
import net.corda.flows.serialization.custom.CustomSerializerFlow
|
||||
import net.corda.node.services.Permissions
|
||||
import net.corda.testing.core.ALICE_NAME
|
||||
import net.corda.testing.core.DUMMY_NOTARY_NAME
|
||||
import net.corda.testing.driver.DriverParameters
|
||||
import net.corda.testing.driver.driver
|
||||
import net.corda.testing.driver.internal.incrementalPortAllocation
|
||||
import net.corda.testing.node.NotarySpec
|
||||
import net.corda.testing.node.User
|
||||
import net.corda.testing.node.internal.CustomCordapp
|
||||
import net.corda.testing.node.internal.cordappWithPackages
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.Test
|
||||
import org.junit.jupiter.api.assertThrows
|
||||
|
||||
@Suppress("FunctionName")
|
||||
class ContractWithCustomSerializerTest {
|
||||
companion object {
|
||||
const val CURRANTS = 5000L
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `flow with custom serializer by rpc`() {
|
||||
val user = User("u", "p", setOf(Permissions.all()))
|
||||
driver(DriverParameters(
|
||||
portAllocation = incrementalPortAllocation(),
|
||||
startNodesInProcess = false,
|
||||
notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, validating = true)),
|
||||
cordappsForAllNodes = listOf(
|
||||
cordappWithPackages("net.corda.flows.serialization.custom"),
|
||||
CustomCordapp(
|
||||
packages = setOf("net.corda.contracts.serialization.custom"),
|
||||
name = "has-custom-serializer"
|
||||
).signed()
|
||||
)
|
||||
)) {
|
||||
val alice = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)).getOrThrow()
|
||||
val ex = assertThrows<TransactionVerificationException> {
|
||||
CordaRPCClient(hostAndPort = alice.rpcAddress)
|
||||
.start(user.username, user.password)
|
||||
.proxy
|
||||
.startFlow(::CustomSerializerFlow, Currantsy(CURRANTS))
|
||||
.returnValue
|
||||
.getOrThrow()
|
||||
}
|
||||
assertThat(ex).hasMessageContaining("Too many currants! $CURRANTS is unraisinable!")
|
||||
}
|
||||
}
|
||||
}
|
@ -4,6 +4,7 @@ import io.github.classgraph.ClassGraph
|
||||
import net.corda.core.internal.*
|
||||
import net.corda.core.internal.cordapp.CordappImpl
|
||||
import net.corda.core.internal.cordapp.set
|
||||
import net.corda.core.serialization.SerializationWhitelist
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.core.utilities.debug
|
||||
import net.corda.testing.core.internal.JarSignatureTestUtils.generateKey
|
||||
@ -58,11 +59,20 @@ data class CustomCordapp(
|
||||
classGraph.whitelistClasses(*classes.map { it.name }.toTypedArray())
|
||||
}
|
||||
|
||||
classGraph.pooledScan().use { scanResult ->
|
||||
classGraph.enableClassInfo().pooledScan().use { scanResult ->
|
||||
val whitelistService = SerializationWhitelist::class.java.name
|
||||
val whitelists = scanResult.getClassesImplementing(whitelistService)
|
||||
|
||||
JarOutputStream(file.outputStream()).use { jos ->
|
||||
jos.addEntry(testEntry(JarFile.MANIFEST_NAME)) {
|
||||
createTestManifest(name, versionId, targetPlatformVersion).write(jos)
|
||||
}
|
||||
if (whitelists.isNotEmpty()) {
|
||||
jos.addEntry(directoryEntry("META-INF/services"))
|
||||
jos.addEntry(testEntry("META-INF/services/$whitelistService")) {
|
||||
jos.write(whitelists.names.joinToString(separator = "\r\n").toByteArray())
|
||||
}
|
||||
}
|
||||
|
||||
// The same resource may be found in different locations (this will happen when running from gradle) so just
|
||||
// pick the first one found.
|
||||
@ -114,6 +124,16 @@ data class CustomCordapp(
|
||||
return ZipEntry(name).setCreationTime(epochFileTime).setLastAccessTime(epochFileTime).setLastModifiedTime(epochFileTime)
|
||||
}
|
||||
|
||||
private fun directoryEntry(name: String): ZipEntry {
|
||||
val directoryName = if (name.endsWith('/')) name else "$name/"
|
||||
return testEntry(directoryName).apply {
|
||||
method = ZipEntry.STORED
|
||||
compressedSize = 0
|
||||
size = 0
|
||||
crc = 0
|
||||
}
|
||||
}
|
||||
|
||||
data class SigningInfo(val keyStorePath: Path?, val numberOfSignatures: Int, val keyAlgorithm: String)
|
||||
|
||||
companion object {
|
||||
|
@ -35,7 +35,6 @@ import net.corda.core.node.NetworkParameters
|
||||
import net.corda.core.node.NotaryInfo
|
||||
import net.corda.core.node.services.NetworkMapCache
|
||||
import net.corda.core.utilities.NetworkHostAndPort
|
||||
import net.corda.core.utilities.Try
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.core.utilities.getOrThrow
|
||||
import net.corda.core.utilities.millis
|
||||
@ -84,6 +83,7 @@ import rx.schedulers.Schedulers
|
||||
import java.io.File
|
||||
import java.net.ConnectException
|
||||
import java.net.URL
|
||||
import java.net.URLClassLoader
|
||||
import java.nio.file.Path
|
||||
import java.security.cert.X509Certificate
|
||||
import java.time.Duration
|
||||
@ -142,6 +142,12 @@ class DriverDSLImpl(
|
||||
private lateinit var _notaries: CordaFuture<List<NotaryHandle>>
|
||||
override val notaryHandles: List<NotaryHandle> get() = _notaries.getOrThrow()
|
||||
|
||||
override val cordappsClassLoader: ClassLoader? = if (!startNodesInProcess) {
|
||||
createCordappsClassLoader(cordappsForAllNodes)
|
||||
} else {
|
||||
null
|
||||
}
|
||||
|
||||
interface Waitable {
|
||||
@Throws(InterruptedException::class)
|
||||
fun waitFor()
|
||||
@ -193,6 +199,7 @@ class DriverDSLImpl(
|
||||
}
|
||||
_shutdownManager?.shutdown()
|
||||
_executorService?.shutdownNow()
|
||||
(cordappsClassLoader as? AutoCloseable)?.close()
|
||||
}
|
||||
|
||||
private fun establishRpc(config: NodeConfig, processDeathFuture: CordaFuture<out Process>): CordaFuture<CordaRPCOps> {
|
||||
@ -982,6 +989,13 @@ class DriverDSLImpl(
|
||||
return config
|
||||
}
|
||||
|
||||
private fun createCordappsClassLoader(cordapps: Collection<TestCordappInternal>?): ClassLoader? {
|
||||
if (cordapps == null || cordapps.isEmpty()) {
|
||||
return null
|
||||
}
|
||||
return URLClassLoader(cordapps.map { it.jarFile.toUri().toURL() }.toTypedArray())
|
||||
}
|
||||
|
||||
private operator fun Config.plus(property: Pair<String, Any>) = withValue(property.first, ConfigValueFactory.fromAnyRef(property.second))
|
||||
|
||||
/**
|
||||
@ -1069,6 +1083,8 @@ interface InternalDriverDSL : DriverDSL {
|
||||
|
||||
val shutdownManager: ShutdownManager
|
||||
|
||||
val cordappsClassLoader: ClassLoader?
|
||||
|
||||
fun baseDirectory(nodeName: String): Path = baseDirectory(CordaX500Name.parse(nodeName))
|
||||
|
||||
/**
|
||||
@ -1113,7 +1129,7 @@ fun <DI : DriverDSL, D : InternalDriverDSL, A> genericDriver(
|
||||
coerce: (D) -> DI,
|
||||
dsl: DI.() -> A
|
||||
): A {
|
||||
val serializationEnv = setDriverSerialization()
|
||||
val serializationEnv = setDriverSerialization(driverDsl.cordappsClassLoader)
|
||||
val shutdownHook = addShutdownHook(driverDsl::shutdown)
|
||||
try {
|
||||
driverDsl.start()
|
||||
|
@ -282,15 +282,18 @@ fun DriverDSL.assertCheckpoints(name: CordaX500Name, expected: Long) {
|
||||
/**
|
||||
* Should only be used by Driver and MockNode.
|
||||
*/
|
||||
fun setDriverSerialization(): AutoCloseable? {
|
||||
fun setDriverSerialization(classLoader: ClassLoader?): AutoCloseable? {
|
||||
return if (_allEnabledSerializationEnvs.isEmpty()) {
|
||||
DriverSerializationEnvironment().enable()
|
||||
DriverSerializationEnvironment(classLoader).enable()
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
private class DriverSerializationEnvironment : SerializationEnvironment by createTestSerializationEnv(), AutoCloseable {
|
||||
fun setDriverSerialization(): AutoCloseable? = setDriverSerialization(null)
|
||||
|
||||
private class DriverSerializationEnvironment(classLoader: ClassLoader?)
|
||||
: SerializationEnvironment by createTestSerializationEnv(classLoader), AutoCloseable {
|
||||
fun enable() = apply { _driverSerializationEnv.set(this) }
|
||||
override fun close() {
|
||||
_driverSerializationEnv.set(null)
|
||||
@ -303,6 +306,8 @@ fun JarOutputStream.addEntry(entry: ZipEntry, input: InputStream) {
|
||||
addEntry(entry) { input.use { it.copyTo(this) } }
|
||||
}
|
||||
|
||||
fun JarOutputStream.addEntry(entry: ZipEntry) = addEntry(entry) {}
|
||||
|
||||
inline fun JarOutputStream.addEntry(entry: ZipEntry, write: () -> Unit) {
|
||||
putNextEntry(entry)
|
||||
write()
|
||||
|
@ -1,20 +1,35 @@
|
||||
package net.corda.testing.internal
|
||||
|
||||
import net.corda.client.rpc.internal.serialization.amqp.AMQPClientSerializationScheme
|
||||
import net.corda.core.internal.createInstancesOfClassesImplementing
|
||||
import net.corda.core.serialization.SerializationCustomSerializer
|
||||
import net.corda.core.serialization.SerializationWhitelist
|
||||
import net.corda.core.serialization.internal.SerializationEnvironment
|
||||
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
|
||||
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
|
||||
import net.corda.node.serialization.kryo.KryoCheckpointSerializer
|
||||
import net.corda.serialization.internal.*
|
||||
import net.corda.testing.common.internal.asContextEnv
|
||||
import java.util.ServiceLoader
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.ExecutorService
|
||||
|
||||
val inVMExecutors = ConcurrentHashMap<SerializationEnvironment, ExecutorService>()
|
||||
|
||||
fun createTestSerializationEnv(): SerializationEnvironment {
|
||||
return createTestSerializationEnv(null)
|
||||
}
|
||||
|
||||
fun createTestSerializationEnv(classLoader: ClassLoader?): SerializationEnvironment {
|
||||
val clientSerializationScheme = if (classLoader != null) {
|
||||
val customSerializers = createInstancesOfClassesImplementing(classLoader, SerializationCustomSerializer::class.java)
|
||||
val serializationWhitelists = ServiceLoader.load(SerializationWhitelist::class.java, classLoader).toSet()
|
||||
AMQPClientSerializationScheme(customSerializers, serializationWhitelists)
|
||||
} else {
|
||||
AMQPClientSerializationScheme(emptyList())
|
||||
}
|
||||
val factory = SerializationFactoryImpl().apply {
|
||||
registerScheme(AMQPClientSerializationScheme(emptyList()))
|
||||
registerScheme(clientSerializationScheme)
|
||||
registerScheme(AMQPServerSerializationScheme(emptyList()))
|
||||
}
|
||||
return SerializationEnvironment.with(
|
||||
|
Loading…
Reference in New Issue
Block a user