Add support for serialising enum types - Part 1

Part 2 will address the carpenting of enum classes
This commit is contained in:
Katelyn Baker 2017-08-24 13:48:07 +01:00
parent f0b2b0a566
commit ed2b2b02ca
12 changed files with 353 additions and 71 deletions

View File

@ -25,7 +25,7 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory) {
private val objectHistory: MutableList<Any> = mutableListOf()
internal companion object {
val BYTES_NEEDED_TO_PEEK: Int = 23
private val BYTES_NEEDED_TO_PEEK: Int = 23
fun peekSize(bytes: ByteArray): Int {
// There's an 8 byte header, and then a 0 byte plus descriptor followed by constructor
@ -57,7 +57,6 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory) {
inline internal fun <reified T : Any> deserializeAndReturnEnvelope(bytes: SerializedBytes<T>): ObjectAndEnvelope<T> =
deserializeAndReturnEnvelope(bytes, T::class.java)
@Throws(NotSerializableException::class)
private fun getEnvelope(bytes: ByteSequence): Envelope {
// Check that the lead bytes match expected header
@ -94,20 +93,16 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory) {
* be deserialized and a schema describing the types of the objects.
*/
@Throws(NotSerializableException::class)
fun <T : Any> deserialize(bytes: ByteSequence, clazz: Class<T>): T {
return des {
val envelope = getEnvelope(bytes)
clazz.cast(readObjectOrNull(envelope.obj, envelope.schema, clazz))
}
fun <T : Any> deserialize(bytes: ByteSequence, clazz: Class<T>): T = des {
val envelope = getEnvelope(bytes)
clazz.cast(readObjectOrNull(envelope.obj, envelope.schema, clazz))
}
@Throws(NotSerializableException::class)
internal fun <T : Any> deserializeAndReturnEnvelope(bytes: SerializedBytes<T>, clazz: Class<T>): ObjectAndEnvelope<T> {
return des {
val envelope = getEnvelope(bytes)
// Now pick out the obj and schema from the envelope.
ObjectAndEnvelope(clazz.cast(readObjectOrNull(envelope.obj, envelope.schema, clazz)), envelope)
}
fun <T : Any> deserializeAndReturnEnvelope(bytes: SerializedBytes<T>, clazz: Class<T>): ObjectAndEnvelope<T> = des {
val envelope = getEnvelope(bytes)
// Now pick out the obj and schema from the envelope.
ObjectAndEnvelope(clazz.cast(readObjectOrNull(envelope.obj, envelope.schema, clazz)), envelope)
}
internal fun readObjectOrNull(obj: Any?, schema: Schema, type: Type): Any? {
@ -115,36 +110,36 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory) {
}
internal fun readObject(obj: Any, schema: Schema, type: Type): Any =
if (obj is DescribedType && ReferencedObject.DESCRIPTOR == obj.descriptor) {
// It must be a reference to an instance that has already been read, cheaply and quickly returning it by reference.
val objectIndex = (obj.described as UnsignedInteger).toInt()
if (objectIndex !in 0..objectHistory.size)
throw NotSerializableException("Retrieval of existing reference failed. Requested index $objectIndex " +
"is outside of the bounds for the list of size: ${objectHistory.size}")
if (obj is DescribedType && ReferencedObject.DESCRIPTOR == obj.descriptor) {
// It must be a reference to an instance that has already been read, cheaply and quickly returning it by reference.
val objectIndex = (obj.described as UnsignedInteger).toInt()
if (objectIndex !in 0..objectHistory.size)
throw NotSerializableException("Retrieval of existing reference failed. Requested index $objectIndex " +
"is outside of the bounds for the list of size: ${objectHistory.size}")
val objectRetrieved = objectHistory[objectIndex]
if (!objectRetrieved::class.java.isSubClassOf(type.asClass()!!))
throw NotSerializableException("Existing reference type mismatch. Expected: '$type', found: '${objectRetrieved::class.java}'")
objectRetrieved
}
else {
val objectRead = when (obj) {
is DescribedType -> {
// Look up serializer in factory by descriptor
val serializer = serializerFactory.get(obj.descriptor, schema)
if (serializer.type != type && with(serializer.type) { !isSubClassOf(type) && !materiallyEquivalentTo(type) })
throw NotSerializableException("Described type with descriptor ${obj.descriptor} was " +
"expected to be of type $type but was ${serializer.type}")
serializer.readObject(obj.described, schema, this)
val objectRetrieved = objectHistory[objectIndex]
if (!objectRetrieved::class.java.isSubClassOf(type))
throw NotSerializableException("Existing reference type mismatch. Expected: '$type', found: '${objectRetrieved::class.java}'")
objectRetrieved
} else {
val objectRead = when (obj) {
is DescribedType -> {
// Look up serializer in factory by descriptor
val serializer = serializerFactory.get(obj.descriptor, schema)
if (serializer.type != type && with(serializer.type) { !isSubClassOf(type) && !materiallyEquivalentTo(type) })
throw NotSerializableException("Described type with descriptor ${obj.descriptor} was " +
"expected to be of type $type but was ${serializer.type}")
serializer.readObject(obj.described, schema, this)
}
is Binary -> obj.array
else -> obj // this will be the case for primitive types like [boolean] et al.
}
is Binary -> obj.array
else -> obj // this will be the case for primitive types like [boolean] et al.
// Store the reference in case we need it later on.
// Skip for primitive types as they are too small and overhead of referencing them will be much higher than their content
if (type.asClass()?.isPrimitive != true) objectHistory.add(objectRead)
objectRead
}
// Store the reference in case we need it later on.
// Skip for primitive types as they are too small and overhead of referencing them will be much higher than their content
if (suitableForObjectReference(objectRead.javaClass)) objectHistory.add(objectRead)
objectRead
}
/**
* TODO: Currently performs rather basic checks aimed in particular at [java.util.List<Command<?>>] and
@ -152,5 +147,5 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory) {
* In the future tighter control might be needed
*/
private fun Type.materiallyEquivalentTo(that: Type): Boolean =
asClass() == that.asClass() && that is ParameterizedType
}
asClass() == that.asClass() && that is ParameterizedType
}

View File

@ -53,6 +53,7 @@ class DeserializedParameterizedType(private val rawType: Class<*>, private val p
var typeStart = 0
var needAType = true
var skippingWhitespace = false
while (pos < params.length) {
if (params[pos] == '<') {
val typeEnd = pos++
@ -102,7 +103,7 @@ class DeserializedParameterizedType(private val rawType: Class<*>, private val p
} else if (!skippingWhitespace && (params[pos] == '.' || params[pos].isJavaIdentifierPart())) {
pos++
} else {
throw NotSerializableException("Invalid character in middle of type: ${params[pos]}")
throw NotSerializableException("Invalid character ${params[pos]} in middle of type $params at idx $pos")
}
}
}

View File

@ -0,0 +1,47 @@
package net.corda.nodeapi.internal.serialization.amqp
import org.apache.qpid.proton.codec.Data
import java.lang.reflect.Type
import java.io.NotSerializableException
class EnumSerializer(declaredType: Type, declaredClass: Class<*>, factory: SerializerFactory) : AMQPSerializer<Any> {
override val type: Type = declaredType
override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}"
private val typeNotation: TypeNotation
init {
typeNotation = RestrictedType(
SerializerFactory.nameForType(declaredType),
null, emptyList(), "enum", Descriptor(typeDescriptor, null),
declaredClass.enumConstants.zip(IntRange(0, declaredClass.enumConstants.size)).map {
Choice(it.first.toString(), it.second.toString())
})
}
override fun writeClassInfo(output: SerializationOutput) {
output.writeTypeNotations(typeNotation)
}
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any {
val enumName = (obj as List<*>)[0] as String
val enumOrd = obj[1] as Int
val fromOrd = type.asClass()!!.enumConstants[enumOrd]
if (enumName != fromOrd?.toString()) {
throw NotSerializableException("Deserializing obj as enum $type with value $enumName.$enumOrd but "
+ "ordinality has changed")
}
return fromOrd
}
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) {
if (obj !is Enum<*>) throw NotSerializableException("Serializing $obj as enum when it isn't")
data.withDescribed(typeNotation.descriptor) {
withList {
data.putString(obj.name)
data.putInt(obj.ordinal)
}
}
}
}

View File

@ -13,7 +13,7 @@ import kotlin.collections.map
/**
* Serialization / deserialization of certain supported [Map] types.
*/
class MapSerializer(val declaredType: ParameterizedType, factory: SerializerFactory) : AMQPSerializer<Any> {
class MapSerializer(private val declaredType: ParameterizedType, factory: SerializerFactory) : AMQPSerializer<Any> {
override val type: Type = declaredType as? DeserializedParameterizedType ?: DeserializedParameterizedType.make(declaredType.toString())
override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}"

View File

@ -252,7 +252,12 @@ data class CompositeType(override val name: String, override val label: String?,
}
}
data class RestrictedType(override val name: String, override val label: String?, override val provides: List<String>, val source: String, override val descriptor: Descriptor, val choices: List<Choice>) : TypeNotation() {
data class RestrictedType(override val name: String,
override val label: String?,
override val provides: List<String>,
val source: String,
override val descriptor: Descriptor,
val choices: List<Choice>) : TypeNotation() {
companion object : DescribedTypeConstructor<RestrictedType> {
val DESCRIPTOR = DescriptorRegistry.RESTRICTED_TYPE.amqpDescriptor
@ -290,6 +295,9 @@ data class RestrictedType(override val name: String, override val label: String?
}
sb.append(">\n")
sb.append(" $descriptor\n")
choices.forEach {
sb.append(" $it\n")
}
sb.append("</type>")
return sb.toString()
}
@ -403,14 +411,13 @@ private fun fingerprintForType(type: Type, contextType: Type?, alreadySeen: Muta
if (type is SerializerFactory.AnyType) {
hasher.putUnencodedChars(ANY_TYPE_HASH)
} else if (type is Class<*>) {
if (type.isArray) {
fingerprintForType(type.componentType, contextType, alreadySeen, hasher, factory).putUnencodedChars(ARRAY_HASH)
} else if (SerializerFactory.isPrimitive(type)) {
hasher.putUnencodedChars(type.name)
} else if (isCollectionOrMap(type)) {
hasher.putUnencodedChars(type.name)
} else {
hasher.fingerprintWithCustomSerializerOrElse(factory, type, type) {
when {
type.isArray -> fingerprintForType(type.componentType, contextType, alreadySeen, hasher, factory).putUnencodedChars(ARRAY_HASH)
SerializerFactory.isPrimitive(type) ||
isCollectionOrMap(type) ||
type.isEnum -> hasher.putUnencodedChars(type.name)
else ->
hasher.fingerprintWithCustomSerializerOrElse(factory, type, type) {
if (type.kotlin.objectInstance != null) {
// TODO: name collision is too likely for kotlin objects, we need to introduce some reference
// to the CorDapp but maybe reference to the JAR in the short term.

View File

@ -21,7 +21,6 @@ data class schemaAndDescriptor (val schema: Schema, val typeDescriptor: Any)
/**
* Factory of serializers designed to be shared across threads and invocations.
*/
// TODO: enums
// TODO: object references - need better fingerprinting?
// TODO: class references? (e.g. cheat with repeated descriptors using a long encoding, like object ref proposal)
// TODO: Inner classes etc. Should we allow? Currently not considered.
@ -66,18 +65,20 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl : ClassLoader) {
val actualType: Type = inferTypeVariables(actualClass, declaredClass, declaredType) ?: declaredType
val serializer = if (Collection::class.java.isAssignableFrom(declaredClass)) {
serializersByType.computeIfAbsent(declaredType) {
CollectionSerializer(declaredType as? ParameterizedType ?: DeserializedParameterizedType(
declaredClass, arrayOf(AnyType), null), this)
val serializer = when {
(Collection::class.java.isAssignableFrom(declaredClass)) -> { serializersByType.computeIfAbsent(declaredType) {
CollectionSerializer(declaredType as? ParameterizedType ?: DeserializedParameterizedType(
declaredClass, arrayOf(AnyType), null), this)
}
}
} else if (Map::class.java.isAssignableFrom(declaredClass)) {
serializersByType.computeIfAbsent(declaredClass) {
Map::class.java.isAssignableFrom(declaredClass) -> serializersByType.computeIfAbsent(declaredClass) {
makeMapSerializer(declaredType as? ParameterizedType ?: DeserializedParameterizedType(
declaredClass, arrayOf(AnyType, AnyType), null))
}
} else {
makeClassSerializer(actualClass ?: declaredClass, actualType, declaredType)
Enum::class.java.isAssignableFrom(declaredClass) -> serializersByType.computeIfAbsent(declaredClass) {
EnumSerializer(actualType, actualClass ?: declaredClass, this)
}
else -> makeClassSerializer(actualClass ?: declaredClass, actualType, declaredType)
}
serializersByDescriptor.putIfAbsent(serializer.typeDescriptor, serializer)
@ -248,8 +249,9 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl : ClassLoader) {
}
internal fun findCustomSerializer(clazz: Class<*>, declaredType: Type): AMQPSerializer<Any>? {
// e.g. Imagine if we provided a Map serializer this way, then it won't work if the declared type is AbstractMap, only Map.
// Otherwise it needs to inject additional schema for a RestrictedType source of the super type. Could be done, but do we need it?
// e.g. Imagine if we provided a Map serializer this way, then it won't work if the declared type is
// AbstractMap, only Map. Otherwise it needs to inject additional schema for a RestrictedType source of the
// super type. Could be done, but do we need it?
for (customSerializer in customSerializers) {
if (customSerializer.isSerializerFor(clazz)) {
val declaredSuperClass = declaredType.asClass()?.superclass
@ -258,7 +260,7 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl : ClassLoader) {
} else {
// Make a subclass serializer for the subclass and return that...
@Suppress("UNCHECKED_CAST")
return CustomSerializer.SubClass<Any>(clazz, customSerializer as CustomSerializer<Any>)
return CustomSerializer.SubClass(clazz, customSerializer as CustomSerializer<Any>)
}
}
}
@ -277,7 +279,7 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl : ClassLoader) {
(!whitelist.hasListed(clazz) && !hasAnnotationInHierarchy(clazz))
// Recursively check the class, interfaces and superclasses for our annotation.
internal fun hasAnnotationInHierarchy(type: Class<*>): Boolean {
private fun hasAnnotationInHierarchy(type: Class<*>): Boolean {
return type.isAnnotationPresent(CordaSerializable::class.java) ||
type.interfaces.any { hasAnnotationInHierarchy(it) }
|| (type.superclass != null && hasAnnotationInHierarchy(type.superclass))

View File

@ -0,0 +1,36 @@
package net.corda.nodeapi.internal.serialization.amqp;
import org.junit.Test;
import net.corda.nodeapi.internal.serialization.AllWhitelist;
import net.corda.core.serialization.SerializedBytes;
import java.io.NotSerializableException;
public class JavaSerialiseEnumTests {
public enum Bras {
TSHIRT, UNDERWIRE, PUSHUP, BRALETTE, STRAPLESS, SPORTS, BACKLESS, PADDED
}
private static class Bra {
private final Bras bra;
private Bra(Bras bra) {
this.bra = bra;
}
public Bras getBra() {
return this.bra;
}
}
@Test
public void testJavaConstructorAnnotations() throws NotSerializableException {
Bra bra = new Bra(Bras.UNDERWIRE);
SerializerFactory factory1 = new SerializerFactory(AllWhitelist.INSTANCE, ClassLoader.getSystemClassLoader());
SerializationOutput ser = new SerializationOutput(factory1);
SerializedBytes<Object> bytes = ser.serialize(bra);
}
}

View File

@ -25,15 +25,17 @@ public class JavaSerializationOutputTests {
}
@ConstructorForDeserialization
public Foo(String fred, int count) {
private Foo(String fred, int count) {
this.bob = fred;
this.count = count;
}
@SuppressWarnings("unused")
public String getFred() {
return bob;
}
@SuppressWarnings("unused")
public int getCount() {
return count;
}
@ -61,15 +63,17 @@ public class JavaSerializationOutputTests {
private final String bob;
private final int count;
public UnAnnotatedFoo(String fred, int count) {
private UnAnnotatedFoo(String fred, int count) {
this.bob = fred;
this.count = count;
}
@SuppressWarnings("unused")
public String getFred() {
return bob;
}
@SuppressWarnings("unused")
public int getCount() {
return count;
}
@ -97,7 +101,7 @@ public class JavaSerializationOutputTests {
private final String fred;
private final Integer count;
public BoxedFoo(String fred, Integer count) {
private BoxedFoo(String fred, Integer count) {
this.fred = fred;
this.count = count;
}
@ -134,7 +138,7 @@ public class JavaSerializationOutputTests {
private final String fred;
private final Integer count;
public BoxedFooNotNull(String fred, Integer count) {
private BoxedFooNotNull(String fred, Integer count) {
this.fred = fred;
this.count = count;
}

View File

@ -6,7 +6,6 @@ import net.corda.nodeapi.internal.serialization.AllWhitelist
import net.corda.nodeapi.internal.serialization.EmptyWhitelist
import java.io.NotSerializableException
fun testDefaultFactory() = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
fun testDefaultFactoryWithWhitelist() = SerializerFactory(EmptyWhitelist, ClassLoader.getSystemClassLoader())

View File

@ -0,0 +1,191 @@
package net.corda.nodeapi.internal.serialization.amqp
import org.junit.Test
import java.time.DayOfWeek
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import java.io.File
import java.io.NotSerializableException
import net.corda.core.serialization.SerializedBytes
class EnumTests {
enum class Bras {
TSHIRT, UNDERWIRE, PUSHUP, BRALETTE, STRAPLESS, SPORTS, BACKLESS, PADDED
}
// The state of the OldBras enum when the tests in changedEnum1 were serialised
// - use if the test file needs regenerating
//enum class OldBras {
// TSHIRT, UNDERWIRE, PUSHUP, BRALETTE
//}
// the new state, SPACER has been added to change the ordinality
enum class OldBras {
SPACER, TSHIRT, UNDERWIRE, PUSHUP, BRALETTE
}
// The state of the OldBras2 enum when the tests in changedEnum2 were serialised
// - use if the test file needs regenerating
//enum class OldBras2 {
// TSHIRT, UNDERWIRE, PUSHUP, BRALETTE
//}
// the new state, note in the test we serialised with value UNDERWIRE so the spacer
// occuring after this won't have changed the ordinality of our serialised value
// and thus should still be deserialisable
enum class OldBras2 {
TSHIRT, UNDERWIRE, PUSHUP, SPACER, BRALETTE, SPACER2
}
enum class BrasWithInit (val someList: List<Int>) {
TSHIRT(emptyList()),
UNDERWIRE(listOf(1, 2, 3)),
PUSHUP(listOf(100, 200)),
BRALETTE(emptyList())
}
private val brasTestName = "${this.javaClass.name}\$Bras"
companion object {
/**
* If you want to see the schema encoded into the envelope after serialisation change this to true
*/
private const val VERBOSE = false
}
@Suppress("NOTHING_TO_INLINE")
inline private fun classTestName(clazz: String) = "${this.javaClass.name}\$${testName()}\$$clazz"
private val sf1 = testDefaultFactory()
@Test
fun serialiseSimpleTest() {
data class C(val c: Bras)
val schema = TestSerializationOutput(VERBOSE, sf1).serializeAndReturnSchema(C(Bras.UNDERWIRE)).schema
assertEquals(2, schema.types.size)
val schema_c = schema.types.find { it.name == classTestName("C") } as CompositeType
val schema_bras = schema.types.find { it.name == brasTestName } as RestrictedType
assertNotNull(schema_c)
assertNotNull(schema_bras)
assertEquals(1, schema_c.fields.size)
assertEquals("c", schema_c.fields.first().name)
assertEquals(brasTestName, schema_c.fields.first().type)
assertEquals(8, schema_bras.choices.size)
Bras.values().forEach {
val bra = it
assertNotNull (schema_bras.choices.find { it.name == bra.name })
}
}
@Test
fun deserialiseSimpleTest() {
data class C(val c: Bras)
val objAndEnvelope = DeserializationInput(sf1).deserializeAndReturnEnvelope(
TestSerializationOutput(VERBOSE, sf1).serialize(C(Bras.UNDERWIRE)))
val obj = objAndEnvelope.obj
val schema = objAndEnvelope.envelope.schema
assertEquals(2, schema.types.size)
val schema_c = schema.types.find { it.name == classTestName("C") } as CompositeType
val schema_bras = schema.types.find { it.name == brasTestName } as RestrictedType
assertEquals(1, schema_c.fields.size)
assertEquals("c", schema_c.fields.first().name)
assertEquals(brasTestName, schema_c.fields.first().type)
assertEquals(8, schema_bras.choices.size)
Bras.values().forEach {
val bra = it
assertNotNull (schema_bras.choices.find { it.name == bra.name })
}
// Test the actual deserialised object
assertEquals(obj.c, Bras.UNDERWIRE)
}
@Test
fun multiEnum() {
data class Support (val top: Bras, val day : DayOfWeek)
data class WeeklySupport (val tops: List<Support>)
val week = WeeklySupport (listOf(
Support (Bras.PUSHUP, DayOfWeek.MONDAY),
Support (Bras.UNDERWIRE, DayOfWeek.WEDNESDAY),
Support (Bras.PADDED, DayOfWeek.SUNDAY)))
val obj = DeserializationInput(sf1).deserialize(TestSerializationOutput(VERBOSE, sf1).serialize(week))
assertEquals(week.tops[0].top, obj.tops[0].top)
assertEquals(week.tops[0].day, obj.tops[0].day)
assertEquals(week.tops[1].top, obj.tops[1].top)
assertEquals(week.tops[1].day, obj.tops[1].day)
assertEquals(week.tops[2].top, obj.tops[2].top)
assertEquals(week.tops[2].day, obj.tops[2].day)
}
@Test
fun enumWithInit() {
data class C(val c: BrasWithInit)
val c = C (BrasWithInit.PUSHUP)
val obj = DeserializationInput(sf1).deserialize(TestSerializationOutput(VERBOSE, sf1).serialize(c))
assertEquals(c.c, obj.c)
}
@Test(expected = NotSerializableException::class)
fun changedEnum1() {
val path = EnumTests::class.java.getResource("EnumTests.changedEnum1")
val f = File(path.toURI())
data class C (val a: OldBras)
// Original version of the class for the serialised version of this class
//
// val a = OldBras.TSHIRT
// val sc = SerializationOutput(sf1).serialize(C(a))
// f.writeBytes(sc.bytes)
// println(path)
val sc2 = f.readBytes()
// we expect this to throw
DeserializationInput(sf1).deserialize(SerializedBytes<C>(sc2))
}
@Test
fun changedEnum2() {
val path = EnumTests::class.java.getResource("EnumTests.changedEnum2")
val f = File(path.toURI())
data class C (val a: OldBras2)
// DO NOT CHANGE THIS, it's important we serialise with a value that doesn't
// change position in the upated enum class
val a = OldBras2.UNDERWIRE
// Original version of the class for the serialised version of this class
//
// val sc = SerializationOutput(sf1).serialize(C(a))
// f.writeBytes(sc.bytes)
// println(path)
val sc2 = f.readBytes()
// we expect this to throw
val obj = DeserializationInput(sf1).deserialize(SerializedBytes<C>(sc2))
assertEquals(a, obj.a)
}
}