diff --git a/experimental/src/main/kotlin/net/corda/carpenter/ClassCarpenter.kt b/experimental/src/main/kotlin/net/corda/carpenter/ClassCarpenter.kt index bc29f684ed..112426729b 100644 --- a/experimental/src/main/kotlin/net/corda/carpenter/ClassCarpenter.kt +++ b/experimental/src/main/kotlin/net/corda/carpenter/ClassCarpenter.kt @@ -77,8 +77,13 @@ class ClassCarpenter { fun fieldsIncludingSuperclasses(): Map> = (superclass?.fieldsIncludingSuperclasses() ?: emptyMap()) + LinkedHashMap(fields) fun descriptorsIncludingSuperclasses(): Map = (superclass?.descriptorsIncludingSuperclasses() ?: emptyMap()) + LinkedHashMap(descriptors) + + val jvmName : String + get() = name.replace (".", "/") } + private val String.jvm: String get() = replace(".", "/") + class ClassSchema( name: String, fields: Map>, @@ -106,8 +111,6 @@ class ClassCarpenter { /** Returns a snapshot of the currently loaded classes as a map of full class name (package names+dots) -> class object */ val loaded: Map> = HashMap(_loaded) - private val String.jvm: String get() = replace(".", "/") - /** * Generate bytecode for the given schema and load into the JVM. The returned class object can be used to * construct instances of the generated class. @@ -125,52 +128,59 @@ class ClassCarpenter { hierarchy += cursor cursor = cursor.superclass } + hierarchy.reversed().forEach { when (it) { - is ClassSchema -> generateClass(it) is InterfaceSchema -> generateInterface(it) + is ClassSchema -> generateClass(it) } } + return _loaded[schema.name]!! } - private fun generateClass(schema: ClassSchema): Class<*> { - val jvmName = schema.name.jvm - // Lazy: we could compute max locals/max stack ourselves, it'd be faster. - val cw = ClassWriter(ClassWriter.COMPUTE_FRAMES or ClassWriter.COMPUTE_MAXS) - with(cw) { - // public class Name implements SimpleFieldAccess { - val superName = schema.superclass?.name?.jvm ?: "java/lang/Object" - val interfaces = arrayOf(SimpleFieldAccess::class.java.name.jvm) + schema.interfaces.map { it.name.jvm } - visit(52, ACC_PUBLIC + ACC_SUPER, jvmName, null, superName, interfaces) - generateFields(schema) - generateConstructor(jvmName, schema) - generateGetters(jvmName, schema) - if (schema.superclass == null) - generateGetMethod() // From SimplePropertyAccess - generateToString(jvmName, schema) - visitEnd() - } - val clazz = classloader.load(schema.name, cw.toByteArray()) - _loaded[schema.name] = clazz - return clazz - } - - private fun generateInterface(schema: InterfaceSchema): Class<*> { - val jvmName = schema.name.jvm - // Lazy: we could compute max locals/max stack ourselves, it'd be faster. - val cw = ClassWriter (ClassWriter.COMPUTE_FRAMES or ClassWriter.COMPUTE_MAXS) - with(cw) { + private fun generateInterface (schema: Schema): Class<*> { + return generate (schema) { cw, schema -> val interfaces = schema.interfaces.map { it.name.jvm }.toTypedArray() - visit(V1_8, ACC_PUBLIC + ACC_ABSTRACT + ACC_INTERFACE, jvmName, null, "java/lang/Object", interfaces) + with (cw) { + visit(V1_8, ACC_PUBLIC + ACC_ABSTRACT + ACC_INTERFACE, schema.jvmName, null, "java/lang/Object", interfaces) - generateAbstractGetters(schema) + generateAbstractGetters(schema) - visitEnd() + visitEnd() + } } - val clazz = classloader.load(schema.name, cw.toByteArray()) + } + private fun generateClass (schema: Schema): Class<*> { + return generate (schema) { cw, schema -> + val superName = schema.superclass?.jvmName ?: "java/lang/Object" + val interfaces = arrayOf(SimpleFieldAccess::class.java.name.jvm) + schema.interfaces.map { it.name.jvm } + + with (cw) { + visit(V1_8, ACC_PUBLIC + ACC_SUPER, schema.jvmName, null, superName, interfaces) + + generateFields(schema) + generateConstructor(schema) + generateGetters(schema) + if (schema.superclass == null) + generateGetMethod() // From SimplePropertyAccess + generateToString(schema) + + visitEnd() + } + } + + } + + private fun generate(schema: Schema, generator : (ClassWriter, Schema) -> Unit): Class<*> { + // Lazy: we could compute max locals/max stack ourselves, it'd be faster. + val cw = ClassWriter (ClassWriter.COMPUTE_FRAMES or ClassWriter.COMPUTE_MAXS) + + generator (cw, schema) + + val clazz = classloader.load(schema.name, cw.toByteArray()) _loaded[schema.name] = clazz return clazz } @@ -181,7 +191,7 @@ class ClassCarpenter { } } - private fun ClassWriter.generateToString(jvmName: String, schema: Schema) { + private fun ClassWriter.generateToString(schema: Schema) { val toStringHelper = "com/google/common/base/MoreObjects\$ToStringHelper" with(visitMethod(ACC_PUBLIC, "toString", "()Ljava/lang/String;", "", null)) { visitCode() @@ -192,7 +202,7 @@ class ClassCarpenter { for ((name, type) in schema.fieldsIncludingSuperclasses().entries) { visitLdcInsn(name) visitVarInsn(ALOAD, 0) // this - visitFieldInsn(GETFIELD, jvmName, name, schema.descriptorsIncludingSuperclasses()[name]) + visitFieldInsn(GETFIELD, schema.jvmName, name, schema.descriptorsIncludingSuperclasses()[name]) val desc = if (type.isPrimitive) schema.descriptors[name] else "Ljava/lang/Object;" visitMethodInsn(INVOKEVIRTUAL, toStringHelper, "add", "(Ljava/lang/String;$desc)L$toStringHelper;", false) } @@ -220,13 +230,13 @@ class ClassCarpenter { } } - private fun ClassWriter.generateGetters(jvmName: String, schema: Schema) { + private fun ClassWriter.generateGetters(schema: Schema) { for ((name, type) in schema.fields) { val descriptor = schema.descriptors[name] with(visitMethod(ACC_PUBLIC, "get" + name.capitalize(), "()" + descriptor, null, null)) { visitCode() visitVarInsn(ALOAD, 0) // Load 'this' - visitFieldInsn(GETFIELD, jvmName, name, descriptor) + visitFieldInsn(GETFIELD, schema.jvmName, name, descriptor) when (type) { java.lang.Boolean.TYPE, Integer.TYPE, java.lang.Short.TYPE, java.lang.Byte.TYPE, TYPE -> visitInsn(IRETURN) java.lang.Long.TYPE -> visitInsn(LRETURN) @@ -240,7 +250,7 @@ class ClassCarpenter { } } - private fun ClassWriter.generateAbstractGetters(schema: InterfaceSchema) { + private fun ClassWriter.generateAbstractGetters(schema: Schema) { for ((name, type) in schema.fields) { val descriptor = schema.descriptors[name] val opcodes = ACC_ABSTRACT + ACC_PUBLIC @@ -251,7 +261,7 @@ class ClassCarpenter { } } - private fun ClassWriter.generateConstructor(jvmName: String, schema: Schema) { + private fun ClassWriter.generateConstructor(schema: Schema) { with(visitMethod(ACC_PUBLIC, "", "(" + schema.descriptorsIncludingSuperclasses().values.joinToString("") + ")V", null, null)) { visitCode() // Calculate the super call. @@ -273,7 +283,7 @@ class ClassCarpenter { throw UnsupportedOperationException("Array types are not implemented yet") visitVarInsn(ALOAD, 0) // Load 'this' onto the stack slot += load(slot, type) // Load the contents of the parameter onto the stack. - visitFieldInsn(PUTFIELD, jvmName, name, schema.descriptors[name]) + visitFieldInsn(PUTFIELD, schema.jvmName, name, schema.descriptors[name]) } visitInsn(RETURN) visitMaxs(0, 0) diff --git a/experimental/src/test/kotlin/net/corda/carpenter/ClassCarpenterTest.kt b/experimental/src/test/kotlin/net/corda/carpenter/ClassCarpenterTest.kt index 2e6b6bebd8..27fe55ea2d 100644 --- a/experimental/src/test/kotlin/net/corda/carpenter/ClassCarpenterTest.kt +++ b/experimental/src/test/kotlin/net/corda/carpenter/ClassCarpenterTest.kt @@ -137,12 +137,12 @@ class ClassCarpenterTest { val schema1 = ClassCarpenter.InterfaceSchema("gen.Interface", mapOf("a" to Int::class.java)) val iface = cc.build(schema1) - assert (iface.isInterface()) - assert (iface.constructors.isEmpty()) - assertEquals (iface.declaredMethods.size, 1) - assertEquals (iface.declaredMethods[0].name, "getA") + assert(iface.isInterface()) + assert(iface.constructors.isEmpty()) + assertEquals(iface.declaredMethods.size, 1) + assertEquals(iface.declaredMethods[0].name, "getA") - val schema2 = ClassCarpenter.ClassSchema("gen.Derived", mapOf("a" to Int::class.java), interfaces = listOf (iface)) + val schema2 = ClassCarpenter.ClassSchema("gen.Derived", mapOf("a" to Int::class.java), interfaces = listOf(iface)) val clazz = cc.build(schema2) val testA = 42 val i = clazz.constructors[0].newInstance(testA) as SimpleFieldAccess @@ -162,7 +162,7 @@ class ClassCarpenterTest { "b" to String::class.java, "c" to Int::class.java, "d" to String::class.java), - interfaces = listOf (cc.build (iFace1), cc.build (iFace2))) + interfaces = listOf(cc.build(iFace1), cc.build(iFace2))) val clazz = cc.build(class1) val testA = 42 @@ -181,7 +181,7 @@ class ClassCarpenterTest { fun `interface implementing interface`() { val iFace1 = ClassCarpenter.InterfaceSchema( "gen.Interface1", - mapOf ( + mapOf( "a" to Int::class.java, "b" to String::class.java)) @@ -190,7 +190,7 @@ class ClassCarpenterTest { mapOf( "c" to Int::class.java, "d" to String::class.java), - interfaces = listOf (cc.build (iFace1))) + interfaces = listOf(cc.build(iFace1))) val class1 = ClassCarpenter.ClassSchema( "gen.Derived", @@ -199,7 +199,7 @@ class ClassCarpenterTest { "b" to String::class.java, "c" to Int::class.java, "d" to String::class.java), - interfaces = listOf (cc.build (iFace2))) + interfaces = listOf(cc.build(iFace2))) val clazz = cc.build(class1) val testA = 99