diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationHelper.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationHelper.kt index 45df43799e..1a56cf901b 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationHelper.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationHelper.kt @@ -1,5 +1,7 @@ package net.corda.nodeapi.internal.serialization.amqp +import net.corda.core.serialization.ClassWhitelist +import net.corda.core.serialization.CordaSerializable import com.google.common.primitives.Primitives import com.google.common.reflect.TypeToken import net.corda.core.serialization.SerializationContext @@ -51,7 +53,7 @@ internal fun constructorForDeserialization(type: Type): KFunction? { } } - return preferredCandidate?.apply { isAccessible = true} + return preferredCandidate?.apply { isAccessible = true } ?: throw NotSerializableException("No constructor for deserialization found for $clazz.") } else { return null @@ -81,7 +83,7 @@ private fun propertiesForSerializationFromConstructor(kotlinConstructo val name = param.name ?: throw NotSerializableException("Constructor parameter of $clazz has no name.") val matchingProperty = properties[name] ?: throw NotSerializableException("No property matching constructor parameter named $name of $clazz." + - " If using Java, check that you have the -parameters option specified in the Java compiler.") + " If using Java, check that you have the -parameters option specified in the Java compiler.") // Check that the method has a getter in java. val getter = matchingProperty.readMethod ?: throw NotSerializableException("Property has no getter method for $name of $clazz." + " If using Java and the parameter name looks anonymous, check that you have the -parameters option specified in the Java compiler.") @@ -123,7 +125,7 @@ private fun exploreType(type: Type?, interfaces: MutableSet, serializerFac val clazz = type?.asClass() if (clazz != null) { if (clazz.isInterface) { - if(serializerFactory.isNotWhitelisted(clazz)) return // We stop exploring once we reach a branch that has no `CordaSerializable` annotation or whitelisting. + if (serializerFactory.whitelist.isNotWhitelisted(clazz)) return // We stop exploring once we reach a branch that has no `CordaSerializable` annotation or whitelisting. else interfaces += type } for (newInterface in clazz.genericInterfaces) { @@ -263,3 +265,21 @@ private fun Throwable.setMessage(newMsg: String) { detailMessageField.isAccessible = true detailMessageField.set(this, newMsg) } + +fun ClassWhitelist.whitelisted(type: Type) { + val clazz = type.asClass()!! + if (isNotWhitelisted(clazz)) { + throw NotSerializableException("Class $type is not on the whitelist or annotated with @CordaSerializable.") + } +} + +// Ignore SimpleFieldAccess as we add it to anything we build in the carpenter. +fun ClassWhitelist.isNotWhitelisted(clazz: Class<*>) = + !(hasListed(clazz) || hasAnnotationInHierarchy(clazz)) + +// Recursively check the class, interfaces and superclasses for our annotation. +fun ClassWhitelist.hasAnnotationInHierarchy(type: Class<*>): Boolean { + return type.isAnnotationPresent(CordaSerializable::class.java) + || type.interfaces.any { hasAnnotationInHierarchy(it) } + || (type.superclass != null && hasAnnotationInHierarchy(type.superclass)) +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt index 2f0a734800..65083a6150 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt @@ -4,7 +4,6 @@ import com.google.common.primitives.Primitives import com.google.common.reflect.TypeResolver import net.corda.core.internal.uncheckedCast import net.corda.core.serialization.ClassWhitelist -import net.corda.core.serialization.CordaSerializable import net.corda.nodeapi.internal.serialization.carpenter.* import org.apache.qpid.proton.amqp.* import java.io.NotSerializableException @@ -243,10 +242,10 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) { if (clazz.componentType.isPrimitive) PrimArraySerializer.make(type, this) else ArraySerializer.make(type, this) } else if (clazz.kotlin.objectInstance != null) { - whitelisted(clazz) + whitelist.whitelisted(clazz) SingletonSerializer(clazz, clazz.kotlin.objectInstance!!, this) } else { - whitelisted(type) + whitelist.whitelisted(type) ObjectSerializer(type, this) } } @@ -271,24 +270,6 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) { return null } - private fun whitelisted(type: Type) { - val clazz = type.asClass()!! - if (isNotWhitelisted(clazz)) { - throw NotSerializableException("Class $type is not on the whitelist or annotated with @CordaSerializable.") - } - } - - // Ignore SimpleFieldAccess as we add it to anything we build in the carpenter. - internal fun isNotWhitelisted(clazz: Class<*>): Boolean = clazz == SimpleFieldAccess::class.java || - (!whitelist.hasListed(clazz) && !hasAnnotationInHierarchy(clazz)) - - // Recursively check the class, interfaces and superclasses for our annotation. - private fun hasAnnotationInHierarchy(type: Class<*>): Boolean { - return type.isAnnotationPresent(CordaSerializable::class.java) || - type.interfaces.any { hasAnnotationInHierarchy(it) } - || (type.superclass != null && hasAnnotationInHierarchy(type.superclass)) - } - private fun makeMapSerializer(declaredType: ParameterizedType): AMQPSerializer { val rawType = declaredType.rawType as Class<*> rawType.checkSupportedMapType() diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt index 9f727e09c7..7a0dee0703 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt @@ -2,14 +2,13 @@ package net.corda.nodeapi.internal.serialization.carpenter import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.CordaSerializable +import net.corda.nodeapi.internal.serialization.amqp.whitelisted import org.objectweb.asm.ClassWriter import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes.* import org.objectweb.asm.Type - import java.lang.Character.isJavaIdentifierPart import java.lang.Character.isJavaIdentifierStart - import java.util.* /** @@ -17,6 +16,7 @@ import java.util.* * as if `this.class.getMethod("get" + name.capitalize()).invoke(this)` had been called. It is intended as a more * convenient alternative to reflection. */ +@CordaSerializable interface SimpleFieldAccess { operator fun get(name: String): Any? } @@ -134,7 +134,10 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader visit(TARGET_VERSION, ACC_PUBLIC + ACC_FINAL + ACC_SUPER + ACC_ENUM, schema.jvmName, "L$jlEnum;", jlEnum, null) - visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd() + if (schema.flags.getOrDefault(SchemaFlags.NotCordaSerializable, "false") == true) { + visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd() + } + generateFields(schema) generateStaticEnumConstructor(schema) generateEnumConstructor() @@ -151,8 +154,10 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader cw.apply { visit(TARGET_VERSION, ACC_PUBLIC + ACC_ABSTRACT + ACC_INTERFACE, schema.jvmName, null, jlObject, interfaces) - visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd() + if (schema.flags.getOrDefault(SchemaFlags.NotCordaSerializable, "false") == true) { + visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd() + } generateAbstractGetters(schema) }.visitEnd() } @@ -163,20 +168,25 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader val superName = schema.superclass?.jvmName ?: jlObject val interfaces = schema.interfaces.map { it.name.jvm }.toMutableList() - if (SimpleFieldAccess::class.java !in schema.interfaces) { + if (SimpleFieldAccess::class.java !in schema.interfaces && ( + (schema.flags.getOrDefault(SchemaFlags.NotCordaSerializable, "false") == true) || + (schema.flags.getOrDefault(SchemaFlags.NoSimpleFieldAccess, "false") == true))) { interfaces.add(SimpleFieldAccess::class.java.name.jvm) } cw.apply { visit(TARGET_VERSION, ACC_PUBLIC + ACC_SUPER, schema.jvmName, null, superName, interfaces.toTypedArray()) - visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd() + if (schema.flags.getOrDefault(SchemaFlags.NotCordaSerializable, "false") == true) { + visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd() + } generateFields(schema) generateClassConstructor(schema) generateGetters(schema) - if (schema.superclass == null) + if (schema.superclass == null) { generateGetMethod() // From SimplePropertyAccess + } generateToString(schema) }.visitEnd() } @@ -388,11 +398,22 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader } } + /** + * If a sub element isn't whitelist we will not build a class containing that type as a member. Since, by + * default, classes created by the [ClassCarpenter] are annotated as [CordaSerializable] we will always + * be able to carpent classes generated from our AMQP library as, at a base level, we will either be able to + * create the lowest level in the meta hierarchy because either all members are jvm primitives or + * whitelisted classes + */ private fun validateSchema(schema: Schema) { if (schema.name in _loaded) throw DuplicateNameException() fun isJavaName(n: String) = n.isNotBlank() && isJavaIdentifierStart(n.first()) && n.all(::isJavaIdentifierPart) require(isJavaName(schema.name.split(".").last())) { "Not a valid Java name: ${schema.name}" } - schema.fields.keys.forEach { require(isJavaName(it)) { "Not a valid Java name: $it" } } + schema.fields.forEach { + require(isJavaName(it.key)) { "Not a valid Java name: $it" } + whitelist.whitelisted(it.value.field) + } + // Now check each interface we've been asked to implement, as the JVM will unfortunately only catch the // fact that we didn't implement the interface we said we would at the moment the missing method is // actually called, which is a bit too dynamic for my tastes. diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Schema.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Schema.kt index e6684756b8..bd98a4f2c7 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Schema.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Schema.kt @@ -1,9 +1,12 @@ package net.corda.nodeapi.internal.serialization.carpenter -import kotlin.collections.LinkedHashMap import org.objectweb.asm.ClassWriter import org.objectweb.asm.Opcodes.* +enum class SchemaFlags { + NoSimpleFieldAccess, NotCordaSerializable +} + /** * A Schema is the representation of an object the Carpenter can contsruct * @@ -17,7 +20,8 @@ abstract class Schema( var fields: Map, val superclass: Schema? = null, val interfaces: List> = emptyList(), - updater: (String, Field) -> Unit) { + updater: (String, Field) -> Unit, + var flags : MutableMap = mutableMapOf()) { private fun Map.descriptors() = LinkedHashMap(this.mapValues { it.value.descriptor }) init { @@ -41,6 +45,20 @@ abstract class Schema( val asArray: String get() = "[L$jvmName;" + + @Suppress("Unused") + fun setNoSimpleFieldAccess() { + flags.replace (SchemaFlags.NoSimpleFieldAccess, true) + } + + fun setNotCordaSerializable() { + flags.replace (SchemaFlags.NotCordaSerializable, true) + } + + @Suppress("Unused") + fun setCordaSerializable() { + flags.replace (SchemaFlags.NotCordaSerializable, false) + } } /** diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterWhitelistTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterWhitelistTest.kt new file mode 100644 index 0000000000..f2da9018ef --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterWhitelistTest.kt @@ -0,0 +1,100 @@ +package net.corda.nodeapi.internal.serialization.carpenter + +import net.corda.core.serialization.ClassWhitelist +import net.corda.core.serialization.CordaSerializable +import org.assertj.core.api.Assertions +import org.junit.Test +import java.io.NotSerializableException + +class ClassCarpenterWhitelistTest { + + // whitelisting a class on the class path will mean we will carpente up a class that + // contains it as a member + @Test + fun whitelisted() { + data class A(val a: Int) + + class WL : ClassWhitelist { + private val allowedClasses = hashSetOf( + A::class.java.name + ) + + override fun hasListed(type: Class<*>): Boolean = type.name in allowedClasses + } + + val cc = ClassCarpenter(whitelist = WL()) + + // if this works, the test works, if it throws then we're in a world of pain, we could + // go further but there are a lot of other tests that test weather we can build + // carpented objects + cc.build(ClassSchema("thing", mapOf("a" to NonNullableField(A::class.java)))) + } + + // However, a class on the class path that isn't whitelisted we will not create + // an object that contains a member of that type + @Test + fun notWhitelisted() { + data class A(val a: Int) + + class WL : ClassWhitelist { + override fun hasListed(type: Class<*>) = false + } + + val cc = ClassCarpenter(whitelist = WL()) + + // Class A isn't on the whitelist, so we should fail to carpent it + Assertions.assertThatThrownBy { + cc.build(ClassSchema("thing", mapOf("a" to NonNullableField(A::class.java)))) + }.isInstanceOf(NotSerializableException::class.java) + } + + // despite now being whitelisted and on the class path, we will carpent this because + // it's marked as CordaSerializable + @Test + fun notWhitelistedButAnnotated() { + @CordaSerializable data class A(val a: Int) + + class WL : ClassWhitelist { + override fun hasListed(type: Class<*>) = false + } + + val cc = ClassCarpenter(whitelist = WL()) + + // again, simply not throwing here is enough to show the test worked and the carpenter + // didn't reject the type even though it wasn't on the whitelist because it was + // annotated properly + cc.build(ClassSchema("thing", mapOf("a" to NonNullableField(A::class.java)))) + } + + @Test + fun notWhitelistedButCarpented() { + // just have the white list reject *Everything* except ints + class WL : ClassWhitelist { + override fun hasListed(type: Class<*>) = type.name == "int" + } + + val cc = ClassCarpenter(whitelist = WL()) + + val schema1a = ClassSchema("thing1a", mapOf("a" to NonNullableField(Int::class.java))) + + // thing 1 won't be set as corda serializable, meaning we won't build schema 2 + schema1a.setNotCordaSerializable() + + val clazz1a = cc.build(schema1a) + val schema2 = ClassSchema("thing2", mapOf("a" to NonNullableField(clazz1a))) + + // thing 2 references thing 1 which wasn't carpented as corda s erializable and thus + // this will fail + Assertions.assertThatThrownBy { + cc.build(schema2) + }.isInstanceOf(NotSerializableException::class.java) + + // create a second type of schema1, this time leave it as corda serialzable + val schema1b = ClassSchema("thing1b", mapOf("a" to NonNullableField(Int::class.java))) + + val clazz1b = cc.build(schema1b) + + // since schema 1b was created as CordaSerializable this will work + val schema2b = ClassSchema("thing2", mapOf("a" to NonNullableField(clazz1b))) + } +} \ No newline at end of file