diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index d02bc0fea3..87fa1b2f29 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -6,6 +6,8 @@ from the previous milestone release. UNRELEASED ---------- +* Adding enum support to the class carpenter + * ``ContractState::contract`` has been moved ``TransactionState::contract`` and it's type has changed to ``String`` in order to support dynamic classloading of contract and contract constraints. diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt index f62f5ca1b1..28019aab42 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt @@ -20,11 +20,21 @@ interface SimpleFieldAccess { operator fun get(name: String): Any? } -class CarpenterClassLoader (parentClassLoader: ClassLoader = Thread.currentThread().contextClassLoader) : +class CarpenterClassLoader(parentClassLoader: ClassLoader = Thread.currentThread().contextClassLoader) : ClassLoader(parentClassLoader) { fun load(name: String, bytes: ByteArray) = defineClass(name, bytes, 0, bytes.size) } +/** + * Which version of the java runtime are we constructing objects against + */ +private const val TARGET_VERSION = V1_8 + +private val jlEnum get() = Type.getInternalName(Enum::class.java) +private val jlString get() = Type.getInternalName(String::class.java) +private val jlObject get() = Type.getInternalName(Object::class.java) +private val jlClass get() = Type.getInternalName(Class::class.java) + /** * A class carpenter generates JVM bytecodes for a class given a schema and then loads it into a sub-classloader. * The generated classes have getters, a toString method and implement a simple property access interface. The @@ -107,49 +117,66 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader when (it) { is InterfaceSchema -> generateInterface(it) is ClassSchema -> generateClass(it) + is EnumSchema -> generateEnum(it) } } - assert (schema.name in _loaded) + assert(schema.name in _loaded) return _loaded[schema.name]!! } + private fun generateEnum(enumSchema: Schema): Class<*> { + return generate(enumSchema) { cw, schema -> + cw.apply { + visit(TARGET_VERSION, ACC_PUBLIC + ACC_FINAL + ACC_SUPER + ACC_ENUM, schema.jvmName, + "L$jlEnum;", jlEnum, null) + + visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd() + generateFields(schema) + generateStaticEnumConstructor(schema) + generateEnumConstructor() + generateEnumValues(schema) + generateEnumValueOf(schema) + }.visitEnd() + } + } + private fun generateInterface(interfaceSchema: Schema): Class<*> { return generate(interfaceSchema) { cw, schema -> val interfaces = schema.interfaces.map { it.name.jvm }.toTypedArray() - with(cw) { - visit(V1_8, ACC_PUBLIC + ACC_ABSTRACT + ACC_INTERFACE, schema.jvmName, null, "java/lang/Object", interfaces) + cw.apply { + visit(TARGET_VERSION, ACC_PUBLIC + ACC_ABSTRACT + ACC_INTERFACE, schema.jvmName, null, + jlObject, interfaces) visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd() generateAbstractGetters(schema) - - visitEnd() - } + }.visitEnd() } } private fun generateClass(classSchema: Schema): Class<*> { return generate(classSchema) { cw, schema -> - val superName = schema.superclass?.jvmName ?: "java/lang/Object" + val superName = schema.superclass?.jvmName ?: jlObject val interfaces = schema.interfaces.map { it.name.jvm }.toMutableList() - if (SimpleFieldAccess::class.java !in schema.interfaces) interfaces.add(SimpleFieldAccess::class.java.name.jvm) + if (SimpleFieldAccess::class.java !in schema.interfaces) { + interfaces.add(SimpleFieldAccess::class.java.name.jvm) + } - with(cw) { - visit(V1_8, ACC_PUBLIC + ACC_SUPER, schema.jvmName, null, superName, interfaces.toTypedArray()) + cw.apply { + visit(TARGET_VERSION, ACC_PUBLIC + ACC_SUPER, schema.jvmName, null, superName, + interfaces.toTypedArray()) visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd() generateFields(schema) - generateConstructor(schema) + generateClassConstructor(schema) generateGetters(schema) if (schema.superclass == null) generateGetMethod() // From SimplePropertyAccess generateToString(schema) - - visitEnd() - } + }.visitEnd() } } @@ -165,25 +192,26 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader } private fun ClassWriter.generateFields(schema: Schema) { - schema.fields.forEach { it.value.generateField(this) } + schema.generateFields(this) } private fun ClassWriter.generateToString(schema: Schema) { val toStringHelper = "com/google/common/base/MoreObjects\$ToStringHelper" - with(visitMethod(ACC_PUBLIC, "toString", "()Ljava/lang/String;", null, null)) { + with(visitMethod(ACC_PUBLIC, "toString", "()L$jlString;", null, null)) { visitCode() // com.google.common.base.MoreObjects.toStringHelper("TypeName") visitLdcInsn(schema.name.split('.').last()) - visitMethodInsn(INVOKESTATIC, "com/google/common/base/MoreObjects", "toStringHelper", "(Ljava/lang/String;)L$toStringHelper;", false) + visitMethodInsn(INVOKESTATIC, "com/google/common/base/MoreObjects", "toStringHelper", + "(L$jlString;)L$toStringHelper;", false) // Call the add() methods. for ((name, field) in schema.fieldsIncludingSuperclasses().entries) { visitLdcInsn(name) visitVarInsn(ALOAD, 0) // this visitFieldInsn(GETFIELD, schema.jvmName, name, schema.descriptorsIncludingSuperclasses()[name]) - visitMethodInsn(INVOKEVIRTUAL, toStringHelper, "add", "(Ljava/lang/String;${field.type})L$toStringHelper;", false) + visitMethodInsn(INVOKEVIRTUAL, toStringHelper, "add", "(L$jlString;${field.type})L$toStringHelper;", false) } // call toString() on the builder and return. - visitMethodInsn(INVOKEVIRTUAL, toStringHelper, "toString", "()Ljava/lang/String;", false) + visitMethodInsn(INVOKEVIRTUAL, toStringHelper, "toString", "()L$jlString;", false) visitInsn(ARETURN) visitMaxs(0, 0) visitEnd() @@ -192,14 +220,14 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader private fun ClassWriter.generateGetMethod() { val ourJvmName = ClassCarpenter::class.java.name.jvm - with(visitMethod(ACC_PUBLIC, "get", "(Ljava/lang/String;)Ljava/lang/Object;", null, null)) { + with(visitMethod(ACC_PUBLIC, "get", "(L$jlString;)L$jlObject;", null, null)) { visitCode() visitVarInsn(ALOAD, 0) // Load 'this' visitVarInsn(ALOAD, 1) // Load the name argument // Using this generic helper method is slow, as it relies on reflection. A faster way would be // to use a tableswitch opcode, or just push back on the user and ask them to use actual reflection // or MethodHandles (super fast reflection) to access the object instead. - visitMethodInsn(INVOKESTATIC, ourJvmName, "getField", "(Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;", false) + visitMethodInsn(INVOKESTATIC, ourJvmName, "getField", "(L$jlObject;L$jlString;)L$jlObject;", false) visitInsn(ARETURN) visitMaxs(0, 0) visitEnd() @@ -207,45 +235,113 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader } private fun ClassWriter.generateGetters(schema: Schema) { - for ((name, type) in schema.fields) { - with(visitMethod(ACC_PUBLIC, "get" + name.capitalize(), "()" + type.descriptor, null, null)) { + @Suppress("UNCHECKED_CAST") + for ((name, type) in (schema.fields as Map)) { + visitMethod(ACC_PUBLIC, "get" + name.capitalize(), "()" + type.descriptor, null, null).apply { type.addNullabilityAnnotation(this) visitCode() visitVarInsn(ALOAD, 0) // Load 'this' visitFieldInsn(GETFIELD, schema.jvmName, name, type.descriptor) when (type.field) { java.lang.Boolean.TYPE, Integer.TYPE, java.lang.Short.TYPE, java.lang.Byte.TYPE, - java.lang.Character.TYPE -> visitInsn(IRETURN) + java.lang.Character.TYPE -> visitInsn(IRETURN) java.lang.Long.TYPE -> visitInsn(LRETURN) java.lang.Double.TYPE -> visitInsn(DRETURN) java.lang.Float.TYPE -> visitInsn(FRETURN) else -> visitInsn(ARETURN) } visitMaxs(0, 0) - visitEnd() - } + }.visitEnd() } } private fun ClassWriter.generateAbstractGetters(schema: Schema) { - for ((name, field) in schema.fields) { + @Suppress("UNCHECKED_CAST") + for ((name, field) in (schema.fields as Map)) { val opcodes = ACC_ABSTRACT + ACC_PUBLIC - with(visitMethod(opcodes, "get" + name.capitalize(), "()${field.descriptor}", null, null)) { - // abstract method doesn't have any implementation so just end - visitEnd() - } + // abstract method doesn't have any implementation so just end + visitMethod(opcodes, "get" + name.capitalize(), "()${field.descriptor}", null, null).visitEnd() } } - private fun ClassWriter.generateConstructor(schema: Schema) { - with(visitMethod( + private fun ClassWriter.generateStaticEnumConstructor(schema: Schema) { + visitMethod(ACC_STATIC, "", "()V", null, null).apply { + visitCode() + visitIntInsn(BIPUSH, schema.fields.size) + visitTypeInsn(ANEWARRAY, schema.jvmName) + visitInsn(DUP) + + var idx = 0 + schema.fields.forEach { + visitInsn(DUP) + visitIntInsn(BIPUSH, idx) + visitTypeInsn(NEW, schema.jvmName) + visitInsn(DUP) + visitLdcInsn(it.key) + visitIntInsn(BIPUSH, idx++) + visitMethodInsn(INVOKESPECIAL, schema.jvmName, "", "(L$jlString;I)V", false) + visitInsn(DUP) + visitFieldInsn(PUTSTATIC, schema.jvmName, it.key, "L${schema.jvmName};") + visitInsn(AASTORE) + } + + visitFieldInsn(PUTSTATIC, schema.jvmName, "\$VALUES", schema.asArray) + visitInsn(RETURN) + + visitMaxs(0, 0) + }.visitEnd() + } + + private fun ClassWriter.generateEnumValues(schema: Schema) { + visitMethod(ACC_PUBLIC + ACC_STATIC, "values", "()${schema.asArray}", null, null).apply { + visitCode() + visitFieldInsn(GETSTATIC, schema.jvmName, "\$VALUES", schema.asArray) + visitMethodInsn(INVOKEVIRTUAL, schema.asArray, "clone", "()L$jlObject;", false) + visitTypeInsn(CHECKCAST, schema.asArray) + visitInsn(ARETURN) + visitMaxs(0, 0) + }.visitEnd() + } + + private fun ClassWriter.generateEnumValueOf(schema: Schema) { + visitMethod(ACC_PUBLIC + ACC_STATIC, "valueOf", "(L$jlString;)L${schema.jvmName};", null, null).apply { + visitCode() + visitLdcInsn(Type.getType("L${schema.jvmName};")) + visitVarInsn(ALOAD, 0) + visitMethodInsn(INVOKESTATIC, jlEnum, "valueOf", "(L$jlClass;L$jlString;)L$jlEnum;", true) + visitTypeInsn(CHECKCAST, schema.jvmName) + visitInsn(ARETURN) + visitMaxs(0, 0) + }.visitEnd() + + } + + private fun ClassWriter.generateEnumConstructor() { + visitMethod(ACC_PROTECTED, "", "(L$jlString;I)V", "()V", null).apply { + visitParameter("\$enum\$name", ACC_SYNTHETIC) + visitParameter("\$enum\$ordinal", ACC_SYNTHETIC) + + visitCode() + + visitVarInsn(ALOAD, 0) // this + visitVarInsn(ALOAD, 1) + visitVarInsn(ILOAD, 2) + visitMethodInsn(INVOKESPECIAL, jlEnum, "", "(L$jlString;I)V", false) + visitInsn(RETURN) + + visitMaxs(0, 0) + }.visitEnd() + } + + private fun ClassWriter.generateClassConstructor(schema: Schema) { + visitMethod( ACC_PUBLIC, "", "(" + schema.descriptorsIncludingSuperclasses().values.joinToString("") + ")V", null, - null)) - { + null).apply { var idx = 0 + schema.fields.values.forEach { it.visitParameter(this, idx++) } visitCode() @@ -255,7 +351,7 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader visitVarInsn(ALOAD, 0) val sc = schema.superclass if (sc == null) { - visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "", "()V", false) + visitMethodInsn(INVOKESPECIAL, jlObject, "", "()V", false) } else { var slot = 1 superclassFields.values.forEach { slot += load(slot, it) } @@ -265,7 +361,8 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader // Assign the fields from parameters. var slot = 1 + superclassFields.size - for ((name, field) in schema.fields.entries) { + @Suppress("UNCHECKED_CAST") + for ((name, field) in (schema.fields as Map)) { field.nullTest(this, slot) visitVarInsn(ALOAD, 0) // Load 'this' onto the stack @@ -274,14 +371,13 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader } visitInsn(RETURN) visitMaxs(0, 0) - visitEnd() - } + }.visitEnd() } private fun MethodVisitor.load(slot: Int, type: Field): Int { when (type.field) { java.lang.Boolean.TYPE, Integer.TYPE, java.lang.Short.TYPE, java.lang.Byte.TYPE, - java.lang.Character.TYPE -> visitVarInsn(ILOAD, slot) + java.lang.Character.TYPE -> visitVarInsn(ILOAD, slot) java.lang.Long.TYPE -> visitVarInsn(LLOAD, slot) java.lang.Double.TYPE -> visitVarInsn(DLOAD, slot) java.lang.Float.TYPE -> visitVarInsn(FLOAD, slot) @@ -325,7 +421,8 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader } companion object { - @JvmStatic @Suppress("UNUSED") + @JvmStatic + @Suppress("UNUSED") fun getField(obj: Any, name: String): Any? = obj.javaClass.getMethod("get" + name.capitalize()).invoke(obj) } } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Exceptions.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Exceptions.kt index cfa2f2a4e8..c96ae86e91 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Exceptions.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Exceptions.kt @@ -1,11 +1,11 @@ package net.corda.nodeapi.internal.serialization.carpenter -class DuplicateNameException : RuntimeException ( +class DuplicateNameException : RuntimeException( "An attempt was made to register two classes with the same name within the same ClassCarpenter namespace.") class InterfaceMismatchException(msg: String) : RuntimeException(msg) class NullablePrimitiveException(msg: String) : RuntimeException(msg) -class UncarpentableException (name: String, field: String, type: String) : - Exception ("Class $name is loadable yet contains field $field of unknown type $type") +class UncarpentableException(name: String, field: String, type: String) : + Exception("Class $name is loadable yet contains field $field of unknown type $type") diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/MetaCarpenter.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/MetaCarpenter.kt index fb0d1f7001..c678dcadde 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/MetaCarpenter.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/MetaCarpenter.kt @@ -22,22 +22,19 @@ import net.corda.nodeapi.internal.serialization.amqp.TypeNotation * in turn look up all of those classes in the [dependsOn] list, remove their dependency on the newly created class, * and if that list is reduced to zero know we can now generate a [Schema] for them and carpent them up */ -data class CarpenterSchemas ( +data class CarpenterSchemas( val carpenterSchemas: MutableList, val dependencies: MutableMap>>, val dependsOn: MutableMap>) { companion object CarpenterSchemaConstructor { fun newInstance(): CarpenterSchemas { - return CarpenterSchemas( - mutableListOf(), - mutableMapOf>>(), - mutableMapOf>()) + return CarpenterSchemas(mutableListOf(), mutableMapOf(), mutableMapOf()) } } fun addDepPair(type: TypeNotation, dependant: String, dependee: String) { - dependsOn.computeIfAbsent(dependee, { mutableListOf() }).add(dependant) - dependencies.computeIfAbsent(dependant, { Pair(type, mutableListOf()) }).second.add(dependee) + dependsOn.computeIfAbsent(dependee, { mutableListOf() }).add(dependant) + dependencies.computeIfAbsent(dependant, { Pair(type, mutableListOf()) }).second.add(dependee) } val size @@ -56,23 +53,23 @@ data class CarpenterSchemas ( * @property cc a reference to the actual class carpenter we're using to constuct classes * @property objects a list of carpented classes loaded into the carpenters class loader */ -abstract class MetaCarpenterBase (val schemas : CarpenterSchemas, val cc : ClassCarpenter = ClassCarpenter()) { +abstract class MetaCarpenterBase(val schemas: CarpenterSchemas, val cc: ClassCarpenter = ClassCarpenter()) { val objects = mutableMapOf>() fun step(newObject: Schema) { - objects[newObject.name] = cc.build (newObject) + objects[newObject.name] = cc.build(newObject) // go over the list of everything that had a dependency on the newly // carpented class existing and remove it from their dependency list, If that // list is now empty we have no impediment to carpenting that class up schemas.dependsOn.remove(newObject.name)?.forEach { dependent -> - assert (newObject.name in schemas.dependencies[dependent]!!.second) + assert(newObject.name in schemas.dependencies[dependent]!!.second) schemas.dependencies[dependent]?.second?.remove(newObject.name) // we're out of blockers so we can now create the type - if (schemas.dependencies[dependent]?.second?.isEmpty() ?: false) { - (schemas.dependencies.remove (dependent)?.first as CompositeType).carpenterSchema ( + if (schemas.dependencies[dependent]?.second?.isEmpty() == true) { + (schemas.dependencies.remove(dependent)?.first as CompositeType).carpenterSchema( classloader = cc.classloader, carpenterSchemas = schemas) } @@ -81,25 +78,25 @@ abstract class MetaCarpenterBase (val schemas : CarpenterSchemas, val cc : Class abstract fun build() - val classloader : ClassLoader - get() = cc.classloader + val classloader: ClassLoader + get() = cc.classloader } -class MetaCarpenter(schemas : CarpenterSchemas, - cc : ClassCarpenter = ClassCarpenter()) : MetaCarpenterBase(schemas, cc) { +class MetaCarpenter(schemas: CarpenterSchemas, + cc: ClassCarpenter = ClassCarpenter()) : MetaCarpenterBase(schemas, cc) { override fun build() { while (schemas.carpenterSchemas.isNotEmpty()) { val newObject = schemas.carpenterSchemas.removeAt(0) - step (newObject) + step(newObject) } } } -class TestMetaCarpenter(schemas : CarpenterSchemas, - cc : ClassCarpenter = ClassCarpenter()) : MetaCarpenterBase(schemas, cc) { +class TestMetaCarpenter(schemas: CarpenterSchemas, + cc: ClassCarpenter = ClassCarpenter()) : MetaCarpenterBase(schemas, cc) { override fun build() { if (schemas.carpenterSchemas.isEmpty()) return - step (schemas.carpenterSchemas.removeAt(0)) + step(schemas.carpenterSchemas.removeAt(0)) } } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Schema.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Schema.kt index a0753bbfe4..16d0e362d5 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Schema.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Schema.kt @@ -1,148 +1,110 @@ package net.corda.nodeapi.internal.serialization.carpenter -import jdk.internal.org.objectweb.asm.Opcodes.* +import kotlin.collections.LinkedHashMap import org.objectweb.asm.ClassWriter -import org.objectweb.asm.MethodVisitor -import org.objectweb.asm.Type -import java.util.* +import org.objectweb.asm.Opcodes.* /** - * A Schema represents a desired class. + * A Schema is the representation of an object the Carpenter can contsruct + * + * Known Sub Classes + * - [ClassSchema] + * - [InterfaceSchema] + * - [EnumSchema] */ abstract class Schema( val name: String, - fields: Map, + var fields: Map, val superclass: Schema? = null, - val interfaces: List> = emptyList()) -{ - private fun Map.descriptors() = - LinkedHashMap(this.mapValues { it.value.descriptor }) + val interfaces: List> = emptyList(), + updater: (String, Field) -> Unit) { + private fun Map.descriptors() = LinkedHashMap(this.mapValues { it.value.descriptor }) - /* Fix the order up front if the user didn't, inject the name into the field as it's - neater when iterating */ - val fields = LinkedHashMap(fields.mapValues { it.value.copy(it.key, it.value.field) }) + init { + fields.forEach { updater(it.key, it.value) } + + // Fix the order up front if the user didn't, inject the name into the field as it's + // neater when iterating + fields = LinkedHashMap(fields) + } fun fieldsIncludingSuperclasses(): Map = (superclass?.fieldsIncludingSuperclasses() ?: emptyMap()) + LinkedHashMap(fields) - fun descriptorsIncludingSuperclasses(): Map = + fun descriptorsIncludingSuperclasses(): Map = (superclass?.descriptorsIncludingSuperclasses() ?: emptyMap()) + fields.descriptors() + abstract fun generateFields(cw: ClassWriter) + val jvmName: String get() = name.replace(".", "/") + + val asArray: String + get() = "[L$jvmName;" } +/** + * Represents a concrete object + */ class ClassSchema( name: String, fields: Map, superclass: Schema? = null, interfaces: List> = emptyList() -) : Schema(name, fields, superclass, interfaces) +) : Schema(name, fields, superclass, interfaces, { name, field -> field.name = name }) { + override fun generateFields(cw: ClassWriter) { + cw.apply { fields.forEach { it.value.generateField(this) } } + } +} +/** + * Represents an interface. Carpented interfaces can be used within [ClassSchema]s + * if that class should be implementing that interface + */ class InterfaceSchema( name: String, fields: Map, superclass: Schema? = null, interfaces: List> = emptyList() -) : Schema(name, fields, superclass, interfaces) +) : Schema(name, fields, superclass, interfaces, { name, field -> field.name = name }) { + override fun generateFields(cw: ClassWriter) { + cw.apply { fields.forEach { it.value.generateField(this) } } + } +} +/** + * Represents an enumerated type + */ +class EnumSchema( + name: String, + fields: Map +) : Schema(name, fields, null, emptyList(), { fieldName, field -> + (field as EnumField).name = fieldName + field.descriptor = "L${name.replace(".", "/")};" +}) { + override fun generateFields(cw: ClassWriter) { + with(cw) { + fields.forEach { it.value.generateField(this) } + + visitField(ACC_PRIVATE + ACC_FINAL + ACC_STATIC + ACC_SYNTHETIC, + "\$VALUES", asArray, null, null) + } + } +} + +/** + * Factory object used by the serialiser when building [Schema]s based + * on an AMQP schema + */ object CarpenterSchemaFactory { - fun newInstance ( + fun newInstance( name: String, fields: Map, superclass: Schema? = null, interfaces: List> = emptyList(), isInterface: Boolean = false - ) : Schema = - if (isInterface) InterfaceSchema (name, fields, superclass, interfaces) - else ClassSchema (name, fields, superclass, interfaces) + ): Schema = + if (isInterface) InterfaceSchema(name, fields, superclass, interfaces) + else ClassSchema(name, fields, superclass, interfaces) } -abstract class Field(val field: Class) { - companion object { - const val unsetName = "Unset" - } - - var name: String = unsetName - abstract val nullabilityAnnotation: String - - val descriptor: String - get() = Type.getDescriptor(this.field) - - val type: String - get() = if (this.field.isPrimitive) this.descriptor else "Ljava/lang/Object;" - - fun generateField(cw: ClassWriter) { - val fieldVisitor = cw.visitField(ACC_PROTECTED + ACC_FINAL, name, descriptor, null, null) - fieldVisitor.visitAnnotation(nullabilityAnnotation, true).visitEnd() - fieldVisitor.visitEnd() - } - - fun addNullabilityAnnotation(mv: MethodVisitor) { - mv.visitAnnotation(nullabilityAnnotation, true).visitEnd() - } - - fun visitParameter(mv: MethodVisitor, idx: Int) { - with(mv) { - visitParameter(name, 0) - if (!field.isPrimitive) { - visitParameterAnnotation(idx, nullabilityAnnotation, true).visitEnd() - } - } - } - - abstract fun copy(name: String, field: Class): Field - abstract fun nullTest(mv: MethodVisitor, slot: Int) -} - -class NonNullableField(field: Class) : Field(field) { - override val nullabilityAnnotation = "Ljavax/annotation/Nonnull;" - - constructor(name: String, field: Class) : this(field) { - this.name = name - } - - override fun copy(name: String, field: Class) = NonNullableField(name, field) - - override fun nullTest(mv: MethodVisitor, slot: Int) { - assert(name != unsetName) - - if (!field.isPrimitive) { - with(mv) { - visitVarInsn(ALOAD, 0) // load this - visitVarInsn(ALOAD, slot) // load parameter - visitLdcInsn("param \"$name\" cannot be null") - visitMethodInsn(INVOKESTATIC, - "java/util/Objects", - "requireNonNull", - "(Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;", false) - visitInsn(POP) - } - } - } -} - -class NullableField(field: Class) : Field(field) { - override val nullabilityAnnotation = "Ljavax/annotation/Nullable;" - - constructor(name: String, field: Class) : this(field) { - if (field.isPrimitive) { - throw NullablePrimitiveException ( - "Field $name is primitive type ${Type.getDescriptor(field)} and thus cannot be nullable") - } - - this.name = name - } - - override fun copy(name: String, field: Class) = NullableField(name, field) - - override fun nullTest(mv: MethodVisitor, slot: Int) { - assert(name != unsetName) - } -} - -object FieldFactory { - fun newInstance (mandatory: Boolean, name: String, field: Class) = - if (mandatory) NonNullableField (name, field) else NullableField (name, field) - -} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/SchemaFields.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/SchemaFields.kt new file mode 100644 index 0000000000..33ab42dcd4 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/SchemaFields.kt @@ -0,0 +1,138 @@ +package net.corda.nodeapi.internal.serialization.carpenter + +import jdk.internal.org.objectweb.asm.Opcodes.* +import org.objectweb.asm.ClassWriter +import org.objectweb.asm.MethodVisitor +import org.objectweb.asm.Type + +abstract class Field(val field: Class) { + abstract var descriptor: String? + + companion object { + const val unsetName = "Unset" + } + + var name: String = unsetName + abstract val type: String + + abstract fun generateField(cw: ClassWriter) + abstract fun visitParameter(mv: MethodVisitor, idx: Int) +} + +/** + * Any field that can be a member of an object + * + * Known + * - [NullableField] + * - [NonNullableField] + */ +abstract class ClassField(field: Class) : Field(field) { + abstract val nullabilityAnnotation: String + abstract fun nullTest(mv: MethodVisitor, slot: Int) + + override var descriptor = Type.getDescriptor(this.field) + override val type: String get() = if (this.field.isPrimitive) this.descriptor else "Ljava/lang/Object;" + + fun addNullabilityAnnotation(mv: MethodVisitor) { + mv.visitAnnotation(nullabilityAnnotation, true).visitEnd() + } + + override fun generateField(cw: ClassWriter) { + cw.visitField(ACC_PROTECTED + ACC_FINAL, name, descriptor, null, null).visitAnnotation( + nullabilityAnnotation, true).visitEnd() + } + + override fun visitParameter(mv: MethodVisitor, idx: Int) { + with(mv) { + visitParameter(name, 0) + if (!field.isPrimitive) { + visitParameterAnnotation(idx, nullabilityAnnotation, true).visitEnd() + } + } + } +} + +/** + * A member of a constructed class that can be assigned to null, the + * mandatory type for primitives, but also any member that cannot be + * null + * + * maps to AMQP mandatory = true fields + */ +open class NonNullableField(field: Class) : ClassField(field) { + override val nullabilityAnnotation = "Ljavax/annotation/Nonnull;" + + constructor(name: String, field: Class) : this(field) { + this.name = name + } + + override fun nullTest(mv: MethodVisitor, slot: Int) { + assert(name != unsetName) + + if (!field.isPrimitive) { + with(mv) { + visitVarInsn(ALOAD, 0) // load this + visitVarInsn(ALOAD, slot) // load parameter + visitLdcInsn("param \"$name\" cannot be null") + visitMethodInsn(INVOKESTATIC, + "java/util/Objects", + "requireNonNull", + "(Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;", false) + visitInsn(POP) + } + } + } +} + +/** + * A member of a constructed class that can be assigned to null, + * + * maps to AMQP mandatory = false fields + */ +class NullableField(field: Class) : ClassField(field) { + override val nullabilityAnnotation = "Ljavax/annotation/Nullable;" + + constructor(name: String, field: Class) : this(field) { + this.name = name + } + + init { + if (field.isPrimitive) { + throw NullablePrimitiveException( + "Field $name is primitive type ${Type.getDescriptor(field)} and thus cannot be nullable") + } + } + + override fun nullTest(mv: MethodVisitor, slot: Int) { + assert(name != unsetName) + } +} + +/** + * Represents enum constants within an enum + */ +class EnumField : Field(Enum::class.java) { + override var descriptor: String? = null + + override val type: String + get() = "Ljava/lang/Enum;" + + override fun generateField(cw: ClassWriter) { + cw.visitField(ACC_PUBLIC + ACC_FINAL + ACC_STATIC + ACC_ENUM, name, + descriptor, null, null).visitEnd() + } + + override fun visitParameter(mv: MethodVisitor, idx: Int) { + mv.visitParameter(name, 0) + } +} + +/** + * Constructs a Field Schema object of the correct type depending weather + * the AMQP schema indicates it's mandatory (non nullable) or not (nullable) + */ +object FieldFactory { + fun newInstance(mandatory: Boolean, name: String, field: Class) = + if (mandatory) NonNullableField(name, field) else NullableField(name, field) + +} diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTest.kt index 9ffef1ebd1..c7b5fd5384 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTest.kt @@ -15,12 +15,12 @@ class ClassCarpenterTest { val b: Int } - val cc = ClassCarpenter() + private val cc = ClassCarpenter() // We have to ignore synthetic fields even though ClassCarpenter doesn't create any because the JaCoCo // coverage framework auto-magically injects one method and one field into every class loaded into the JVM. - val Class<*>.nonSyntheticFields: List get() = declaredFields.filterNot { it.isSynthetic } - val Class<*>.nonSyntheticMethods: List get() = declaredMethods.filterNot { it.isSynthetic } + private val Class<*>.nonSyntheticFields: List get() = declaredFields.filterNot { it.isSynthetic } + private val Class<*>.nonSyntheticMethods: List get() = declaredMethods.filterNot { it.isSynthetic } @Test fun empty() { diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTestUtils.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTestUtils.kt index b7f1c2d348..a59858bead 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTestUtils.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTestUtils.kt @@ -3,6 +3,7 @@ package net.corda.nodeapi.internal.serialization.carpenter import net.corda.nodeapi.internal.serialization.amqp.* import net.corda.nodeapi.internal.serialization.amqp.Field import net.corda.nodeapi.internal.serialization.amqp.Schema +import net.corda.nodeapi.internal.serialization.AllWhitelist fun mangleName(name: String) = "${name}__carpenter" @@ -34,7 +35,8 @@ fun Schema.mangleNames(names: List): Schema { } open class AmqpCarpenterBase { - var factory = testDefaultFactory() + var cc = ClassCarpenter() + var factory = SerializerFactory(AllWhitelist, cc.classloader) fun serialise(clazz: Any) = SerializationOutput(factory).serialize(clazz) fun testName(): String = Thread.currentThread().stackTrace[2].methodName diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/EnumClassTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/EnumClassTests.kt new file mode 100644 index 0000000000..d4c40fcc0a --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/EnumClassTests.kt @@ -0,0 +1,105 @@ +package net.corda.nodeapi.internal.serialization.carpenter + +import org.junit.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class EnumClassTests : AmqpCarpenterBase() { + + @Test + fun oneValue() { + val enumConstants = mapOf("A" to EnumField()) + + val schema = EnumSchema("gen.enum", enumConstants) + + assertTrue(cc.build(schema).isEnum) + } + + @Test + fun oneValueInstantiate() { + val enumConstants = mapOf("A" to EnumField()) + val schema = EnumSchema("gen.enum", enumConstants) + val clazz = cc.build(schema) + + assertTrue(clazz.isEnum) + assertEquals(enumConstants.size, clazz.enumConstants.size) + assertEquals("A", clazz.enumConstants.first().toString()) + assertEquals(0, (clazz.enumConstants.first() as Enum<*>).ordinal) + assertEquals("A", (clazz.enumConstants.first() as Enum<*>).name) + } + + @Test + fun twoValuesInstantiate() { + val enumConstants = mapOf("left" to EnumField(), "right" to EnumField()) + val schema = EnumSchema("gen.enum", enumConstants) + val clazz = cc.build(schema) + + assertTrue(clazz.isEnum) + assertEquals(enumConstants.size, clazz.enumConstants.size) + + val left = clazz.enumConstants[0] as Enum<*> + val right = clazz.enumConstants[1] as Enum<*> + + assertEquals(0, left.ordinal) + assertEquals("left", left.name) + assertEquals(1, right.ordinal) + assertEquals("right", right.name) + } + + @Test + fun manyValues() { + val enumConstants = listOf("AAA", "BBB", "CCC", "DDD", "EEE", "FFF", + "GGG", "HHH", "III", "JJJ").associateBy({ it }, { EnumField() }) + val schema = EnumSchema("gen.enum", enumConstants) + val clazz = cc.build(schema) + + assertTrue(clazz.isEnum) + assertEquals(enumConstants.size, clazz.enumConstants.size) + + var idx = 0 + enumConstants.forEach { + val constant = clazz.enumConstants[idx] as Enum<*> + assertEquals(idx++, constant.ordinal) + assertEquals(it.key, constant.name) + } + } + + @Test + fun assignment() { + val enumConstants = listOf("AAA", "BBB", "CCC", "DDD", "EEE", "FFF").associateBy({ it }, { EnumField() }) + val schema = EnumSchema("gen.enum", enumConstants) + val clazz = cc.build(schema) + + assertEquals("CCC", clazz.getMethod("valueOf", String::class.java).invoke(null, "CCC").toString()) + assertEquals("CCC", (clazz.getMethod("valueOf", String::class.java).invoke(null, "CCC") as Enum<*>).name) + + val ddd = clazz.getMethod("valueOf", String::class.java).invoke(null, "DDD") as Enum<*> + + assertTrue(ddd::class.java.isEnum) + assertEquals("DDD", ddd.name) + assertEquals(3, ddd.ordinal) + } + + // if anything goes wrong with this test it's going to end up throwing *some* + // exception, hence the lack of asserts + @Test + fun assignAndTest() { + val cc2 = ClassCarpenter() + + val schema1 = EnumSchema("gen.enum", + listOf("AAA", "BBB", "CCC", "DDD", "EEE", "FFF").associateBy({ it }, { EnumField() })) + + val enumClazz = cc2.build(schema1) + + val schema2 = ClassSchema("gen.class", + mapOf( + "a" to NonNullableField(Int::class.java), + "b" to NonNullableField(enumClazz))) + + val classClazz = cc2.build(schema2) + + // make sure we can construct a class that has an enum we've constructed as a member + classClazz.constructors[0].newInstance(1, enumClazz.getMethod( + "valueOf", String::class.java).invoke(null, "BBB")) + } +} \ No newline at end of file