ENT-1463: Isolate more non-deterministic code from AMQP serialisation. (#3138)

This commit is contained in:
Chris Rankin 2018-05-14 16:50:43 +01:00 committed by GitHub
parent 6e59a694c1
commit 4f9bbc8820
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 58 additions and 49 deletions

View File

@ -58,7 +58,7 @@ abstract class CustomSerializer<T : Any> : AMQPSerializer<T>, SerializerFor {
* subclass in the schema, so that we can distinguish between subclasses. * subclass in the schema, so that we can distinguish between subclasses.
*/ */
// TODO: should this be a custom serializer at all, or should it just be a plain AMQPSerializer? // TODO: should this be a custom serializer at all, or should it just be a plain AMQPSerializer?
class SubClass<T : Any>(protected val clazz: Class<*>, protected val superClassSerializer: CustomSerializer<T>) : CustomSerializer<T>() { class SubClass<T : Any>(private val clazz: Class<*>, private val superClassSerializer: CustomSerializer<T>) : CustomSerializer<T>() {
// TODO: should this be empty or contain the schema of the super? // TODO: should this be empty or contain the schema of the super?
override val schemaForDocumentation = Schema(emptyList()) override val schemaForDocumentation = Schema(emptyList())

View File

@ -239,7 +239,7 @@ class EvolutionSerializerGetter : EvolutionSerializerGetterBase() {
typeNotation: TypeNotation, typeNotation: TypeNotation,
newSerializer: AMQPSerializer<Any>, newSerializer: AMQPSerializer<Any>,
schemas: SerializationSchemas): AMQPSerializer<Any> { schemas: SerializationSchemas): AMQPSerializer<Any> {
return factory.getSerializersByDescriptor().computeIfAbsent(typeNotation.descriptor.name!!) { return factory.serializersByDescriptor.computeIfAbsent(typeNotation.descriptor.name!!) {
when (typeNotation) { when (typeNotation) {
is CompositeType -> EvolutionSerializer.make(typeNotation, newSerializer as ObjectSerializer, factory) is CompositeType -> EvolutionSerializer.make(typeNotation, newSerializer as ObjectSerializer, factory)
is RestrictedType -> { is RestrictedType -> {

View File

@ -40,32 +40,40 @@ open class SerializerFactory(
val whitelist: ClassWhitelist, val whitelist: ClassWhitelist,
val classCarpenter: ClassCarpenter, val classCarpenter: ClassCarpenter,
private val evolutionSerializerGetter: EvolutionSerializerGetterBase = EvolutionSerializerGetter(), private val evolutionSerializerGetter: EvolutionSerializerGetterBase = EvolutionSerializerGetter(),
val fingerPrinter: FingerPrinter = SerializerFingerPrinter()) { val fingerPrinter: FingerPrinter = SerializerFingerPrinter(),
private val serializersByType: MutableMap<Type, AMQPSerializer<Any>>,
val serializersByDescriptor: MutableMap<Any, AMQPSerializer<Any>>,
private val customSerializers: MutableList<SerializerFor>,
val transformsCache: MutableMap<String, EnumMap<TransformTypes, MutableList<Transform>>>) {
constructor(whitelist: ClassWhitelist,
classCarpenter: ClassCarpenter,
evolutionSerializerGetter: EvolutionSerializerGetterBase = EvolutionSerializerGetter(),
fingerPrinter: FingerPrinter = SerializerFingerPrinter()
) : this(whitelist, classCarpenter, evolutionSerializerGetter, fingerPrinter,
serializersByType = ConcurrentHashMap(),
serializersByDescriptor = ConcurrentHashMap(),
customSerializers = CopyOnWriteArrayList(),
transformsCache = ConcurrentHashMap())
constructor(whitelist: ClassWhitelist, constructor(whitelist: ClassWhitelist,
classLoader: ClassLoader, classLoader: ClassLoader,
evolutionSerializerGetter: EvolutionSerializerGetterBase = EvolutionSerializerGetter(), evolutionSerializerGetter: EvolutionSerializerGetterBase = EvolutionSerializerGetter(),
fingerPrinter: FingerPrinter = SerializerFingerPrinter() fingerPrinter: FingerPrinter = SerializerFingerPrinter()
) : this(whitelist, ClassCarpenterImpl(classLoader, whitelist), evolutionSerializerGetter, fingerPrinter) ) : this(whitelist, ClassCarpenterImpl(classLoader, whitelist), evolutionSerializerGetter, fingerPrinter,
serializersByType = ConcurrentHashMap(),
serializersByDescriptor = ConcurrentHashMap(),
customSerializers = CopyOnWriteArrayList(),
transformsCache = ConcurrentHashMap())
init { init {
fingerPrinter.setOwner(this) fingerPrinter.setOwner(this)
} }
private val serializersByType = ConcurrentHashMap<Type, AMQPSerializer<Any>>()
private val serializersByDescriptor = ConcurrentHashMap<Any, AMQPSerializer<Any>>()
private val customSerializers = CopyOnWriteArrayList<SerializerFor>()
private val transformsCache = ConcurrentHashMap<String, EnumMap<TransformTypes, MutableList<Transform>>>()
val classloader: ClassLoader val classloader: ClassLoader
get() = classCarpenter.classloader get() = classCarpenter.classloader
private fun getEvolutionSerializer(typeNotation: TypeNotation, newSerializer: AMQPSerializer<Any>, private fun getEvolutionSerializer(typeNotation: TypeNotation, newSerializer: AMQPSerializer<Any>,
schemas: SerializationSchemas) = evolutionSerializerGetter.getEvolutionSerializer(this, typeNotation, newSerializer, schemas) schemas: SerializationSchemas) = evolutionSerializerGetter.getEvolutionSerializer(this, typeNotation, newSerializer, schemas)
fun getSerializersByDescriptor() = serializersByDescriptor
fun getTransformsCache() = transformsCache
/** /**
* Look up, and manufacture if necessary, a serializer for the given type. * Look up, and manufacture if necessary, a serializer for the given type.
* *
@ -219,7 +227,7 @@ open class SerializerFactory(
/** /**
* Iterate over an AMQP schema, for each type ascertain whether it's on ClassPath of [classloader] and, * Iterate over an AMQP schema, for each type ascertain whether it's on ClassPath of [classloader] and,
* if not, use the [ClassCarpenter] to generate a class to use in it's place. * if not, use the [ClassCarpenter] to generate a class to use in its place.
*/ */
private fun processSchema(schemaAndDescriptor: FactorySchemaAndDescriptor, sentinel: Boolean = false) { private fun processSchema(schemaAndDescriptor: FactorySchemaAndDescriptor, sentinel: Boolean = false) {
val metaSchema = CarpenterMetaSchema.newInstance() val metaSchema = CarpenterMetaSchema.newInstance()
@ -239,24 +247,28 @@ open class SerializerFactory(
} }
if (metaSchema.isNotEmpty()) { if (metaSchema.isNotEmpty()) {
val mc = MetaCarpenter(metaSchema, classCarpenter) runCarpentry(schemaAndDescriptor, metaSchema)
try {
mc.build()
} catch (e: MetaCarpenterException) {
// preserve the actual message locally
loggerFor<SerializerFactory>().apply {
error("${e.message} [hint: enable trace debugging for the stack trace]")
trace("", e)
}
// prevent carpenter exceptions escaping into the world, convert things into a nice
// NotSerializableException for when this escapes over the wire
throw NotSerializableException(e.name)
}
processSchema(schemaAndDescriptor, true)
} }
} }
private fun runCarpentry(schemaAndDescriptor: FactorySchemaAndDescriptor, metaSchema: CarpenterMetaSchema) {
val mc = MetaCarpenter(metaSchema, classCarpenter)
try {
mc.build()
} catch (e: MetaCarpenterException) {
// preserve the actual message locally
loggerFor<SerializerFactory>().apply {
error("${e.message} [hint: enable trace debugging for the stack trace]")
trace("", e)
}
// prevent carpenter exceptions escaping into the world, convert things into a nice
// NotSerializableException for when this escapes over the wire
throw NotSerializableException(e.name)
}
processSchema(schemaAndDescriptor, true)
}
private fun processSchemaEntry(typeNotation: TypeNotation) = when (typeNotation) { private fun processSchemaEntry(typeNotation: TypeNotation) = when (typeNotation) {
is CompositeType -> processCompositeType(typeNotation) // java.lang.Class (whether a class or interface) is CompositeType -> processCompositeType(typeNotation) // java.lang.Class (whether a class or interface)
is RestrictedType -> processRestrictedType(typeNotation) // Collection / Map, possibly with generics is RestrictedType -> processRestrictedType(typeNotation) // Collection / Map, possibly with generics

View File

@ -200,7 +200,7 @@ data class TransformsSchema(val types: Map<String, EnumMap<TransformTypes, Mutab
* @param sf the [SerializerFactory] building this transform set. Needed as each can define it's own * @param sf the [SerializerFactory] building this transform set. Needed as each can define it's own
* class loader and this dictates which classes we can and cannot see * class loader and this dictates which classes we can and cannot see
*/ */
fun get(name: String, sf: SerializerFactory) = sf.getTransformsCache().computeIfAbsent(name) { fun get(name: String, sf: SerializerFactory) = sf.transformsCache.computeIfAbsent(name) {
val transforms = EnumMap<TransformTypes, MutableList<Transform>>(TransformTypes::class.java) val transforms = EnumMap<TransformTypes, MutableList<Transform>>(TransformTypes::class.java)
try { try {
val clazz = sf.classloader.loadClass(name) val clazz = sf.classloader.loadClass(name)

View File

@ -48,7 +48,7 @@ private val toStringHelper: String = Type.getInternalName(MoreObjects.ToStringHe
// Allow us to create alternative ClassCarpenters. // Allow us to create alternative ClassCarpenters.
interface ClassCarpenter { interface ClassCarpenter {
val whitelist: ClassWhitelist val whitelist: ClassWhitelist
val classloader: CarpenterClassLoader val classloader: ClassLoader
fun build(schema: Schema): Class<*> fun build(schema: Schema): Class<*>
} }

View File

@ -30,8 +30,8 @@ abstract class ClassField(field: Class<out Any?>) : Field(field) {
abstract val nullabilityAnnotation: String abstract val nullabilityAnnotation: String
abstract fun nullTest(mv: MethodVisitor, slot: Int) abstract fun nullTest(mv: MethodVisitor, slot: Int)
override var descriptor = Type.getDescriptor(this.field) override var descriptor: String? = Type.getDescriptor(this.field)
override val type: String get() = if (this.field.isPrimitive) this.descriptor else "Ljava/lang/Object;" override val type: String get() = if (this.field.isPrimitive) this.descriptor!! else "Ljava/lang/Object;"
fun addNullabilityAnnotation(mv: MethodVisitor) { fun addNullabilityAnnotation(mv: MethodVisitor) {
mv.visitAnnotation(nullabilityAnnotation, true).visitEnd() mv.visitAnnotation(nullabilityAnnotation, true).visitEnd()

View File

@ -7,7 +7,7 @@ import static org.junit.Assert.*;
import java.io.NotSerializableException; import java.io.NotSerializableException;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.util.concurrent.ConcurrentHashMap; import java.util.Map;
public class JavaPrivatePropertyTests { public class JavaPrivatePropertyTests {
static class C { static class C {
@ -116,7 +116,7 @@ public class JavaPrivatePropertyTests {
B3 b2 = des.deserialize(ser.serialize(b, TestSerializationContext.testSerializationContext), B3.class, TestSerializationContext.testSerializationContext); B3 b2 = des.deserialize(ser.serialize(b, TestSerializationContext.testSerializationContext), B3.class, TestSerializationContext.testSerializationContext);
// since we can't find a getter for b (isb != isB) then we won't serialize that parameter // since we can't find a getter for b (isb != isB) then we won't serialize that parameter
assertEquals (null, b2.b); assertNull (b2.b);
} }
@Test @Test
@ -154,8 +154,7 @@ public class JavaPrivatePropertyTests {
Field f = SerializerFactory.class.getDeclaredField("serializersByDescriptor"); Field f = SerializerFactory.class.getDeclaredField("serializersByDescriptor");
f.setAccessible(true); f.setAccessible(true);
ConcurrentHashMap<Object, AMQPSerializer<Object>> serializersByDescriptor = Map<?, AMQPSerializer<?>> serializersByDescriptor = (Map<?, AMQPSerializer<?>>) f.get(factory);
(ConcurrentHashMap<Object, AMQPSerializer<Object>>) f.get(factory);
assertEquals(1, serializersByDescriptor.size()); assertEquals(1, serializersByDescriptor.size());
ObjectSerializer cSerializer = ((ObjectSerializer)serializersByDescriptor.values().toArray()[0]); ObjectSerializer cSerializer = ((ObjectSerializer)serializersByDescriptor.values().toArray()[0]);
@ -185,8 +184,7 @@ public class JavaPrivatePropertyTests {
// //
Field f = SerializerFactory.class.getDeclaredField("serializersByDescriptor"); Field f = SerializerFactory.class.getDeclaredField("serializersByDescriptor");
f.setAccessible(true); f.setAccessible(true);
ConcurrentHashMap<Object, AMQPSerializer<Object>> serializersByDescriptor = Map<?, AMQPSerializer<?>> serializersByDescriptor = (Map<?, AMQPSerializer<?>>) f.get(factory);
(ConcurrentHashMap<Object, AMQPSerializer<Object>>) f.get(factory);
assertEquals(1, serializersByDescriptor.size()); assertEquals(1, serializersByDescriptor.size());
ObjectSerializer cSerializer = ((ObjectSerializer)serializersByDescriptor.values().toArray()[0]); ObjectSerializer cSerializer = ((ObjectSerializer)serializersByDescriptor.values().toArray()[0]);

View File

@ -10,7 +10,6 @@ import net.corda.core.identity.CordaX500Name
import net.corda.nodeapi.internal.serialization.amqp.testutils.* import net.corda.nodeapi.internal.serialization.amqp.testutils.*
import net.corda.testing.core.TestIdentity import net.corda.testing.core.TestIdentity
import java.util.* import java.util.*
import java.util.concurrent.ConcurrentHashMap
import kotlin.test.assertEquals import kotlin.test.assertEquals
data class TestContractState( data class TestContractState(
@ -36,7 +35,7 @@ class GenericsTests {
private fun <T : Any> BytesAndSchemas<T>.printSchema() = if (VERBOSE) println("${this.schema}\n") else Unit private fun <T : Any> BytesAndSchemas<T>.printSchema() = if (VERBOSE) println("${this.schema}\n") else Unit
private fun ConcurrentHashMap<Any, AMQPSerializer<Any>>.printKeyToType() { private fun MutableMap<Any, AMQPSerializer<Any>>.printKeyToType() {
if (!VERBOSE) return if (!VERBOSE) return
forEach { forEach {
@ -53,11 +52,11 @@ class GenericsTests {
val bytes1 = SerializationOutput(factory).serializeAndReturnSchema(G("hi")).apply { printSchema() } val bytes1 = SerializationOutput(factory).serializeAndReturnSchema(G("hi")).apply { printSchema() }
factory.getSerializersByDescriptor().printKeyToType() factory.serializersByDescriptor.printKeyToType()
val bytes2 = SerializationOutput(factory).serializeAndReturnSchema(G(121)).apply { printSchema() } val bytes2 = SerializationOutput(factory).serializeAndReturnSchema(G(121)).apply { printSchema() }
factory.getSerializersByDescriptor().printKeyToType() factory.serializersByDescriptor.printKeyToType()
listOf(factory, testDefaultFactory()).forEach { f -> listOf(factory, testDefaultFactory()).forEach { f ->
DeserializationInput(f).deserialize(bytes1.obj).apply { assertEquals("hi", this.a) } DeserializationInput(f).deserialize(bytes1.obj).apply { assertEquals("hi", this.a) }
@ -90,14 +89,14 @@ class GenericsTests {
val bytes = ser.serializeAndReturnSchema(G("hi")).apply { printSchema() } val bytes = ser.serializeAndReturnSchema(G("hi")).apply { printSchema() }
factory.getSerializersByDescriptor().printKeyToType() factory.serializersByDescriptor.printKeyToType()
assertEquals("hi", DeserializationInput(factory).deserialize(bytes.obj).a) assertEquals("hi", DeserializationInput(factory).deserialize(bytes.obj).a)
assertEquals("hi", DeserializationInput(altContextFactory).deserialize(bytes.obj).a) assertEquals("hi", DeserializationInput(altContextFactory).deserialize(bytes.obj).a)
val bytes2 = ser.serializeAndReturnSchema(Wrapper(1, G("hi"))).apply { printSchema() } val bytes2 = ser.serializeAndReturnSchema(Wrapper(1, G("hi"))).apply { printSchema() }
factory.getSerializersByDescriptor().printKeyToType() factory.serializersByDescriptor.printKeyToType()
printSeparator() printSeparator()
@ -149,21 +148,21 @@ class GenericsTests {
ser.serialize(Wrapper(Container(InnerA(1)))).apply { ser.serialize(Wrapper(Container(InnerA(1)))).apply {
factories.forEach { factories.forEach {
DeserializationInput(it).deserialize(this).apply { assertEquals(1, c.b.a_a) } DeserializationInput(it).deserialize(this).apply { assertEquals(1, c.b.a_a) }
it.getSerializersByDescriptor().printKeyToType(); printSeparator() it.serializersByDescriptor.printKeyToType(); printSeparator()
} }
} }
ser.serialize(Wrapper(Container(InnerB(1)))).apply { ser.serialize(Wrapper(Container(InnerB(1)))).apply {
factories.forEach { factories.forEach {
DeserializationInput(it).deserialize(this).apply { assertEquals(1, c.b.a_b) } DeserializationInput(it).deserialize(this).apply { assertEquals(1, c.b.a_b) }
it.getSerializersByDescriptor().printKeyToType(); printSeparator() it.serializersByDescriptor.printKeyToType(); printSeparator()
} }
} }
ser.serialize(Wrapper(Container(InnerC("Ho ho ho")))).apply { ser.serialize(Wrapper(Container(InnerC("Ho ho ho")))).apply {
factories.forEach { factories.forEach {
DeserializationInput(it).deserialize(this).apply { assertEquals("Ho ho ho", c.b.a_c) } DeserializationInput(it).deserialize(this).apply { assertEquals("Ho ho ho", c.b.a_c) }
it.getSerializersByDescriptor().printKeyToType(); printSeparator() it.serializersByDescriptor.printKeyToType(); printSeparator()
} }
} }
} }
@ -199,7 +198,7 @@ class GenericsTests {
a: ForceWildcard<*>, a: ForceWildcard<*>,
factory: SerializerFactory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())): SerializedBytes<*> { factory: SerializerFactory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())): SerializedBytes<*> {
val bytes = SerializationOutput(factory).serializeAndReturnSchema(a) val bytes = SerializationOutput(factory).serializeAndReturnSchema(a)
factory.getSerializersByDescriptor().printKeyToType() factory.serializersByDescriptor.printKeyToType()
bytes.printSchema() bytes.printSchema()
return bytes.obj return bytes.obj
} }