From 6f437b5b09473f5a30f9b5cdad03dc530b568482 Mon Sep 17 00:00:00 2001 From: Chris Rankin Date: Sun, 19 Apr 2020 23:29:08 +0100 Subject: [PATCH] CORDA-3716: Fix SandboxEnumSerializer to handle enums that override toString(). (#6159) * CORDA-3716: Fix SandboxEnumSerializer to handle enums that override toString(). * Remove more uses of Enum.toString() from the Corda serializer. * Add test coverage for this case to standard enum serializer. * Increase maxWaitTimeout in IRSDemoTest to 150 seconds. --- .../kotlin/net/corda/irs/IRSDemoTest.kt | 2 +- .../djvm/deserializers/GetEnumNames.kt | 9 +++ .../corda/serialization/djvm/Serialization.kt | 5 +- .../djvm/serializers/SandboxEnumSerializer.kt | 13 +++- .../djvm/DeserializeCustomisedEnumTest.kt | 60 +++++++++++++++++++ .../serialization/djvm/LocalTypeModelTest.kt | 25 +++++++- .../internal/amqp/EnumSerializer.kt | 5 +- .../WhitelistBasedTypeModelConfiguration.kt | 5 +- .../model/LocalTypeInformationBuilder.kt | 6 +- .../internal/model/LocalTypeModel.kt | 2 +- .../serialization/internal/amqp/EnumTests.kt | 31 ++++++++++ .../internal/model/LocalTypeModelTests.kt | 22 +++++++ 12 files changed, 170 insertions(+), 15 deletions(-) create mode 100644 serialization-djvm/deserializers/src/main/kotlin/net/corda/serialization/djvm/deserializers/GetEnumNames.kt create mode 100644 serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/DeserializeCustomisedEnumTest.kt diff --git a/samples/irs-demo/src/integration-test/kotlin/net/corda/irs/IRSDemoTest.kt b/samples/irs-demo/src/integration-test/kotlin/net/corda/irs/IRSDemoTest.kt index ec95489f1d..fd1a4917e1 100644 --- a/samples/irs-demo/src/integration-test/kotlin/net/corda/irs/IRSDemoTest.kt +++ b/samples/irs-demo/src/integration-test/kotlin/net/corda/irs/IRSDemoTest.kt @@ -48,7 +48,7 @@ class IRSDemoTest { private val rpcUsers = listOf(User("user", "password", setOf("ALL"))) private val currentDate: LocalDate = LocalDate.now() private val futureDate: LocalDate = currentDate.plusMonths(6) - private val maxWaitTime: Duration = 60.seconds + private val maxWaitTime: Duration = 150.seconds @Test(timeout=300_000) fun `runs IRS demo`() { diff --git a/serialization-djvm/deserializers/src/main/kotlin/net/corda/serialization/djvm/deserializers/GetEnumNames.kt b/serialization-djvm/deserializers/src/main/kotlin/net/corda/serialization/djvm/deserializers/GetEnumNames.kt new file mode 100644 index 0000000000..5e60f530b4 --- /dev/null +++ b/serialization-djvm/deserializers/src/main/kotlin/net/corda/serialization/djvm/deserializers/GetEnumNames.kt @@ -0,0 +1,9 @@ +package net.corda.serialization.djvm.deserializers + +import java.util.function.Function + +class GetEnumNames : Function>, Array> { + override fun apply(enumValues: Array>): Array { + return enumValues.map(Enum<*>::name).toTypedArray() + } +} diff --git a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/Serialization.kt b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/Serialization.kt index 0c431e669b..fdf18afe99 100644 --- a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/Serialization.kt +++ b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/Serialization.kt @@ -12,6 +12,7 @@ import net.corda.djvm.rewiring.createSandboxPredicate import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.serialization.djvm.deserializers.CheckEnum import net.corda.serialization.djvm.deserializers.DescribeEnum +import net.corda.serialization.djvm.deserializers.GetEnumNames import net.corda.serialization.djvm.serializers.PrimitiveSerializer import net.corda.serialization.internal.GlobalTransientClassWhiteList import net.corda.serialization.internal.SerializationContextImpl @@ -60,7 +61,9 @@ fun createSandboxSerializationEnv( @Suppress("unchecked_cast") val isEnumPredicate = predicateFactory.apply(CheckEnum::class.java) as Predicate> @Suppress("unchecked_cast") - val enumConstants = taskFactory.apply(DescribeEnum::class.java) as Function, Array> + val enumConstants = taskFactory.apply(DescribeEnum::class.java) + .andThen(taskFactory.apply(GetEnumNames::class.java)) + .andThen { (it as Array).map(Any::toString) } as Function, List> val sandboxLocalTypes = BaseLocalTypes( collectionClass = classLoader.toSandboxClass(Collection::class.java), diff --git a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxEnumSerializer.kt b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxEnumSerializer.kt index 73b4421e19..a052e799da 100644 --- a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxEnumSerializer.kt +++ b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxEnumSerializer.kt @@ -4,6 +4,7 @@ import net.corda.core.serialization.SerializationContext import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.serialization.djvm.deserializers.CheckEnum import net.corda.serialization.djvm.deserializers.DescribeEnum +import net.corda.serialization.djvm.deserializers.GetEnumNames import net.corda.serialization.djvm.toSandboxAnyClass import net.corda.serialization.internal.amqp.AMQPNotSerializableException import net.corda.serialization.internal.amqp.AMQPSerializer @@ -32,6 +33,10 @@ class SandboxEnumSerializer( private val describeEnum: Function, Array> = taskFactory.apply(DescribeEnum::class.java) as Function, Array> @Suppress("unchecked_cast") + private val getEnumNames: Function, List> + = (taskFactory.apply(GetEnumNames::class.java) as Function, Array>) + .andThen { it.map(Any::toString) } + @Suppress("unchecked_cast") private val isEnum: Predicate> = predicateFactory.apply(CheckEnum::class.java) as Predicate> @@ -46,7 +51,8 @@ class SandboxEnumSerializer( return null } val members = describeEnum.apply(declaredType) - return ConcreteEnumSerializer(declaredType, members, localFactory) + val memberNames = getEnumNames.apply(members) + return ConcreteEnumSerializer(declaredType, members, memberNames, localFactory) } override fun readObject( @@ -65,6 +71,7 @@ class SandboxEnumSerializer( private class ConcreteEnumSerializer( declaredType: Class<*>, private val members: Array, + private val memberNames: List, factory: LocalSerializerFactory ) : AMQPSerializer { override val type: Class<*> = declaredType @@ -78,7 +85,7 @@ private class ConcreteEnumSerializer( LocalTypeInformation.AnEnum( declaredType, TypeIdentifier.forGenericType(declaredType), - members.map(Any::toString), + memberNames, emptyList(), EnumTransforms.empty ) @@ -92,7 +99,7 @@ private class ConcreteEnumSerializer( val enumOrd = obj[1] as Int val fromOrd = members[enumOrd] - if (enumName != fromOrd.toString()) { + if (enumName != memberNames[enumOrd]) { throw AMQPNotSerializableException( type, "Deserializing obj as enum $type with value $enumName.$enumOrd but ordinality has changed" diff --git a/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/DeserializeCustomisedEnumTest.kt b/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/DeserializeCustomisedEnumTest.kt new file mode 100644 index 0000000000..657ab25f34 --- /dev/null +++ b/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/DeserializeCustomisedEnumTest.kt @@ -0,0 +1,60 @@ +package net.corda.serialization.djvm + +import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.internal._contextSerializationEnv +import net.corda.core.serialization.serialize +import net.corda.serialization.djvm.SandboxType.KOTLIN +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.extension.ExtendWith +import org.junit.jupiter.api.fail +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.EnumSource +import java.util.function.Function + +@ExtendWith(LocalSerialization::class) +class DeserializeCustomisedEnumTest : TestBase(KOTLIN) { + @ParameterizedTest + @EnumSource(UserRole::class) + fun `test deserialize enum with custom toString`(role: UserRole) { + val userEnumData = UserEnumData(role) + val data = userEnumData.serialize() + + sandbox { + _contextSerializationEnv.set(createSandboxSerializationEnv(classLoader)) + + val sandboxData = data.deserializeFor(classLoader) + + val taskFactory = classLoader.createRawTaskFactory() + val showUserEnumData = taskFactory.compose(classLoader.createSandboxFunction()).apply(ShowUserEnumData::class.java) + val result = showUserEnumData.apply(sandboxData) ?: fail("Result cannot be null") + + assertEquals(ShowUserEnumData().apply(userEnumData), result.toString()) + assertEquals("UserRole: name='${role.roleName}', ordinal='${role.ordinal}'", result.toString()) + assertEquals(SANDBOX_STRING, result::class.java.name) + } + } + + class ShowUserEnumData : Function { + override fun apply(input: UserEnumData): String { + return with(input) { + "UserRole: name='${role.roleName}', ordinal='${role.ordinal}'" + } + } + } +} + +interface Role { + val roleName: String +} + +@Suppress("unused") +@CordaSerializable +enum class UserRole(override val roleName: String) : Role { + CONTROLLER(roleName = "Controller"), + WORKER(roleName = "Worker"); + + override fun toString() = roleName +} + +@CordaSerializable +data class UserEnumData(val role: UserRole) diff --git a/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/LocalTypeModelTest.kt b/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/LocalTypeModelTest.kt index 7ca5a3a61b..90b5d0813f 100644 --- a/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/LocalTypeModelTest.kt +++ b/serialization-djvm/src/test/kotlin/net/corda/serialization/djvm/LocalTypeModelTest.kt @@ -38,13 +38,14 @@ class LocalTypeModelTest : TestBase(KOTLIN) { return classLoader.toSandboxClass(T::class.java) } - private inline fun assertLocalType(type: Class<*>) { - assertLocalType(LOCAL::class.java, type) + private inline fun assertLocalType(type: Class<*>): LOCAL { + return assertLocalType(LOCAL::class.java, type) as LOCAL } - private fun assertLocalType(localType: Class, type: Class<*>) { + private fun assertLocalType(localType: Class, type: Class<*>): LocalTypeInformation { val typeData = serializerFactory.getTypeInformation(type) assertThat(typeData).isInstanceOf(localType) + return typeData } @Test @@ -174,6 +175,14 @@ class LocalTypeModelTest : TestBase(KOTLIN) { assertLocalType(sandbox(classLoader)) } + @Test + fun testCustomEnum() = sandbox { + _contextSerializationEnv.set(createSandboxSerializationEnv(classLoader)) + val anEnum = assertLocalType(sandbox(classLoader)) + assertThat(anEnum.members) + .containsExactlyElementsOf(CustomEnum::class.java.enumConstants.map(CustomEnum::name)) + } + @Test fun testEnumSet() = sandbox { _contextSerializationEnv.set(createSandboxSerializationEnv(classLoader)) @@ -188,4 +197,14 @@ class LocalTypeModelTest : TestBase(KOTLIN) { _contextSerializationEnv.set(createSandboxSerializationEnv(classLoader)) assertLocalType(sandbox>(classLoader)) } + + @Suppress("unused") + enum class CustomEnum { + ONE, + TWO; + + override fun toString(): String { + return "[${name.toLowerCase()}]" + } + } } \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EnumSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EnumSerializer.kt index da8b922649..f4e9647486 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EnumSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EnumSerializer.kt @@ -14,11 +14,12 @@ class EnumSerializer(declaredType: Type, declaredClass: Class<*>, factory: Local override val typeDescriptor = factory.createDescriptor(type) init { + @Suppress("unchecked_cast") typeNotation = RestrictedType( AMQPTypeIdentifiers.nameForType(declaredType), null, emptyList(), "list", Descriptor(typeDescriptor), - declaredClass.enumConstants.zip(IntRange(0, declaredClass.enumConstants.size)).map { - Choice(it.first.toString(), it.second.toString()) + (declaredClass as Class>).enumConstants.zip(IntRange(0, declaredClass.enumConstants.size)).map { + Choice(it.first.name, it.second.toString()) }) } diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/WhitelistBasedTypeModelConfiguration.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/WhitelistBasedTypeModelConfiguration.kt index 3e9aea7aa0..fe9dcbb357 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/WhitelistBasedTypeModelConfiguration.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/WhitelistBasedTypeModelConfiguration.kt @@ -53,6 +53,7 @@ private val opaqueTypes = setOf( Symbol::class.java ) +@Suppress("unchecked_cast") private val DEFAULT_BASE_TYPES = BaseLocalTypes( collectionClass = Collection::class.java, enumSetClass = EnumSet::class.java, @@ -60,5 +61,7 @@ private val DEFAULT_BASE_TYPES = BaseLocalTypes( mapClass = Map::class.java, stringClass = String::class.java, isEnum = Predicate { clazz -> clazz.isEnum }, - enumConstants = Function { clazz -> clazz.enumConstants } + enumConstants = Function { clazz -> + (clazz as Class>).enumConstants.map(Enum<*>::name) + } ) \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeInformationBuilder.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeInformationBuilder.kt index 700cbf0ea8..add971b99a 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeInformationBuilder.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeInformationBuilder.kt @@ -119,7 +119,7 @@ internal data class LocalTypeInformationBuilder(val lookup: LocalTypeLookup, AnEnum( type, typeIdentifier, - enumConstants.map(Any::toString), + enumConstants, buildInterfaceInformation(type), getEnumTransforms(type, enumConstants) ) @@ -142,9 +142,9 @@ internal data class LocalTypeInformationBuilder(val lookup: LocalTypeLookup, } } - private fun getEnumTransforms(type: Class<*>, enumConstants: Array): EnumTransforms { + private fun getEnumTransforms(type: Class<*>, enumConstants: List): EnumTransforms { try { - val constants = enumConstants.asSequence().mapIndexed { index, constant -> constant.toString() to index }.toMap() + val constants = enumConstants.asSequence().mapIndexed { index, constant -> constant to index }.toMap() return EnumTransforms.build(TransformsAnnotationProcessor.getTransformsSchema(type), constants) } catch (e: InvalidEnumTransformsException) { throw NotSerializableDetailedException(type.name, e.message!!) diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeModel.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeModel.kt index 19ac1e018b..6186a09dbf 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeModel.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeModel.kt @@ -136,5 +136,5 @@ class BaseLocalTypes( val mapClass: Class<*>, val stringClass: Class<*>, val isEnum: Predicate>, - val enumConstants: Function, Array> + val enumConstants: Function, List> ) diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/EnumTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/EnumTests.kt index c2cda2c1cf..b406283d9f 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/EnumTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/EnumTests.kt @@ -3,6 +3,8 @@ package net.corda.serialization.internal.amqp import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.deserialize +import net.corda.serialization.internal.EmptyWhitelist import net.corda.serialization.internal.amqp.testutils.TestSerializationOutput import net.corda.serialization.internal.amqp.testutils.deserialize import net.corda.serialization.internal.amqp.testutils.deserializeAndReturnEnvelope @@ -11,6 +13,7 @@ import net.corda.serialization.internal.amqp.testutils.testDefaultFactoryNoEvolu import net.corda.serialization.internal.amqp.testutils.testName import net.corda.serialization.internal.carpenter.ClassCarpenterImpl import org.assertj.core.api.Assertions +import org.junit.Assert.assertNotSame import org.junit.Test import java.io.NotSerializableException import java.time.DayOfWeek @@ -279,4 +282,32 @@ class EnumTests { DeserializationInput(factory2).deserialize(bytes) }.isInstanceOf(NotSerializableException::class.java) } + + @Test(timeout = 300_000) + fun deserializeCustomisedEnum() { + val input = CustomEnumWrapper(CustomEnum.ONE) + val factory1 = SerializerFactoryBuilder.build(EmptyWhitelist, ClassLoader.getSystemClassLoader()) + val serialized = TestSerializationOutput(VERBOSE, factory1).serialize(input) + + val factory2 = SerializerFactoryBuilder.build(EmptyWhitelist, ClassLoader.getSystemClassLoader()) + val output = DeserializationInput(factory2).deserialize(serialized) + + assertEquals(input, output) + assertNotSame("Deserialized object should be brand new.", input, output) + } + + @Suppress("unused") + @CordaSerializable + enum class CustomEnum { + ONE, + TWO, + THREE; + + override fun toString(): String { + return "[${name.toLowerCase()}]" + } + } + + @CordaSerializable + data class CustomEnumWrapper(val data: CustomEnum) } \ No newline at end of file diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/model/LocalTypeModelTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/model/LocalTypeModelTests.kt index a1ecbe6799..8d74f3eef8 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/model/LocalTypeModelTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/model/LocalTypeModelTests.kt @@ -4,9 +4,11 @@ import com.google.common.reflect.TypeToken import net.corda.core.serialization.SerializableCalculatedProperty import net.corda.serialization.internal.AllWhitelist import net.corda.serialization.internal.amqp.* +import org.assertj.core.api.Assertions.assertThat import org.junit.Assert.assertEquals import org.junit.Assert.assertTrue import org.junit.Test +import org.junit.jupiter.api.fail import java.lang.reflect.Type import java.util.* @@ -206,6 +208,26 @@ class LocalTypeModelTests { } } + @Suppress("unused") + enum class CustomEnum { + ONE, + TWO; + + override fun toString(): String { + return "[${name.toLowerCase()}]" + } + } + + @Test(timeout = 300_000) + fun `test type information for customised enum`() { + modelWithoutOpacity.inspect(typeOf()).let { typeInformation -> + val anEnum = typeInformation as? LocalTypeInformation.AnEnum ?: fail("Not AnEnum!") + assertThat(anEnum.members).containsExactlyElementsOf( + CustomEnum::class.java.enumConstants.map(CustomEnum::name) + ) + } + } + private inline fun assertInformation(expected: String) { assertEquals(expected.trimIndent(), model.inspect(typeOf()).prettyPrint()) }