From c48a37a08051c16233f69eb1d80b79d614feab77 Mon Sep 17 00:00:00 2001 From: Katelyn Baker Date: Fri, 1 Sep 2017 15:07:08 +0100 Subject: [PATCH] CORDA-539 - Add enum support to the carpenter If the serializer is going to support enumerated types then the class carpenter also has to Refactor the Carpenter schema and fields to add an enum type, add code in the carpenter to generate enum's and of course add tests --- docs/source/changelog.rst | 2 + .../serialization/carpenter/ClassCarpenter.kt | 168 +++++++++++++---- .../serialization/carpenter/Schema.kt | 171 +++++++----------- .../serialization/carpenter/SchemaFields.kt | 138 ++++++++++++++ .../carpenter/ClassCarpenterTest.kt | 6 +- .../carpenter/ClassCarpenterTestUtils.kt | 4 +- .../serialization/carpenter/EnumClassTests.kt | 105 +++++++++++ 7 files changed, 449 insertions(+), 145 deletions(-) create mode 100644 node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/SchemaFields.kt create mode 100644 node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/EnumClassTests.kt diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index dc44deb053..c2bd2bfe89 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -6,6 +6,8 @@ from the previous milestone release. UNRELEASED ---------- +* Adding enum support to the class carpenter + * ``ContractState::contract`` has been moved ``TransactionState::contract`` and it's type has changed to ``String`` in order to support dynamic classloading of contract and contract constraints. diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt index f62f5ca1b1..3a01d4f9c3 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt @@ -25,6 +25,16 @@ class CarpenterClassLoader (parentClassLoader: ClassLoader = Thread.currentThrea fun load(name: String, bytes: ByteArray) = defineClass(name, bytes, 0, bytes.size) } +/** + * Which version of the java runtime are we constructing objects against + */ +private const val TARGET_VERSION = V1_8 + +private const val jlEnum = "java/lang/Enum" +private const val jlString = "java/lang/String" +private const val jlObject = "java/lang/Object" +private const val jlClass = "java/lang/Class" + /** * A class carpenter generates JVM bytecodes for a class given a schema and then loads it into a sub-classloader. * The generated classes have getters, a toString method and implement a simple property access interface. The @@ -107,6 +117,7 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader when (it) { is InterfaceSchema -> generateInterface(it) is ClassSchema -> generateClass(it) + is EnumSchema -> generateEnum(it) } } @@ -115,41 +126,55 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader return _loaded[schema.name]!! } + private fun generateEnum(enumSchema: Schema): Class<*> { + return generate(enumSchema) { cw, schema -> + cw.apply { + visit(TARGET_VERSION, ACC_PUBLIC + ACC_FINAL + ACC_SUPER + ACC_ENUM, schema.jvmName, + "L$jlEnum;", jlEnum, null) + generateFields(schema) + generateStaticEnumConstructor(schema) + generateEnumConstructor() + generateEnumValues(schema) + generateEnumValueOf(schema) + }.visitEnd() + } + } + private fun generateInterface(interfaceSchema: Schema): Class<*> { return generate(interfaceSchema) { cw, schema -> val interfaces = schema.interfaces.map { it.name.jvm }.toTypedArray() - with(cw) { - visit(V1_8, ACC_PUBLIC + ACC_ABSTRACT + ACC_INTERFACE, schema.jvmName, null, "java/lang/Object", interfaces) + cw.apply { + visit(TARGET_VERSION, ACC_PUBLIC + ACC_ABSTRACT + ACC_INTERFACE, schema.jvmName, null, + jlObject, interfaces) visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd() generateAbstractGetters(schema) - - visitEnd() - } + }.visitEnd() } } private fun generateClass(classSchema: Schema): Class<*> { return generate(classSchema) { cw, schema -> - val superName = schema.superclass?.jvmName ?: "java/lang/Object" + val superName = schema.superclass?.jvmName ?: jlObject val interfaces = schema.interfaces.map { it.name.jvm }.toMutableList() - if (SimpleFieldAccess::class.java !in schema.interfaces) interfaces.add(SimpleFieldAccess::class.java.name.jvm) + if (SimpleFieldAccess::class.java !in schema.interfaces) { + interfaces.add(SimpleFieldAccess::class.java.name.jvm) + } - with(cw) { - visit(V1_8, ACC_PUBLIC + ACC_SUPER, schema.jvmName, null, superName, interfaces.toTypedArray()) + cw.apply { + visit(TARGET_VERSION, ACC_PUBLIC + ACC_SUPER, schema.jvmName, null, superName, + interfaces.toTypedArray()) visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd() generateFields(schema) - generateConstructor(schema) + generateClassConstructor(schema) generateGetters(schema) if (schema.superclass == null) generateGetMethod() // From SimplePropertyAccess generateToString(schema) - - visitEnd() - } + }.visitEnd() } } @@ -165,25 +190,26 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader } private fun ClassWriter.generateFields(schema: Schema) { - schema.fields.forEach { it.value.generateField(this) } + schema.generateFields(this) } private fun ClassWriter.generateToString(schema: Schema) { val toStringHelper = "com/google/common/base/MoreObjects\$ToStringHelper" - with(visitMethod(ACC_PUBLIC, "toString", "()Ljava/lang/String;", null, null)) { + with(visitMethod(ACC_PUBLIC, "toString", "()L$jlString;", null, null)) { visitCode() // com.google.common.base.MoreObjects.toStringHelper("TypeName") visitLdcInsn(schema.name.split('.').last()) - visitMethodInsn(INVOKESTATIC, "com/google/common/base/MoreObjects", "toStringHelper", "(Ljava/lang/String;)L$toStringHelper;", false) + visitMethodInsn(INVOKESTATIC, "com/google/common/base/MoreObjects", "toStringHelper", + "(L$jlString;)L$toStringHelper;", false) // Call the add() methods. for ((name, field) in schema.fieldsIncludingSuperclasses().entries) { visitLdcInsn(name) visitVarInsn(ALOAD, 0) // this visitFieldInsn(GETFIELD, schema.jvmName, name, schema.descriptorsIncludingSuperclasses()[name]) - visitMethodInsn(INVOKEVIRTUAL, toStringHelper, "add", "(Ljava/lang/String;${field.type})L$toStringHelper;", false) + visitMethodInsn(INVOKEVIRTUAL, toStringHelper, "add", "(L$jlString;${field.type})L$toStringHelper;", false) } // call toString() on the builder and return. - visitMethodInsn(INVOKEVIRTUAL, toStringHelper, "toString", "()Ljava/lang/String;", false) + visitMethodInsn(INVOKEVIRTUAL, toStringHelper, "toString", "()L$jlString;", false) visitInsn(ARETURN) visitMaxs(0, 0) visitEnd() @@ -192,14 +218,14 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader private fun ClassWriter.generateGetMethod() { val ourJvmName = ClassCarpenter::class.java.name.jvm - with(visitMethod(ACC_PUBLIC, "get", "(Ljava/lang/String;)Ljava/lang/Object;", null, null)) { + with(visitMethod(ACC_PUBLIC, "get", "(L$jlString;)L$jlObject;", null, null)) { visitCode() visitVarInsn(ALOAD, 0) // Load 'this' visitVarInsn(ALOAD, 1) // Load the name argument // Using this generic helper method is slow, as it relies on reflection. A faster way would be // to use a tableswitch opcode, or just push back on the user and ask them to use actual reflection // or MethodHandles (super fast reflection) to access the object instead. - visitMethodInsn(INVOKESTATIC, ourJvmName, "getField", "(Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;", false) + visitMethodInsn(INVOKESTATIC, ourJvmName, "getField", "(L$jlObject;L$jlString;)L$jlObject;", false) visitInsn(ARETURN) visitMaxs(0, 0) visitEnd() @@ -207,8 +233,9 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader } private fun ClassWriter.generateGetters(schema: Schema) { - for ((name, type) in schema.fields) { - with(visitMethod(ACC_PUBLIC, "get" + name.capitalize(), "()" + type.descriptor, null, null)) { + @Suppress("UNCHECKED_CAST") + for ((name, type) in (schema.fields as Map)) { + visitMethod(ACC_PUBLIC, "get" + name.capitalize(), "()" + type.descriptor, null, null).apply { type.addNullabilityAnnotation(this) visitCode() visitVarInsn(ALOAD, 0) // Load 'this' @@ -222,30 +249,97 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader else -> visitInsn(ARETURN) } visitMaxs(0, 0) - visitEnd() - } + }.visitEnd() } } private fun ClassWriter.generateAbstractGetters(schema: Schema) { - for ((name, field) in schema.fields) { + @Suppress("UNCHECKED_CAST") + for ((name, field) in (schema.fields as Map)) { val opcodes = ACC_ABSTRACT + ACC_PUBLIC - with(visitMethod(opcodes, "get" + name.capitalize(), "()${field.descriptor}", null, null)) { - // abstract method doesn't have any implementation so just end - visitEnd() - } + // abstract method doesn't have any implementation so just end + visitMethod(opcodes, "get" + name.capitalize(), "()${field.descriptor}", null, null).visitEnd() } } - private fun ClassWriter.generateConstructor(schema: Schema) { - with(visitMethod( + private fun ClassWriter.generateStaticEnumConstructor(schema: Schema) { + visitMethod(ACC_STATIC, "", "()V", null, null).apply { + visitCode() + visitIntInsn(BIPUSH, schema.fields.size) + visitTypeInsn(ANEWARRAY, schema.jvmName) + visitInsn(DUP) + + var idx = 0 + schema.fields.forEach { + visitInsn(DUP) + visitIntInsn(BIPUSH, idx) + visitTypeInsn(NEW, schema.jvmName) + visitInsn(DUP) + visitLdcInsn(it.key) + visitIntInsn(BIPUSH, idx++) + visitMethodInsn(INVOKESPECIAL, schema.jvmName, "", "(L$jlString;I)V", false) + visitInsn(DUP) + visitFieldInsn(PUTSTATIC, schema.jvmName, it.key, "L${schema.jvmName};") + visitInsn(AASTORE) + } + + visitFieldInsn(PUTSTATIC, schema.jvmName, "\$VALUES", schema.asArray) + visitInsn(RETURN) + + visitMaxs(0,0) + }.visitEnd() + } + + private fun ClassWriter.generateEnumValues(schema: Schema) { + visitMethod(ACC_PUBLIC + ACC_STATIC, "values", "()${schema.asArray}", null, null).apply { + visitCode() + visitFieldInsn(GETSTATIC, schema.jvmName, "\$VALUES", schema.asArray) + visitMethodInsn(INVOKEVIRTUAL, schema.asArray, "clone", "()L$jlObject;", false) + visitTypeInsn(CHECKCAST, schema.asArray) + visitInsn(ARETURN) + visitMaxs(0, 0) + }.visitEnd() + } + + private fun ClassWriter.generateEnumValueOf(schema: Schema) { + visitMethod(ACC_PUBLIC + ACC_STATIC, "valueOf", "(L$jlString;)L${schema.jvmName};", null, null).apply { + visitCode() + visitLdcInsn(Type.getType("L${schema.jvmName};")) + visitVarInsn(ALOAD, 0) + visitMethodInsn(INVOKESTATIC, jlEnum, "valueOf", "(L$jlClass;L$jlString;)L$jlEnum;", true) + visitTypeInsn(CHECKCAST, schema.jvmName) + visitInsn(ARETURN) + visitMaxs(0, 0) + }.visitEnd() + + } + + private fun ClassWriter.generateEnumConstructor() { + visitMethod(ACC_PROTECTED, "", "(L$jlString;I)V", "()V", null).apply { + visitParameter("\$enum\$name", ACC_SYNTHETIC) + visitParameter("\$enum\$ordinal", ACC_SYNTHETIC) + + visitCode() + + visitVarInsn(ALOAD, 0) // this + visitVarInsn(ALOAD, 1) + visitVarInsn(ILOAD, 2) + visitMethodInsn(INVOKESPECIAL, jlEnum, "", "(L$jlString;I)V", false) + visitInsn(RETURN) + + visitMaxs(0, 0) + }.visitEnd() + } + + private fun ClassWriter.generateClassConstructor(schema: Schema) { + visitMethod( ACC_PUBLIC, "", "(" + schema.descriptorsIncludingSuperclasses().values.joinToString("") + ")V", null, - null)) - { + null).apply { var idx = 0 + schema.fields.values.forEach { it.visitParameter(this, idx++) } visitCode() @@ -255,7 +349,7 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader visitVarInsn(ALOAD, 0) val sc = schema.superclass if (sc == null) { - visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "", "()V", false) + visitMethodInsn(INVOKESPECIAL, jlObject, "", "()V", false) } else { var slot = 1 superclassFields.values.forEach { slot += load(slot, it) } @@ -265,7 +359,8 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader // Assign the fields from parameters. var slot = 1 + superclassFields.size - for ((name, field) in schema.fields.entries) { + @Suppress("UNCHECKED_CAST") + for ((name, field) in (schema.fields as Map)) { field.nullTest(this, slot) visitVarInsn(ALOAD, 0) // Load 'this' onto the stack @@ -274,8 +369,7 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader } visitInsn(RETURN) visitMaxs(0, 0) - visitEnd() - } + }.visitEnd() } private fun MethodVisitor.load(slot: Int, type: Field): Int { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Schema.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Schema.kt index a0753bbfe4..06fae3c466 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Schema.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Schema.kt @@ -1,53 +1,104 @@ package net.corda.nodeapi.internal.serialization.carpenter -import jdk.internal.org.objectweb.asm.Opcodes.* +import kotlin.collections.LinkedHashMap import org.objectweb.asm.ClassWriter -import org.objectweb.asm.MethodVisitor -import org.objectweb.asm.Type -import java.util.* +import org.objectweb.asm.Opcodes.* /** - * A Schema represents a desired class. + * A Schema is the representation of an object the Carpenter can contsruct + * + * Known Sub Classes + * - [ClassSchema] + * - [InterfaceSchema] + * - [EnumSchema] */ abstract class Schema( val name: String, - fields: Map, + var fields: Map, val superclass: Schema? = null, - val interfaces: List> = emptyList()) + val interfaces: List> = emptyList(), + updater : (String, Field) -> Unit) { - private fun Map.descriptors() = - LinkedHashMap(this.mapValues { it.value.descriptor }) + private fun Map.descriptors() = LinkedHashMap(this.mapValues { it.value.descriptor }) - /* Fix the order up front if the user didn't, inject the name into the field as it's - neater when iterating */ - val fields = LinkedHashMap(fields.mapValues { it.value.copy(it.key, it.value.field) }) + init { + fields.forEach { updater (it.key, it.value) } + + // Fix the order up front if the user didn't, inject the name into the field as it's + // neater when iterating + fields = LinkedHashMap(fields) + } fun fieldsIncludingSuperclasses(): Map = (superclass?.fieldsIncludingSuperclasses() ?: emptyMap()) + LinkedHashMap(fields) - fun descriptorsIncludingSuperclasses(): Map = + fun descriptorsIncludingSuperclasses(): Map = (superclass?.descriptorsIncludingSuperclasses() ?: emptyMap()) + fields.descriptors() + abstract fun generateFields(cw: ClassWriter) + val jvmName: String get() = name.replace(".", "/") + + val asArray: String + get() = "[L$jvmName;" } +/** + * Represents a concrete object + */ class ClassSchema( name: String, fields: Map, superclass: Schema? = null, interfaces: List> = emptyList() -) : Schema(name, fields, superclass, interfaces) +) : Schema(name, fields, superclass, interfaces, { name, field -> field.name = name }) { + override fun generateFields(cw: ClassWriter) { + cw.apply { fields.forEach { it.value.generateField(this) } } + } +} +/** + * Represents an interface. Carpented interfaces can be used within [ClassSchema]s + * if that class should be implementing that interface + */ class InterfaceSchema( name: String, fields: Map, superclass: Schema? = null, interfaces: List> = emptyList() -) : Schema(name, fields, superclass, interfaces) +) : Schema(name, fields, superclass, interfaces, { name, field -> field.name = name }) { + override fun generateFields(cw: ClassWriter) { + cw.apply { fields.forEach { it.value.generateField(this) } } + } +} +/** + * Represents an enumerated type + */ +class EnumSchema( + name: String, + fields: Map +) : Schema(name, fields, null, emptyList(), { fieldName, field -> + (field as EnumField).name = fieldName + field.descriptor = "L${name.replace(".", "/")};" +}) { + override fun generateFields(cw: ClassWriter) { + with(cw) { + fields.forEach { it.value.generateField(this) } + + visitField(ACC_PRIVATE + ACC_FINAL + ACC_STATIC + ACC_SYNTHETIC, + "\$VALUES", asArray, null, null) + } + } +} + +/** + * Factory object used by the serialiser when build [Schema]s based + * on an AMQP schema + */ object CarpenterSchemaFactory { - fun newInstance ( + fun newInstance( name: String, fields: Map, superclass: Schema? = null, @@ -58,91 +109,3 @@ object CarpenterSchemaFactory { else ClassSchema (name, fields, superclass, interfaces) } -abstract class Field(val field: Class) { - companion object { - const val unsetName = "Unset" - } - - var name: String = unsetName - abstract val nullabilityAnnotation: String - - val descriptor: String - get() = Type.getDescriptor(this.field) - - val type: String - get() = if (this.field.isPrimitive) this.descriptor else "Ljava/lang/Object;" - - fun generateField(cw: ClassWriter) { - val fieldVisitor = cw.visitField(ACC_PROTECTED + ACC_FINAL, name, descriptor, null, null) - fieldVisitor.visitAnnotation(nullabilityAnnotation, true).visitEnd() - fieldVisitor.visitEnd() - } - - fun addNullabilityAnnotation(mv: MethodVisitor) { - mv.visitAnnotation(nullabilityAnnotation, true).visitEnd() - } - - fun visitParameter(mv: MethodVisitor, idx: Int) { - with(mv) { - visitParameter(name, 0) - if (!field.isPrimitive) { - visitParameterAnnotation(idx, nullabilityAnnotation, true).visitEnd() - } - } - } - - abstract fun copy(name: String, field: Class): Field - abstract fun nullTest(mv: MethodVisitor, slot: Int) -} - -class NonNullableField(field: Class) : Field(field) { - override val nullabilityAnnotation = "Ljavax/annotation/Nonnull;" - - constructor(name: String, field: Class) : this(field) { - this.name = name - } - - override fun copy(name: String, field: Class) = NonNullableField(name, field) - - override fun nullTest(mv: MethodVisitor, slot: Int) { - assert(name != unsetName) - - if (!field.isPrimitive) { - with(mv) { - visitVarInsn(ALOAD, 0) // load this - visitVarInsn(ALOAD, slot) // load parameter - visitLdcInsn("param \"$name\" cannot be null") - visitMethodInsn(INVOKESTATIC, - "java/util/Objects", - "requireNonNull", - "(Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;", false) - visitInsn(POP) - } - } - } -} - -class NullableField(field: Class) : Field(field) { - override val nullabilityAnnotation = "Ljavax/annotation/Nullable;" - - constructor(name: String, field: Class) : this(field) { - if (field.isPrimitive) { - throw NullablePrimitiveException ( - "Field $name is primitive type ${Type.getDescriptor(field)} and thus cannot be nullable") - } - - this.name = name - } - - override fun copy(name: String, field: Class) = NullableField(name, field) - - override fun nullTest(mv: MethodVisitor, slot: Int) { - assert(name != unsetName) - } -} - -object FieldFactory { - fun newInstance (mandatory: Boolean, name: String, field: Class) = - if (mandatory) NonNullableField (name, field) else NullableField (name, field) - -} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/SchemaFields.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/SchemaFields.kt new file mode 100644 index 0000000000..0090f4dd3e --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/SchemaFields.kt @@ -0,0 +1,138 @@ +package net.corda.nodeapi.internal.serialization.carpenter + +import jdk.internal.org.objectweb.asm.Opcodes.* +import org.objectweb.asm.ClassWriter +import org.objectweb.asm.MethodVisitor +import org.objectweb.asm.Type +import java.beans.BeanDescriptor +import java.util.* + +abstract class Field(val field: Class) { + abstract var descriptor: String? + + companion object { + const val unsetName = "Unset" + } + + var name: String = unsetName + abstract val type: String + + abstract fun generateField(cw: ClassWriter) + abstract fun visitParameter(mv: MethodVisitor, idx: Int) +} + +/** + * Any field that can be a member of an object + * + * Known + * - [NullableField] + * - [NonNullableField] + */ +abstract class ClassField(field: Class) : Field(field) { + abstract val nullabilityAnnotation: String + abstract fun nullTest(mv: MethodVisitor, slot: Int) + + override var descriptor = Type.getDescriptor(this.field) + override val type: String get() = if (this.field.isPrimitive) this.descriptor else "Ljava/lang/Object;" + + fun addNullabilityAnnotation(mv: MethodVisitor) { + mv.visitAnnotation(nullabilityAnnotation, true).visitEnd() + } + + override fun generateField(cw: ClassWriter) { + cw.visitField(ACC_PROTECTED + ACC_FINAL, name, descriptor, null, null).visitAnnotation( + nullabilityAnnotation, true).visitEnd() + } + + override fun visitParameter(mv: MethodVisitor, idx: Int) { + with(mv) { + visitParameter(name, 0) + if (!field.isPrimitive) { + visitParameterAnnotation(idx, nullabilityAnnotation, true).visitEnd() + } + } + } +} + +/** + * A member of a constructed class that can be assigned to null, the + * mandatory type for primitives, but also any member that cannot be + * null + * + * maps to AMQP mandatory = true fields + */ +open class NonNullableField(field: Class) : ClassField(field) { + override val nullabilityAnnotation = "Ljavax/annotation/Nonnull;" + + constructor(name: String, field: Class) : this(field) { + this.name = name + } + + override fun nullTest(mv: MethodVisitor, slot: Int) { + assert(name != unsetName) + + if (!field.isPrimitive) { + with(mv) { + visitVarInsn(ALOAD, 0) // load this + visitVarInsn(ALOAD, slot) // load parameter + visitLdcInsn("param \"$name\" cannot be null") + visitMethodInsn(INVOKESTATIC, + "java/util/Objects", + "requireNonNull", + "(Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;", false) + visitInsn(POP) + } + } + } +} + +/** + * A member of a constructed class that can be assigned to null, + * + * maps to AMQP mandatory = false fields + */ +class NullableField(field: Class) : ClassField(field) { + override val nullabilityAnnotation = "Ljavax/annotation/Nullable;" + + constructor(name: String, field: Class) : this(field) { + if (field.isPrimitive) { + throw NullablePrimitiveException ( + "Field $name is primitive type ${Type.getDescriptor(field)} and thus cannot be nullable") + } + + this.name = name + } + + override fun nullTest(mv: MethodVisitor, slot: Int) { + assert(name != unsetName) + } +} + +/** + * Represents enum constants within an enum + */ +class EnumField: Field(Enum::class.java) { + override var descriptor : String? = null + + override val type: String + get() = "Ljava/lang/Enum;" + + override fun generateField(cw: ClassWriter) { + cw.visitField(ACC_PUBLIC + ACC_FINAL + ACC_STATIC + ACC_ENUM, name, + descriptor, null, null).visitEnd() + } + + override fun visitParameter(mv: MethodVisitor, idx: Int) { + mv.visitParameter(name, 0) + } +} + +/** + * Constructs a Field Schema object of the correct type depending weather + * the AMQP schema indicates it's mandatory (non nullable) or not (nullable) + */ +object FieldFactory { + fun newInstance (mandatory: Boolean, name: String, field: Class) = + if (mandatory) NonNullableField (name, field) else NullableField (name, field) + +} diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTest.kt index 9ffef1ebd1..c7b5fd5384 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTest.kt @@ -15,12 +15,12 @@ class ClassCarpenterTest { val b: Int } - val cc = ClassCarpenter() + private val cc = ClassCarpenter() // We have to ignore synthetic fields even though ClassCarpenter doesn't create any because the JaCoCo // coverage framework auto-magically injects one method and one field into every class loaded into the JVM. - val Class<*>.nonSyntheticFields: List get() = declaredFields.filterNot { it.isSynthetic } - val Class<*>.nonSyntheticMethods: List get() = declaredMethods.filterNot { it.isSynthetic } + private val Class<*>.nonSyntheticFields: List get() = declaredFields.filterNot { it.isSynthetic } + private val Class<*>.nonSyntheticMethods: List get() = declaredMethods.filterNot { it.isSynthetic } @Test fun empty() { diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTestUtils.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTestUtils.kt index b7f1c2d348..a59858bead 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTestUtils.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTestUtils.kt @@ -3,6 +3,7 @@ package net.corda.nodeapi.internal.serialization.carpenter import net.corda.nodeapi.internal.serialization.amqp.* import net.corda.nodeapi.internal.serialization.amqp.Field import net.corda.nodeapi.internal.serialization.amqp.Schema +import net.corda.nodeapi.internal.serialization.AllWhitelist fun mangleName(name: String) = "${name}__carpenter" @@ -34,7 +35,8 @@ fun Schema.mangleNames(names: List): Schema { } open class AmqpCarpenterBase { - var factory = testDefaultFactory() + var cc = ClassCarpenter() + var factory = SerializerFactory(AllWhitelist, cc.classloader) fun serialise(clazz: Any) = SerializationOutput(factory).serialize(clazz) fun testName(): String = Thread.currentThread().stackTrace[2].methodName diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/EnumClassTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/EnumClassTests.kt new file mode 100644 index 0000000000..13c0ced089 --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/EnumClassTests.kt @@ -0,0 +1,105 @@ +package net.corda.nodeapi.internal.serialization.carpenter + +import org.junit.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class EnumClassTests : AmqpCarpenterBase() { + + @Test + fun oneValue() { + val enumConstants = mapOf ("A" to EnumField()) + + val schema = EnumSchema("gen.enum", enumConstants) + + cc.build(schema) + } + + @Test + fun oneValueInstantiate() { + val enumConstants = mapOf ("A" to EnumField()) + val schema = EnumSchema("gen.enum", enumConstants) + val clazz = cc.build(schema) + + assertTrue(clazz.isEnum) + assertEquals(enumConstants.size, clazz.enumConstants.size) + assertEquals("A", clazz.enumConstants.first().toString()) + assertEquals(0, (clazz.enumConstants.first() as Enum<*>).ordinal) + assertEquals("A", (clazz.enumConstants.first() as Enum<*>).name) + } + + @Test + fun twoValuesInstantiate() { + val enumConstants = mapOf ("left" to EnumField(), "right" to EnumField()) + val schema = EnumSchema("gen.enum", enumConstants) + val clazz = cc.build(schema) + + assertTrue(clazz.isEnum) + assertEquals(enumConstants.size, clazz.enumConstants.size) + + val left = clazz.enumConstants[0] as Enum<*> + val right = clazz.enumConstants[1] as Enum<*> + + assertEquals(0, left.ordinal) + assertEquals("left", left.name) + assertEquals(1, right.ordinal) + assertEquals("right", right.name) + } + + @Test + fun manyValues() { + val enumConstants = listOf("AAA", "BBB", "CCC", "DDD", "EEE", "FFF", + "GGG", "HHH", "III", "JJJ").associateBy({ it }, { EnumField() }) + val schema = EnumSchema("gen.enum", enumConstants) + val clazz = cc.build(schema) + + assertTrue(clazz.isEnum) + assertEquals(enumConstants.size, clazz.enumConstants.size) + + var idx = 0 + enumConstants.forEach { + val constant = clazz.enumConstants[idx] as Enum<*> + assertEquals(idx++, constant.ordinal) + assertEquals(it.key, constant.name) + } + } + + @Test + fun assignment() { + val enumConstants = listOf("AAA", "BBB", "CCC", "DDD", "EEE", "FFF").associateBy({ it }, { EnumField() }) + val schema = EnumSchema("gen.enum", enumConstants) + val clazz = cc.build(schema) + + assertEquals("CCC", clazz.getMethod("valueOf", String::class.java).invoke (null, "CCC").toString()) + assertEquals("CCC", (clazz.getMethod("valueOf", String::class.java).invoke (null, "CCC") as Enum<*>).name) + + val ddd = clazz.getMethod("valueOf", String::class.java).invoke (null, "DDD") as Enum<*> + + assertTrue (ddd::class.java.isEnum) + assertEquals("DDD", ddd.name) + assertEquals(3, ddd.ordinal) + } + + // if anything goes wrong with this test it's going to end up throwing *some* + // exception, hence the lack of asserts + @Test + fun assignAndTest() { + val cc2 = ClassCarpenter() + + val schema1 = EnumSchema("gen.enum", + listOf("AAA", "BBB", "CCC", "DDD", "EEE", "FFF").associateBy({ it }, { EnumField() })) + + val enumClazz = cc2.build(schema1) + + val schema2 = ClassSchema ("gen.class", + mapOf( + "a" to NonNullableField(Int::class.java), + "b" to NonNullableField(enumClazz))) + + val classClazz = cc2.build(schema2) + + // make sure we can construct a class that has an enum we've constructed as a member + classClazz.constructors[0].newInstance(1, enumClazz.getMethod( + "valueOf", String::class.java).invoke(null, "BBB")) + } +} \ No newline at end of file