diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/EnumEvolutionSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/EnumEvolutionSerializer.kt index 25413d650c..ec8a61c793 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/EnumEvolutionSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/EnumEvolutionSerializer.kt @@ -6,6 +6,7 @@ import org.apache.qpid.proton.codec.Data import java.io.NotSerializableException import java.lang.UnsupportedOperationException import java.lang.reflect.Type +import java.util.* /** * Used whenever a deserialized enums fingerprint doesn't match the fingerprint of the generated @@ -62,11 +63,13 @@ class EnumEvolutionSerializer( fun make(old: RestrictedType, new: AMQPSerializer, factory: SerializerFactory, - transformsFromBlob: TransformsSchema): AMQPSerializer { - - val wireTransforms = transformsFromBlob.types[old.name] + schemas: SerializationSchemas): AMQPSerializer { + val wireTransforms = schemas.transforms.types[old.name] ?: EnumMap>(TransformTypes::class.java) val localTransforms = TransformsSchema.get(old.name, factory) - val transforms = if (wireTransforms?.size ?: -1 > localTransforms.size) wireTransforms!! else localTransforms + + // remember, the longer the list the newer we're assuming the transform set it as we assume + // evolution annotations are never removed, only added to + val transforms = if (wireTransforms.size > localTransforms.size) wireTransforms else localTransforms // if either of these isn't of the cast type then something has gone terribly wrong // elsewhere in the code @@ -84,8 +87,12 @@ class EnumEvolutionSerializer( val rules: MutableMap = mutableMapOf() rules.putAll(defaultRules?.associateBy({ it.new }, { it.old }) ?: emptyMap()) - rules.putAll(renameRules?.associateBy({ it.to }, { it.from }) ?: emptyMap()) + val renameRulesMap = renameRules?.associateBy({ it.to }, { it.from }) ?: emptyMap() + rules.putAll(renameRulesMap) + // take out set of all possible constants and build a map from those to the + // existing constants applying the rename and defaulting rules as defined + // in the schema while (conversions.filterNot { it.value in localValues }.isNotEmpty()) { conversions.mapInPlace { rules[it] ?: it } } @@ -93,8 +100,19 @@ class EnumEvolutionSerializer( // you'd think this was overkill to get access to the ordinal values for each constant but it's actually // rather tricky when you don't have access to the actual type, so this is a nice way to be able // to precompute and pass to the actual object - return EnumEvolutionSerializer(new.type, factory, conversions, - localValues.mapIndexed { i, s -> Pair (s, i)}.toMap()) + val ordinals = localValues.mapIndexed { i, s -> Pair(s, i) }.toMap() + + // create a mapping between the ordinal value and the name as it was serialised converted + // to the name as it exists. We want to test any new constants have been added to the end + // of the enum class + val serialisedOrds = ((schemas.schema.types.find { it.name == old.name } as RestrictedType).choices + .associateBy ({ it.value.toInt() }, { conversions[it.name] })) + + if (ordinals.filterNot { serialisedOrds[it.value] == it.key }.isNotEmpty()) { + throw NotSerializableException("Constants have been reordered, additions must be appended to the end") + } + + return EnumEvolutionSerializer(new.type, factory, conversions, ordinals) } } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt index 764f0b5458..39abcc58f5 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt @@ -46,11 +46,11 @@ open class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) { private fun getEvolutionSerializer( typeNotation: TypeNotation, newSerializer: AMQPSerializer, - transforms: TransformsSchema): AMQPSerializer { + schemas: SerializationSchemas): AMQPSerializer { return serializersByDescriptor.computeIfAbsent(typeNotation.descriptor.name!!) { when (typeNotation) { is CompositeType -> EvolutionSerializer.make(typeNotation, newSerializer as ObjectSerializer, this) - is RestrictedType -> EnumEvolutionSerializer.make(typeNotation, newSerializer, this, transforms) + is RestrictedType -> EnumEvolutionSerializer.make(typeNotation, newSerializer, this, schemas) } } } @@ -210,7 +210,7 @@ open class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) { // doesn't match that of the serialised object then we are dealing with different // instance of the class, as such we need to build an EvolutionSerialiser if (serialiser.typeDescriptor != typeNotation.descriptor.name) { - getEvolutionSerializer(typeNotation, serialiser, schemaAndDescriptor.schemas.transforms) + getEvolutionSerializer(typeNotation, serialiser, schemaAndDescriptor.schemas) } } catch (e: ClassNotFoundException) { if (sentinel) throw e diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/TansformTypes.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/TansformTypes.kt index e21498f7b0..0003a48ba2 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/TansformTypes.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/TansformTypes.kt @@ -28,14 +28,14 @@ enum class TransformTypes(val build: (Annotation) -> Transform) : DescribedType Unknown({ UnknownTransform() }) { override fun getDescriptor(): Any = DESCRIPTOR override fun getDescribed(): Any = ordinal - override fun validate(l : List, constants: Set) { } + override fun validate(l : List, constants: Map) { } }, EnumDefault({ a -> EnumDefaultSchemaTransform((a as CordaSerializationTransformEnumDefault).old, a.new) }) { override fun getDescriptor(): Any = DESCRIPTOR override fun getDescribed(): Any = ordinal /** - * Validates a list of constant additions to an enumerated types, to be valid a default (the value + * Validates a list of constant additions to an enumerated type. To be valid a default (the value * that should be used when we cannot use the new value) must refer to a constant that exists in the * enum class as it exists now and it cannot refer to itself. * @@ -43,8 +43,12 @@ enum class TransformTypes(val build: (Annotation) -> Transform) : DescribedType * existing value * @param constants The list of enum constants on the type the transforms are being applied to */ - override fun validate(l : List, constants: Set) { - uncheckedCast, List>(l).forEach { + override fun validate(list : List, constants: Map) { + uncheckedCast, List>(list).forEach { + if (!constants.contains(it.new)) { + throw NotSerializableException("Unknown enum constant ${it.new}") + } + if (!constants.contains(it.old)) { throw NotSerializableException( "Enum extension defaults must be to a valid constant: ${it.new} -> ${it.old}. ${it.old} " + @@ -54,6 +58,12 @@ enum class TransformTypes(val build: (Annotation) -> Transform) : DescribedType if (it.old == it.new) { throw NotSerializableException("Enum extension ${it.new} cannot default to itself") } + + if (constants[it.old]!! >= constants[it.new]!!) { + throw NotSerializableException( + "Enum extensions must default to older constants. ${it.new}[${constants[it.new]}] " + + "defaults to ${it.old}[${constants[it.old]}] which is greater") + } } } }, @@ -70,7 +80,7 @@ enum class TransformTypes(val build: (Annotation) -> Transform) : DescribedType * and old values * @param constants The list of enum constants on the type the transforms are being applied to */ - override fun validate(l : List, constants: Set) { + override fun validate(l : List, constants: Map) { object : Any() { val from : MutableSet = mutableSetOf() val to : MutableSet = mutableSetOf() }.apply { @@ -94,7 +104,7 @@ enum class TransformTypes(val build: (Annotation) -> Transform) : DescribedType //} ; - abstract fun validate(l: List, constants: Set) + abstract fun validate(l: List, constants: Map) companion object : DescribedTypeConstructor { val DESCRIPTOR = AMQPDescriptorRegistry.TRANSFORM_ELEMENT_KEY.amqpDescriptor diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/TransformsSchema.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/TransformsSchema.kt index c88addacaa..378675b84e 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/TransformsSchema.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/TransformsSchema.kt @@ -225,7 +225,7 @@ data class TransformsSchema(val types: Map Pair(s.toString(), i) }.toMap()) } } } catch (_: ClassNotFoundException) { diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/EnumEvolveTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/EnumEvolveTests.kt index 2b7930c707..b2939d7fa9 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/EnumEvolveTests.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/EnumEvolveTests.kt @@ -355,4 +355,60 @@ class EnumEvolveTests { load (stage4Resources).forEach { assertEquals(it.second, it.first.e) } load (stage5Resources).forEach { assertEquals(it.second, it.first.e) } } + + @CordaSerializationTransformEnumDefault(old = "A", new = "F") + enum class BadNewValue { A, B, C, D } + + @Test + fun badNewValue() { + val sf = testDefaultFactory() + + data class C (val e : BadNewValue) + + Assertions.assertThatThrownBy { + SerializationOutput(sf).serialize(C(BadNewValue.A)) + }.isInstanceOf(NotSerializableException::class.java) + } + + @CordaSerializationTransformEnumDefaults( + CordaSerializationTransformEnumDefault(new = "D", old = "E"), + CordaSerializationTransformEnumDefault(new = "E", old = "A") + ) + enum class OutOfOrder { A, B, C, D, E} + + @Test + fun outOfOrder() { + val sf = testDefaultFactory() + + data class C (val e : OutOfOrder) + + Assertions.assertThatThrownBy { + SerializationOutput(sf).serialize(C(OutOfOrder.A)) + }.isInstanceOf(NotSerializableException::class.java) + } + + // class as it existed as it was serialized + // + // enum class ChangedOrdinality { A, B, C } + // + // class as it exists for the tests + @CordaSerializationTransformEnumDefault("D", "A") + enum class ChangedOrdinality { A, B, D, C } + + @Test + fun changedOrdinality() { + val resource = "${javaClass.simpleName}.${testName()}" + val sf = testDefaultFactory() + + data class C(val e: ChangedOrdinality) + + // Uncomment to re-generate test files, needs to be done in three stages + // File(URI("$localPath/$resource")).writeBytes( + // SerializationOutput(sf).serialize(C(ChangedOrdinality.A)).bytes) + + Assertions.assertThatThrownBy { + DeserializationInput(sf).deserialize(SerializedBytes( + File(EvolvabilityTests::class.java.getResource(resource).toURI()).readBytes())) + }.isInstanceOf(NotSerializableException::class.java) + } } diff --git a/node-api/src/test/resources/net/corda/nodeapi/internal/serialization/amqp/EnumEvolveTests.changedOrdinality b/node-api/src/test/resources/net/corda/nodeapi/internal/serialization/amqp/EnumEvolveTests.changedOrdinality new file mode 100644 index 0000000000..af084de97a Binary files /dev/null and b/node-api/src/test/resources/net/corda/nodeapi/internal/serialization/amqp/EnumEvolveTests.changedOrdinality differ