CORDA-3716: Fix SandboxEnumSerializer to handle enums that override toString(). (#6159)

* CORDA-3716: Fix SandboxEnumSerializer to handle enums that override toString().

* Remove more uses of Enum.toString() from the Corda serializer.

* Add test coverage for this case to standard enum serializer.

* Increase maxWaitTimeout in IRSDemoTest to 150 seconds.
This commit is contained in:
Chris Rankin 2020-04-19 23:29:08 +01:00 committed by GitHub
parent 45b43f116d
commit 6f437b5b09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 170 additions and 15 deletions

View File

@ -48,7 +48,7 @@ class IRSDemoTest {
private val rpcUsers = listOf(User("user", "password", setOf("ALL"))) private val rpcUsers = listOf(User("user", "password", setOf("ALL")))
private val currentDate: LocalDate = LocalDate.now() private val currentDate: LocalDate = LocalDate.now()
private val futureDate: LocalDate = currentDate.plusMonths(6) private val futureDate: LocalDate = currentDate.plusMonths(6)
private val maxWaitTime: Duration = 60.seconds private val maxWaitTime: Duration = 150.seconds
@Test(timeout=300_000) @Test(timeout=300_000)
fun `runs IRS demo`() { fun `runs IRS demo`() {

View File

@ -0,0 +1,9 @@
package net.corda.serialization.djvm.deserializers
import java.util.function.Function
class GetEnumNames : Function<Array<Enum<*>>, Array<String>> {
override fun apply(enumValues: Array<Enum<*>>): Array<String> {
return enumValues.map(Enum<*>::name).toTypedArray()
}
}

View File

@ -12,6 +12,7 @@ import net.corda.djvm.rewiring.createSandboxPredicate
import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.djvm.rewiring.SandboxClassLoader
import net.corda.serialization.djvm.deserializers.CheckEnum import net.corda.serialization.djvm.deserializers.CheckEnum
import net.corda.serialization.djvm.deserializers.DescribeEnum import net.corda.serialization.djvm.deserializers.DescribeEnum
import net.corda.serialization.djvm.deserializers.GetEnumNames
import net.corda.serialization.djvm.serializers.PrimitiveSerializer import net.corda.serialization.djvm.serializers.PrimitiveSerializer
import net.corda.serialization.internal.GlobalTransientClassWhiteList import net.corda.serialization.internal.GlobalTransientClassWhiteList
import net.corda.serialization.internal.SerializationContextImpl import net.corda.serialization.internal.SerializationContextImpl
@ -60,7 +61,9 @@ fun createSandboxSerializationEnv(
@Suppress("unchecked_cast") @Suppress("unchecked_cast")
val isEnumPredicate = predicateFactory.apply(CheckEnum::class.java) as Predicate<Class<*>> val isEnumPredicate = predicateFactory.apply(CheckEnum::class.java) as Predicate<Class<*>>
@Suppress("unchecked_cast") @Suppress("unchecked_cast")
val enumConstants = taskFactory.apply(DescribeEnum::class.java) as Function<Class<*>, Array<out Any>> val enumConstants = taskFactory.apply(DescribeEnum::class.java)
.andThen(taskFactory.apply(GetEnumNames::class.java))
.andThen { (it as Array<out Any>).map(Any::toString) } as Function<Class<*>, List<String>>
val sandboxLocalTypes = BaseLocalTypes( val sandboxLocalTypes = BaseLocalTypes(
collectionClass = classLoader.toSandboxClass(Collection::class.java), collectionClass = classLoader.toSandboxClass(Collection::class.java),

View File

@ -4,6 +4,7 @@ import net.corda.core.serialization.SerializationContext
import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.djvm.rewiring.SandboxClassLoader
import net.corda.serialization.djvm.deserializers.CheckEnum import net.corda.serialization.djvm.deserializers.CheckEnum
import net.corda.serialization.djvm.deserializers.DescribeEnum import net.corda.serialization.djvm.deserializers.DescribeEnum
import net.corda.serialization.djvm.deserializers.GetEnumNames
import net.corda.serialization.djvm.toSandboxAnyClass import net.corda.serialization.djvm.toSandboxAnyClass
import net.corda.serialization.internal.amqp.AMQPNotSerializableException import net.corda.serialization.internal.amqp.AMQPNotSerializableException
import net.corda.serialization.internal.amqp.AMQPSerializer import net.corda.serialization.internal.amqp.AMQPSerializer
@ -32,6 +33,10 @@ class SandboxEnumSerializer(
private val describeEnum: Function<Class<*>, Array<Any>> private val describeEnum: Function<Class<*>, Array<Any>>
= taskFactory.apply(DescribeEnum::class.java) as Function<Class<*>, Array<Any>> = taskFactory.apply(DescribeEnum::class.java) as Function<Class<*>, Array<Any>>
@Suppress("unchecked_cast") @Suppress("unchecked_cast")
private val getEnumNames: Function<Array<Any>, List<String>>
= (taskFactory.apply(GetEnumNames::class.java) as Function<Array<Any>, Array<Any>>)
.andThen { it.map(Any::toString) }
@Suppress("unchecked_cast")
private val isEnum: Predicate<Class<*>> private val isEnum: Predicate<Class<*>>
= predicateFactory.apply(CheckEnum::class.java) as Predicate<Class<*>> = predicateFactory.apply(CheckEnum::class.java) as Predicate<Class<*>>
@ -46,7 +51,8 @@ class SandboxEnumSerializer(
return null return null
} }
val members = describeEnum.apply(declaredType) val members = describeEnum.apply(declaredType)
return ConcreteEnumSerializer(declaredType, members, localFactory) val memberNames = getEnumNames.apply(members)
return ConcreteEnumSerializer(declaredType, members, memberNames, localFactory)
} }
override fun readObject( override fun readObject(
@ -65,6 +71,7 @@ class SandboxEnumSerializer(
private class ConcreteEnumSerializer( private class ConcreteEnumSerializer(
declaredType: Class<*>, declaredType: Class<*>,
private val members: Array<Any>, private val members: Array<Any>,
private val memberNames: List<String>,
factory: LocalSerializerFactory factory: LocalSerializerFactory
) : AMQPSerializer<Any> { ) : AMQPSerializer<Any> {
override val type: Class<*> = declaredType override val type: Class<*> = declaredType
@ -78,7 +85,7 @@ private class ConcreteEnumSerializer(
LocalTypeInformation.AnEnum( LocalTypeInformation.AnEnum(
declaredType, declaredType,
TypeIdentifier.forGenericType(declaredType), TypeIdentifier.forGenericType(declaredType),
members.map(Any::toString), memberNames,
emptyList(), emptyList(),
EnumTransforms.empty EnumTransforms.empty
) )
@ -92,7 +99,7 @@ private class ConcreteEnumSerializer(
val enumOrd = obj[1] as Int val enumOrd = obj[1] as Int
val fromOrd = members[enumOrd] val fromOrd = members[enumOrd]
if (enumName != fromOrd.toString()) { if (enumName != memberNames[enumOrd]) {
throw AMQPNotSerializableException( throw AMQPNotSerializableException(
type, type,
"Deserializing obj as enum $type with value $enumName.$enumOrd but ordinality has changed" "Deserializing obj as enum $type with value $enumName.$enumOrd but ordinality has changed"

View File

@ -0,0 +1,60 @@
package net.corda.serialization.djvm
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.internal._contextSerializationEnv
import net.corda.core.serialization.serialize
import net.corda.serialization.djvm.SandboxType.KOTLIN
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.extension.ExtendWith
import org.junit.jupiter.api.fail
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.EnumSource
import java.util.function.Function
@ExtendWith(LocalSerialization::class)
class DeserializeCustomisedEnumTest : TestBase(KOTLIN) {
@ParameterizedTest
@EnumSource(UserRole::class)
fun `test deserialize enum with custom toString`(role: UserRole) {
val userEnumData = UserEnumData(role)
val data = userEnumData.serialize()
sandbox {
_contextSerializationEnv.set(createSandboxSerializationEnv(classLoader))
val sandboxData = data.deserializeFor(classLoader)
val taskFactory = classLoader.createRawTaskFactory()
val showUserEnumData = taskFactory.compose(classLoader.createSandboxFunction()).apply(ShowUserEnumData::class.java)
val result = showUserEnumData.apply(sandboxData) ?: fail("Result cannot be null")
assertEquals(ShowUserEnumData().apply(userEnumData), result.toString())
assertEquals("UserRole: name='${role.roleName}', ordinal='${role.ordinal}'", result.toString())
assertEquals(SANDBOX_STRING, result::class.java.name)
}
}
class ShowUserEnumData : Function<UserEnumData, String> {
override fun apply(input: UserEnumData): String {
return with(input) {
"UserRole: name='${role.roleName}', ordinal='${role.ordinal}'"
}
}
}
}
interface Role {
val roleName: String
}
@Suppress("unused")
@CordaSerializable
enum class UserRole(override val roleName: String) : Role {
CONTROLLER(roleName = "Controller"),
WORKER(roleName = "Worker");
override fun toString() = roleName
}
@CordaSerializable
data class UserEnumData(val role: UserRole)

View File

@ -38,13 +38,14 @@ class LocalTypeModelTest : TestBase(KOTLIN) {
return classLoader.toSandboxClass(T::class.java) return classLoader.toSandboxClass(T::class.java)
} }
private inline fun <reified LOCAL: LocalTypeInformation> assertLocalType(type: Class<*>) { private inline fun <reified LOCAL: LocalTypeInformation> assertLocalType(type: Class<*>): LOCAL {
assertLocalType(LOCAL::class.java, type) return assertLocalType(LOCAL::class.java, type) as LOCAL
} }
private fun <LOCAL: LocalTypeInformation> assertLocalType(localType: Class<LOCAL>, type: Class<*>) { private fun <LOCAL: LocalTypeInformation> assertLocalType(localType: Class<LOCAL>, type: Class<*>): LocalTypeInformation {
val typeData = serializerFactory.getTypeInformation(type) val typeData = serializerFactory.getTypeInformation(type)
assertThat(typeData).isInstanceOf(localType) assertThat(typeData).isInstanceOf(localType)
return typeData
} }
@Test @Test
@ -174,6 +175,14 @@ class LocalTypeModelTest : TestBase(KOTLIN) {
assertLocalType<AnEnum>(sandbox<ExampleEnum>(classLoader)) assertLocalType<AnEnum>(sandbox<ExampleEnum>(classLoader))
} }
@Test
fun testCustomEnum() = sandbox {
_contextSerializationEnv.set(createSandboxSerializationEnv(classLoader))
val anEnum = assertLocalType<AnEnum>(sandbox<CustomEnum>(classLoader))
assertThat(anEnum.members)
.containsExactlyElementsOf(CustomEnum::class.java.enumConstants.map(CustomEnum::name))
}
@Test @Test
fun testEnumSet() = sandbox { fun testEnumSet() = sandbox {
_contextSerializationEnv.set(createSandboxSerializationEnv(classLoader)) _contextSerializationEnv.set(createSandboxSerializationEnv(classLoader))
@ -188,4 +197,14 @@ class LocalTypeModelTest : TestBase(KOTLIN) {
_contextSerializationEnv.set(createSandboxSerializationEnv(classLoader)) _contextSerializationEnv.set(createSandboxSerializationEnv(classLoader))
assertLocalType<AMap>(sandbox<Map<*,*>>(classLoader)) assertLocalType<AMap>(sandbox<Map<*,*>>(classLoader))
} }
@Suppress("unused")
enum class CustomEnum {
ONE,
TWO;
override fun toString(): String {
return "[${name.toLowerCase()}]"
}
}
} }

View File

@ -14,11 +14,12 @@ class EnumSerializer(declaredType: Type, declaredClass: Class<*>, factory: Local
override val typeDescriptor = factory.createDescriptor(type) override val typeDescriptor = factory.createDescriptor(type)
init { init {
@Suppress("unchecked_cast")
typeNotation = RestrictedType( typeNotation = RestrictedType(
AMQPTypeIdentifiers.nameForType(declaredType), AMQPTypeIdentifiers.nameForType(declaredType),
null, emptyList(), "list", Descriptor(typeDescriptor), null, emptyList(), "list", Descriptor(typeDescriptor),
declaredClass.enumConstants.zip(IntRange(0, declaredClass.enumConstants.size)).map { (declaredClass as Class<out Enum<*>>).enumConstants.zip(IntRange(0, declaredClass.enumConstants.size)).map {
Choice(it.first.toString(), it.second.toString()) Choice(it.first.name, it.second.toString())
}) })
} }

View File

@ -53,6 +53,7 @@ private val opaqueTypes = setOf(
Symbol::class.java Symbol::class.java
) )
@Suppress("unchecked_cast")
private val DEFAULT_BASE_TYPES = BaseLocalTypes( private val DEFAULT_BASE_TYPES = BaseLocalTypes(
collectionClass = Collection::class.java, collectionClass = Collection::class.java,
enumSetClass = EnumSet::class.java, enumSetClass = EnumSet::class.java,
@ -60,5 +61,7 @@ private val DEFAULT_BASE_TYPES = BaseLocalTypes(
mapClass = Map::class.java, mapClass = Map::class.java,
stringClass = String::class.java, stringClass = String::class.java,
isEnum = Predicate { clazz -> clazz.isEnum }, isEnum = Predicate { clazz -> clazz.isEnum },
enumConstants = Function { clazz -> clazz.enumConstants } enumConstants = Function { clazz ->
(clazz as Class<out Enum<*>>).enumConstants.map(Enum<*>::name)
}
) )

View File

@ -119,7 +119,7 @@ internal data class LocalTypeInformationBuilder(val lookup: LocalTypeLookup,
AnEnum( AnEnum(
type, type,
typeIdentifier, typeIdentifier,
enumConstants.map(Any::toString), enumConstants,
buildInterfaceInformation(type), buildInterfaceInformation(type),
getEnumTransforms(type, enumConstants) getEnumTransforms(type, enumConstants)
) )
@ -142,9 +142,9 @@ internal data class LocalTypeInformationBuilder(val lookup: LocalTypeLookup,
} }
} }
private fun getEnumTransforms(type: Class<*>, enumConstants: Array<out Any>): EnumTransforms { private fun getEnumTransforms(type: Class<*>, enumConstants: List<String>): EnumTransforms {
try { try {
val constants = enumConstants.asSequence().mapIndexed { index, constant -> constant.toString() to index }.toMap() val constants = enumConstants.asSequence().mapIndexed { index, constant -> constant to index }.toMap()
return EnumTransforms.build(TransformsAnnotationProcessor.getTransformsSchema(type), constants) return EnumTransforms.build(TransformsAnnotationProcessor.getTransformsSchema(type), constants)
} catch (e: InvalidEnumTransformsException) { } catch (e: InvalidEnumTransformsException) {
throw NotSerializableDetailedException(type.name, e.message!!) throw NotSerializableDetailedException(type.name, e.message!!)

View File

@ -136,5 +136,5 @@ class BaseLocalTypes(
val mapClass: Class<*>, val mapClass: Class<*>,
val stringClass: Class<*>, val stringClass: Class<*>,
val isEnum: Predicate<Class<*>>, val isEnum: Predicate<Class<*>>,
val enumConstants: Function<Class<*>, Array<out Any>> val enumConstants: Function<Class<*>, List<String>>
) )

View File

@ -3,6 +3,8 @@ package net.corda.serialization.internal.amqp
import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize
import net.corda.serialization.internal.EmptyWhitelist
import net.corda.serialization.internal.amqp.testutils.TestSerializationOutput import net.corda.serialization.internal.amqp.testutils.TestSerializationOutput
import net.corda.serialization.internal.amqp.testutils.deserialize import net.corda.serialization.internal.amqp.testutils.deserialize
import net.corda.serialization.internal.amqp.testutils.deserializeAndReturnEnvelope import net.corda.serialization.internal.amqp.testutils.deserializeAndReturnEnvelope
@ -11,6 +13,7 @@ import net.corda.serialization.internal.amqp.testutils.testDefaultFactoryNoEvolu
import net.corda.serialization.internal.amqp.testutils.testName import net.corda.serialization.internal.amqp.testutils.testName
import net.corda.serialization.internal.carpenter.ClassCarpenterImpl import net.corda.serialization.internal.carpenter.ClassCarpenterImpl
import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions
import org.junit.Assert.assertNotSame
import org.junit.Test import org.junit.Test
import java.io.NotSerializableException import java.io.NotSerializableException
import java.time.DayOfWeek import java.time.DayOfWeek
@ -279,4 +282,32 @@ class EnumTests {
DeserializationInput(factory2).deserialize(bytes) DeserializationInput(factory2).deserialize(bytes)
}.isInstanceOf(NotSerializableException::class.java) }.isInstanceOf(NotSerializableException::class.java)
} }
@Test(timeout = 300_000)
fun deserializeCustomisedEnum() {
val input = CustomEnumWrapper(CustomEnum.ONE)
val factory1 = SerializerFactoryBuilder.build(EmptyWhitelist, ClassLoader.getSystemClassLoader())
val serialized = TestSerializationOutput(VERBOSE, factory1).serialize(input)
val factory2 = SerializerFactoryBuilder.build(EmptyWhitelist, ClassLoader.getSystemClassLoader())
val output = DeserializationInput(factory2).deserialize(serialized)
assertEquals(input, output)
assertNotSame("Deserialized object should be brand new.", input, output)
}
@Suppress("unused")
@CordaSerializable
enum class CustomEnum {
ONE,
TWO,
THREE;
override fun toString(): String {
return "[${name.toLowerCase()}]"
}
}
@CordaSerializable
data class CustomEnumWrapper(val data: CustomEnum)
} }

View File

@ -4,9 +4,11 @@ import com.google.common.reflect.TypeToken
import net.corda.core.serialization.SerializableCalculatedProperty import net.corda.core.serialization.SerializableCalculatedProperty
import net.corda.serialization.internal.AllWhitelist import net.corda.serialization.internal.AllWhitelist
import net.corda.serialization.internal.amqp.* import net.corda.serialization.internal.amqp.*
import org.assertj.core.api.Assertions.assertThat
import org.junit.Assert.assertEquals import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue import org.junit.Assert.assertTrue
import org.junit.Test import org.junit.Test
import org.junit.jupiter.api.fail
import java.lang.reflect.Type import java.lang.reflect.Type
import java.util.* import java.util.*
@ -206,6 +208,26 @@ class LocalTypeModelTests {
} }
} }
@Suppress("unused")
enum class CustomEnum {
ONE,
TWO;
override fun toString(): String {
return "[${name.toLowerCase()}]"
}
}
@Test(timeout = 300_000)
fun `test type information for customised enum`() {
modelWithoutOpacity.inspect(typeOf<CustomEnum>()).let { typeInformation ->
val anEnum = typeInformation as? LocalTypeInformation.AnEnum ?: fail("Not AnEnum!")
assertThat(anEnum.members).containsExactlyElementsOf(
CustomEnum::class.java.enumConstants.map(CustomEnum::name)
)
}
}
private inline fun <reified T> assertInformation(expected: String) { private inline fun <reified T> assertInformation(expected: String) {
assertEquals(expected.trimIndent(), model.inspect(typeOf<T>()).prettyPrint()) assertEquals(expected.trimIndent(), model.inspect(typeOf<T>()).prettyPrint())
} }