AMQP fixes: Don't attempt to include non-serializable interfaces in schemas (#1281)

* Ignore interfaces that are not serializable

* Annotate so test still works.

* Make methods accessible to serializer.

* Make sure interfaces are annotated in the carpenter.  Expand tests to check whitelisting.  Object is now whitelisted by default since it has no fields.

* Prevented Object from being whitelisted but allow arrays of Objects (i.e. pretty much an untyped array)
This commit is contained in:
Rick Parker 2017-08-21 10:27:27 +01:00 committed by GitHub
parent 06ad2ddd45
commit 56a84882a7
10 changed files with 53 additions and 17 deletions

View File

@ -25,7 +25,7 @@ class ObjectSerializer(val clazz: Type, factory: SerializerFactory) : AMQPSerial
private val typeName = nameForType(clazz)
override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}"
private val interfaces = interfacesForSerialization(clazz) // TODO maybe this proves too much and we need annotations to restrict.
private val interfaces = interfacesForSerialization(clazz, factory) // We restrict to only those annotated or whitelisted
internal val typeNotation: TypeNotation = CompositeType(typeName, null, generateProvides(), Descriptor(typeDescriptor, null), generateFields())

View File

@ -54,6 +54,7 @@ sealed class PropertySerializer(val name: String, val readMethod: Method, val re
companion object {
fun make(name: String, readMethod: Method, resolvedType: Type, factory: SerializerFactory): PropertySerializer {
readMethod.isAccessible = true
if (SerializerFactory.isPrimitive(resolvedType)) {
return when(resolvedType) {
Char::class.java, Character::class.java -> AMQPCharPropertySerializer(name, readMethod)

View File

@ -403,6 +403,6 @@ private fun fingerprintForObject(type: Type, contextType: Type?, alreadySeen: Mu
propertiesForSerialization(constructorForDeserialization(type), contextType ?: type, factory).fold(hasher.putUnencodedChars(name)) { orig, prop ->
fingerprintForType(prop.resolvedType, type, alreadySeen, orig, factory).putUnencodedChars(prop.name).putUnencodedChars(if (prop.mandatory) NOT_NULLABLE_HASH else NULLABLE_HASH)
}
interfacesForSerialization(type).map { fingerprintForType(it, type, alreadySeen, hasher, factory) }
interfacesForSerialization(type, factory).map { fingerprintForType(it, type, alreadySeen, hasher, factory) }
return hasher
}

View File

@ -103,23 +103,26 @@ private fun propertiesForSerializationFromAbstract(clazz: Class<*>, type: Type,
return rc
}
internal fun interfacesForSerialization(type: Type): List<Type> {
internal fun interfacesForSerialization(type: Type, serializerFactory: SerializerFactory): List<Type> {
val interfaces = LinkedHashSet<Type>()
exploreType(type, interfaces)
exploreType(type, interfaces, serializerFactory)
return interfaces.toList()
}
private fun exploreType(type: Type?, interfaces: MutableSet<Type>) {
private fun exploreType(type: Type?, interfaces: MutableSet<Type>, serializerFactory: SerializerFactory) {
val clazz = type?.asClass()
if (clazz != null) {
if (clazz.isInterface) interfaces += type
if (clazz.isInterface) {
if(serializerFactory.isNotWhitelisted(clazz)) return // We stop exploring once we reach a branch that has no `CordaSerializable` annotation or whitelisting.
else interfaces += type
}
for (newInterface in clazz.genericInterfaces) {
if (newInterface !in interfaces) {
exploreType(resolveTypeVariables(newInterface, type), interfaces)
exploreType(resolveTypeVariables(newInterface, type), interfaces, serializerFactory)
}
}
val superClass = clazz.genericSuperclass ?: return
exploreType(resolveTypeVariables(superClass, type), interfaces)
exploreType(resolveTypeVariables(superClass, type), interfaces, serializerFactory)
}
}

View File

@ -4,10 +4,7 @@ import com.google.common.primitives.Primitives
import com.google.common.reflect.TypeResolver
import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.CordaSerializable
import net.corda.nodeapi.internal.serialization.carpenter.CarpenterSchemas
import net.corda.nodeapi.internal.serialization.carpenter.ClassCarpenter
import net.corda.nodeapi.internal.serialization.carpenter.MetaCarpenter
import net.corda.nodeapi.internal.serialization.carpenter.carpenterSchema
import net.corda.nodeapi.internal.serialization.carpenter.*
import org.apache.qpid.proton.amqp.*
import java.io.NotSerializableException
import java.lang.reflect.GenericArrayType
@ -224,7 +221,8 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl : ClassLoader) {
} else {
findCustomSerializer(clazz, declaredType) ?: run {
if (type.isArray()) {
whitelisted(type.componentType())
// Allow Object[] since this can be quite common (i.e. an untyped array)
if(type.componentType() != Object::class.java) whitelisted(type.componentType())
if (clazz.componentType.isPrimitive) PrimArraySerializer.make(type, this)
else ArraySerializer.make(type, this)
} else if (clazz.kotlin.objectInstance != null) {
@ -258,11 +256,15 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl : ClassLoader) {
private fun whitelisted(type: Type) {
val clazz = type.asClass()!!
if (!whitelist.hasListed(clazz) && !hasAnnotationInHierarchy(clazz)) {
if (isNotWhitelisted(clazz)) {
throw NotSerializableException("Class $type is not on the whitelist or annotated with @CordaSerializable.")
}
}
// Ignore SimpleFieldAccess as we add it to anything we build in the carpenter.
internal fun isNotWhitelisted(clazz: Class<*>): Boolean = clazz == SimpleFieldAccess::class.java ||
(!whitelist.hasListed(clazz) && !hasAnnotationInHierarchy(clazz))
// Recursively check the class, interfaces and superclasses for our annotation.
internal fun hasAnnotationInHierarchy(type: Class<*>): Boolean {
return type.isAnnotationPresent(CordaSerializable::class.java) ||

View File

@ -10,7 +10,7 @@ import java.lang.reflect.Type
*/
class SingletonSerializer(override val type: Class<*>, val singleton: Any, factory: SerializerFactory) : AMQPSerializer<Any> {
override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}"
private val interfaces = interfacesForSerialization(type)
private val interfaces = interfacesForSerialization(type, factory)
private fun generateProvides(): List<String> = interfaces.map { it.typeName }

View File

@ -1,8 +1,10 @@
package net.corda.nodeapi.internal.serialization.carpenter
import net.corda.core.serialization.CordaSerializable
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.Type
import java.lang.Character.isJavaIdentifierPart
import java.lang.Character.isJavaIdentifierStart
@ -119,6 +121,7 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader
with(cw) {
visit(V1_8, ACC_PUBLIC + ACC_ABSTRACT + ACC_INTERFACE, schema.jvmName, null, "java/lang/Object", interfaces)
visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd()
generateAbstractGetters(schema)
@ -136,6 +139,7 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader
with(cw) {
visit(V1_8, ACC_PUBLIC + ACC_SUPER, schema.jvmName, null, superName, interfaces.toTypedArray())
visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd()
generateFields(schema)
generateConstructor(schema)

View File

@ -2,12 +2,14 @@ package net.corda.nodeapi.internal.serialization.amqp
import org.apache.qpid.proton.codec.Data
import net.corda.nodeapi.internal.serialization.AllWhitelist
import net.corda.nodeapi.internal.serialization.EmptyWhitelist
fun testDefaultFactory() = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
fun testDefaultFactoryWithWhitelist() = SerializerFactory(EmptyWhitelist, ClassLoader.getSystemClassLoader())
class TestSerializationOutput(
private val verbose: Boolean,
serializerFactory: SerializerFactory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()))
serializerFactory: SerializerFactory = testDefaultFactory())
: SerializationOutput(serializerFactory) {
override fun writeSchema(schema: Schema, data: Data) {

View File

@ -1,5 +1,6 @@
package net.corda.nodeapi.internal.serialization.amqp
import net.corda.core.serialization.CordaSerializable
import org.junit.Test
import kotlin.test.assertEquals
import kotlin.test.assertNotEquals
@ -43,4 +44,22 @@ class DeserializeAndReturnEnvelopeTests {
assertNotEquals(null, obj.envelope.schema.types.find { it.name == classTestName("A") })
assertNotEquals(null, obj.envelope.schema.types.find { it.name == classTestName("B") })
}
@Test
fun unannotatedInterfaceIsNotInSchema() {
@CordaSerializable
data class Foo(val bar: Int) : Comparable<Foo> {
override fun compareTo(other: Foo): Int = bar.compareTo(other.bar)
}
val a = Foo(123)
val factory = testDefaultFactoryWithWhitelist()
fun serialise(clazz: Any) = SerializationOutput(factory).serialize(clazz)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a))
assertTrue(obj.obj is Foo)
assertEquals(1, obj.envelope.schema.types.size)
assertNotEquals(null, obj.envelope.schema.types.find { it.name == classTestName("Foo") })
assertEquals(null, obj.envelope.schema.types.find { it.name == "java.lang.Comparable<${classTestName("Foo")}>" })
}
}

View File

@ -1,10 +1,12 @@
package net.corda.nodeapi.internal.serialization.amqp
import net.corda.core.serialization.CordaSerializable
import org.junit.Test
import kotlin.test.*
import net.corda.nodeapi.internal.serialization.carpenter.*
import net.corda.nodeapi.internal.serialization.AllWhitelist
@CordaSerializable
interface I {
fun getName() : String
}
@ -26,7 +28,7 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase() {
}
val sf1 = testDefaultFactory()
val sf2 = testDefaultFactory()
val sf2 = testDefaultFactoryWithWhitelist() // Deserialize with whitelisting on to check that `CordaSerializable` annotation present.
@Test
fun verySimpleType() {
@ -115,6 +117,7 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase() {
fun arrayOfTypes() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf("a" to NonNullableField(Int::class.java))))
@CordaSerializable
data class Outer (val a : Array<Any>)
val outer = Outer (arrayOf (
@ -185,6 +188,7 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase() {
val nestedClass = cc.build (ClassSchema("nestedType",
mapOf("name" to NonNullableField(String::class.java))))
@CordaSerializable
data class outer(val a: Any, val b: Any)
val classInstance = outer (
@ -204,6 +208,7 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase() {
"v1" to NonNullableField(Int::class.java),
"v2" to NonNullableField(Int::class.java))))
@CordaSerializable
data class outer (val l : List<Any>)
val toSerialise = outer (listOf (
unknownClass.constructors.first().newInstance(1, 2),