From 6dc7f694e4dce6dae4cbf2bfee7f8c8c9a5b43fb Mon Sep 17 00:00:00 2001 From: Katelyn Baker Date: Thu, 29 Jun 2017 17:53:07 +0100 Subject: [PATCH] Add explicit support for nullable types Remove prohibition against non string object classes such as arrays Squashed Commmits: * Tidyup whitespace * WIP * Review Comments * WIP - adding concept of nullabltily into the carpenter * Add explicit nullable and non nullable fields * Rebase onto master, fix package names in carpenter --- .../serialization/carpenter/ClassCarpenter.kt | 168 +++++++-- .../carpenter/ClassCarpenterTest.kt | 325 ++++++++++++++++-- 2 files changed, 422 insertions(+), 71 deletions(-) diff --git a/core/src/main/kotlin/net/corda/core/serialization/carpenter/ClassCarpenter.kt b/core/src/main/kotlin/net/corda/core/serialization/carpenter/ClassCarpenter.kt index 534083325f..46bbafff3b 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/carpenter/ClassCarpenter.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/carpenter/ClassCarpenter.kt @@ -1,4 +1,4 @@ -package net.corda.carpenter +package net.corda.core.serialization.carpenter import org.objectweb.asm.ClassWriter import org.objectweb.asm.MethodVisitor @@ -60,23 +60,118 @@ interface SimpleFieldAccess { * * Equals/hashCode methods are not yet supported. */ + +fun Map.descriptors() = LinkedHashMap(this.mapValues { it.value.descriptor }) + class ClassCarpenter { - // TODO: Array types. // TODO: Generics. // TODO: Sandbox the generated code when a security manager is in use. // TODO: Generate equals/hashCode. // TODO: Support annotations. // TODO: isFoo getter patterns for booleans (this is what Kotlin generates) + + class DuplicateName : RuntimeException("An attempt was made to register two classes with the same name within the same ClassCarpenter namespace.") + class InterfaceMismatch(msg: String) : RuntimeException(msg) + class NullablePrimitive(msg: String) : RuntimeException(msg) + + abstract class Field(val field: Class) { + val unsetName = "Unset" + var name: String = unsetName + abstract val nullabilityAnnotation: String + + val descriptor: String + get() = Type.getDescriptor(this.field) + + val type: String + get() = if (this.field.isPrimitive) this.descriptor else "Ljava/lang/Object;" + + fun generateField(cw: ClassWriter) { + val fieldVisitor = cw.visitField(ACC_PROTECTED + ACC_FINAL, name, descriptor, null, null) + cw.visitAnnotation(nullabilityAnnotation, false).visitEnd() + fieldVisitor.visitEnd() + } + + fun addNullabilityAnnotation(mv: MethodVisitor) { + mv.visitAnnotation(nullabilityAnnotation, false) + } + + abstract fun copy(name: String, field: Class): Field + abstract fun nullTest(mv: MethodVisitor, slot: Int) + fun visitParameter(mv: MethodVisitor, idx: Int) { + with(mv) { + visitParameter(name, 0) + if (!field.isPrimitive) { + visitParameterAnnotation(idx, nullabilityAnnotation, false).visitEnd() + } + } + } + } + + class NonNullableField(field: Class) : Field(field) { + override val nullabilityAnnotation = "Lorg/jetbrains/annotations/Nullable;" + + constructor(name: String, field: Class) : this(field) { + this.name = name + } + + override fun copy(name: String, field: Class) = NonNullableField(name, field) + + override fun nullTest(mv: MethodVisitor, slot: Int) { + assert(name != unsetName) + + if (!field.isPrimitive) { + with(mv) { + visitVarInsn(ALOAD, 0) // load this + visitVarInsn(ALOAD, slot) // load parameter + visitLdcInsn("param \"$name\" cannot be null") + visitMethodInsn(INVOKESTATIC, + "java/util/Objects", + "requireNonNull", + "(Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;", false) + visitInsn(POP) + } + } + } + } + + + class NullableField(field: Class) : Field(field) { + override val nullabilityAnnotation = "Lorg/jetbrains/annotations/NotNull;" + + constructor(name: String, field: Class) : this(field) { + if (field.isPrimitive) { + throw NullablePrimitive ( + "Field $name is primitive type ${Type.getDescriptor(field)} and thus cannot be nullable") + } + + this.name = name + } + + override fun copy(name: String, field: Class) = NullableField(name, field) + + override fun nullTest(mv: MethodVisitor, slot: Int) { + assert(name != unsetName) + } + } + /** * A Schema represents a desired class. */ - open class Schema(val name: String, fields: Map>, val superclass: Schema? = null, val interfaces: List> = emptyList()) { - val fields = LinkedHashMap(fields) // Fix the order up front if the user didn't. - val descriptors = fields.map { it.key to Type.getDescriptor(it.value) }.toMap() + abstract class Schema( + val name: String, + fields: Map, + val superclass: Schema? = null, + val interfaces: List> = emptyList()) { + /* Fix the order up front if the user didn't, inject the name into the field as it's + neater when iterating */ + val fields = LinkedHashMap(fields.mapValues { it.value.copy(it.key, it.value.field) }) - fun fieldsIncludingSuperclasses(): Map> = (superclass?.fieldsIncludingSuperclasses() ?: emptyMap()) + LinkedHashMap(fields) - fun descriptorsIncludingSuperclasses(): Map = (superclass?.descriptorsIncludingSuperclasses() ?: emptyMap()) + LinkedHashMap(descriptors) + fun fieldsIncludingSuperclasses(): Map = + (superclass?.fieldsIncludingSuperclasses() ?: emptyMap()) + LinkedHashMap(fields) + + fun descriptorsIncludingSuperclasses(): Map = + (superclass?.descriptorsIncludingSuperclasses() ?: emptyMap()) + fields.descriptors() val jvmName: String get() = name.replace(".", "/") @@ -86,21 +181,18 @@ class ClassCarpenter { class ClassSchema( name: String, - fields: Map>, + fields: Map, superclass: Schema? = null, interfaces: List> = emptyList() ) : Schema(name, fields, superclass, interfaces) class InterfaceSchema( name: String, - fields: Map>, + fields: Map, superclass: Schema? = null, interfaces: List> = emptyList() ) : Schema(name, fields, superclass, interfaces) - class DuplicateName : RuntimeException("An attempt was made to register two classes with the same name within the same ClassCarpenter namespace.") - class InterfaceMismatch(msg: String) : RuntimeException(msg) - private class CarpenterClassLoader : ClassLoader(Thread.currentThread().contextClassLoader) { fun load(name: String, bytes: ByteArray) = defineClass(name, bytes, 0, bytes.size) } @@ -186,9 +278,7 @@ class ClassCarpenter { } private fun ClassWriter.generateFields(schema: Schema) { - for ((name, desc) in schema.descriptors) { - visitField(ACC_PROTECTED + ACC_FINAL, name, desc, null, null).visitEnd() - } + schema.fields.forEach { it.value.generateField(this) } } private fun ClassWriter.generateToString(schema: Schema) { @@ -199,12 +289,11 @@ class ClassCarpenter { visitLdcInsn(schema.name.split('.').last()) visitMethodInsn(INVOKESTATIC, "com/google/common/base/MoreObjects", "toStringHelper", "(Ljava/lang/String;)L$toStringHelper;", false) // Call the add() methods. - for ((name, type) in schema.fieldsIncludingSuperclasses().entries) { + for ((name, field) in schema.fieldsIncludingSuperclasses().entries) { visitLdcInsn(name) visitVarInsn(ALOAD, 0) // this visitFieldInsn(GETFIELD, schema.jvmName, name, schema.descriptorsIncludingSuperclasses()[name]) - val desc = if (type.isPrimitive) schema.descriptors[name] else "Ljava/lang/Object;" - visitMethodInsn(INVOKEVIRTUAL, toStringHelper, "add", "(Ljava/lang/String;$desc)L$toStringHelper;", false) + visitMethodInsn(INVOKEVIRTUAL, toStringHelper, "add", "(Ljava/lang/String;${field.type})L$toStringHelper;", false) } // call toString() on the builder and return. visitMethodInsn(INVOKEVIRTUAL, toStringHelper, "toString", "()Ljava/lang/String;", false) @@ -232,12 +321,12 @@ class ClassCarpenter { private fun ClassWriter.generateGetters(schema: Schema) { for ((name, type) in schema.fields) { - val descriptor = schema.descriptors[name] - with(visitMethod(ACC_PUBLIC, "get" + name.capitalize(), "()" + descriptor, null, null)) { + with(visitMethod(ACC_PUBLIC, "get" + name.capitalize(), "()" + type.descriptor, null, null)) { + type.addNullabilityAnnotation(this) visitCode() visitVarInsn(ALOAD, 0) // Load 'this' - visitFieldInsn(GETFIELD, schema.jvmName, name, descriptor) - when (type) { + visitFieldInsn(GETFIELD, schema.jvmName, name, type.descriptor) + when (type.field) { java.lang.Boolean.TYPE, Integer.TYPE, java.lang.Short.TYPE, java.lang.Byte.TYPE, TYPE -> visitInsn(IRETURN) java.lang.Long.TYPE -> visitInsn(LRETURN) java.lang.Double.TYPE -> visitInsn(DRETURN) @@ -251,8 +340,8 @@ class ClassCarpenter { } private fun ClassWriter.generateAbstractGetters(schema: Schema) { - for ((name, _) in schema.fields) { - val descriptor = schema.descriptors[name] + for ((name, field) in schema.fields) { + val descriptor = field.descriptor val opcodes = ACC_ABSTRACT + ACC_PUBLIC with(visitMethod(opcodes, "get" + name.capitalize(), "()" + descriptor, null, null)) { // abstract method doesn't have any implementation so just end @@ -262,8 +351,18 @@ class ClassCarpenter { } private fun ClassWriter.generateConstructor(schema: Schema) { - with(visitMethod(ACC_PUBLIC, "", "(" + schema.descriptorsIncludingSuperclasses().values.joinToString("") + ")V", null, null)) { + with(visitMethod( + ACC_PUBLIC, + "", + "(" + schema.descriptorsIncludingSuperclasses().values.joinToString("") + ")V", + null, + null)) + { + var idx = 0 + schema.fields.values.forEach { it.visitParameter(this, idx++) } + visitCode() + // Calculate the super call. val superclassFields = schema.superclass?.fieldsIncludingSuperclasses() ?: emptyMap() visitVarInsn(ALOAD, 0) @@ -276,14 +375,16 @@ class ClassCarpenter { val superDesc = schema.superclass.descriptorsIncludingSuperclasses().values.joinToString("") visitMethodInsn(INVOKESPECIAL, schema.superclass.name.jvm, "", "($superDesc)V", false) } + // Assign the fields from parameters. var slot = 1 + superclassFields.size - for ((name, type) in schema.fields.entries) { - if (type.isArray) - throw UnsupportedOperationException("Array types are not implemented yet") + + for ((name, field) in schema.fields.entries) { + field.nullTest(this, slot) + visitVarInsn(ALOAD, 0) // Load 'this' onto the stack - slot += load(slot, type) // Load the contents of the parameter onto the stack. - visitFieldInsn(PUTFIELD, schema.jvmName, name, schema.descriptors[name]) + slot += load(slot, field) // Load the contents of the parameter onto the stack. + visitFieldInsn(PUTFIELD, schema.jvmName, name, field.descriptor) } visitInsn(RETURN) visitMaxs(0, 0) @@ -291,16 +392,15 @@ class ClassCarpenter { } } - // Returns how many slots the given type takes up. - private fun MethodVisitor.load(slot: Int, type: Class): Int { - when (type) { + private fun MethodVisitor.load(slot: Int, type: Field): Int { + when (type.field) { java.lang.Boolean.TYPE, Integer.TYPE, java.lang.Short.TYPE, java.lang.Byte.TYPE, TYPE -> visitVarInsn(ILOAD, slot) java.lang.Long.TYPE -> visitVarInsn(LLOAD, slot) java.lang.Double.TYPE -> visitVarInsn(DLOAD, slot) java.lang.Float.TYPE -> visitVarInsn(FLOAD, slot) else -> visitVarInsn(ALOAD, slot) } - return when (type) { + return when (type.field) { java.lang.Long.TYPE, java.lang.Double.TYPE -> 2 else -> 1 } diff --git a/core/src/test/kotlin/net/corda/core/serialization/carpenter/ClassCarpenterTest.kt b/core/src/test/kotlin/net/corda/core/serialization/carpenter/ClassCarpenterTest.kt index 1827c83d79..8baa7fd4ae 100644 --- a/core/src/test/kotlin/net/corda/core/serialization/carpenter/ClassCarpenterTest.kt +++ b/core/src/test/kotlin/net/corda/core/serialization/carpenter/ClassCarpenterTest.kt @@ -1,4 +1,5 @@ -package net.corda.carpenter +package net.corda.core.serialization.carpenter + import org.junit.Test import java.lang.reflect.Field @@ -30,16 +31,19 @@ class ClassCarpenterTest { @Test fun prims() { - val clazz = cc.build(ClassCarpenter.ClassSchema("gen.Prims", mapOf( - "anIntField" to Int::class.javaPrimitiveType!!, - "aLongField" to Long::class.javaPrimitiveType!!, - "someCharField" to Char::class.javaPrimitiveType!!, - "aShortField" to Short::class.javaPrimitiveType!!, - "doubleTrouble" to Double::class.javaPrimitiveType!!, - "floatMyBoat" to Float::class.javaPrimitiveType!!, - "byteMe" to Byte::class.javaPrimitiveType!!, - "booleanField" to Boolean::class.javaPrimitiveType!! - ))) + val clazz = cc.build(ClassCarpenter.ClassSchema( + "gen.Prims", + mapOf( + "anIntField" to Int::class.javaPrimitiveType!!, + "aLongField" to Long::class.javaPrimitiveType!!, + "someCharField" to Char::class.javaPrimitiveType!!, + "aShortField" to Short::class.javaPrimitiveType!!, + "doubleTrouble" to Double::class.javaPrimitiveType!!, + "floatMyBoat" to Float::class.javaPrimitiveType!!, + "byteMe" to Byte::class.javaPrimitiveType!!, + "booleanField" to Boolean::class.javaPrimitiveType!!).mapValues { + ClassCarpenter.NonNullableField (it.value) + })) assertEquals(8, clazz.nonSyntheticFields.size) assertEquals(10, clazz.nonSyntheticMethods.size) assertEquals(8, clazz.declaredConstructors[0].parameterCount) @@ -68,7 +72,7 @@ class ClassCarpenterTest { val clazz = cc.build(ClassCarpenter.ClassSchema("gen.Person", mapOf( "age" to Int::class.javaPrimitiveType!!, "name" to String::class.java - ))) + ).mapValues { ClassCarpenter.NonNullableField (it.value) } )) val i = clazz.constructors[0].newInstance(32, "Mike") return Pair(clazz, i) } @@ -82,7 +86,7 @@ class ClassCarpenterTest { @Test fun `generated toString`() { - val (clazz, i) = genPerson() + val (_, i) = genPerson() assertEquals("Person{age=32, name=Mike}", i.toString()) } @@ -96,7 +100,7 @@ class ClassCarpenterTest { fun `can refer to each other`() { val (clazz1, i) = genPerson() val clazz2 = cc.build(ClassCarpenter.ClassSchema("gen.Referee", mapOf( - "ref" to clazz1 + "ref" to ClassCarpenter.NonNullableField (clazz1) ))) val i2 = clazz2.constructors[0].newInstance(i) assertEquals(i, (i2 as SimpleFieldAccess)["ref"]) @@ -104,8 +108,15 @@ class ClassCarpenterTest { @Test fun superclasses() { - val schema1 = ClassCarpenter.ClassSchema("gen.A", mapOf("a" to String::class.java)) - val schema2 = ClassCarpenter.ClassSchema("gen.B", mapOf("b" to String::class.java), schema1) + val schema1 = ClassCarpenter.ClassSchema( + "gen.A", + mapOf("a" to ClassCarpenter.NonNullableField (String::class.java))) + + val schema2 = ClassCarpenter.ClassSchema( + "gen.B", + mapOf("b" to ClassCarpenter.NonNullableField (String::class.java)), + schema1) + val clazz = cc.build(schema2) val i = clazz.constructors[0].newInstance("xa", "xb") as SimpleFieldAccess assertEquals("xa", i["a"]) @@ -115,8 +126,14 @@ class ClassCarpenterTest { @Test fun interfaces() { - val schema1 = ClassCarpenter.ClassSchema("gen.A", mapOf("a" to String::class.java)) - val schema2 = ClassCarpenter.ClassSchema("gen.B", mapOf("b" to Int::class.java), schema1, interfaces = listOf(DummyInterface::class.java)) + val schema1 = ClassCarpenter.ClassSchema( + "gen.A", + mapOf("a" to ClassCarpenter.NonNullableField(String::class.java))) + + val schema2 = ClassCarpenter.ClassSchema("gen.B", + mapOf("b" to ClassCarpenter.NonNullableField(Int::class.java)), + schema1, + interfaces = listOf(DummyInterface::class.java)) val clazz = cc.build(schema2) val i = clazz.constructors[0].newInstance("xa", 1) as DummyInterface assertEquals("xa", i.a) @@ -125,8 +142,16 @@ class ClassCarpenterTest { @Test(expected = ClassCarpenter.InterfaceMismatch::class) fun `mismatched interface`() { - val schema1 = ClassCarpenter.ClassSchema("gen.A", mapOf("a" to String::class.java)) - val schema2 = ClassCarpenter.ClassSchema("gen.B", mapOf("c" to Int::class.java), schema1, interfaces = listOf(DummyInterface::class.java)) + val schema1 = ClassCarpenter.ClassSchema( + "gen.A", + mapOf("a" to ClassCarpenter.NonNullableField(String::class.java))) + + val schema2 = ClassCarpenter.ClassSchema( + "gen.B", + mapOf("c" to ClassCarpenter.NonNullableField(Int::class.java)), + schema1, + interfaces = listOf(DummyInterface::class.java)) + val clazz = cc.build(schema2) val i = clazz.constructors[0].newInstance("xa", 1) as DummyInterface assertEquals(1, i.b) @@ -134,15 +159,22 @@ class ClassCarpenterTest { @Test fun `generate interface`() { - val schema1 = ClassCarpenter.InterfaceSchema("gen.Interface", mapOf("a" to Int::class.java)) + val schema1 = ClassCarpenter.InterfaceSchema( + "gen.Interface", + mapOf("a" to ClassCarpenter.NonNullableField (Int::class.java))) + val iface = cc.build(schema1) - assert(iface.isInterface()) + assert(iface.isInterface) assert(iface.constructors.isEmpty()) assertEquals(iface.declaredMethods.size, 1) assertEquals(iface.declaredMethods[0].name, "getA") - val schema2 = ClassCarpenter.ClassSchema("gen.Derived", mapOf("a" to Int::class.java), interfaces = listOf(iface)) + val schema2 = ClassCarpenter.ClassSchema( + "gen.Derived", + mapOf("a" to ClassCarpenter.NonNullableField (Int::class.java)), + interfaces = listOf(iface)) + val clazz = cc.build(schema2) val testA = 42 val i = clazz.constructors[0].newInstance(testA) as SimpleFieldAccess @@ -152,16 +184,25 @@ class ClassCarpenterTest { @Test fun `generate multiple interfaces`() { - val iFace1 = ClassCarpenter.InterfaceSchema("gen.Interface1", mapOf("a" to Int::class.java, "b" to String::class.java)) - val iFace2 = ClassCarpenter.InterfaceSchema("gen.Interface2", mapOf("c" to Int::class.java, "d" to String::class.java)) + val iFace1 = ClassCarpenter.InterfaceSchema( + "gen.Interface1", + mapOf( + "a" to ClassCarpenter.NonNullableField(Int::class.java), + "b" to ClassCarpenter.NonNullableField(String::class.java))) + + val iFace2 = ClassCarpenter.InterfaceSchema( + "gen.Interface2", + mapOf( + "c" to ClassCarpenter.NonNullableField(Int::class.java), + "d" to ClassCarpenter.NonNullableField(String::class.java))) val class1 = ClassCarpenter.ClassSchema( "gen.Derived", mapOf( - "a" to Int::class.java, - "b" to String::class.java, - "c" to Int::class.java, - "d" to String::class.java), + "a" to ClassCarpenter.NonNullableField(Int::class.java), + "b" to ClassCarpenter.NonNullableField(String::class.java), + "c" to ClassCarpenter.NonNullableField(Int::class.java), + "d" to ClassCarpenter.NonNullableField(String::class.java)), interfaces = listOf(cc.build(iFace1), cc.build(iFace2))) val clazz = cc.build(class1) @@ -182,23 +223,23 @@ class ClassCarpenterTest { val iFace1 = ClassCarpenter.InterfaceSchema( "gen.Interface1", mapOf( - "a" to Int::class.java, - "b" to String::class.java)) + "a" to ClassCarpenter.NonNullableField (Int::class.java), + "b" to ClassCarpenter.NonNullableField(String::class.java))) val iFace2 = ClassCarpenter.InterfaceSchema( "gen.Interface2", mapOf( - "c" to Int::class.java, - "d" to String::class.java), + "c" to ClassCarpenter.NonNullableField(Int::class.java), + "d" to ClassCarpenter.NonNullableField(String::class.java)), interfaces = listOf(cc.build(iFace1))) val class1 = ClassCarpenter.ClassSchema( "gen.Derived", mapOf( - "a" to Int::class.java, - "b" to String::class.java, - "c" to Int::class.java, - "d" to String::class.java), + "a" to ClassCarpenter.NonNullableField(Int::class.java), + "b" to ClassCarpenter.NonNullableField(String::class.java), + "c" to ClassCarpenter.NonNullableField(Int::class.java), + "d" to ClassCarpenter.NonNullableField(String::class.java)), interfaces = listOf(cc.build(iFace2))) val clazz = cc.build(class1) @@ -213,4 +254,214 @@ class ClassCarpenterTest { assertEquals(testC, i["c"]) assertEquals(testD, i["d"]) } + + @Test(expected = java.lang.IllegalArgumentException::class) + fun `null parameter small int`() { + val className = "iEnjoySwede" + val schema = ClassCarpenter.ClassSchema( + "gen.$className", + mapOf("a" to ClassCarpenter.NonNullableField (Int::class.java))) + + val clazz = cc.build(schema) + + val a : Int? = null + clazz.constructors[0].newInstance(a) + } + + @Test(expected = ClassCarpenter.NullablePrimitive::class) + fun `nullable parameter small int`() { + val className = "iEnjoySwede" + val schema = ClassCarpenter.ClassSchema( + "gen.$className", + mapOf("a" to ClassCarpenter.NullableField (Int::class.java))) + + cc.build(schema) + } + + @Test + fun `nullable parameter integer`() { + val className = "iEnjoyWibble" + val schema = ClassCarpenter.ClassSchema( + "gen.$className", + mapOf("a" to ClassCarpenter.NullableField (Integer::class.java))) + + val clazz = cc.build(schema) + val a1 : Int? = null + clazz.constructors[0].newInstance(a1) + + val a2 : Int? = 10 + clazz.constructors[0].newInstance(a2) + } + + @Test + fun `non nullable parameter integer with non null`() { + val className = "iEnjoyWibble" + val schema = ClassCarpenter.ClassSchema( + "gen.$className", + mapOf("a" to ClassCarpenter.NonNullableField (Integer::class.java))) + + val clazz = cc.build(schema) + + val a : Int? = 10 + clazz.constructors[0].newInstance(a) + } + + @Test(expected = java.lang.reflect.InvocationTargetException::class) + fun `non nullable parameter integer with null`() { + val className = "iEnjoyWibble" + val schema = ClassCarpenter.ClassSchema( + "gen.$className", + mapOf("a" to ClassCarpenter.NonNullableField (Integer::class.java))) + + val clazz = cc.build(schema) + + val a : Int? = null + clazz.constructors[0].newInstance(a) + } + + @Test + @Suppress("UNCHECKED_CAST") + fun `int array`() { + val className = "iEnjoyPotato" + val schema = ClassCarpenter.ClassSchema( + "gen.$className", + mapOf("a" to ClassCarpenter.NonNullableField(IntArray::class.java))) + + val clazz = cc.build(schema) + + val i = clazz.constructors[0].newInstance(intArrayOf(1, 2, 3)) as SimpleFieldAccess + + val arr = clazz.getMethod("getA").invoke(i) + + assertEquals(1, (arr as IntArray)[0]) + assertEquals(2, arr[1]) + assertEquals(3, arr[2]) + assertEquals("$className{a=[1, 2, 3]}", i.toString()) + } + + @Test(expected = java.lang.reflect.InvocationTargetException::class) + fun `nullable int array throws`() { + val className = "iEnjoySwede" + val schema = ClassCarpenter.ClassSchema( + "gen.$className", + mapOf("a" to ClassCarpenter.NonNullableField(IntArray::class.java))) + + val clazz = cc.build(schema) + + val a : IntArray? = null + clazz.constructors[0].newInstance(a) + } + + @Test + @Suppress("UNCHECKED_CAST") + fun `integer array`() { + val className = "iEnjoyFlan" + val schema = ClassCarpenter.ClassSchema( + "gen.$className", + mapOf("a" to ClassCarpenter.NonNullableField(Array::class.java))) + + val clazz = cc.build(schema) + + val i = clazz.constructors[0].newInstance(arrayOf(1, 2, 3)) as SimpleFieldAccess + + val arr = clazz.getMethod("getA").invoke(i) + + assertEquals(1, (arr as Array)[0]) + assertEquals(2, arr[1]) + assertEquals(3, arr[2]) + assertEquals("$className{a=[1, 2, 3]}", i.toString()) + } + + @Test + @Suppress("UNCHECKED_CAST") + fun `int array with ints`() { + val className = "iEnjoyCrumble" + val schema = ClassCarpenter.ClassSchema( + "gen.$className", mapOf( + "a" to Int::class.java, + "b" to IntArray::class.java, + "c" to Int::class.java).mapValues { ClassCarpenter.NonNullableField(it.value) }) + + val clazz = cc.build(schema) + + val i = clazz.constructors[0].newInstance(2, intArrayOf(4, 8), 16) as SimpleFieldAccess + + assertEquals(2, clazz.getMethod("getA").invoke(i)) + assertEquals(4, (clazz.getMethod("getB").invoke(i) as IntArray)[0]) + assertEquals(8, (clazz.getMethod("getB").invoke(i) as IntArray)[1]) + assertEquals(16, clazz.getMethod("getC").invoke(i)) + + assertEquals("$className{a=2, b=[4, 8], c=16}", i.toString()) + } + + @Test + @Suppress("UNCHECKED_CAST") + fun `multiple int arrays`() { + val className = "iEnjoyJam" + val schema = ClassCarpenter.ClassSchema( + "gen.$className", mapOf( + "a" to IntArray::class.java, + "b" to Int::class.java, + "c" to IntArray::class.java).mapValues { ClassCarpenter.NonNullableField(it.value) }) + + val clazz = cc.build(schema) + val i = clazz.constructors[0].newInstance(intArrayOf(1, 2), 3, intArrayOf(4, 5, 6)) + + assertEquals(1, (clazz.getMethod("getA").invoke(i) as IntArray)[0]) + assertEquals(2, (clazz.getMethod("getA").invoke(i) as IntArray)[1]) + assertEquals(3, clazz.getMethod("getB").invoke(i)) + assertEquals(4, (clazz.getMethod("getC").invoke(i) as IntArray)[0]) + assertEquals(5, (clazz.getMethod("getC").invoke(i) as IntArray)[1]) + assertEquals(6, (clazz.getMethod("getC").invoke(i) as IntArray)[2]) + + assertEquals("$className{a=[1, 2], b=3, c=[4, 5, 6]}", i.toString()) + } + + @Test + @Suppress("UNCHECKED_CAST") + fun `string array`() { + val className = "iEnjoyToast" + val schema = ClassCarpenter.ClassSchema( + "gen.$className", + mapOf("a" to ClassCarpenter.NullableField(Array::class.java))) + + val clazz = cc.build(schema) + + val i = clazz.constructors[0].newInstance(arrayOf("toast", "butter", "jam")) + val arr = clazz.getMethod("getA").invoke(i) as Array + + assertEquals("toast", arr[0]) + assertEquals("butter", arr[1]) + assertEquals("jam", arr[2]) + } + + @Test + @Suppress("UNCHECKED_CAST") + fun `string arrays`() { + val className = "iEnjoyToast" + val schema = ClassCarpenter.ClassSchema( + "gen.$className", + mapOf( + "a" to Array::class.java, + "b" to String::class.java, + "c" to Array::class.java).mapValues { ClassCarpenter.NullableField (it.value) }) + + val clazz = cc.build(schema) + + val i = clazz.constructors[0].newInstance( + arrayOf("bread", "spread", "cheese"), + "and on the side", + arrayOf("some pickles", "some fries")) + + + val arr1 = clazz.getMethod("getA").invoke(i) as Array + val arr2 = clazz.getMethod("getC").invoke(i) as Array + + assertEquals("bread", arr1[0]) + assertEquals("spread", arr1[1]) + assertEquals("cheese", arr1[2]) + assertEquals("and on the side", clazz.getMethod("getB").invoke(i)) + assertEquals("some pickles", arr2[0]) + assertEquals("some fries", arr2[1]) + } }