mirror of
https://github.com/corda/corda.git
synced 2025-06-14 13:18:18 +00:00
AMQP fixes: Don't attempt to include non-serializable interfaces in schemas (#1281)
* Ignore interfaces that are not serializable * Annotate so test still works. * Make methods accessible to serializer. * Make sure interfaces are annotated in the carpenter. Expand tests to check whitelisting. Object is now whitelisted by default since it has no fields. * Prevented Object from being whitelisted but allow arrays of Objects (i.e. pretty much an untyped array)
This commit is contained in:
@ -25,7 +25,7 @@ class ObjectSerializer(val clazz: Type, factory: SerializerFactory) : AMQPSerial
|
|||||||
private val typeName = nameForType(clazz)
|
private val typeName = nameForType(clazz)
|
||||||
|
|
||||||
override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}"
|
override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}"
|
||||||
private val interfaces = interfacesForSerialization(clazz) // TODO maybe this proves too much and we need annotations to restrict.
|
private val interfaces = interfacesForSerialization(clazz, factory) // We restrict to only those annotated or whitelisted
|
||||||
|
|
||||||
internal val typeNotation: TypeNotation = CompositeType(typeName, null, generateProvides(), Descriptor(typeDescriptor, null), generateFields())
|
internal val typeNotation: TypeNotation = CompositeType(typeName, null, generateProvides(), Descriptor(typeDescriptor, null), generateFields())
|
||||||
|
|
||||||
|
@ -54,6 +54,7 @@ sealed class PropertySerializer(val name: String, val readMethod: Method, val re
|
|||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
fun make(name: String, readMethod: Method, resolvedType: Type, factory: SerializerFactory): PropertySerializer {
|
fun make(name: String, readMethod: Method, resolvedType: Type, factory: SerializerFactory): PropertySerializer {
|
||||||
|
readMethod.isAccessible = true
|
||||||
if (SerializerFactory.isPrimitive(resolvedType)) {
|
if (SerializerFactory.isPrimitive(resolvedType)) {
|
||||||
return when(resolvedType) {
|
return when(resolvedType) {
|
||||||
Char::class.java, Character::class.java -> AMQPCharPropertySerializer(name, readMethod)
|
Char::class.java, Character::class.java -> AMQPCharPropertySerializer(name, readMethod)
|
||||||
|
@ -403,6 +403,6 @@ private fun fingerprintForObject(type: Type, contextType: Type?, alreadySeen: Mu
|
|||||||
propertiesForSerialization(constructorForDeserialization(type), contextType ?: type, factory).fold(hasher.putUnencodedChars(name)) { orig, prop ->
|
propertiesForSerialization(constructorForDeserialization(type), contextType ?: type, factory).fold(hasher.putUnencodedChars(name)) { orig, prop ->
|
||||||
fingerprintForType(prop.resolvedType, type, alreadySeen, orig, factory).putUnencodedChars(prop.name).putUnencodedChars(if (prop.mandatory) NOT_NULLABLE_HASH else NULLABLE_HASH)
|
fingerprintForType(prop.resolvedType, type, alreadySeen, orig, factory).putUnencodedChars(prop.name).putUnencodedChars(if (prop.mandatory) NOT_NULLABLE_HASH else NULLABLE_HASH)
|
||||||
}
|
}
|
||||||
interfacesForSerialization(type).map { fingerprintForType(it, type, alreadySeen, hasher, factory) }
|
interfacesForSerialization(type, factory).map { fingerprintForType(it, type, alreadySeen, hasher, factory) }
|
||||||
return hasher
|
return hasher
|
||||||
}
|
}
|
||||||
|
@ -103,23 +103,26 @@ private fun propertiesForSerializationFromAbstract(clazz: Class<*>, type: Type,
|
|||||||
return rc
|
return rc
|
||||||
}
|
}
|
||||||
|
|
||||||
internal fun interfacesForSerialization(type: Type): List<Type> {
|
internal fun interfacesForSerialization(type: Type, serializerFactory: SerializerFactory): List<Type> {
|
||||||
val interfaces = LinkedHashSet<Type>()
|
val interfaces = LinkedHashSet<Type>()
|
||||||
exploreType(type, interfaces)
|
exploreType(type, interfaces, serializerFactory)
|
||||||
return interfaces.toList()
|
return interfaces.toList()
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun exploreType(type: Type?, interfaces: MutableSet<Type>) {
|
private fun exploreType(type: Type?, interfaces: MutableSet<Type>, serializerFactory: SerializerFactory) {
|
||||||
val clazz = type?.asClass()
|
val clazz = type?.asClass()
|
||||||
if (clazz != null) {
|
if (clazz != null) {
|
||||||
if (clazz.isInterface) interfaces += type
|
if (clazz.isInterface) {
|
||||||
|
if(serializerFactory.isNotWhitelisted(clazz)) return // We stop exploring once we reach a branch that has no `CordaSerializable` annotation or whitelisting.
|
||||||
|
else interfaces += type
|
||||||
|
}
|
||||||
for (newInterface in clazz.genericInterfaces) {
|
for (newInterface in clazz.genericInterfaces) {
|
||||||
if (newInterface !in interfaces) {
|
if (newInterface !in interfaces) {
|
||||||
exploreType(resolveTypeVariables(newInterface, type), interfaces)
|
exploreType(resolveTypeVariables(newInterface, type), interfaces, serializerFactory)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
val superClass = clazz.genericSuperclass ?: return
|
val superClass = clazz.genericSuperclass ?: return
|
||||||
exploreType(resolveTypeVariables(superClass, type), interfaces)
|
exploreType(resolveTypeVariables(superClass, type), interfaces, serializerFactory)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,10 +4,7 @@ import com.google.common.primitives.Primitives
|
|||||||
import com.google.common.reflect.TypeResolver
|
import com.google.common.reflect.TypeResolver
|
||||||
import net.corda.core.serialization.ClassWhitelist
|
import net.corda.core.serialization.ClassWhitelist
|
||||||
import net.corda.core.serialization.CordaSerializable
|
import net.corda.core.serialization.CordaSerializable
|
||||||
import net.corda.nodeapi.internal.serialization.carpenter.CarpenterSchemas
|
import net.corda.nodeapi.internal.serialization.carpenter.*
|
||||||
import net.corda.nodeapi.internal.serialization.carpenter.ClassCarpenter
|
|
||||||
import net.corda.nodeapi.internal.serialization.carpenter.MetaCarpenter
|
|
||||||
import net.corda.nodeapi.internal.serialization.carpenter.carpenterSchema
|
|
||||||
import org.apache.qpid.proton.amqp.*
|
import org.apache.qpid.proton.amqp.*
|
||||||
import java.io.NotSerializableException
|
import java.io.NotSerializableException
|
||||||
import java.lang.reflect.GenericArrayType
|
import java.lang.reflect.GenericArrayType
|
||||||
@ -224,7 +221,8 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl : ClassLoader) {
|
|||||||
} else {
|
} else {
|
||||||
findCustomSerializer(clazz, declaredType) ?: run {
|
findCustomSerializer(clazz, declaredType) ?: run {
|
||||||
if (type.isArray()) {
|
if (type.isArray()) {
|
||||||
whitelisted(type.componentType())
|
// Allow Object[] since this can be quite common (i.e. an untyped array)
|
||||||
|
if(type.componentType() != Object::class.java) whitelisted(type.componentType())
|
||||||
if (clazz.componentType.isPrimitive) PrimArraySerializer.make(type, this)
|
if (clazz.componentType.isPrimitive) PrimArraySerializer.make(type, this)
|
||||||
else ArraySerializer.make(type, this)
|
else ArraySerializer.make(type, this)
|
||||||
} else if (clazz.kotlin.objectInstance != null) {
|
} else if (clazz.kotlin.objectInstance != null) {
|
||||||
@ -258,11 +256,15 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl : ClassLoader) {
|
|||||||
|
|
||||||
private fun whitelisted(type: Type) {
|
private fun whitelisted(type: Type) {
|
||||||
val clazz = type.asClass()!!
|
val clazz = type.asClass()!!
|
||||||
if (!whitelist.hasListed(clazz) && !hasAnnotationInHierarchy(clazz)) {
|
if (isNotWhitelisted(clazz)) {
|
||||||
throw NotSerializableException("Class $type is not on the whitelist or annotated with @CordaSerializable.")
|
throw NotSerializableException("Class $type is not on the whitelist or annotated with @CordaSerializable.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ignore SimpleFieldAccess as we add it to anything we build in the carpenter.
|
||||||
|
internal fun isNotWhitelisted(clazz: Class<*>): Boolean = clazz == SimpleFieldAccess::class.java ||
|
||||||
|
(!whitelist.hasListed(clazz) && !hasAnnotationInHierarchy(clazz))
|
||||||
|
|
||||||
// Recursively check the class, interfaces and superclasses for our annotation.
|
// Recursively check the class, interfaces and superclasses for our annotation.
|
||||||
internal fun hasAnnotationInHierarchy(type: Class<*>): Boolean {
|
internal fun hasAnnotationInHierarchy(type: Class<*>): Boolean {
|
||||||
return type.isAnnotationPresent(CordaSerializable::class.java) ||
|
return type.isAnnotationPresent(CordaSerializable::class.java) ||
|
||||||
|
@ -10,7 +10,7 @@ import java.lang.reflect.Type
|
|||||||
*/
|
*/
|
||||||
class SingletonSerializer(override val type: Class<*>, val singleton: Any, factory: SerializerFactory) : AMQPSerializer<Any> {
|
class SingletonSerializer(override val type: Class<*>, val singleton: Any, factory: SerializerFactory) : AMQPSerializer<Any> {
|
||||||
override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}"
|
override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}"
|
||||||
private val interfaces = interfacesForSerialization(type)
|
private val interfaces = interfacesForSerialization(type, factory)
|
||||||
|
|
||||||
private fun generateProvides(): List<String> = interfaces.map { it.typeName }
|
private fun generateProvides(): List<String> = interfaces.map { it.typeName }
|
||||||
|
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
package net.corda.nodeapi.internal.serialization.carpenter
|
package net.corda.nodeapi.internal.serialization.carpenter
|
||||||
|
|
||||||
|
import net.corda.core.serialization.CordaSerializable
|
||||||
import org.objectweb.asm.ClassWriter
|
import org.objectweb.asm.ClassWriter
|
||||||
import org.objectweb.asm.MethodVisitor
|
import org.objectweb.asm.MethodVisitor
|
||||||
import org.objectweb.asm.Opcodes.*
|
import org.objectweb.asm.Opcodes.*
|
||||||
|
import org.objectweb.asm.Type
|
||||||
|
|
||||||
import java.lang.Character.isJavaIdentifierPart
|
import java.lang.Character.isJavaIdentifierPart
|
||||||
import java.lang.Character.isJavaIdentifierStart
|
import java.lang.Character.isJavaIdentifierStart
|
||||||
@ -119,6 +121,7 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader
|
|||||||
|
|
||||||
with(cw) {
|
with(cw) {
|
||||||
visit(V1_8, ACC_PUBLIC + ACC_ABSTRACT + ACC_INTERFACE, schema.jvmName, null, "java/lang/Object", interfaces)
|
visit(V1_8, ACC_PUBLIC + ACC_ABSTRACT + ACC_INTERFACE, schema.jvmName, null, "java/lang/Object", interfaces)
|
||||||
|
visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd()
|
||||||
|
|
||||||
generateAbstractGetters(schema)
|
generateAbstractGetters(schema)
|
||||||
|
|
||||||
@ -136,6 +139,7 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader
|
|||||||
|
|
||||||
with(cw) {
|
with(cw) {
|
||||||
visit(V1_8, ACC_PUBLIC + ACC_SUPER, schema.jvmName, null, superName, interfaces.toTypedArray())
|
visit(V1_8, ACC_PUBLIC + ACC_SUPER, schema.jvmName, null, superName, interfaces.toTypedArray())
|
||||||
|
visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd()
|
||||||
|
|
||||||
generateFields(schema)
|
generateFields(schema)
|
||||||
generateConstructor(schema)
|
generateConstructor(schema)
|
||||||
|
@ -2,12 +2,14 @@ package net.corda.nodeapi.internal.serialization.amqp
|
|||||||
|
|
||||||
import org.apache.qpid.proton.codec.Data
|
import org.apache.qpid.proton.codec.Data
|
||||||
import net.corda.nodeapi.internal.serialization.AllWhitelist
|
import net.corda.nodeapi.internal.serialization.AllWhitelist
|
||||||
|
import net.corda.nodeapi.internal.serialization.EmptyWhitelist
|
||||||
|
|
||||||
fun testDefaultFactory() = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
|
fun testDefaultFactory() = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
|
||||||
|
fun testDefaultFactoryWithWhitelist() = SerializerFactory(EmptyWhitelist, ClassLoader.getSystemClassLoader())
|
||||||
|
|
||||||
class TestSerializationOutput(
|
class TestSerializationOutput(
|
||||||
private val verbose: Boolean,
|
private val verbose: Boolean,
|
||||||
serializerFactory: SerializerFactory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()))
|
serializerFactory: SerializerFactory = testDefaultFactory())
|
||||||
: SerializationOutput(serializerFactory) {
|
: SerializationOutput(serializerFactory) {
|
||||||
|
|
||||||
override fun writeSchema(schema: Schema, data: Data) {
|
override fun writeSchema(schema: Schema, data: Data) {
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
package net.corda.nodeapi.internal.serialization.amqp
|
package net.corda.nodeapi.internal.serialization.amqp
|
||||||
|
|
||||||
|
import net.corda.core.serialization.CordaSerializable
|
||||||
import org.junit.Test
|
import org.junit.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
import kotlin.test.assertNotEquals
|
import kotlin.test.assertNotEquals
|
||||||
@ -43,4 +44,22 @@ class DeserializeAndReturnEnvelopeTests {
|
|||||||
assertNotEquals(null, obj.envelope.schema.types.find { it.name == classTestName("A") })
|
assertNotEquals(null, obj.envelope.schema.types.find { it.name == classTestName("A") })
|
||||||
assertNotEquals(null, obj.envelope.schema.types.find { it.name == classTestName("B") })
|
assertNotEquals(null, obj.envelope.schema.types.find { it.name == classTestName("B") })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun unannotatedInterfaceIsNotInSchema() {
|
||||||
|
@CordaSerializable
|
||||||
|
data class Foo(val bar: Int) : Comparable<Foo> {
|
||||||
|
override fun compareTo(other: Foo): Int = bar.compareTo(other.bar)
|
||||||
|
}
|
||||||
|
|
||||||
|
val a = Foo(123)
|
||||||
|
val factory = testDefaultFactoryWithWhitelist()
|
||||||
|
fun serialise(clazz: Any) = SerializationOutput(factory).serialize(clazz)
|
||||||
|
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a))
|
||||||
|
|
||||||
|
assertTrue(obj.obj is Foo)
|
||||||
|
assertEquals(1, obj.envelope.schema.types.size)
|
||||||
|
assertNotEquals(null, obj.envelope.schema.types.find { it.name == classTestName("Foo") })
|
||||||
|
assertEquals(null, obj.envelope.schema.types.find { it.name == "java.lang.Comparable<${classTestName("Foo")}>" })
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
package net.corda.nodeapi.internal.serialization.amqp
|
package net.corda.nodeapi.internal.serialization.amqp
|
||||||
|
|
||||||
|
import net.corda.core.serialization.CordaSerializable
|
||||||
import org.junit.Test
|
import org.junit.Test
|
||||||
import kotlin.test.*
|
import kotlin.test.*
|
||||||
import net.corda.nodeapi.internal.serialization.carpenter.*
|
import net.corda.nodeapi.internal.serialization.carpenter.*
|
||||||
import net.corda.nodeapi.internal.serialization.AllWhitelist
|
import net.corda.nodeapi.internal.serialization.AllWhitelist
|
||||||
|
|
||||||
|
@CordaSerializable
|
||||||
interface I {
|
interface I {
|
||||||
fun getName() : String
|
fun getName() : String
|
||||||
}
|
}
|
||||||
@ -26,7 +28,7 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
val sf1 = testDefaultFactory()
|
val sf1 = testDefaultFactory()
|
||||||
val sf2 = testDefaultFactory()
|
val sf2 = testDefaultFactoryWithWhitelist() // Deserialize with whitelisting on to check that `CordaSerializable` annotation present.
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun verySimpleType() {
|
fun verySimpleType() {
|
||||||
@ -115,6 +117,7 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase() {
|
|||||||
fun arrayOfTypes() {
|
fun arrayOfTypes() {
|
||||||
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf("a" to NonNullableField(Int::class.java))))
|
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf("a" to NonNullableField(Int::class.java))))
|
||||||
|
|
||||||
|
@CordaSerializable
|
||||||
data class Outer (val a : Array<Any>)
|
data class Outer (val a : Array<Any>)
|
||||||
|
|
||||||
val outer = Outer (arrayOf (
|
val outer = Outer (arrayOf (
|
||||||
@ -185,6 +188,7 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase() {
|
|||||||
val nestedClass = cc.build (ClassSchema("nestedType",
|
val nestedClass = cc.build (ClassSchema("nestedType",
|
||||||
mapOf("name" to NonNullableField(String::class.java))))
|
mapOf("name" to NonNullableField(String::class.java))))
|
||||||
|
|
||||||
|
@CordaSerializable
|
||||||
data class outer(val a: Any, val b: Any)
|
data class outer(val a: Any, val b: Any)
|
||||||
|
|
||||||
val classInstance = outer (
|
val classInstance = outer (
|
||||||
@ -204,6 +208,7 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase() {
|
|||||||
"v1" to NonNullableField(Int::class.java),
|
"v1" to NonNullableField(Int::class.java),
|
||||||
"v2" to NonNullableField(Int::class.java))))
|
"v2" to NonNullableField(Int::class.java))))
|
||||||
|
|
||||||
|
@CordaSerializable
|
||||||
data class outer (val l : List<Any>)
|
data class outer (val l : List<Any>)
|
||||||
val toSerialise = outer (listOf (
|
val toSerialise = outer (listOf (
|
||||||
unknownClass.constructors.first().newInstance(1, 2),
|
unknownClass.constructors.first().newInstance(1, 2),
|
||||||
|
Reference in New Issue
Block a user