diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/CustomSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/CustomSerializer.kt index ccdd67a718..46de47d12a 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/CustomSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/CustomSerializer.kt @@ -58,7 +58,7 @@ abstract class CustomSerializer : AMQPSerializer, SerializerFor { * subclass in the schema, so that we can distinguish between subclasses. */ // TODO: should this be a custom serializer at all, or should it just be a plain AMQPSerializer? - class SubClass(protected val clazz: Class<*>, protected val superClassSerializer: CustomSerializer) : CustomSerializer() { + class SubClass(private val clazz: Class<*>, private val superClassSerializer: CustomSerializer) : CustomSerializer() { // TODO: should this be empty or contain the schema of the super? override val schemaForDocumentation = Schema(emptyList()) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/EvolutionSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/EvolutionSerializer.kt index 28a72d984a..691921bc8c 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/EvolutionSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/EvolutionSerializer.kt @@ -239,7 +239,7 @@ class EvolutionSerializerGetter : EvolutionSerializerGetterBase() { typeNotation: TypeNotation, newSerializer: AMQPSerializer, schemas: SerializationSchemas): AMQPSerializer { - return factory.getSerializersByDescriptor().computeIfAbsent(typeNotation.descriptor.name!!) { + return factory.serializersByDescriptor.computeIfAbsent(typeNotation.descriptor.name!!) { when (typeNotation) { is CompositeType -> EvolutionSerializer.make(typeNotation, newSerializer as ObjectSerializer, factory) is RestrictedType -> { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt index b1e27f8b36..8a56baf1b0 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt @@ -40,32 +40,40 @@ open class SerializerFactory( val whitelist: ClassWhitelist, val classCarpenter: ClassCarpenter, private val evolutionSerializerGetter: EvolutionSerializerGetterBase = EvolutionSerializerGetter(), - val fingerPrinter: FingerPrinter = SerializerFingerPrinter()) { + val fingerPrinter: FingerPrinter = SerializerFingerPrinter(), + private val serializersByType: MutableMap>, + val serializersByDescriptor: MutableMap>, + private val customSerializers: MutableList, + val transformsCache: MutableMap>>) { + constructor(whitelist: ClassWhitelist, + classCarpenter: ClassCarpenter, + evolutionSerializerGetter: EvolutionSerializerGetterBase = EvolutionSerializerGetter(), + fingerPrinter: FingerPrinter = SerializerFingerPrinter() + ) : this(whitelist, classCarpenter, evolutionSerializerGetter, fingerPrinter, + serializersByType = ConcurrentHashMap(), + serializersByDescriptor = ConcurrentHashMap(), + customSerializers = CopyOnWriteArrayList(), + transformsCache = ConcurrentHashMap()) constructor(whitelist: ClassWhitelist, classLoader: ClassLoader, evolutionSerializerGetter: EvolutionSerializerGetterBase = EvolutionSerializerGetter(), fingerPrinter: FingerPrinter = SerializerFingerPrinter() - ) : this(whitelist, ClassCarpenterImpl(classLoader, whitelist), evolutionSerializerGetter, fingerPrinter) + ) : this(whitelist, ClassCarpenterImpl(classLoader, whitelist), evolutionSerializerGetter, fingerPrinter, + serializersByType = ConcurrentHashMap(), + serializersByDescriptor = ConcurrentHashMap(), + customSerializers = CopyOnWriteArrayList(), + transformsCache = ConcurrentHashMap()) init { fingerPrinter.setOwner(this) } - private val serializersByType = ConcurrentHashMap>() - private val serializersByDescriptor = ConcurrentHashMap>() - private val customSerializers = CopyOnWriteArrayList() - private val transformsCache = ConcurrentHashMap>>() - val classloader: ClassLoader get() = classCarpenter.classloader private fun getEvolutionSerializer(typeNotation: TypeNotation, newSerializer: AMQPSerializer, schemas: SerializationSchemas) = evolutionSerializerGetter.getEvolutionSerializer(this, typeNotation, newSerializer, schemas) - fun getSerializersByDescriptor() = serializersByDescriptor - - fun getTransformsCache() = transformsCache - /** * Look up, and manufacture if necessary, a serializer for the given type. * @@ -219,7 +227,7 @@ open class SerializerFactory( /** * Iterate over an AMQP schema, for each type ascertain whether it's on ClassPath of [classloader] and, - * if not, use the [ClassCarpenter] to generate a class to use in it's place. + * if not, use the [ClassCarpenter] to generate a class to use in its place. */ private fun processSchema(schemaAndDescriptor: FactorySchemaAndDescriptor, sentinel: Boolean = false) { val metaSchema = CarpenterMetaSchema.newInstance() @@ -239,24 +247,28 @@ open class SerializerFactory( } if (metaSchema.isNotEmpty()) { - val mc = MetaCarpenter(metaSchema, classCarpenter) - try { - mc.build() - } catch (e: MetaCarpenterException) { - // preserve the actual message locally - loggerFor().apply { - error("${e.message} [hint: enable trace debugging for the stack trace]") - trace("", e) - } - - // prevent carpenter exceptions escaping into the world, convert things into a nice - // NotSerializableException for when this escapes over the wire - throw NotSerializableException(e.name) - } - processSchema(schemaAndDescriptor, true) + runCarpentry(schemaAndDescriptor, metaSchema) } } + private fun runCarpentry(schemaAndDescriptor: FactorySchemaAndDescriptor, metaSchema: CarpenterMetaSchema) { + val mc = MetaCarpenter(metaSchema, classCarpenter) + try { + mc.build() + } catch (e: MetaCarpenterException) { + // preserve the actual message locally + loggerFor().apply { + error("${e.message} [hint: enable trace debugging for the stack trace]") + trace("", e) + } + + // prevent carpenter exceptions escaping into the world, convert things into a nice + // NotSerializableException for when this escapes over the wire + throw NotSerializableException(e.name) + } + processSchema(schemaAndDescriptor, true) + } + private fun processSchemaEntry(typeNotation: TypeNotation) = when (typeNotation) { is CompositeType -> processCompositeType(typeNotation) // java.lang.Class (whether a class or interface) is RestrictedType -> processRestrictedType(typeNotation) // Collection / Map, possibly with generics diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/TransformsSchema.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/TransformsSchema.kt index edb7f2711d..ee695fd7bb 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/TransformsSchema.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/TransformsSchema.kt @@ -200,7 +200,7 @@ data class TransformsSchema(val types: Map>(TransformTypes::class.java) try { val clazz = sf.classloader.loadClass(name) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt index 42eb64f824..f828041a36 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt @@ -48,7 +48,7 @@ private val toStringHelper: String = Type.getInternalName(MoreObjects.ToStringHe // Allow us to create alternative ClassCarpenters. interface ClassCarpenter { val whitelist: ClassWhitelist - val classloader: CarpenterClassLoader + val classloader: ClassLoader fun build(schema: Schema): Class<*> } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/SchemaFields.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/SchemaFields.kt index 0ce9a8f3d1..351aaecb70 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/SchemaFields.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/SchemaFields.kt @@ -30,8 +30,8 @@ abstract class ClassField(field: Class) : Field(field) { abstract val nullabilityAnnotation: String abstract fun nullTest(mv: MethodVisitor, slot: Int) - override var descriptor = Type.getDescriptor(this.field) - override val type: String get() = if (this.field.isPrimitive) this.descriptor else "Ljava/lang/Object;" + override var descriptor: String? = Type.getDescriptor(this.field) + override val type: String get() = if (this.field.isPrimitive) this.descriptor!! else "Ljava/lang/Object;" fun addNullabilityAnnotation(mv: MethodVisitor) { mv.visitAnnotation(nullabilityAnnotation, true).visitEnd() diff --git a/node-api/src/test/java/net/corda/nodeapi/internal/serialization/amqp/JavaPrivatePropertyTests.java b/node-api/src/test/java/net/corda/nodeapi/internal/serialization/amqp/JavaPrivatePropertyTests.java index 75e6d44ea7..2c0fe80a9e 100644 --- a/node-api/src/test/java/net/corda/nodeapi/internal/serialization/amqp/JavaPrivatePropertyTests.java +++ b/node-api/src/test/java/net/corda/nodeapi/internal/serialization/amqp/JavaPrivatePropertyTests.java @@ -7,7 +7,7 @@ import static org.junit.Assert.*; import java.io.NotSerializableException; import java.lang.reflect.Field; -import java.util.concurrent.ConcurrentHashMap; +import java.util.Map; public class JavaPrivatePropertyTests { static class C { @@ -116,7 +116,7 @@ public class JavaPrivatePropertyTests { B3 b2 = des.deserialize(ser.serialize(b, TestSerializationContext.testSerializationContext), B3.class, TestSerializationContext.testSerializationContext); // since we can't find a getter for b (isb != isB) then we won't serialize that parameter - assertEquals (null, b2.b); + assertNull (b2.b); } @Test @@ -154,8 +154,7 @@ public class JavaPrivatePropertyTests { Field f = SerializerFactory.class.getDeclaredField("serializersByDescriptor"); f.setAccessible(true); - ConcurrentHashMap> serializersByDescriptor = - (ConcurrentHashMap>) f.get(factory); + Map> serializersByDescriptor = (Map>) f.get(factory); assertEquals(1, serializersByDescriptor.size()); ObjectSerializer cSerializer = ((ObjectSerializer)serializersByDescriptor.values().toArray()[0]); @@ -185,8 +184,7 @@ public class JavaPrivatePropertyTests { // Field f = SerializerFactory.class.getDeclaredField("serializersByDescriptor"); f.setAccessible(true); - ConcurrentHashMap> serializersByDescriptor = - (ConcurrentHashMap>) f.get(factory); + Map> serializersByDescriptor = (Map>) f.get(factory); assertEquals(1, serializersByDescriptor.size()); ObjectSerializer cSerializer = ((ObjectSerializer)serializersByDescriptor.values().toArray()[0]); diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/GenericsTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/GenericsTests.kt index 590cdfc5ca..3dd237b89e 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/GenericsTests.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/GenericsTests.kt @@ -10,7 +10,6 @@ import net.corda.core.identity.CordaX500Name import net.corda.nodeapi.internal.serialization.amqp.testutils.* import net.corda.testing.core.TestIdentity import java.util.* -import java.util.concurrent.ConcurrentHashMap import kotlin.test.assertEquals data class TestContractState( @@ -36,7 +35,7 @@ class GenericsTests { private fun BytesAndSchemas.printSchema() = if (VERBOSE) println("${this.schema}\n") else Unit - private fun ConcurrentHashMap>.printKeyToType() { + private fun MutableMap>.printKeyToType() { if (!VERBOSE) return forEach { @@ -53,11 +52,11 @@ class GenericsTests { val bytes1 = SerializationOutput(factory).serializeAndReturnSchema(G("hi")).apply { printSchema() } - factory.getSerializersByDescriptor().printKeyToType() + factory.serializersByDescriptor.printKeyToType() val bytes2 = SerializationOutput(factory).serializeAndReturnSchema(G(121)).apply { printSchema() } - factory.getSerializersByDescriptor().printKeyToType() + factory.serializersByDescriptor.printKeyToType() listOf(factory, testDefaultFactory()).forEach { f -> DeserializationInput(f).deserialize(bytes1.obj).apply { assertEquals("hi", this.a) } @@ -90,14 +89,14 @@ class GenericsTests { val bytes = ser.serializeAndReturnSchema(G("hi")).apply { printSchema() } - factory.getSerializersByDescriptor().printKeyToType() + factory.serializersByDescriptor.printKeyToType() assertEquals("hi", DeserializationInput(factory).deserialize(bytes.obj).a) assertEquals("hi", DeserializationInput(altContextFactory).deserialize(bytes.obj).a) val bytes2 = ser.serializeAndReturnSchema(Wrapper(1, G("hi"))).apply { printSchema() } - factory.getSerializersByDescriptor().printKeyToType() + factory.serializersByDescriptor.printKeyToType() printSeparator() @@ -149,21 +148,21 @@ class GenericsTests { ser.serialize(Wrapper(Container(InnerA(1)))).apply { factories.forEach { DeserializationInput(it).deserialize(this).apply { assertEquals(1, c.b.a_a) } - it.getSerializersByDescriptor().printKeyToType(); printSeparator() + it.serializersByDescriptor.printKeyToType(); printSeparator() } } ser.serialize(Wrapper(Container(InnerB(1)))).apply { factories.forEach { DeserializationInput(it).deserialize(this).apply { assertEquals(1, c.b.a_b) } - it.getSerializersByDescriptor().printKeyToType(); printSeparator() + it.serializersByDescriptor.printKeyToType(); printSeparator() } } ser.serialize(Wrapper(Container(InnerC("Ho ho ho")))).apply { factories.forEach { DeserializationInput(it).deserialize(this).apply { assertEquals("Ho ho ho", c.b.a_c) } - it.getSerializersByDescriptor().printKeyToType(); printSeparator() + it.serializersByDescriptor.printKeyToType(); printSeparator() } } } @@ -199,7 +198,7 @@ class GenericsTests { a: ForceWildcard<*>, factory: SerializerFactory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())): SerializedBytes<*> { val bytes = SerializationOutput(factory).serializeAndReturnSchema(a) - factory.getSerializersByDescriptor().printKeyToType() + factory.serializersByDescriptor.printKeyToType() bytes.printSchema() return bytes.obj }