From 107819f5b53ceaff17b9bc291395054394d37d83 Mon Sep 17 00:00:00 2001 From: Chris Rankin Date: Thu, 30 Apr 2020 14:59:10 +0100 Subject: [PATCH] CORDA-3745: Modify DJVM serializers to support Enum Evolution. (#6189) --- .../DeterministicVerifierFactoryService.kt | 8 + .../djvm/SandboxSerializerFactoryFactory.kt | 3 +- .../corda/serialization/djvm/Serialization.kt | 8 +- .../djvm/DeserializeEnumWithEvolutionTest.kt | 155 ++++++++++++++++++ .../djvm/SafeDeserialisationTest.kt | 33 +--- .../net/corda/serialization/djvm/TestBase.kt | 8 + .../corda/serialization/djvm/TestHelpers.kt | 28 ++++ .../internal/amqp/EnumEvolutionSerializer.kt | 4 +- .../amqp/EvolutionSerializerFactory.kt | 5 +- .../internal/amqp/SerializerFactoryBuilder.kt | 9 +- .../internal/amqp/TransformsSchema.kt | 22 ++- .../WhitelistBasedTypeModelConfiguration.kt | 3 +- .../model/LocalTypeInformationBuilder.kt | 8 +- .../internal/model/LocalTypeModel.kt | 3 +- 14 files changed, 240 insertions(+), 57 deletions(-) create mode 100644 serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/DeserializeEnumWithEvolutionTest.kt create mode 100644 serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/TestHelpers.kt diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/DeterministicVerifierFactoryService.kt b/node/src/main/kotlin/net/corda/node/services/transactions/DeterministicVerifierFactoryService.kt index 333b680bf5..c1a18936a5 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/DeterministicVerifierFactoryService.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/DeterministicVerifierFactoryService.kt @@ -4,6 +4,10 @@ import net.corda.core.internal.BasicVerifier import net.corda.core.internal.Verifier import net.corda.core.serialization.ConstructorForDeserialization import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.CordaSerializationTransformEnumDefault +import net.corda.core.serialization.CordaSerializationTransformEnumDefaults +import net.corda.core.serialization.CordaSerializationTransformRename +import net.corda.core.serialization.CordaSerializationTransformRenames import net.corda.core.serialization.DeprecatedConstructorForDeserialization import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.transactions.LedgerTransaction @@ -38,6 +42,10 @@ class DeterministicVerifierFactoryService( whitelist = Whitelist.MINIMAL, visibleAnnotations = setOf( CordaSerializable::class.java, + CordaSerializationTransformEnumDefault::class.java, + CordaSerializationTransformEnumDefaults::class.java, + CordaSerializationTransformRename::class.java, + CordaSerializationTransformRenames::class.java, ConstructorForDeserialization::class.java, DeprecatedConstructorForDeserialization::class.java ), diff --git a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/SandboxSerializerFactoryFactory.kt b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/SandboxSerializerFactoryFactory.kt index 95abdcd3fe..bc614f2f95 100644 --- a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/SandboxSerializerFactoryFactory.kt +++ b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/SandboxSerializerFactoryFactory.kt @@ -98,7 +98,8 @@ class SandboxSerializerFactoryFactory( localSerializerFactory = localSerializerFactory, classLoader = classLoader, mustPreserveDataWhenEvolving = context.preventDataLoss, - primitiveTypes = primitiveTypes + primitiveTypes = primitiveTypes, + baseTypes = localTypes ) val remoteSerializerFactory = DefaultRemoteSerializerFactory( diff --git a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/Serialization.kt b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/Serialization.kt index fdf18afe99..6b73fa6e61 100644 --- a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/Serialization.kt +++ b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/Serialization.kt @@ -61,8 +61,9 @@ fun createSandboxSerializationEnv( @Suppress("unchecked_cast") val isEnumPredicate = predicateFactory.apply(CheckEnum::class.java) as Predicate> @Suppress("unchecked_cast") - val enumConstants = taskFactory.apply(DescribeEnum::class.java) - .andThen(taskFactory.apply(GetEnumNames::class.java)) + val enumConstants = taskFactory.apply(DescribeEnum::class.java) as Function, Array> + @Suppress("unchecked_cast") + val enumConstantNames = enumConstants.andThen(taskFactory.apply(GetEnumNames::class.java)) .andThen { (it as Array).map(Any::toString) } as Function, List> val sandboxLocalTypes = BaseLocalTypes( @@ -72,7 +73,8 @@ fun createSandboxSerializationEnv( mapClass = classLoader.toSandboxClass(Map::class.java), stringClass = classLoader.toSandboxClass(String::class.java), isEnum = isEnumPredicate, - enumConstants = enumConstants + enumConstants = enumConstants, + enumConstantNames = enumConstantNames ) val schemeBuilder = SandboxSerializationSchemeBuilder( classLoader = classLoader, diff --git a/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/DeserializeEnumWithEvolutionTest.kt b/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/DeserializeEnumWithEvolutionTest.kt new file mode 100644 index 0000000000..32fea60fd0 --- /dev/null +++ b/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/DeserializeEnumWithEvolutionTest.kt @@ -0,0 +1,155 @@ +package net.corda.serialization.djvm + +import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.CordaSerializationTransformEnumDefault +import net.corda.core.serialization.CordaSerializationTransformEnumDefaults +import net.corda.core.serialization.CordaSerializationTransformRename +import net.corda.core.serialization.CordaSerializationTransformRenames +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.internal._contextSerializationEnv +import net.corda.core.serialization.serialize +import net.corda.serialization.djvm.EvolvedEnum.ONE +import net.corda.serialization.djvm.EvolvedEnum.TWO +import net.corda.serialization.djvm.EvolvedEnum.THREE +import net.corda.serialization.djvm.EvolvedEnum.FOUR +import net.corda.serialization.djvm.OriginalEnum.One +import net.corda.serialization.djvm.OriginalEnum.Two +import net.corda.serialization.djvm.SandboxType.KOTLIN +import net.corda.serialization.internal.amqp.CompositeType +import net.corda.serialization.internal.amqp.DeserializationInput +import net.corda.serialization.internal.amqp.RestrictedType +import net.corda.serialization.internal.amqp.Transform +import net.corda.serialization.internal.amqp.TransformTypes +import net.corda.serialization.internal.amqp.TypeNotation +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.extension.ExtendWith +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.api.fail +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider +import org.junit.jupiter.params.provider.ArgumentsSource +import java.util.EnumMap +import java.util.function.Function +import java.util.stream.Stream + +@ExtendWith(LocalSerialization::class) +class DeserializeEnumWithEvolutionTest : TestBase(KOTLIN) { + class EvolutionArgumentProvider : ArgumentsProvider { + override fun provideArguments(context: ExtensionContext?): Stream { + return Stream.of( + Arguments.of(ONE, One), + Arguments.of(TWO, Two), + Arguments.of(THREE, One), + Arguments.of(FOUR, Two) + ) + } + } + + private fun String.devolve() = replace("Evolved", "Original") + + private fun devolveType(type: TypeNotation): TypeNotation { + return when (type) { + is CompositeType -> type.copy( + name = type.name.devolve(), + fields = type.fields.map { it.copy(type = it.type.devolve()) } + ) + is RestrictedType -> type.copy(name = type.name.devolve()) + else -> type + } + } + + private fun SerializedBytes<*>.devolve(context: SerializationContext): SerializedBytes { + val envelope = DeserializationInput.getEnvelope(this, context.encodingWhitelist).apply { + val schemaTypes = schema.types.map(::devolveType) + with(schema.types as MutableList) { + clear() + addAll(schemaTypes) + } + + val transforms = transformsSchema.types.asSequence().associateTo(LinkedHashMap()) { + it.key.devolve() to it.value + } + with(transformsSchema.types as MutableMap>>) { + clear() + putAll(transforms) + } + } + return SerializedBytes(envelope.write()) + } + + @ParameterizedTest + @ArgumentsSource(EvolutionArgumentProvider::class) + fun `test deserialising evolved enum`(value: EvolvedEnum, expected: OriginalEnum) { + val context = (_contextSerializationEnv.get() ?: fail("No serialization environment!")).p2pContext + + val evolvedData = value.serialize() + val originalData = evolvedData.devolve(context) + + sandbox { + _contextSerializationEnv.set(createSandboxSerializationEnv(classLoader)) + val sandboxOriginal = originalData.deserializeFor(classLoader) + assertEquals("sandbox." + OriginalEnum::class.java.name, sandboxOriginal::class.java.name) + assertEquals(expected.toString(), sandboxOriginal.toString()) + } + } + + @ParameterizedTest + @ArgumentsSource(EvolutionArgumentProvider::class) + fun `test deserialising data with evolved enum`(value: EvolvedEnum, expected: OriginalEnum) { + val context = (_contextSerializationEnv.get() ?: fail("No serialization environment!")).p2pContext + + val evolvedData = EvolvedData(value).serialize() + val originalData = evolvedData.devolve(context) + + sandbox { + _contextSerializationEnv.set(createSandboxSerializationEnv(classLoader)) + val sandboxOriginal = originalData.deserializeFor(classLoader) + + val taskFactory = classLoader.createRawTaskFactory() + val result = taskFactory.compose(classLoader.createSandboxFunction()) + .apply(ShowOriginalData::class.java) + .apply(sandboxOriginal) ?: fail("Result cannot be null") + assertThat(result.toString()) + .isEqualTo(ShowOriginalData().apply(OriginalData(expected))) + } + } + + class ShowOriginalData : Function { + override fun apply(input: OriginalData): String { + return with(input) { + "Name='${value.name}', Ordinal='${value.ordinal}'" + } + } + } +} + +@CordaSerializable +enum class OriginalEnum { + One, + Two +} + +@CordaSerializable +data class OriginalData(val value: OriginalEnum) + +@CordaSerializable +@CordaSerializationTransformRenames( + CordaSerializationTransformRename(from = "One", to = "ONE"), + CordaSerializationTransformRename(from = "Two", to = "TWO") +) +@CordaSerializationTransformEnumDefaults( + CordaSerializationTransformEnumDefault(new = "THREE", old = "One"), + CordaSerializationTransformEnumDefault(new = "FOUR", old = "Two") +) +enum class EvolvedEnum { + ONE, + TWO, + THREE, + FOUR +} + +@CordaSerializable +data class EvolvedData(val value: EvolvedEnum) \ No newline at end of file diff --git a/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/SafeDeserialisationTest.kt b/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/SafeDeserialisationTest.kt index 9b55ff4f43..1552279f45 100644 --- a/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/SafeDeserialisationTest.kt +++ b/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/SafeDeserialisationTest.kt @@ -4,22 +4,14 @@ import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.internal._contextSerializationEnv import net.corda.core.serialization.serialize import net.corda.serialization.djvm.SandboxType.KOTLIN -import net.corda.serialization.internal.SectionId import net.corda.serialization.internal.amqp.CompositeType import net.corda.serialization.internal.amqp.DeserializationInput -import net.corda.serialization.internal.amqp.Envelope import net.corda.serialization.internal.amqp.TypeNotation -import net.corda.serialization.internal.amqp.alsoAsByteBuffer -import net.corda.serialization.internal.amqp.amqpMagic -import net.corda.serialization.internal.amqp.withDescribed -import net.corda.serialization.internal.amqp.withList -import org.apache.qpid.proton.codec.Data import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import org.junit.jupiter.api.extension.ExtendWith import org.junit.jupiter.api.fail -import java.io.ByteArrayOutputStream import java.util.function.Function @ExtendWith(LocalSerialization::class) @@ -37,12 +29,8 @@ class SafeDeserialisationTest : TestBase(KOTLIN) { val innocentData = innocent.serialize() val envelope = DeserializationInput.getEnvelope(innocentData, context.encodingWhitelist).apply { val innocentType = schema.types[0] as CompositeType - (schema.types as MutableList)[0] = CompositeType( - name = innocentType.name.replace("Innocent", "VeryEvil"), - label = innocentType.label, - provides = innocentType.provides, - descriptor = innocentType.descriptor, - fields = innocentType.fields + (schema.types as MutableList)[0] = innocentType.copy( + name = innocentType.name.replace("Innocent", "VeryEvil") ) } val evilData = SerializedBytes(envelope.write()) @@ -68,23 +56,6 @@ class SafeDeserialisationTest : TestBase(KOTLIN) { } } - private fun Envelope.write(): ByteArray { - val data = Data.Factory.create() - data.withDescribed(Envelope.DESCRIPTOR_OBJECT) { - withList { - putObject(obj) - putObject(schema) - putObject(transformsSchema) - } - } - return ByteArrayOutputStream().use { - amqpMagic.writeTo(it) - SectionId.DATA_AND_STOP.writeTo(it) - it.alsoAsByteBuffer(data.encodedSize().toInt(), data::encode) - it.toByteArray() - } - } - class ShowInnocentData : Function { override fun apply(data: InnocentData): String { return "${data::class.java.name}: ${data.message}, ${data.number}" diff --git a/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/TestBase.kt b/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/TestBase.kt index 05edb2468d..898090e00a 100644 --- a/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/TestBase.kt +++ b/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/TestBase.kt @@ -2,6 +2,10 @@ package net.corda.serialization.djvm import net.corda.core.serialization.ConstructorForDeserialization import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.CordaSerializationTransformEnumDefault +import net.corda.core.serialization.CordaSerializationTransformEnumDefaults +import net.corda.core.serialization.CordaSerializationTransformRename +import net.corda.core.serialization.CordaSerializationTransformRenames import net.corda.core.serialization.DeprecatedConstructorForDeserialization import net.corda.djvm.SandboxConfiguration import net.corda.djvm.SandboxRuntimeContext @@ -51,6 +55,10 @@ abstract class TestBase(type: SandboxType) { whitelist = MINIMAL, visibleAnnotations = setOf( CordaSerializable::class.java, + CordaSerializationTransformEnumDefault::class.java, + CordaSerializationTransformEnumDefaults::class.java, + CordaSerializationTransformRename::class.java, + CordaSerializationTransformRenames::class.java, ConstructorForDeserialization::class.java, DeprecatedConstructorForDeserialization::class.java ), diff --git a/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/TestHelpers.kt b/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/TestHelpers.kt new file mode 100644 index 0000000000..0d5a46d179 --- /dev/null +++ b/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/TestHelpers.kt @@ -0,0 +1,28 @@ +@file:JvmName("TestHelpers") +package net.corda.serialization.djvm + +import net.corda.serialization.internal.SectionId +import net.corda.serialization.internal.amqp.Envelope +import net.corda.serialization.internal.amqp.alsoAsByteBuffer +import net.corda.serialization.internal.amqp.amqpMagic +import net.corda.serialization.internal.amqp.withDescribed +import net.corda.serialization.internal.amqp.withList +import org.apache.qpid.proton.codec.Data +import java.io.ByteArrayOutputStream + +fun Envelope.write(): ByteArray { + val data = Data.Factory.create() + data.withDescribed(Envelope.DESCRIPTOR_OBJECT) { + withList { + putObject(obj) + putObject(schema) + putObject(transformsSchema) + } + } + return ByteArrayOutputStream().use { + amqpMagic.writeTo(it) + SectionId.DATA_AND_STOP.writeTo(it) + it.alsoAsByteBuffer(data.encodedSize().toInt(), data::encode) + it.toByteArray() + } +} diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EnumEvolutionSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EnumEvolutionSerializer.kt index 52aa3449b0..77b57f9235 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EnumEvolutionSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EnumEvolutionSerializer.kt @@ -1,6 +1,7 @@ package net.corda.serialization.internal.amqp import net.corda.core.serialization.SerializationContext +import net.corda.serialization.internal.model.BaseLocalTypes import org.apache.qpid.proton.codec.Data import java.lang.UnsupportedOperationException import java.lang.reflect.Type @@ -34,6 +35,7 @@ import java.lang.reflect.Type class EnumEvolutionSerializer( override val type: Type, factory: LocalSerializerFactory, + private val baseLocalTypes: BaseLocalTypes, private val conversions: Map, private val ordinals: Map) : AMQPSerializer { override val typeDescriptor = factory.createDescriptor(type) @@ -46,7 +48,7 @@ class EnumEvolutionSerializer( val converted = conversions[enumName] ?: throw AMQPNotSerializableException(type, "No rule to evolve enum constant $type::$enumName") val ordinal = ordinals[converted] ?: throw AMQPNotSerializableException(type, "Ordinal not found for enum value $type::$converted") - return type.asClass().enumConstants[ordinal] + return baseLocalTypes.enumConstants.apply(type.asClass())[ordinal] } override fun writeClassInfo(output: SerializationOutput) { diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EvolutionSerializerFactory.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EvolutionSerializerFactory.kt index 11b57b7ae3..4d2b710bf0 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EvolutionSerializerFactory.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EvolutionSerializerFactory.kt @@ -40,7 +40,8 @@ class DefaultEvolutionSerializerFactory( private val localSerializerFactory: LocalSerializerFactory, private val classLoader: ClassLoader, private val mustPreserveDataWhenEvolving: Boolean, - override val primitiveTypes: Map, Class<*>> + override val primitiveTypes: Map, Class<*>>, + private val baseTypes: BaseLocalTypes ): EvolutionSerializerFactory { // Invert the "primitive -> boxed primitive" mapping. private val primitiveBoxedTypes: Map, Class<*>> @@ -172,7 +173,7 @@ class DefaultEvolutionSerializerFactory( if (constantsAreReordered(localOrdinals, convertedOrdinals)) throw EvolutionSerializationException(this, "Constants have been reordered, additions must be appended to the end") - return EnumEvolutionSerializer(localTypeInformation.observedType, localSerializerFactory, conversions, localOrdinals) + return EnumEvolutionSerializer(localTypeInformation.observedType, localSerializerFactory, baseTypes, conversions, localOrdinals) } private fun constantsAreReordered(localOrdinals: Map, convertedOrdinals: Map): Boolean = diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializerFactoryBuilder.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializerFactoryBuilder.kt index 4e7fbb466b..27650621d7 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializerFactoryBuilder.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializerFactoryBuilder.kt @@ -97,10 +97,8 @@ object SerializerFactoryBuilder { mustPreserveDataWhenEvolving: Boolean): SerializerFactory { val customSerializerRegistry = CachingCustomSerializerRegistry(descriptorBasedSerializerRegistry) - val localTypeModel = ConfigurableLocalTypeModel( - WhitelistBasedTypeModelConfiguration( - whitelist, - customSerializerRegistry)) + val typeModelConfiguration = WhitelistBasedTypeModelConfiguration(whitelist, customSerializerRegistry) + val localTypeModel = ConfigurableLocalTypeModel(typeModelConfiguration) val fingerPrinter = overrideFingerPrinter ?: TypeModellingFingerPrinter(customSerializerRegistry) @@ -124,7 +122,8 @@ object SerializerFactoryBuilder { localSerializerFactory, classCarpenter.classloader, mustPreserveDataWhenEvolving, - javaPrimitiveTypes + javaPrimitiveTypes, + typeModelConfiguration.baseTypes ) else NoEvolutionSerializerFactory val remoteSerializerFactory = DefaultRemoteSerializerFactory( diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/TransformsSchema.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/TransformsSchema.kt index 726eeba09e..09a3a09308 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/TransformsSchema.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/TransformsSchema.kt @@ -46,7 +46,7 @@ abstract class Transform : DescribedType { * descendants of this class */ override fun newInstance(obj: Any?): Transform { - val described = Transform.checkDescribed(obj) as List<*> + val described = checkDescribed(obj) as List<*> return when (described[0]) { EnumDefaultSchemaTransform.typeName -> EnumDefaultSchemaTransform.newInstance(described) RenameSchemaTransform.typeName -> RenameSchemaTransform.newInstance(described) @@ -195,18 +195,24 @@ object TransformsAnnotationProcessor { * Obtain all of the transforms applied for the given [Class]. */ fun getTransformsSchema(type: Class<*>): TransformsMap { - val result = TransformsMap(TransformTypes::class.java) - // We only have transforms for enums at present. - if (!type.isEnum) return result + return when { + // This only detects Enum classes that are outside the DJVM sandbox. + type.isEnum -> getEnumTransformsSchema(type) + // We only have transforms for enums at present. + else -> TransformsMap(TransformTypes::class.java) + } + } + + fun getEnumTransformsSchema(type: Class<*>): TransformsMap { + val result = TransformsMap(TransformTypes::class.java) supportedTransforms.forEach { supportedTransform -> val annotationContainer = type.getAnnotation(supportedTransform.type) ?: return@forEach result.processAnnotations( - type, - supportedTransform.enum, - supportedTransform.getAnnotations(annotationContainer)) + type, + supportedTransform.enum, + supportedTransform.getAnnotations(annotationContainer)) } - return result } diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/WhitelistBasedTypeModelConfiguration.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/WhitelistBasedTypeModelConfiguration.kt index fe9dcbb357..2e907059aa 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/WhitelistBasedTypeModelConfiguration.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/WhitelistBasedTypeModelConfiguration.kt @@ -61,7 +61,8 @@ private val DEFAULT_BASE_TYPES = BaseLocalTypes( mapClass = Map::class.java, stringClass = String::class.java, isEnum = Predicate { clazz -> clazz.isEnum }, - enumConstants = Function { clazz -> + enumConstants = Function { clazz -> clazz.enumConstants }, + enumConstantNames = Function { clazz -> (clazz as Class>).enumConstants.map(Enum<*>::name) } ) \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeInformationBuilder.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeInformationBuilder.kt index add971b99a..95d2e594e0 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeInformationBuilder.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeInformationBuilder.kt @@ -115,13 +115,13 @@ internal data class LocalTypeInformationBuilder(val lookup: LocalTypeLookup, baseTypes.mapClass.isAssignableFrom(type) -> AMap(type, typeIdentifier, Unknown, Unknown) type === baseTypes.stringClass -> Atomic(type, typeIdentifier) type.kotlin.javaPrimitiveType != null -> Atomic(type, typeIdentifier) - baseTypes.isEnum.test(type) -> baseTypes.enumConstants.apply(type).let { enumConstants -> + baseTypes.isEnum.test(type) -> baseTypes.enumConstantNames.apply(type).let { enumConstantNames -> AnEnum( type, typeIdentifier, - enumConstants, + enumConstantNames, buildInterfaceInformation(type), - getEnumTransforms(type, enumConstants) + getEnumTransforms(type, enumConstantNames) ) } type.kotlinObjectInstance != null -> Singleton( @@ -145,7 +145,7 @@ internal data class LocalTypeInformationBuilder(val lookup: LocalTypeLookup, private fun getEnumTransforms(type: Class<*>, enumConstants: List): EnumTransforms { try { val constants = enumConstants.asSequence().mapIndexed { index, constant -> constant to index }.toMap() - return EnumTransforms.build(TransformsAnnotationProcessor.getTransformsSchema(type), constants) + return EnumTransforms.build(TransformsAnnotationProcessor.getEnumTransformsSchema(type), constants) } catch (e: InvalidEnumTransformsException) { throw NotSerializableDetailedException(type.name, e.message!!) } diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeModel.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeModel.kt index 6186a09dbf..7cfdfa3cfc 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeModel.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeModel.kt @@ -136,5 +136,6 @@ class BaseLocalTypes( val mapClass: Class<*>, val stringClass: Class<*>, val isEnum: Predicate>, - val enumConstants: Function, List> + val enumConstants: Function, Array>, + val enumConstantNames: Function, List> )