diff --git a/experimental/kryo-hook/build.gradle b/experimental/kryo-hook/build.gradle new file mode 100644 index 0000000000..cf52f3c9bb --- /dev/null +++ b/experimental/kryo-hook/build.gradle @@ -0,0 +1,53 @@ +buildscript { + // For sharing constants between builds + Properties constants = new Properties() + file("$projectDir/../../constants.properties").withInputStream { constants.load(it) } + + ext.kotlin_version = constants.getProperty("kotlinVersion") + ext.javaassist_version = "3.12.1.GA" + + repositories { + mavenLocal() + mavenCentral() + jcenter() + } + + dependencies { + classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version" + } +} + +repositories { + mavenLocal() + mavenCentral() + jcenter() +} + +apply plugin: 'kotlin' +apply plugin: 'kotlin-kapt' +apply plugin: 'idea' + +description 'A javaagent to allow hooking into Kryo' + +dependencies { + compile "org.jetbrains.kotlin:kotlin-stdlib-jre8:$kotlin_version" + compile "org.jetbrains.kotlin:kotlin-reflect:$kotlin_version" + compile "javassist:javassist:$javaassist_version" + compile "com.esotericsoftware:kryo:4.0.0" + compile "co.paralleluniverse:quasar-core:$quasar_version:jdk8" +} + +jar { + archiveName = "${project.name}.jar" + manifest { + attributes( + 'Premain-Class': 'net.corda.kryohook.KryoHookAgent', + 'Can-Redefine-Classes': 'true', + 'Can-Retransform-Classes': 'true', + 'Can-Set-Native-Method-Prefix': 'true', + 'Implementation-Title': "KryoHook", + 'Implementation-Version': rootProject.version + ) + } + from { configurations.compile.collect { it.isDirectory() ? it : zipTree(it) } } +} diff --git a/experimental/kryo-hook/src/main/kotlin/net/corda/kryohook/KryoHook.kt b/experimental/kryo-hook/src/main/kotlin/net/corda/kryohook/KryoHook.kt new file mode 100644 index 0000000000..8d57dc8dbc --- /dev/null +++ b/experimental/kryo-hook/src/main/kotlin/net/corda/kryohook/KryoHook.kt @@ -0,0 +1,159 @@ +package net.corda.kryohook + +import co.paralleluniverse.strands.Strand +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.Output +import javassist.ClassPool +import javassist.CtClass +import java.io.ByteArrayInputStream +import java.lang.StringBuilder +import java.lang.instrument.ClassFileTransformer +import java.lang.instrument.Instrumentation +import java.security.ProtectionDomain +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicInteger + +class KryoHookAgent { + companion object { + @JvmStatic + fun premain(argumentsString: String?, instrumentation: Instrumentation) { + Runtime.getRuntime().addShutdownHook(Thread { + val statsTrees = KryoHook.events.values.flatMap { + readTrees(it, 0).second + } + val builder = StringBuilder() + statsTrees.forEach { + prettyStatsTree(0, it, builder) + } + print(builder.toString()) + }) + instrumentation.addTransformer(KryoHook) + } + } +} + +fun prettyStatsTree(indent: Int, statsTree: StatsTree, builder: StringBuilder) { + when (statsTree) { + is StatsTree.Object -> { + builder.append(kotlin.CharArray(indent) { ' ' }) + builder.append(statsTree.className) + builder.append(" ") + builder.append(statsTree.size) + builder.append("\n") + for (child in statsTree.children) { + prettyStatsTree(indent + 2, child, builder) + } + } + } +} + +object KryoHook : ClassFileTransformer { + val classPool = ClassPool.getDefault() + + val hookClassName = javaClass.name + + override fun transform( + loader: ClassLoader?, + className: String, + classBeingRedefined: Class<*>?, + protectionDomain: ProtectionDomain?, + classfileBuffer: ByteArray + ): ByteArray? { + if (className.startsWith("java") || className.startsWith("javassist") || className.startsWith("kotlin")) { + return null + } + return try { + val clazz = classPool.makeClass(ByteArrayInputStream(classfileBuffer)) + instrumentClass(clazz)?.toBytecode() + } catch (throwable: Throwable) { + println("SOMETHING WENT WRONG") + throwable.printStackTrace(System.out) + null + } + } + + private fun instrumentClass(clazz: CtClass): CtClass? { + for (method in clazz.declaredBehaviors) { + if (method.name == "write") { + val parameterTypeNames = method.parameterTypes.map { it.name } + if (parameterTypeNames == listOf("com.esotericsoftware.kryo.Kryo", "com.esotericsoftware.kryo.io.Output", "java.lang.Object")) { + if (method.isEmpty) continue + println("Instrumenting ${clazz.name}") + method.insertBefore("$hookClassName.${this::writeEnter.name}($1, $2, $3);") + method.insertAfter("$hookClassName.${this::writeExit.name}($1, $2, $3);") + return clazz + } + } + } + return null + } + + val events = ConcurrentHashMap>() + val eventCount = AtomicInteger(0) + + @JvmStatic + fun writeEnter(kryo: Kryo, output: Output, obj: Any) { + events.getOrPut(Strand.currentStrand().id) { ArrayList() }.add( + StatsEvent.Enter(obj.javaClass.name, output.total()) + ) + if (eventCount.incrementAndGet() % 100 == 0) { + println("EVENT COUNT ${eventCount}") + } + } + @JvmStatic + fun writeExit(kryo: Kryo, output: Output, obj: Any) { + events.get(Strand.currentStrand().id)!!.add( + StatsEvent.Exit(obj.javaClass.name, output.total()) + ) + } +} + +sealed class StatsEvent { + data class Enter(val className: String, val offset: Long) : StatsEvent() + data class Exit(val className: String, val offset: Long) : StatsEvent() +} + +sealed class StatsTree { + data class Object( + val className: String, + val size: Long, + val children: List + ) : StatsTree() +} + + +fun readTree(events: List, index: Int): Pair { + val event = events[index] + when (event) { + is StatsEvent.Enter -> { + val (nextIndex, children) = readTrees(events, index + 1) + val exit = events[nextIndex] as StatsEvent.Exit + require(event.className == exit.className) + return Pair(nextIndex + 1, StatsTree.Object(event.className, exit.offset - event.offset, children)) + } + is StatsEvent.Exit -> { + throw IllegalStateException("Wasn't expecting Exit") + } + } +} + +fun readTrees(events: List, index: Int): Pair> { + val trees = ArrayList() + var i = index + while (true) { + val event = events.getOrNull(i) + when (event) { + is StatsEvent.Enter -> { + val (nextIndex, tree) = readTree(events, i) + trees.add(tree) + i = nextIndex + } + is StatsEvent.Exit -> { + return Pair(i, trees) + } + null -> { + return Pair(i, trees) + } + } + } +}