[CORDA-2422] - Remove interfaces from carpenting (#4687)

Remove interfaces from carpenting.
This commit is contained in:
Dimos Raptis 2019-02-04 14:55:09 +00:00 committed by Mike Hearn
parent 5c2a7ed72e
commit 492f54d6fe
3 changed files with 62 additions and 37 deletions

View File

@ -2,6 +2,7 @@ package net.corda.nodeapi.internal.serialization.amqp
import com.google.common.primitives.Primitives import com.google.common.primitives.Primitives
import com.google.common.reflect.TypeResolver import com.google.common.reflect.TypeResolver
import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.uncheckedCast import net.corda.core.internal.uncheckedCast
import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.ClassWhitelist
import net.corda.nodeapi.internal.serialization.carpenter.CarpenterMetaSchema import net.corda.nodeapi.internal.serialization.carpenter.CarpenterMetaSchema
@ -44,7 +45,8 @@ open class SerializerFactory(
cl: ClassLoader, cl: ClassLoader,
private val evolutionSerializerGetter: EvolutionSerializerGetterBase = EvolutionSerializerGetter() private val evolutionSerializerGetter: EvolutionSerializerGetterBase = EvolutionSerializerGetter()
) { ) {
private val serializersByType = ConcurrentHashMap<Type, AMQPSerializer<Any>>() @VisibleForTesting
internal val serializersByType = ConcurrentHashMap<Type, AMQPSerializer<Any>>()
private val serializersByDescriptor = ConcurrentHashMap<Any, AMQPSerializer<Any>>() private val serializersByDescriptor = ConcurrentHashMap<Any, AMQPSerializer<Any>>()
private val customSerializers = CopyOnWriteArrayList<SerializerFor>() private val customSerializers = CopyOnWriteArrayList<SerializerFor>()
private val transformsCache = ConcurrentHashMap<String, EnumMap<TransformTypes, MutableList<Transform>>>() private val transformsCache = ConcurrentHashMap<String, EnumMap<TransformTypes, MutableList<Transform>>>()
@ -216,16 +218,38 @@ open class SerializerFactory(
/** /**
* Iterate over an AMQP schema, for each type ascertain whether it's on ClassPath of [classloader] and, * Iterate over an AMQP schema, for each type ascertain whether it's on ClassPath of [classloader] and,
* if not, use the [ClassCarpenter] to generate a class to use in it's place. * if not, use the [ClassCarpenter] to generate a class to use in its place.
*
* The processing of the schema is performed in the following steps:
* - All the (non-interface) types are attempted to be loaded from the current classpath.
* - For any of those types that cannot be found in the current classpath:
* - The associated interfaces are loaded from the classpath.
* - These types are added to the CarpenterMetaSchema, which contains everything in need of carpenting
*
* As a result, interfaces are only loaded on-demand, according to the needs for carpenting.
* This is done in order to preserve backwards compatibility, in cases where 2 nodes communicate and one of the transported classes
* implements an interface that one of them is unaware of (i.e. introduced by a subsequent version). In this case, this node is not
* expected to make use of this interface anyway, since the associated CorDapps will be developed in versions that do not contain it,
* so it should not attempt to load it all.
*/ */
private fun processSchema(schemaAndDescriptor: FactorySchemaAndDescriptor) { private fun processSchema(schemaAndDescriptor: FactorySchemaAndDescriptor) {
val schemaTypes = schemaAndDescriptor.schemas.schema.types
val interfacesPerClass = schemaTypes.associateBy({it.name},
{type -> schemaTypes.filter { it.name in type.provides }}
)
val allInterfaceNames = interfacesPerClass.values.asSequence().flatten().map { it.name }
val metaSchema = CarpenterMetaSchema.newInstance() val metaSchema = CarpenterMetaSchema.newInstance()
val notationByName = schemaAndDescriptor.schemas.schema.types.associate { it.name to it } val notationByNameForNonInterfaceTypes = schemaTypes
val noCarpentryRequired = notationByName.mapNotNull { (name, notation) -> .filterNot { it.name in allInterfaceNames }
.associateBy({it.name}, {it})
val noCarpentryRequired = notationByNameForNonInterfaceTypes.mapNotNull { (name, notation) ->
try { try {
logger.debug { "descriptor=${schemaAndDescriptor.typeDescriptor}, typeNotation=$name" } logger.debug { "descriptor=${schemaAndDescriptor.typeDescriptor}, typeNotation=$name" }
name to processSchemaEntry(notation) name to processSchemaEntry(notation)
} catch (e: ClassNotFoundException) { } catch (e: ClassNotFoundException) {
// class missing from the classpath, so load its interfaces and add it for carpenting (see method docs).
interfacesPerClass[name]!!.forEach { processSchemaEntry(it) }
metaSchema.buildFor(notation, classloader) metaSchema.buildFor(notation, classloader)
null null
} }
@ -236,14 +260,14 @@ open class SerializerFactory(
mc.build() mc.build()
} }
val carpented = notationByName.minus(noCarpentryRequired.keys).mapValues { (name, notation) -> val carpented = notationByNameForNonInterfaceTypes.minus(noCarpentryRequired.keys).mapValues { (name, notation) ->
processSchemaEntry(notation) processSchemaEntry(notation)
} }
val allLocalSerializers = noCarpentryRequired + carpented val allLocalSerializers = noCarpentryRequired + carpented
allLocalSerializers.forEach { (name, serializer) -> allLocalSerializers.forEach { (name, serializer) ->
val typeNotation = notationByName[name]!! val typeNotation = notationByNameForNonInterfaceTypes[name]!!
if (serializer.typeDescriptor != typeNotation.descriptor.name ) { if (serializer.typeDescriptor != typeNotation.descriptor.name ) {
getEvolutionSerializer(typeNotation, serializer, schemaAndDescriptor.schemas) getEvolutionSerializer(typeNotation, serializer, schemaAndDescriptor.schemas)
} }

View File

@ -232,26 +232,4 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase(AllWhitelist) {
} }
} }
@Test
fun unknownInterface() {
val cc = ClassCarpenter(whitelist = AllWhitelist)
val interfaceClass = cc.build(InterfaceSchema(
"gen.Interface",
mapOf("age" to NonNullableField(Int::class.java))))
val concreteClass = cc.build(ClassSchema(testName(), mapOf(
"age" to NonNullableField(Int::class.java),
"name" to NonNullableField(String::class.java)),
interfaces = listOf(I::class.java, interfaceClass)))
val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(
concreteClass.constructors.first().newInstance(12, "timmy"))
val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes)
assertTrue(deserializedObj is I)
assertEquals("timmy", (deserializedObj as I).getName())
assertEquals("timmy", deserializedObj::class.java.getMethod("getName").invoke(deserializedObj))
assertEquals(12, deserializedObj::class.java.getMethod("getAge").invoke(deserializedObj))
}
} }

View File

@ -1,15 +1,17 @@
package net.corda.nodeapi.internal.serialization.amqp package net.corda.nodeapi.internal.serialization.amqp
import net.corda.core.contracts.Attachment
import net.corda.core.contracts.AttachmentConstraint
import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.nodeapi.internal.serialization.AllWhitelist import net.corda.nodeapi.internal.serialization.AllWhitelist
import net.corda.nodeapi.internal.serialization.amqp.testutils.TestSerializationOutput
import net.corda.nodeapi.internal.serialization.carpenter.ClassCarpenter import net.corda.nodeapi.internal.serialization.carpenter.ClassCarpenter
import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test import org.junit.Test
import java.io.File import java.io.File
import java.io.NotSerializableException import java.io.NotSerializableException
import java.lang.reflect.Type
import java.util.concurrent.ConcurrentHashMap
import kotlin.test.assertEquals import kotlin.test.assertEquals
class InStatic : Exception("Help!, help!, I'm being repressed") class InStatic : Exception("Help!, help!, I'm being repressed")
@ -50,14 +52,10 @@ class StaticInitialisationOfSerializedObjectTest {
val sf = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()) val sf = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
val typeMap = sf::class.java.getDeclaredField("serializersByType") val serializersByType = sf.serializersByType
typeMap.isAccessible = true
@Suppress("UNCHECKED_CAST")
val serialisersByType = typeMap.get(sf) as ConcurrentHashMap<Type, AMQPSerializer<Any>>
// pre building a serializer, we shouldn't have anything registered // pre building a serializer, we shouldn't have anything registered
assertEquals(0, serialisersByType.size) assertEquals(0, serializersByType.size)
// build a serializer for type D without an instance of it to serialise, since // build a serializer for type D without an instance of it to serialise, since
// we can't actually construct one // we can't actually construct one
@ -65,7 +63,31 @@ class StaticInitialisationOfSerializedObjectTest {
// post creation of the serializer we should have one element in the map, this // post creation of the serializer we should have one element in the map, this
// proves we didn't statically construct an instance of C when building the serializer // proves we didn't statically construct an instance of C when building the serializer
assertEquals(1, serialisersByType.size) assertEquals(1, serializersByType.size)
}
@Test
fun interfacesAreNotLoadedWhenNotNeeded() {
data class DummyClass(val c: Int): AttachmentConstraint {
override fun isSatisfiedBy(attachment: Attachment): Boolean = true
}
val schemaForClass = TestSerializationOutput(EnumEvolvabilityTests.VERBOSE).serializeAndReturnSchema(DummyClass(2)).schema
val schemaTypes = schemaForClass.types
val classType = schemaTypes.find { it.name.contains("DummyClass") }!!
val interfaceType = schemaTypes.find { it.name.contains("AttachmentConstraint") }!!
val schemas = SerializationSchemas(schemaForClass, TransformsSchema(emptyMap()))
val factory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
val serializersByType = factory.serializersByType
factory.get(classType.descriptor.name!!, schemas)
// Class D is in the classpath (no need to carpent it), so the interface should not be loaded
val loadedTypes = serializersByType.keys().toList().map { it.typeName }
assertThat(loadedTypes)
.contains(classType.name)
.doesNotContain(interfaceType.name)
} }
@ -143,4 +165,5 @@ class StaticInitialisationOfSerializedObjectTest {
DeserializationInput(sf2).deserialize(SerializedBytes<D>(bytes)) DeserializationInput(sf2).deserialize(SerializedBytes<D>(bytes))
}.isInstanceOf(NotSerializableException::class.java) }.isInstanceOf(NotSerializableException::class.java)
} }
} }