From 492f54d6fe3c03c40a88e1995d8fb350c4a98c49 Mon Sep 17 00:00:00 2001 From: Dimos Raptis Date: Mon, 4 Feb 2019 14:55:09 +0000 Subject: [PATCH] [CORDA-2422] - Remove interfaces from carpenting (#4687) Remove interfaces from carpenting. --- .../serialization/amqp/SerializerFactory.kt | 36 +++++++++++++--- .../amqp/DeserializeNeedingCarpentryTests.kt | 22 ---------- ...ticInitialisationOfSerializedObjectTest.kt | 41 +++++++++++++++---- 3 files changed, 62 insertions(+), 37 deletions(-) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt index ce33b64c41..efc98b2bab 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt @@ -2,6 +2,7 @@ package net.corda.nodeapi.internal.serialization.amqp import com.google.common.primitives.Primitives import com.google.common.reflect.TypeResolver +import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.uncheckedCast import net.corda.core.serialization.ClassWhitelist import net.corda.nodeapi.internal.serialization.carpenter.CarpenterMetaSchema @@ -44,7 +45,8 @@ open class SerializerFactory( cl: ClassLoader, private val evolutionSerializerGetter: EvolutionSerializerGetterBase = EvolutionSerializerGetter() ) { - private val serializersByType = ConcurrentHashMap>() + @VisibleForTesting + internal val serializersByType = ConcurrentHashMap>() private val serializersByDescriptor = ConcurrentHashMap>() private val customSerializers = CopyOnWriteArrayList() private val transformsCache = ConcurrentHashMap>>() @@ -216,16 +218,38 @@ open class SerializerFactory( /** * Iterate over an AMQP schema, for each type ascertain whether it's on ClassPath of [classloader] and, - * if not, use the [ClassCarpenter] to generate a class to use in it's place. + * if not, use the [ClassCarpenter] to generate a class to use in its place. + * + * The processing of the schema is performed in the following steps: + * - All the (non-interface) types are attempted to be loaded from the current classpath. + * - For any of those types that cannot be found in the current classpath: + * - The associated interfaces are loaded from the classpath. + * - These types are added to the CarpenterMetaSchema, which contains everything in need of carpenting + * + * As a result, interfaces are only loaded on-demand, according to the needs for carpenting. + * This is done in order to preserve backwards compatibility, in cases where 2 nodes communicate and one of the transported classes + * implements an interface that one of them is unaware of (i.e. introduced by a subsequent version). In this case, this node is not + * expected to make use of this interface anyway, since the associated CorDapps will be developed in versions that do not contain it, + * so it should not attempt to load it all. */ private fun processSchema(schemaAndDescriptor: FactorySchemaAndDescriptor) { + val schemaTypes = schemaAndDescriptor.schemas.schema.types + val interfacesPerClass = schemaTypes.associateBy({it.name}, + {type -> schemaTypes.filter { it.name in type.provides }} + ) + val allInterfaceNames = interfacesPerClass.values.asSequence().flatten().map { it.name } + val metaSchema = CarpenterMetaSchema.newInstance() - val notationByName = schemaAndDescriptor.schemas.schema.types.associate { it.name to it } - val noCarpentryRequired = notationByName.mapNotNull { (name, notation) -> + val notationByNameForNonInterfaceTypes = schemaTypes + .filterNot { it.name in allInterfaceNames } + .associateBy({it.name}, {it}) + val noCarpentryRequired = notationByNameForNonInterfaceTypes.mapNotNull { (name, notation) -> try { logger.debug { "descriptor=${schemaAndDescriptor.typeDescriptor}, typeNotation=$name" } name to processSchemaEntry(notation) } catch (e: ClassNotFoundException) { + // class missing from the classpath, so load its interfaces and add it for carpenting (see method docs). + interfacesPerClass[name]!!.forEach { processSchemaEntry(it) } metaSchema.buildFor(notation, classloader) null } @@ -236,14 +260,14 @@ open class SerializerFactory( mc.build() } - val carpented = notationByName.minus(noCarpentryRequired.keys).mapValues { (name, notation) -> + val carpented = notationByNameForNonInterfaceTypes.minus(noCarpentryRequired.keys).mapValues { (name, notation) -> processSchemaEntry(notation) } val allLocalSerializers = noCarpentryRequired + carpented allLocalSerializers.forEach { (name, serializer) -> - val typeNotation = notationByName[name]!! + val typeNotation = notationByNameForNonInterfaceTypes[name]!! if (serializer.typeDescriptor != typeNotation.descriptor.name ) { getEvolutionSerializer(typeNotation, serializer, schemaAndDescriptor.schemas) } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeNeedingCarpentryTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeNeedingCarpentryTests.kt index 37d4a22380..466f7b0f93 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeNeedingCarpentryTests.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeNeedingCarpentryTests.kt @@ -232,26 +232,4 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase(AllWhitelist) { } } - @Test - fun unknownInterface() { - val cc = ClassCarpenter(whitelist = AllWhitelist) - - val interfaceClass = cc.build(InterfaceSchema( - "gen.Interface", - mapOf("age" to NonNullableField(Int::class.java)))) - - val concreteClass = cc.build(ClassSchema(testName(), mapOf( - "age" to NonNullableField(Int::class.java), - "name" to NonNullableField(String::class.java)), - interfaces = listOf(I::class.java, interfaceClass))) - - val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize( - concreteClass.constructors.first().newInstance(12, "timmy")) - val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) - - assertTrue(deserializedObj is I) - assertEquals("timmy", (deserializedObj as I).getName()) - assertEquals("timmy", deserializedObj::class.java.getMethod("getName").invoke(deserializedObj)) - assertEquals(12, deserializedObj::class.java.getMethod("getAge").invoke(deserializedObj)) - } } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/StaticInitialisationOfSerializedObjectTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/StaticInitialisationOfSerializedObjectTest.kt index 71925aa1da..64b83a7525 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/StaticInitialisationOfSerializedObjectTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/StaticInitialisationOfSerializedObjectTest.kt @@ -1,15 +1,17 @@ package net.corda.nodeapi.internal.serialization.amqp +import net.corda.core.contracts.Attachment +import net.corda.core.contracts.AttachmentConstraint import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.SerializedBytes import net.corda.nodeapi.internal.serialization.AllWhitelist +import net.corda.nodeapi.internal.serialization.amqp.testutils.TestSerializationOutput import net.corda.nodeapi.internal.serialization.carpenter.ClassCarpenter import org.assertj.core.api.Assertions +import org.assertj.core.api.Assertions.assertThat import org.junit.Test import java.io.File import java.io.NotSerializableException -import java.lang.reflect.Type -import java.util.concurrent.ConcurrentHashMap import kotlin.test.assertEquals class InStatic : Exception("Help!, help!, I'm being repressed") @@ -50,14 +52,10 @@ class StaticInitialisationOfSerializedObjectTest { val sf = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()) - val typeMap = sf::class.java.getDeclaredField("serializersByType") - typeMap.isAccessible = true - - @Suppress("UNCHECKED_CAST") - val serialisersByType = typeMap.get(sf) as ConcurrentHashMap> + val serializersByType = sf.serializersByType // pre building a serializer, we shouldn't have anything registered - assertEquals(0, serialisersByType.size) + assertEquals(0, serializersByType.size) // build a serializer for type D without an instance of it to serialise, since // we can't actually construct one @@ -65,7 +63,31 @@ class StaticInitialisationOfSerializedObjectTest { // post creation of the serializer we should have one element in the map, this // proves we didn't statically construct an instance of C when building the serializer - assertEquals(1, serialisersByType.size) + assertEquals(1, serializersByType.size) + } + + @Test + fun interfacesAreNotLoadedWhenNotNeeded() { + data class DummyClass(val c: Int): AttachmentConstraint { + override fun isSatisfiedBy(attachment: Attachment): Boolean = true + } + + val schemaForClass = TestSerializationOutput(EnumEvolvabilityTests.VERBOSE).serializeAndReturnSchema(DummyClass(2)).schema + val schemaTypes = schemaForClass.types + val classType = schemaTypes.find { it.name.contains("DummyClass") }!! + val interfaceType = schemaTypes.find { it.name.contains("AttachmentConstraint") }!! + val schemas = SerializationSchemas(schemaForClass, TransformsSchema(emptyMap())) + + val factory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()) + val serializersByType = factory.serializersByType + + factory.get(classType.descriptor.name!!, schemas) + + // Class D is in the classpath (no need to carpent it), so the interface should not be loaded + val loadedTypes = serializersByType.keys().toList().map { it.typeName } + assertThat(loadedTypes) + .contains(classType.name) + .doesNotContain(interfaceType.name) } @@ -143,4 +165,5 @@ class StaticInitialisationOfSerializedObjectTest { DeserializationInput(sf2).deserialize(SerializedBytes(bytes)) }.isInstanceOf(NotSerializableException::class.java) } + }