diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolver.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolver.kt index 09983c8c68..45f98cbda9 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolver.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolver.kt @@ -11,6 +11,7 @@ import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.SerializationContext import net.corda.core.utilities.loggerFor +import net.corda.nodeapi.internal.serialization.amqp.hasAnnotationInHierarchy import java.io.PrintWriter import java.lang.reflect.Modifier.isAbstract import java.nio.charset.StandardCharsets @@ -115,13 +116,7 @@ class CordaClassResolver(serializationContext: SerializationContext) : DefaultCl return (type.classLoader !is AttachmentsClassLoader) && !KryoSerializable::class.java.isAssignableFrom(type) && !type.isAnnotationPresent(DefaultSerializer::class.java) - && (type.isAnnotationPresent(CordaSerializable::class.java) || hasInheritedAnnotation(type)) - } - - // Recursively check interfaces for our annotation. - private fun hasInheritedAnnotation(type: Class<*>): Boolean { - return type.interfaces.any { it.isAnnotationPresent(CordaSerializable::class.java) || hasInheritedAnnotation(it) } - || (type.superclass != null && hasInheritedAnnotation(type.superclass)) + && (type.isAnnotationPresent(CordaSerializable::class.java) || whitelist.hasAnnotationInHierarchy(type)) } // Need to clear out class names from attachments. diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationHelper.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationHelper.kt index 1a56cf901b..ab218fb247 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationHelper.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationHelper.kt @@ -266,16 +266,14 @@ private fun Throwable.setMessage(newMsg: String) { detailMessageField.set(this, newMsg) } -fun ClassWhitelist.whitelisted(type: Type) { - val clazz = type.asClass()!! - if (isNotWhitelisted(clazz)) { +fun ClassWhitelist.requireWhitelisted(type: Type) { + if (!this.isWhitelisted(type.asClass()!!)) { throw NotSerializableException("Class $type is not on the whitelist or annotated with @CordaSerializable.") } } -// Ignore SimpleFieldAccess as we add it to anything we build in the carpenter. -fun ClassWhitelist.isNotWhitelisted(clazz: Class<*>) = - !(hasListed(clazz) || hasAnnotationInHierarchy(clazz)) +fun ClassWhitelist.isWhitelisted(clazz: Class<*>) = (hasListed(clazz) || hasAnnotationInHierarchy(clazz)) +fun ClassWhitelist.isNotWhitelisted(clazz: Class<*>) = !(this.isWhitelisted(clazz)) // Recursively check the class, interfaces and superclasses for our annotation. fun ClassWhitelist.hasAnnotationInHierarchy(type: Class<*>): Boolean { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt index 65083a6150..da8ae1afc6 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt @@ -30,11 +30,11 @@ data class FactorySchemaAndDescriptor(val schema: Schema, val typeDescriptor: An // TODO: need to rethink matching of constructor to properties in relation to implementing interfaces and needing those properties etc. // TODO: need to support super classes as well as interfaces with our current code base... what's involved? If we continue to ban, what is the impact? @ThreadSafe -class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) { +open class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) { private val serializersByType = ConcurrentHashMap>() private val serializersByDescriptor = ConcurrentHashMap>() private val customSerializers = CopyOnWriteArrayList>() - val classCarpenter = ClassCarpenter(cl, whitelist) + open val classCarpenter = ClassCarpenter(cl, whitelist) val classloader: ClassLoader get() = classCarpenter.classloader @@ -81,7 +81,7 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) { } } Enum::class.java.isAssignableFrom(actualClass ?: declaredClass) -> serializersByType.computeIfAbsent(actualClass ?: declaredClass) { - whitelisted(actualType) + whitelist.requireWhitelisted(actualType) EnumSerializer(actualType, actualClass ?: declaredClass, this) } else -> makeClassSerializer(actualClass ?: declaredClass, actualType, declaredType) @@ -242,10 +242,10 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) { if (clazz.componentType.isPrimitive) PrimArraySerializer.make(type, this) else ArraySerializer.make(type, this) } else if (clazz.kotlin.objectInstance != null) { - whitelist.whitelisted(clazz) + whitelist.requireWhitelisted(clazz) SingletonSerializer(clazz, clazz.kotlin.objectInstance!!, this) } else { - whitelist.whitelisted(type) + whitelist.requireWhitelisted(type) ObjectSerializer(type, this) } } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt index 7a0dee0703..2c98ebbc07 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt @@ -2,7 +2,6 @@ package net.corda.nodeapi.internal.serialization.carpenter import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.CordaSerializable -import net.corda.nodeapi.internal.serialization.amqp.whitelisted import org.objectweb.asm.ClassWriter import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes.* @@ -134,7 +133,7 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader visit(TARGET_VERSION, ACC_PUBLIC + ACC_FINAL + ACC_SUPER + ACC_ENUM, schema.jvmName, "L$jlEnum;", jlEnum, null) - if (schema.flags.getOrDefault(SchemaFlags.NotCordaSerializable, "false") == true) { + if (schema.flags.cordaSerializable()) { visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd() } @@ -155,7 +154,7 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader visit(TARGET_VERSION, ACC_PUBLIC + ACC_ABSTRACT + ACC_INTERFACE, schema.jvmName, null, jlObject, interfaces) - if (schema.flags.getOrDefault(SchemaFlags.NotCordaSerializable, "false") == true) { + if (schema.flags.cordaSerializable()) { visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd() } generateAbstractGetters(schema) @@ -168,9 +167,9 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader val superName = schema.superclass?.jvmName ?: jlObject val interfaces = schema.interfaces.map { it.name.jvm }.toMutableList() - if (SimpleFieldAccess::class.java !in schema.interfaces && ( - (schema.flags.getOrDefault(SchemaFlags.NotCordaSerializable, "false") == true) || - (schema.flags.getOrDefault(SchemaFlags.NoSimpleFieldAccess, "false") == true))) { + if (SimpleFieldAccess::class.java !in schema.interfaces + && schema.flags.cordaSerializable() + && schema.flags.simpleFieldAccess()) { interfaces.add(SimpleFieldAccess::class.java.name.jvm) } @@ -178,7 +177,7 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader visit(TARGET_VERSION, ACC_PUBLIC + ACC_SUPER, schema.jvmName, null, superName, interfaces.toTypedArray()) - if (schema.flags.getOrDefault(SchemaFlags.NotCordaSerializable, "false") == true) { + if (schema.flags.cordaSerializable()) { visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd() } generateFields(schema) @@ -411,7 +410,6 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader require(isJavaName(schema.name.split(".").last())) { "Not a valid Java name: ${schema.name}" } schema.fields.forEach { require(isJavaName(it.key)) { "Not a valid Java name: $it" } - whitelist.whitelisted(it.value.field) } // Now check each interface we've been asked to implement, as the JVM will unfortunately only catch the diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Schema.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Schema.kt index bd98a4f2c7..e3a3bfab56 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Schema.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Schema.kt @@ -2,9 +2,10 @@ package net.corda.nodeapi.internal.serialization.carpenter import org.objectweb.asm.ClassWriter import org.objectweb.asm.Opcodes.* +import java.util.* enum class SchemaFlags { - NoSimpleFieldAccess, NotCordaSerializable + SimpleFieldAccess, CordaSerializable } /** @@ -20,10 +21,11 @@ abstract class Schema( var fields: Map, val superclass: Schema? = null, val interfaces: List> = emptyList(), - updater: (String, Field) -> Unit, - var flags : MutableMap = mutableMapOf()) { + updater: (String, Field) -> Unit) { private fun Map.descriptors() = LinkedHashMap(this.mapValues { it.value.descriptor }) + var flags : EnumMap = EnumMap(SchemaFlags::class.java) + init { fields.forEach { updater(it.key, it.value) } @@ -46,19 +48,17 @@ abstract class Schema( val asArray: String get() = "[L$jvmName;" - @Suppress("Unused") - fun setNoSimpleFieldAccess() { - flags.replace (SchemaFlags.NoSimpleFieldAccess, true) + fun unsetCordaSerializable() { + flags.replace (SchemaFlags.CordaSerializable, false) } +} - fun setNotCordaSerializable() { - flags.replace (SchemaFlags.NotCordaSerializable, true) - } +fun EnumMap.cordaSerializable() : Boolean { + return this.getOrDefault(SchemaFlags.CordaSerializable, true) == true +} - @Suppress("Unused") - fun setCordaSerializable() { - flags.replace (SchemaFlags.NotCordaSerializable, false) - } +fun EnumMap.simpleFieldAccess() : Boolean { + return this.getOrDefault(SchemaFlags.SimpleFieldAccess, true) == true } /** diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/StaticInitialisationOfSerializedObjectTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/StaticInitialisationOfSerializedObjectTest.kt new file mode 100644 index 0000000000..47decde9c0 --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/StaticInitialisationOfSerializedObjectTest.kt @@ -0,0 +1,144 @@ +package net.corda.nodeapi.internal.serialization.amqp + +import net.corda.core.serialization.ClassWhitelist +import net.corda.core.serialization.SerializedBytes +import net.corda.nodeapi.internal.serialization.AllWhitelist +import net.corda.nodeapi.internal.serialization.carpenter.ClassCarpenter +import org.assertj.core.api.Assertions +import org.junit.Test +import java.io.File +import java.io.NotSerializableException +import java.lang.reflect.Type +import java.util.concurrent.ConcurrentHashMap +import kotlin.test.assertEquals + +class InStatic : Exception ("Help!, help!, I'm being repressed") + +class C { + companion object { + init { + throw InStatic() + } + } +} + +// To re-setup the resource file for the tests +// * deserializeTest +// * deserializeTest2 +// comment out the companion object from here, comment out the test code and uncomment +// the generation code, then re-run the test and copy the file shown in the output print +// to the resource directory +class C2 (var b: Int) { + /* + companion object { + init { + throw InStatic() + } + } + */ +} + +class StaticInitialisationOfSerializedObjectTest { + @Test(expected=java.lang.ExceptionInInitializerError::class) + fun itBlowsUp() { + C() + } + + @Test + fun KotlinObjectWithCompanionObject() { + data class D (val c : C) + + val sf = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()) + + val typeMap = sf::class.java.getDeclaredField("serializersByType") + typeMap.isAccessible = true + + @Suppress("UNCHECKED_CAST") + val serialisersByType = typeMap.get(sf) as ConcurrentHashMap> + + // pre building a serializer, we shouldn't have anything registered + assertEquals(0, serialisersByType.size) + + // build a serializer for type D without an instance of it to serialise, since + // we can't actually construct one + sf.get(null, D::class.java) + + // post creation of the serializer we should have one element in the map, this + // proves we didn't statically construct an instance of C when building the serializer + assertEquals(1, serialisersByType.size) + } + + + @Test + fun deserializeTest() { + data class D (val c : C2) + + val path = EvolvabilityTests::class.java.getResource("StaticInitialisationOfSerializedObjectTest.deserializeTest") + val f = File(path.toURI()) + + // Original version of the class for the serialised version of this class + // + //val sf1 = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()) + //val sc = SerializationOutput(sf1).serialize(D(C2(20))) + //f.writeBytes(sc.bytes) + //println (path) + + class WL : ClassWhitelist { + override fun hasListed(type: Class<*>) = + type.name == "net.corda.nodeapi.internal.serialization.amqp" + + ".StaticInitialisationOfSerializedObjectTest\$deserializeTest\$D" + } + + val sf2 = SerializerFactory(WL(), ClassLoader.getSystemClassLoader()) + val bytes = f.readBytes() + + Assertions.assertThatThrownBy { + DeserializationInput(sf2).deserialize(SerializedBytes(bytes)) + }.isInstanceOf(NotSerializableException::class.java) + } + + // Version of a serializer factory that will allow the class carpenter living on the + // factory to have a different whitelist applied to it than the factory + class TestSerializerFactory(wl1: ClassWhitelist, wl2: ClassWhitelist) : + SerializerFactory (wl1, ClassLoader.getSystemClassLoader()) { + override val classCarpenter = ClassCarpenter(ClassLoader.getSystemClassLoader(), wl2) + } + + // This time have the serilization factory and the carpenter use different whitelists + @Test + fun deserializeTest2() { + data class D (val c : C2) + + val path = EvolvabilityTests::class.java.getResource("StaticInitialisationOfSerializedObjectTest.deserializeTest2") + val f = File(path.toURI()) + + // Original version of the class for the serialised version of this class + // + //val sf1 = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()) + //val sc = SerializationOutput(sf1).serialize(D(C2(20))) + //f.writeBytes(sc.bytes) + //println (path) + + // whitelist to be used by the serialisation factory + class WL1 : ClassWhitelist { + override fun hasListed(type: Class<*>) = + type.name == "net.corda.nodeapi.internal.serialization.amqp" + + ".StaticInitialisationOfSerializedObjectTest\$deserializeTest\$D" + } + + // whitelist to be used by the carpenter + class WL2 : ClassWhitelist { + override fun hasListed(type: Class<*>) = true + } + + val sf2 = TestSerializerFactory(WL1(), WL2()) + val bytes = f.readBytes() + + // Deserializing should throw because C is not on the whitelist NOT because + // we ever went anywhere near statically constructing it prior to not actually + // creating an instance of it + Assertions.assertThatThrownBy { + DeserializationInput(sf2).deserialize(SerializedBytes(bytes)) + }.isInstanceOf(NotSerializableException::class.java) + } +} \ No newline at end of file diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterWhitelistTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterWhitelistTest.kt index f2da9018ef..deb51da65f 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterWhitelistTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterWhitelistTest.kt @@ -3,6 +3,7 @@ package net.corda.nodeapi.internal.serialization.carpenter import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.CordaSerializable import org.assertj.core.api.Assertions +import org.junit.Ignore import org.junit.Test import java.io.NotSerializableException @@ -30,9 +31,9 @@ class ClassCarpenterWhitelistTest { cc.build(ClassSchema("thing", mapOf("a" to NonNullableField(A::class.java)))) } - // However, a class on the class path that isn't whitelisted we will not create - // an object that contains a member of that type @Test + @Ignore("Currently the carpenter doesn't inspect it's whitelist so will carpent anything" + + "it's asked relying on the serializer factory to not ask for anything") fun notWhitelisted() { data class A(val a: Int) @@ -67,6 +68,8 @@ class ClassCarpenterWhitelistTest { } @Test + @Ignore("Currently the carpenter doesn't inspect it's whitelist so will carpent anything" + + "it's asked relying on the serializer factory to not ask for anything") fun notWhitelistedButCarpented() { // just have the white list reject *Everything* except ints class WL : ClassWhitelist { @@ -78,7 +81,7 @@ class ClassCarpenterWhitelistTest { val schema1a = ClassSchema("thing1a", mapOf("a" to NonNullableField(Int::class.java))) // thing 1 won't be set as corda serializable, meaning we won't build schema 2 - schema1a.setNotCordaSerializable() + schema1a.unsetCordaSerializable() val clazz1a = cc.build(schema1a) val schema2 = ClassSchema("thing2", mapOf("a" to NonNullableField(clazz1a))) @@ -95,6 +98,6 @@ class ClassCarpenterWhitelistTest { val clazz1b = cc.build(schema1b) // since schema 1b was created as CordaSerializable this will work - val schema2b = ClassSchema("thing2", mapOf("a" to NonNullableField(clazz1b))) + ClassSchema("thing2", mapOf("a" to NonNullableField(clazz1b))) } } \ No newline at end of file diff --git a/node-api/src/test/resources/net/corda/nodeapi/internal/serialization/amqp/StaticInitialisationOfSerializedObjectTest.deserializeTest b/node-api/src/test/resources/net/corda/nodeapi/internal/serialization/amqp/StaticInitialisationOfSerializedObjectTest.deserializeTest new file mode 100644 index 0000000000..0ba0299a3f Binary files /dev/null and b/node-api/src/test/resources/net/corda/nodeapi/internal/serialization/amqp/StaticInitialisationOfSerializedObjectTest.deserializeTest differ diff --git a/node-api/src/test/resources/net/corda/nodeapi/internal/serialization/amqp/StaticInitialisationOfSerializedObjectTest.deserializeTest2 b/node-api/src/test/resources/net/corda/nodeapi/internal/serialization/amqp/StaticInitialisationOfSerializedObjectTest.deserializeTest2 new file mode 100644 index 0000000000..6423985eaf Binary files /dev/null and b/node-api/src/test/resources/net/corda/nodeapi/internal/serialization/amqp/StaticInitialisationOfSerializedObjectTest.deserializeTest2 differ