diff --git a/experimental/src/main/kotlin/net/corda/carpenter/ClassCarpenter.kt b/experimental/src/main/kotlin/net/corda/carpenter/ClassCarpenter.kt index 7641b9c6ea..534083325f 100644 --- a/experimental/src/main/kotlin/net/corda/carpenter/ClassCarpenter.kt +++ b/experimental/src/main/kotlin/net/corda/carpenter/ClassCarpenter.kt @@ -71,20 +71,40 @@ class ClassCarpenter { /** * A Schema represents a desired class. */ - class Schema(val name: String, fields: Map>, val superclass: Schema? = null, val interfaces: List> = emptyList()) { + open class Schema(val name: String, fields: Map>, val superclass: Schema? = null, val interfaces: List> = emptyList()) { val fields = LinkedHashMap(fields) // Fix the order up front if the user didn't. val descriptors = fields.map { it.key to Type.getDescriptor(it.value) }.toMap() fun fieldsIncludingSuperclasses(): Map> = (superclass?.fieldsIncludingSuperclasses() ?: emptyMap()) + LinkedHashMap(fields) - fun descriptorsIncludingSuperclasses(): Map = (superclass?.descriptorsIncludingSuperclasses() ?: emptyMap()) + LinkedHashMap(descriptors) + fun descriptorsIncludingSuperclasses(): Map = (superclass?.descriptorsIncludingSuperclasses() ?: emptyMap()) + LinkedHashMap(descriptors) + + val jvmName: String + get() = name.replace(".", "/") } + private val String.jvm: String get() = replace(".", "/") + + class ClassSchema( + name: String, + fields: Map>, + superclass: Schema? = null, + interfaces: List> = emptyList() + ) : Schema(name, fields, superclass, interfaces) + + class InterfaceSchema( + name: String, + fields: Map>, + superclass: Schema? = null, + interfaces: List> = emptyList() + ) : Schema(name, fields, superclass, interfaces) + class DuplicateName : RuntimeException("An attempt was made to register two classes with the same name within the same ClassCarpenter namespace.") class InterfaceMismatch(msg: String) : RuntimeException(msg) private class CarpenterClassLoader : ClassLoader(Thread.currentThread().contextClassLoader) { fun load(name: String, bytes: ByteArray) = defineClass(name, bytes, 0, bytes.size) } + private val classloader = CarpenterClassLoader() private val _loaded = HashMap>() @@ -92,8 +112,6 @@ class ClassCarpenter { /** Returns a snapshot of the currently loaded classes as a map of full class name (package names+dots) -> class object */ val loaded: Map> = HashMap(_loaded) - private val String.jvm: String get() = replace(".", "/") - /** * Generate bytecode for the given schema and load into the JVM. The returned class object can be used to * construct instances of the generated class. @@ -111,27 +129,57 @@ class ClassCarpenter { hierarchy += cursor cursor = cursor.superclass } - hierarchy.reversed().forEach { generateClass(it) } + + hierarchy.reversed().forEach { + when (it) { + is InterfaceSchema -> generateInterface(it) + is ClassSchema -> generateClass(it) + } + } + return _loaded[schema.name]!! } + private fun generateInterface(schema: Schema): Class<*> { + return generate(schema) { cw, schema -> + val interfaces = schema.interfaces.map { it.name.jvm }.toTypedArray() + + with(cw) { + visit(V1_8, ACC_PUBLIC + ACC_ABSTRACT + ACC_INTERFACE, schema.jvmName, null, "java/lang/Object", interfaces) + + generateAbstractGetters(schema) + + visitEnd() + } + } + } + private fun generateClass(schema: Schema): Class<*> { - val jvmName = schema.name.jvm + return generate(schema) { cw, schema -> + val superName = schema.superclass?.jvmName ?: "java/lang/Object" + val interfaces = arrayOf(SimpleFieldAccess::class.java.name.jvm) + schema.interfaces.map { it.name.jvm } + + with(cw) { + visit(V1_8, ACC_PUBLIC + ACC_SUPER, schema.jvmName, null, superName, interfaces) + + generateFields(schema) + generateConstructor(schema) + generateGetters(schema) + if (schema.superclass == null) + generateGetMethod() // From SimplePropertyAccess + generateToString(schema) + + visitEnd() + } + } + } + + private fun generate(schema: Schema, generator: (ClassWriter, Schema) -> Unit): Class<*> { // Lazy: we could compute max locals/max stack ourselves, it'd be faster. val cw = ClassWriter(ClassWriter.COMPUTE_FRAMES or ClassWriter.COMPUTE_MAXS) - with(cw) { - // public class Name implements SimpleFieldAccess { - val superName = schema.superclass?.name?.jvm ?: "java/lang/Object" - val interfaces = arrayOf(SimpleFieldAccess::class.java.name.jvm) + schema.interfaces.map { it.name.jvm } - visit(52, ACC_PUBLIC + ACC_SUPER, jvmName, null, superName, interfaces) - generateFields(schema) - generateConstructor(jvmName, schema) - generateGetters(jvmName, schema) - if (schema.superclass == null) - generateGetMethod() // From SimplePropertyAccess - generateToString(jvmName, schema) - visitEnd() - } + + generator(cw, schema) + val clazz = classloader.load(schema.name, cw.toByteArray()) _loaded[schema.name] = clazz return clazz @@ -143,7 +191,7 @@ class ClassCarpenter { } } - private fun ClassWriter.generateToString(jvmName: String, schema: Schema) { + private fun ClassWriter.generateToString(schema: Schema) { val toStringHelper = "com/google/common/base/MoreObjects\$ToStringHelper" with(visitMethod(ACC_PUBLIC, "toString", "()Ljava/lang/String;", "", null)) { visitCode() @@ -154,7 +202,7 @@ class ClassCarpenter { for ((name, type) in schema.fieldsIncludingSuperclasses().entries) { visitLdcInsn(name) visitVarInsn(ALOAD, 0) // this - visitFieldInsn(GETFIELD, jvmName, name, schema.descriptorsIncludingSuperclasses()[name]) + visitFieldInsn(GETFIELD, schema.jvmName, name, schema.descriptorsIncludingSuperclasses()[name]) val desc = if (type.isPrimitive) schema.descriptors[name] else "Ljava/lang/Object;" visitMethodInsn(INVOKEVIRTUAL, toStringHelper, "add", "(Ljava/lang/String;$desc)L$toStringHelper;", false) } @@ -182,13 +230,13 @@ class ClassCarpenter { } } - private fun ClassWriter.generateGetters(jvmName: String, schema: Schema) { + private fun ClassWriter.generateGetters(schema: Schema) { for ((name, type) in schema.fields) { val descriptor = schema.descriptors[name] with(visitMethod(ACC_PUBLIC, "get" + name.capitalize(), "()" + descriptor, null, null)) { visitCode() visitVarInsn(ALOAD, 0) // Load 'this' - visitFieldInsn(GETFIELD, jvmName, name, descriptor) + visitFieldInsn(GETFIELD, schema.jvmName, name, descriptor) when (type) { java.lang.Boolean.TYPE, Integer.TYPE, java.lang.Short.TYPE, java.lang.Byte.TYPE, TYPE -> visitInsn(IRETURN) java.lang.Long.TYPE -> visitInsn(LRETURN) @@ -202,7 +250,18 @@ class ClassCarpenter { } } - private fun ClassWriter.generateConstructor(jvmName: String, schema: Schema) { + private fun ClassWriter.generateAbstractGetters(schema: Schema) { + for ((name, _) in schema.fields) { + val descriptor = schema.descriptors[name] + val opcodes = ACC_ABSTRACT + ACC_PUBLIC + with(visitMethod(opcodes, "get" + name.capitalize(), "()" + descriptor, null, null)) { + // abstract method doesn't have any implementation so just end + visitEnd() + } + } + } + + private fun ClassWriter.generateConstructor(schema: Schema) { with(visitMethod(ACC_PUBLIC, "", "(" + schema.descriptorsIncludingSuperclasses().values.joinToString("") + ")V", null, null)) { visitCode() // Calculate the super call. @@ -224,7 +283,7 @@ class ClassCarpenter { throw UnsupportedOperationException("Array types are not implemented yet") visitVarInsn(ALOAD, 0) // Load 'this' onto the stack slot += load(slot, type) // Load the contents of the parameter onto the stack. - visitFieldInsn(PUTFIELD, jvmName, name, schema.descriptors[name]) + visitFieldInsn(PUTFIELD, schema.jvmName, name, schema.descriptors[name]) } visitInsn(RETURN) visitMaxs(0, 0) @@ -257,13 +316,14 @@ class ClassCarpenter { // actually called, which is a bit too dynamic for my tastes. val allFields = schema.fieldsIncludingSuperclasses() for (itf in schema.interfaces) { - for (method in itf.methods) { + itf.methods.forEach { val fieldNameFromItf = when { - method.name.startsWith("get") -> method.name.substring(3).decapitalize() - else -> throw InterfaceMismatch("Requested interfaces must consist only of methods that start with 'get': ${itf.name}.${method.name}") + it.name.startsWith("get") -> it.name.substring(3).decapitalize() + else -> throw InterfaceMismatch("Requested interfaces must consist only of methods that start with 'get': ${itf.name}.${it.name}") } - if (fieldNameFromItf !in allFields) - throw InterfaceMismatch("Interface ${itf.name} requires a field named ${fieldNameFromItf} but that isn't found in the schema or any superclass schemas") + + if ((schema is ClassSchema) and (fieldNameFromItf !in allFields)) + throw InterfaceMismatch("Interface ${itf.name} requires a field named $fieldNameFromItf but that isn't found in the schema or any superclass schemas") } } } diff --git a/experimental/src/test/kotlin/net/corda/carpenter/ClassCarpenterTest.kt b/experimental/src/test/kotlin/net/corda/carpenter/ClassCarpenterTest.kt index 5ffe32e51d..1827c83d79 100644 --- a/experimental/src/test/kotlin/net/corda/carpenter/ClassCarpenterTest.kt +++ b/experimental/src/test/kotlin/net/corda/carpenter/ClassCarpenterTest.kt @@ -21,7 +21,7 @@ class ClassCarpenterTest { @Test fun empty() { - val clazz = cc.build(ClassCarpenter.Schema("gen.EmptyClass", emptyMap(), null)) + val clazz = cc.build(ClassCarpenter.ClassSchema("gen.EmptyClass", emptyMap(), null)) assertEquals(0, clazz.nonSyntheticFields.size) assertEquals(2, clazz.nonSyntheticMethods.size) // get, toString assertEquals(0, clazz.declaredConstructors[0].parameterCount) @@ -30,7 +30,7 @@ class ClassCarpenterTest { @Test fun prims() { - val clazz = cc.build(ClassCarpenter.Schema("gen.Prims", mapOf( + val clazz = cc.build(ClassCarpenter.ClassSchema("gen.Prims", mapOf( "anIntField" to Int::class.javaPrimitiveType!!, "aLongField" to Long::class.javaPrimitiveType!!, "someCharField" to Char::class.javaPrimitiveType!!, @@ -65,7 +65,7 @@ class ClassCarpenterTest { } private fun genPerson(): Pair, Any> { - val clazz = cc.build(ClassCarpenter.Schema("gen.Person", mapOf( + val clazz = cc.build(ClassCarpenter.ClassSchema("gen.Person", mapOf( "age" to Int::class.javaPrimitiveType!!, "name" to String::class.java ))) @@ -88,14 +88,14 @@ class ClassCarpenterTest { @Test(expected = ClassCarpenter.DuplicateName::class) fun duplicates() { - cc.build(ClassCarpenter.Schema("gen.EmptyClass", emptyMap(), null)) - cc.build(ClassCarpenter.Schema("gen.EmptyClass", emptyMap(), null)) + cc.build(ClassCarpenter.ClassSchema("gen.EmptyClass", emptyMap(), null)) + cc.build(ClassCarpenter.ClassSchema("gen.EmptyClass", emptyMap(), null)) } @Test fun `can refer to each other`() { val (clazz1, i) = genPerson() - val clazz2 = cc.build(ClassCarpenter.Schema("gen.Referee", mapOf( + val clazz2 = cc.build(ClassCarpenter.ClassSchema("gen.Referee", mapOf( "ref" to clazz1 ))) val i2 = clazz2.constructors[0].newInstance(i) @@ -104,8 +104,8 @@ class ClassCarpenterTest { @Test fun superclasses() { - val schema1 = ClassCarpenter.Schema("gen.A", mapOf("a" to String::class.java)) - val schema2 = ClassCarpenter.Schema("gen.B", mapOf("b" to String::class.java), schema1) + val schema1 = ClassCarpenter.ClassSchema("gen.A", mapOf("a" to String::class.java)) + val schema2 = ClassCarpenter.ClassSchema("gen.B", mapOf("b" to String::class.java), schema1) val clazz = cc.build(schema2) val i = clazz.constructors[0].newInstance("xa", "xb") as SimpleFieldAccess assertEquals("xa", i["a"]) @@ -115,8 +115,8 @@ class ClassCarpenterTest { @Test fun interfaces() { - val schema1 = ClassCarpenter.Schema("gen.A", mapOf("a" to String::class.java)) - val schema2 = ClassCarpenter.Schema("gen.B", mapOf("b" to Int::class.java), schema1, interfaces = listOf(DummyInterface::class.java)) + val schema1 = ClassCarpenter.ClassSchema("gen.A", mapOf("a" to String::class.java)) + val schema2 = ClassCarpenter.ClassSchema("gen.B", mapOf("b" to Int::class.java), schema1, interfaces = listOf(DummyInterface::class.java)) val clazz = cc.build(schema2) val i = clazz.constructors[0].newInstance("xa", 1) as DummyInterface assertEquals("xa", i.a) @@ -125,10 +125,92 @@ class ClassCarpenterTest { @Test(expected = ClassCarpenter.InterfaceMismatch::class) fun `mismatched interface`() { - val schema1 = ClassCarpenter.Schema("gen.A", mapOf("a" to String::class.java)) - val schema2 = ClassCarpenter.Schema("gen.B", mapOf("c" to Int::class.java), schema1, interfaces = listOf(DummyInterface::class.java)) + val schema1 = ClassCarpenter.ClassSchema("gen.A", mapOf("a" to String::class.java)) + val schema2 = ClassCarpenter.ClassSchema("gen.B", mapOf("c" to Int::class.java), schema1, interfaces = listOf(DummyInterface::class.java)) val clazz = cc.build(schema2) val i = clazz.constructors[0].newInstance("xa", 1) as DummyInterface assertEquals(1, i.b) } -} \ No newline at end of file + + @Test + fun `generate interface`() { + val schema1 = ClassCarpenter.InterfaceSchema("gen.Interface", mapOf("a" to Int::class.java)) + val iface = cc.build(schema1) + + assert(iface.isInterface()) + assert(iface.constructors.isEmpty()) + assertEquals(iface.declaredMethods.size, 1) + assertEquals(iface.declaredMethods[0].name, "getA") + + val schema2 = ClassCarpenter.ClassSchema("gen.Derived", mapOf("a" to Int::class.java), interfaces = listOf(iface)) + val clazz = cc.build(schema2) + val testA = 42 + val i = clazz.constructors[0].newInstance(testA) as SimpleFieldAccess + + assertEquals(testA, i["a"]) + } + + @Test + fun `generate multiple interfaces`() { + val iFace1 = ClassCarpenter.InterfaceSchema("gen.Interface1", mapOf("a" to Int::class.java, "b" to String::class.java)) + val iFace2 = ClassCarpenter.InterfaceSchema("gen.Interface2", mapOf("c" to Int::class.java, "d" to String::class.java)) + + val class1 = ClassCarpenter.ClassSchema( + "gen.Derived", + mapOf( + "a" to Int::class.java, + "b" to String::class.java, + "c" to Int::class.java, + "d" to String::class.java), + interfaces = listOf(cc.build(iFace1), cc.build(iFace2))) + + val clazz = cc.build(class1) + val testA = 42 + val testB = "don't touch me, I'm scared" + val testC = 0xDEAD + val testD = "wibble" + val i = clazz.constructors[0].newInstance(testA, testB, testC, testD) as SimpleFieldAccess + + assertEquals(testA, i["a"]) + assertEquals(testB, i["b"]) + assertEquals(testC, i["c"]) + assertEquals(testD, i["d"]) + } + + @Test + fun `interface implementing interface`() { + val iFace1 = ClassCarpenter.InterfaceSchema( + "gen.Interface1", + mapOf( + "a" to Int::class.java, + "b" to String::class.java)) + + val iFace2 = ClassCarpenter.InterfaceSchema( + "gen.Interface2", + mapOf( + "c" to Int::class.java, + "d" to String::class.java), + interfaces = listOf(cc.build(iFace1))) + + val class1 = ClassCarpenter.ClassSchema( + "gen.Derived", + mapOf( + "a" to Int::class.java, + "b" to String::class.java, + "c" to Int::class.java, + "d" to String::class.java), + interfaces = listOf(cc.build(iFace2))) + + val clazz = cc.build(class1) + val testA = 99 + val testB = "green is not a creative colour" + val testC = 7 + val testD = "I like jam" + val i = clazz.constructors[0].newInstance(testA, testB, testC, testD) as SimpleFieldAccess + + assertEquals(testA, i["a"]) + assertEquals(testB, i["b"]) + assertEquals(testC, i["c"]) + assertEquals(testD, i["d"]) + } +}