ENT-1906: Update DJVM to load deterministic rt.jar into sandbox. (#3973)

* Update DJVM to load deterministic rt.jar into sandbox.
* Disallow invocations of notify(), notifyAll() and wait() APIs.
* Pass entire derivedMember to MemberVisitorImpl.
* Updates after review.
* Refactor MethodBody handlers to use EmitterModule.
This commit is contained in:
Chris Rankin 2018-09-27 17:28:22 +01:00 committed by GitHub
parent e92ad538cf
commit d35a47bc82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
64 changed files with 1006 additions and 810 deletions

View File

@ -3,6 +3,7 @@ plugins {
} }
apply plugin: 'net.corda.plugins.publish-utils' apply plugin: 'net.corda.plugins.publish-utils'
apply plugin: 'com.jfrog.artifactory' apply plugin: 'com.jfrog.artifactory'
apply plugin: 'idea'
description 'Corda deterministic JVM sandbox' description 'Corda deterministic JVM sandbox'
@ -11,8 +12,18 @@ ext {
asm_version = '6.1.1' asm_version = '6.1.1'
} }
repositories {
maven {
url "$artifactory_contextUrl/corda-dev"
}
}
configurations { configurations {
testCompile.extendsFrom shadow testCompile.extendsFrom shadow
jdkRt.resolutionStrategy {
// Always check the repository for a newer SNAPSHOT.
cacheChangingModulesFor 0, 'seconds'
}
} }
dependencies { dependencies {
@ -32,6 +43,7 @@ dependencies {
testCompile "junit:junit:$junit_version" testCompile "junit:junit:$junit_version"
testCompile "org.assertj:assertj-core:$assertj_version" testCompile "org.assertj:assertj-core:$assertj_version"
testCompile "org.apache.logging.log4j:log4j-slf4j-impl:$log4j_version" testCompile "org.apache.logging.log4j:log4j-slf4j-impl:$log4j_version"
jdkRt "net.corda:deterministic-rt:latest.integration"
} }
jar.enabled = false jar.enabled = false
@ -43,6 +55,10 @@ shadowJar {
} }
assemble.dependsOn shadowJar assemble.dependsOn shadowJar
tasks.withType(Test) {
systemProperty 'deterministic-rt.path', configurations.jdkRt.asPath
}
artifacts { artifacts {
publish shadowJar publish shadowJar
} }
@ -51,3 +67,10 @@ publish {
dependenciesFrom configurations.shadow dependenciesFrom configurations.shadow
name shadowJar.baseName name shadowJar.baseName
} }
idea {
module {
downloadJavadoc = true
downloadSources = true
}
}

View File

@ -63,7 +63,7 @@ abstract class ClassCommand : CommandBase() {
private lateinit var classLoader: ClassLoader private lateinit var classLoader: ClassLoader
protected var executor = SandboxExecutor<Any, Any>() protected var executor = SandboxExecutor<Any, Any>(SandboxConfiguration.DEFAULT)
private var derivedWhitelist: Whitelist = Whitelist.MINIMAL private var derivedWhitelist: Whitelist = Whitelist.MINIMAL
@ -114,7 +114,7 @@ abstract class ClassCommand : CommandBase() {
} }
private fun findDiscoverableRunnables(filters: Array<String>): List<Class<*>> { private fun findDiscoverableRunnables(filters: Array<String>): List<Class<*>> {
val classes = find<DiscoverableRunnable>() val classes = find<java.util.function.Function<*,*>>()
val applicableFilters = filters val applicableFilters = filters
.filter { !isJarFile(it) && !isFullClassName(it) } .filter { !isJarFile(it) && !isFullClassName(it) }
val filteredClasses = applicableFilters val filteredClasses = applicableFilters
@ -125,7 +125,7 @@ abstract class ClassCommand : CommandBase() {
} }
if (applicableFilters.isNotEmpty() && filteredClasses.isEmpty()) { if (applicableFilters.isNotEmpty() && filteredClasses.isEmpty()) {
throw Exception("Could not find any classes implementing ${SandboxedRunnable::class.java.simpleName} " + throw Exception("Could not find any classes implementing ${java.util.function.Function::class.java.simpleName} " +
"whose name matches '${applicableFilters.joinToString(" ")}'") "whose name matches '${applicableFilters.joinToString(" ")}'")
} }
@ -189,7 +189,7 @@ abstract class ClassCommand : CommandBase() {
profile = profile, profile = profile,
rules = if (ignoreRules) { emptyList() } else { Discovery.find() }, rules = if (ignoreRules) { emptyList() } else { Discovery.find() },
emitters = ignoreEmitters.emptyListIfTrueOtherwiseNull(), emitters = ignoreEmitters.emptyListIfTrueOtherwiseNull(),
definitionProviders = if(ignoreDefinitionProviders) { emptyList() } else { Discovery.find() }, definitionProviders = if (ignoreDefinitionProviders) { emptyList() } else { Discovery.find() },
enableTracing = !disableTracing, enableTracing = !disableTracing,
analysisConfiguration = AnalysisConfiguration( analysisConfiguration = AnalysisConfiguration(
whitelist = whitelist, whitelist = whitelist,

View File

@ -1,6 +1,5 @@
package net.corda.djvm.tools.cli package net.corda.djvm.tools.cli
import net.corda.djvm.execution.SandboxedRunnable
import net.corda.djvm.source.ClassSource import net.corda.djvm.source.ClassSource
import picocli.CommandLine.Command import picocli.CommandLine.Command
import picocli.CommandLine.Parameters import picocli.CommandLine.Parameters
@ -20,7 +19,7 @@ class RunCommand : ClassCommand() {
var classes: Array<String> = emptyArray() var classes: Array<String> = emptyArray()
override fun processClasses(classes: List<Class<*>>) { override fun processClasses(classes: List<Class<*>>) {
val interfaceName = SandboxedRunnable::class.java.simpleName val interfaceName = java.util.function.Function::class.java.simpleName
for (clazz in classes) { for (clazz in classes) {
if (!clazz.interfaces.any { it.simpleName == interfaceName }) { if (!clazz.interfaces.any { it.simpleName == interfaceName }) {
printError("Class is not an instance of $interfaceName; ${clazz.name}") printError("Class is not an instance of $interfaceName; ${clazz.name}")

View File

@ -33,8 +33,9 @@ class WhitelistGenerateCommand : CommandBase() {
override fun validateArguments() = paths.isNotEmpty() override fun validateArguments() = paths.isNotEmpty()
override fun handleCommand(): Boolean { override fun handleCommand(): Boolean {
val entries = AnalysisConfiguration().use { configuration ->
val entries = mutableListOf<String>() val entries = mutableListOf<String>()
val visitor = object : ClassAndMemberVisitor() { val visitor = object : ClassAndMemberVisitor(configuration, null) {
override fun visitClass(clazz: ClassRepresentation): ClassRepresentation { override fun visitClass(clazz: ClassRepresentation): ClassRepresentation {
entries.add(clazz.name) entries.add(clazz.name)
return super.visitClass(clazz) return super.visitClass(clazz)
@ -54,31 +55,31 @@ class WhitelistGenerateCommand : CommandBase() {
entries.add("${clazz.name}.${member.memberName}:${member.signature}") entries.add("${clazz.name}.${member.memberName}:${member.signature}")
} }
} }
val context = AnalysisContext.fromConfiguration(AnalysisConfiguration(), emptyList()) val context = AnalysisContext.fromConfiguration(configuration)
for (path in paths) { for (path in paths) {
ClassSource.fromPath(path).getStreamIterator().forEach { ClassSource.fromPath(path).getStreamIterator().forEach {
visitor.analyze(it, context) visitor.analyze(it, context)
} }
} }
val output = output entries
if (output != null) { }
Files.newOutputStream(output, StandardOpenOption.CREATE).use { output?.also {
GZIPOutputStream(it).use { Files.newOutputStream(it, StandardOpenOption.CREATE).use { out ->
PrintStream(it).use { GZIPOutputStream(out).use { gzip ->
it.println(""" PrintStream(gzip).use { pout ->
pout.println("""
|java/.* |java/.*
|javax/.* |javax/.*
|jdk/.* |jdk/.*
|com/sun/.*
|sun/.* |sun/.*
|--- |---
""".trimMargin().trim()) """.trimMargin().trim())
printEntries(it, entries) printEntries(pout, entries)
} }
} }
} }
} else { } ?: printEntries(System.out, entries)
printEntries(System.out, entries)
}
return true return true
} }

View File

@ -3,25 +3,20 @@ package net.corda.djvm
import net.corda.djvm.analysis.AnalysisContext import net.corda.djvm.analysis.AnalysisContext
import net.corda.djvm.costing.RuntimeCostSummary import net.corda.djvm.costing.RuntimeCostSummary
import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.djvm.rewiring.SandboxClassLoader
import net.corda.djvm.source.ClassSource
/** /**
* The context in which a sandboxed operation is run. * The context in which a sandboxed operation is run.
* *
* @property configuration The configuration of the sandbox. * @property configuration The configuration of the sandbox.
* @property inputClasses The classes passed in for analysis.
*/ */
class SandboxRuntimeContext( class SandboxRuntimeContext(val configuration: SandboxConfiguration) {
val configuration: SandboxConfiguration,
private val inputClasses: List<ClassSource>
) {
/** /**
* The class loader to use inside the sandbox. * The class loader to use inside the sandbox.
*/ */
val classLoader: SandboxClassLoader = SandboxClassLoader( val classLoader: SandboxClassLoader = SandboxClassLoader(
configuration, configuration,
AnalysisContext.fromConfiguration(configuration.analysisConfiguration, inputClasses) AnalysisContext.fromConfiguration(configuration.analysisConfiguration)
) )
/** /**
@ -35,7 +30,7 @@ class SandboxRuntimeContext(
fun use(action: SandboxRuntimeContext.() -> Unit) { fun use(action: SandboxRuntimeContext.() -> Unit) {
SandboxRuntimeContext.instance = this SandboxRuntimeContext.instance = this
try { try {
this.action() action(this)
} finally { } finally {
threadLocalContext.remove() threadLocalContext.remove()
} }
@ -43,9 +38,7 @@ class SandboxRuntimeContext(
companion object { companion object {
private val threadLocalContext = object : ThreadLocal<SandboxRuntimeContext?>() { private val threadLocalContext = ThreadLocal<SandboxRuntimeContext?>()
override fun initialValue(): SandboxRuntimeContext? = null
}
/** /**
* When called from within a sandbox, this returns the context for the current sandbox thread. * When called from within a sandbox, this returns the context for the current sandbox thread.

View File

@ -1,9 +1,15 @@
package net.corda.djvm.analysis package net.corda.djvm.analysis
import net.corda.djvm.code.ruleViolationError
import net.corda.djvm.code.thresholdViolationError
import net.corda.djvm.messages.Severity import net.corda.djvm.messages.Severity
import net.corda.djvm.references.ClassModule import net.corda.djvm.references.ClassModule
import net.corda.djvm.references.MemberModule import net.corda.djvm.references.MemberModule
import net.corda.djvm.source.BootstrapClassLoader
import net.corda.djvm.source.SourceClassLoader
import sandbox.net.corda.djvm.costing.RuntimeCostAccounter import sandbox.net.corda.djvm.costing.RuntimeCostAccounter
import java.io.Closeable
import java.io.IOException
import java.nio.file.Path import java.nio.file.Path
/** /**
@ -13,7 +19,8 @@ import java.nio.file.Path
* @param additionalPinnedClasses Classes that have already been declared in the sandbox namespace and that should be * @param additionalPinnedClasses Classes that have already been declared in the sandbox namespace and that should be
* made available inside the sandboxed environment. * made available inside the sandboxed environment.
* @property minimumSeverityLevel The minimum severity level to log and report. * @property minimumSeverityLevel The minimum severity level to log and report.
* @property classPath The extended class path to use for the analysis. * @param classPath The extended class path to use for the analysis.
* @param bootstrapJar The location of a jar containing the Java APIs.
* @property analyzeAnnotations Analyze annotations despite not being explicitly referenced. * @property analyzeAnnotations Analyze annotations despite not being explicitly referenced.
* @property prefixFilters Only record messages where the originating class name matches one of the provided prefixes. * @property prefixFilters Only record messages where the originating class name matches one of the provided prefixes.
* If none are provided, all messages will be reported. * If none are provided, all messages will be reported.
@ -24,32 +31,47 @@ class AnalysisConfiguration(
val whitelist: Whitelist = Whitelist.MINIMAL, val whitelist: Whitelist = Whitelist.MINIMAL,
additionalPinnedClasses: Set<String> = emptySet(), additionalPinnedClasses: Set<String> = emptySet(),
val minimumSeverityLevel: Severity = Severity.WARNING, val minimumSeverityLevel: Severity = Severity.WARNING,
val classPath: List<Path> = emptyList(), classPath: List<Path> = emptyList(),
bootstrapJar: Path? = null,
val analyzeAnnotations: Boolean = false, val analyzeAnnotations: Boolean = false,
val prefixFilters: List<String> = emptyList(), val prefixFilters: List<String> = emptyList(),
val classModule: ClassModule = ClassModule(), val classModule: ClassModule = ClassModule(),
val memberModule: MemberModule = MemberModule() val memberModule: MemberModule = MemberModule()
) { ) : Closeable {
/** /**
* Classes that have already been declared in the sandbox namespace and that should be made * Classes that have already been declared in the sandbox namespace and that should be made
* available inside the sandboxed environment. * available inside the sandboxed environment.
*/ */
val pinnedClasses: Set<String> = setOf(SANDBOXED_OBJECT, RUNTIME_COST_ACCOUNTER) + additionalPinnedClasses val pinnedClasses: Set<String> = setOf(
SANDBOXED_OBJECT,
RuntimeCostAccounter.TYPE_NAME,
ruleViolationError,
thresholdViolationError
) + additionalPinnedClasses
/** /**
* Functionality used to resolve the qualified name and relevant information about a class. * Functionality used to resolve the qualified name and relevant information about a class.
*/ */
val classResolver: ClassResolver = ClassResolver(pinnedClasses, whitelist, SANDBOX_PREFIX) val classResolver: ClassResolver = ClassResolver(pinnedClasses, whitelist, SANDBOX_PREFIX)
private val bootstrapClassLoader = bootstrapJar?.let { BootstrapClassLoader(it, classResolver) }
val supportingClassLoader = SourceClassLoader(classPath, classResolver, bootstrapClassLoader)
@Throws(IOException::class)
override fun close() {
supportingClassLoader.use {
bootstrapClassLoader?.close()
}
}
companion object { companion object {
/** /**
* The package name prefix to use for classes loaded into a sandbox. * The package name prefix to use for classes loaded into a sandbox.
*/ */
private const val SANDBOX_PREFIX: String = "sandbox/" private const val SANDBOX_PREFIX: String = "sandbox/"
private const val SANDBOXED_OBJECT = "sandbox/java/lang/Object" private const val SANDBOXED_OBJECT = SANDBOX_PREFIX + "java/lang/Object"
private const val RUNTIME_COST_ACCOUNTER = RuntimeCostAccounter.TYPE_NAME
} }
} }

View File

@ -1,10 +1,10 @@
package net.corda.djvm.analysis package net.corda.djvm.analysis
import net.corda.djvm.code.asPackagePath
import net.corda.djvm.messages.MessageCollection import net.corda.djvm.messages.MessageCollection
import net.corda.djvm.references.ClassHierarchy import net.corda.djvm.references.ClassHierarchy
import net.corda.djvm.references.EntityReference import net.corda.djvm.references.EntityReference
import net.corda.djvm.references.ReferenceMap import net.corda.djvm.references.ReferenceMap
import net.corda.djvm.source.ClassSource
/** /**
* The context in which one or more classes are analysed. * The context in which one or more classes are analysed.
@ -13,13 +13,11 @@ import net.corda.djvm.source.ClassSource
* @property classes List of class definitions that have been analyzed. * @property classes List of class definitions that have been analyzed.
* @property references A collection of all referenced members found during analysis together with the locations from * @property references A collection of all referenced members found during analysis together with the locations from
* where each member has been accessed or invoked. * where each member has been accessed or invoked.
* @property inputClasses The classes passed in for analysis.
*/ */
class AnalysisContext private constructor( class AnalysisContext private constructor(
val messages: MessageCollection, val messages: MessageCollection,
val classes: ClassHierarchy, val classes: ClassHierarchy,
val references: ReferenceMap, val references: ReferenceMap
val inputClasses: List<ClassSource>
) { ) {
private val origins = mutableMapOf<String, MutableSet<EntityReference>>() private val origins = mutableMapOf<String, MutableSet<EntityReference>>()
@ -28,7 +26,7 @@ class AnalysisContext private constructor(
* Record a class origin in the current analysis context. * Record a class origin in the current analysis context.
*/ */
fun recordClassOrigin(name: String, origin: EntityReference) { fun recordClassOrigin(name: String, origin: EntityReference) {
origins.getOrPut(name.normalize()) { mutableSetOf() }.add(origin) origins.getOrPut(name.asPackagePath) { mutableSetOf() }.add(origin)
} }
/** /**
@ -42,20 +40,14 @@ class AnalysisContext private constructor(
/** /**
* Create a new analysis context from provided configuration. * Create a new analysis context from provided configuration.
*/ */
fun fromConfiguration(configuration: AnalysisConfiguration, classes: List<ClassSource>): AnalysisContext { fun fromConfiguration(configuration: AnalysisConfiguration): AnalysisContext {
return AnalysisContext( return AnalysisContext(
MessageCollection(configuration.minimumSeverityLevel, configuration.prefixFilters), MessageCollection(configuration.minimumSeverityLevel, configuration.prefixFilters),
ClassHierarchy(configuration.classModule, configuration.memberModule), ClassHierarchy(configuration.classModule, configuration.memberModule),
ReferenceMap(configuration.classModule), ReferenceMap(configuration.classModule)
classes
) )
} }
/**
* Local extension method for normalizing a class name.
*/
private fun String.normalize() = this.replace("/", ".")
} }
} }

View File

@ -4,30 +4,25 @@ import net.corda.djvm.code.EmitterModule
import net.corda.djvm.code.Instruction import net.corda.djvm.code.Instruction
import net.corda.djvm.code.instructions.* import net.corda.djvm.code.instructions.*
import net.corda.djvm.messages.Message import net.corda.djvm.messages.Message
import net.corda.djvm.references.ClassReference import net.corda.djvm.references.*
import net.corda.djvm.references.ClassRepresentation
import net.corda.djvm.references.Member
import net.corda.djvm.references.MemberReference
import net.corda.djvm.source.SourceClassLoader
import org.objectweb.asm.* import org.objectweb.asm.*
import java.io.InputStream import java.io.InputStream
/** /**
* Functionality for traversing a class and its members. * Functionality for traversing a class and its members.
* *
* @property classVisitor Class visitor to use when traversing the structure of classes.
* @property configuration The configuration to use for the analysis * @property configuration The configuration to use for the analysis
* @property classVisitor Class visitor to use when traversing the structure of classes.
*/ */
open class ClassAndMemberVisitor( open class ClassAndMemberVisitor(
private val classVisitor: ClassVisitor? = null, private val configuration: AnalysisConfiguration,
private val configuration: AnalysisConfiguration = AnalysisConfiguration() private val classVisitor: ClassVisitor?
) { ) {
/** /**
* Holds a reference to the currently used analysis context. * Holds a reference to the currently used analysis context.
*/ */
protected var analysisContext: AnalysisContext = protected var analysisContext: AnalysisContext = AnalysisContext.fromConfiguration(configuration)
AnalysisContext.fromConfiguration(configuration, emptyList())
/** /**
* Holds a link to the class currently being traversed. * Holds a link to the class currently being traversed.
@ -44,12 +39,6 @@ open class ClassAndMemberVisitor(
*/ */
private var sourceLocation = SourceLocation() private var sourceLocation = SourceLocation()
/**
* The class loader used to find classes on the extended class path.
*/
private val supportingClassLoader =
SourceClassLoader(configuration.classPath, configuration.classResolver)
/** /**
* Analyze class by using the provided qualified name of the class. * Analyze class by using the provided qualified name of the class.
*/ */
@ -63,7 +52,7 @@ open class ClassAndMemberVisitor(
* @param origin The originating class for the analysis. * @param origin The originating class for the analysis.
*/ */
fun analyze(className: String, context: AnalysisContext, origin: String? = null) { fun analyze(className: String, context: AnalysisContext, origin: String? = null) {
supportingClassLoader.classReader(className, context, origin).apply { configuration.supportingClassLoader.classReader(className, context, origin).apply {
analyze(this, context) analyze(this, context)
} }
} }
@ -167,7 +156,8 @@ open class ClassAndMemberVisitor(
} }
/** /**
* Run action with a guard that populates [messages] based on the output. * Run action with a guard that populates [AnalysisRuntimeContext.messages]
* based on the output.
*/ */
private inline fun captureExceptions(action: () -> Unit): Boolean { private inline fun captureExceptions(action: () -> Unit): Boolean {
return try { return try {
@ -229,9 +219,7 @@ open class ClassAndMemberVisitor(
ClassRepresentation(version, access, name, superClassName, interfaceNames, genericsDetails = signature ?: "").also { ClassRepresentation(version, access, name, superClassName, interfaceNames, genericsDetails = signature ?: "").also {
currentClass = it currentClass = it
currentMember = null currentMember = null
sourceLocation = SourceLocation( sourceLocation = SourceLocation(className = name)
className = name
)
} }
captureExceptions { captureExceptions {
currentClass = visitClass(currentClass!!) currentClass = visitClass(currentClass!!)
@ -251,7 +239,7 @@ open class ClassAndMemberVisitor(
override fun visitEnd() { override fun visitEnd() {
configuration.classModule configuration.classModule
.getClassReferencesFromClass(currentClass!!, configuration.analyzeAnnotations) .getClassReferencesFromClass(currentClass!!, configuration.analyzeAnnotations)
.forEach { recordTypeReference(it) } .forEach(::recordTypeReference)
captureExceptions { captureExceptions {
visitClassEnd(currentClass!!) visitClassEnd(currentClass!!)
} }
@ -306,14 +294,15 @@ open class ClassAndMemberVisitor(
configuration.memberModule.addToClass(clazz, visitedMember ?: member) configuration.memberModule.addToClass(clazz, visitedMember ?: member)
return if (processMember) { return if (processMember) {
val derivedMember = visitedMember ?: member val derivedMember = visitedMember ?: member
val targetVisitor = super.visitMethod( super.visitMethod(
derivedMember.access, derivedMember.access,
derivedMember.memberName, derivedMember.memberName,
derivedMember.signature, derivedMember.signature,
signature, signature,
derivedMember.exceptions.toTypedArray() derivedMember.exceptions.toTypedArray()
) )?.let { targetVisitor ->
MethodVisitorImpl(targetVisitor) MethodVisitorImpl(targetVisitor, derivedMember)
}
} else { } else {
null null
} }
@ -340,14 +329,15 @@ open class ClassAndMemberVisitor(
configuration.memberModule.addToClass(clazz, visitedMember ?: member) configuration.memberModule.addToClass(clazz, visitedMember ?: member)
return if (processMember) { return if (processMember) {
val derivedMember = visitedMember ?: member val derivedMember = visitedMember ?: member
val targetVisitor = super.visitField( super.visitField(
derivedMember.access, derivedMember.access,
derivedMember.memberName, derivedMember.memberName,
derivedMember.signature, derivedMember.signature,
signature, signature,
derivedMember.value derivedMember.value
) )?.let { targetVisitor ->
FieldVisitorImpl(targetVisitor) FieldVisitorImpl(targetVisitor)
}
} else { } else {
null null
} }
@ -359,7 +349,8 @@ open class ClassAndMemberVisitor(
* Visitor used to traverse and analyze a method. * Visitor used to traverse and analyze a method.
*/ */
private inner class MethodVisitorImpl( private inner class MethodVisitorImpl(
targetVisitor: MethodVisitor? targetVisitor: MethodVisitor,
private val method: Member
) : MethodVisitor(API_VERSION, targetVisitor) { ) : MethodVisitor(API_VERSION, targetVisitor) {
/** /**
@ -387,6 +378,16 @@ open class ClassAndMemberVisitor(
return super.visitAnnotation(desc, visible) return super.visitAnnotation(desc, visible)
} }
/**
* Write any new method body code, assuming the definition providers
* have provided any. This handler will not be visited if this method
* has no existing code.
*/
override fun visitCode() {
tryReplaceMethodBody()
super.visitCode()
}
/** /**
* Extract information about provided field access instruction. * Extract information about provided field access instruction.
*/ */
@ -493,6 +494,29 @@ open class ClassAndMemberVisitor(
} }
} }
/**
* Finish visiting this method, writing any new method body byte-code
* if we haven't written it already. This would (presumably) only happen
* for methods that previously had no body, e.g. native methods.
*/
override fun visitEnd() {
tryReplaceMethodBody()
super.visitEnd()
}
private fun tryReplaceMethodBody() {
if (method.body.isNotEmpty() && (mv != null)) {
EmitterModule(mv).apply {
for (body in method.body) {
body(this)
}
}
mv.visitMaxs(-1, -1)
mv.visitEnd()
mv = null
}
}
/** /**
* Helper function used to streamline the access to an instruction and to catch any related processing errors. * Helper function used to streamline the access to an instruction and to catch any related processing errors.
*/ */
@ -517,7 +541,7 @@ open class ClassAndMemberVisitor(
* Visitor used to traverse and analyze a field. * Visitor used to traverse and analyze a field.
*/ */
private inner class FieldVisitorImpl( private inner class FieldVisitorImpl(
targetVisitor: FieldVisitor? targetVisitor: FieldVisitor
) : FieldVisitor(API_VERSION, targetVisitor) { ) : FieldVisitor(API_VERSION, targetVisitor) {
/** /**

View File

@ -1,5 +1,8 @@
package net.corda.djvm.analysis package net.corda.djvm.analysis
import net.corda.djvm.code.asPackagePath
import net.corda.djvm.code.asResourcePath
/** /**
* Functionality for resolving the class name of a sandboxable class. * Functionality for resolving the class name of a sandboxable class.
* *
@ -32,12 +35,12 @@ class ClassResolver(
*/ */
fun resolve(name: String): String { fun resolve(name: String): String {
return when { return when {
name.startsWith("[") && name.endsWith(";") -> { name.startsWith('[') && name.endsWith(';') -> {
complexArrayTypeRegex.replace(name) { complexArrayTypeRegex.replace(name) {
"${it.groupValues[1]}L${resolveName(it.groupValues[2])};" "${it.groupValues[1]}L${resolveName(it.groupValues[2])};"
} }
} }
name.startsWith("[") && !name.endsWith(";") -> name name.startsWith('[') && !name.endsWith(';') -> name
else -> resolveName(name) else -> resolveName(name)
} }
} }
@ -46,7 +49,7 @@ class ClassResolver(
* Resolve the class name from a fully qualified normalized name. * Resolve the class name from a fully qualified normalized name.
*/ */
fun resolveNormalized(name: String): String { fun resolveNormalized(name: String): String {
return resolve(name.replace('.', '/')).replace('/', '.') return resolve(name.asResourcePath).asPackagePath
} }
/** /**
@ -96,7 +99,7 @@ class ClassResolver(
* Reverse the resolution of a class name from a fully qualified normalized name. * Reverse the resolution of a class name from a fully qualified normalized name.
*/ */
fun reverseNormalized(name: String): String { fun reverseNormalized(name: String): String {
return reverse(name.replace('.', '/')).replace('/', '.') return reverse(name.asResourcePath).asPackagePath
} }
/** /**

View File

@ -117,7 +117,8 @@ open class Whitelist private constructor(
"^java/lang/Throwable(\\..*)?$".toRegex(), "^java/lang/Throwable(\\..*)?$".toRegex(),
"^java/lang/Void(\\..*)?$".toRegex(), "^java/lang/Void(\\..*)?$".toRegex(),
"^java/lang/.*Error(\\..*)?$".toRegex(), "^java/lang/.*Error(\\..*)?$".toRegex(),
"^java/lang/.*Exception(\\..*)?$".toRegex() "^java/lang/.*Exception(\\..*)?$".toRegex(),
"^java/lang/reflect/Array(\\..*)?$".toRegex()
) )
/** /**

View File

@ -20,7 +20,7 @@ class ClassMutator(
private val configuration: AnalysisConfiguration, private val configuration: AnalysisConfiguration,
private val definitionProviders: List<DefinitionProvider> = emptyList(), private val definitionProviders: List<DefinitionProvider> = emptyList(),
private val emitters: List<Emitter> = emptyList() private val emitters: List<Emitter> = emptyList()
) : ClassAndMemberVisitor(classVisitor, configuration = configuration) { ) : ClassAndMemberVisitor(configuration, classVisitor) {
/** /**
* Tracks whether any modifications have been applied to any of the processed class(es) and pertinent members. * Tracks whether any modifications have been applied to any of the processed class(es) and pertinent members.
@ -82,7 +82,8 @@ class ClassMutator(
*/ */
override fun visitInstruction(method: Member, emitter: EmitterModule, instruction: Instruction) { override fun visitInstruction(method: Member, emitter: EmitterModule, instruction: Instruction) {
val context = EmitterContext(currentAnalysisContext(), configuration, emitter) val context = EmitterContext(currentAnalysisContext(), configuration, emitter)
Processor.processEntriesOfType<Emitter>(emitters, analysisContext.messages) { // We need to apply the tracing emitters before the non-tracing ones.
Processor.processEntriesOfType<Emitter>(emitters.sortedByDescending(Emitter::isTracer), analysisContext.messages) {
it.emit(context, instruction) it.emit(context, instruction)
} }
if (!emitter.emitDefaultInstruction || emitter.hasEmittedCustomCode) { if (!emitter.emitDefaultInstruction || emitter.hasEmittedCustomCode) {

View File

@ -1,7 +1,9 @@
package net.corda.djvm.code package net.corda.djvm.code
import org.objectweb.asm.Label
import org.objectweb.asm.MethodVisitor import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.Type
import sandbox.net.corda.djvm.costing.RuntimeCostAccounter import sandbox.net.corda.djvm.costing.RuntimeCostAccounter
/** /**
@ -29,7 +31,7 @@ class EmitterModule(
/** /**
* Emit instruction for creating a new object of type [typeName]. * Emit instruction for creating a new object of type [typeName].
*/ */
fun new(typeName: String, opcode: Int = Opcodes.NEW) { fun new(typeName: String, opcode: Int = NEW) {
hasEmittedCustomCode = true hasEmittedCustomCode = true
methodVisitor.visitTypeInsn(opcode, typeName) methodVisitor.visitTypeInsn(opcode, typeName)
} }
@ -38,7 +40,7 @@ class EmitterModule(
* Emit instruction for creating a new object of type [T]. * Emit instruction for creating a new object of type [T].
*/ */
inline fun <reified T> new() { inline fun <reified T> new() {
new(T::class.java.name) new(Type.getInternalName(T::class.java))
} }
/** /**
@ -62,7 +64,7 @@ class EmitterModule(
*/ */
fun invokeStatic(owner: String, name: String, descriptor: String, isInterface: Boolean = false) { fun invokeStatic(owner: String, name: String, descriptor: String, isInterface: Boolean = false) {
hasEmittedCustomCode = true hasEmittedCustomCode = true
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, owner, name, descriptor, isInterface) methodVisitor.visitMethodInsn(INVOKESTATIC, owner, name, descriptor, isInterface)
} }
/** /**
@ -70,14 +72,14 @@ class EmitterModule(
*/ */
fun invokeSpecial(owner: String, name: String, descriptor: String, isInterface: Boolean = false) { fun invokeSpecial(owner: String, name: String, descriptor: String, isInterface: Boolean = false) {
hasEmittedCustomCode = true hasEmittedCustomCode = true
methodVisitor.visitMethodInsn(Opcodes.INVOKESPECIAL, owner, name, descriptor, isInterface) methodVisitor.visitMethodInsn(INVOKESPECIAL, owner, name, descriptor, isInterface)
} }
/** /**
* Emit instruction for invoking a special method on class [T], e.g. a constructor or a method on a super-type. * Emit instruction for invoking a special method on class [T], e.g. a constructor or a method on a super-type.
*/ */
inline fun <reified T> invokeSpecial(name: String, descriptor: String, isInterface: Boolean = false) { inline fun <reified T> invokeSpecial(name: String, descriptor: String, isInterface: Boolean = false) {
invokeSpecial(T::class.java.name, name, descriptor, isInterface) invokeSpecial(Type.getInternalName(T::class.java), name, descriptor, isInterface)
} }
/** /**
@ -85,7 +87,7 @@ class EmitterModule(
*/ */
fun pop() { fun pop() {
hasEmittedCustomCode = true hasEmittedCustomCode = true
methodVisitor.visitInsn(Opcodes.POP) methodVisitor.visitInsn(POP)
} }
/** /**
@ -93,19 +95,40 @@ class EmitterModule(
*/ */
fun duplicate() { fun duplicate() {
hasEmittedCustomCode = true hasEmittedCustomCode = true
methodVisitor.visitInsn(Opcodes.DUP) methodVisitor.visitInsn(DUP)
} }
/** /**
* Emit a sequence of instructions for instantiating and throwing an exception based on the provided message. * Emit a sequence of instructions for instantiating and throwing an exception based on the provided message.
*/ */
fun throwError(message: String) { fun <T : Throwable> throwException(exceptionType: Class<T>, message: String) {
hasEmittedCustomCode = true hasEmittedCustomCode = true
new<java.lang.Exception>() val exceptionName = Type.getInternalName(exceptionType)
methodVisitor.visitInsn(Opcodes.DUP) new(exceptionName)
methodVisitor.visitInsn(DUP)
methodVisitor.visitLdcInsn(message) methodVisitor.visitLdcInsn(message)
invokeSpecial<java.lang.Exception>("<init>", "(Ljava/lang/String;)V") invokeSpecial(exceptionName, "<init>", "(Ljava/lang/String;)V")
methodVisitor.visitInsn(Opcodes.ATHROW) methodVisitor.visitInsn(ATHROW)
}
inline fun <reified T : Throwable> throwException(message: String) = throwException(T::class.java, message)
/**
* Emit instruction for returning from "void" method.
*/
fun returnVoid() {
methodVisitor.visitInsn(RETURN)
hasEmittedCustomCode = true
}
/**
* Emit instructions for a new line number.
*/
fun lineNumber(line: Int) {
val label = Label()
methodVisitor.visitLabel(label)
methodVisitor.visitLineNumber(line, label)
hasEmittedCustomCode = true
} }
/** /**

View File

@ -0,0 +1,15 @@
@file:JvmName("Types")
package net.corda.djvm.code
import org.objectweb.asm.Type
import sandbox.net.corda.djvm.costing.ThresholdViolationError
import sandbox.net.corda.djvm.rules.RuleViolationError
val ruleViolationError: String = Type.getInternalName(RuleViolationError::class.java)
val thresholdViolationError: String = Type.getInternalName(ThresholdViolationError::class.java)
/**
* Local extension method for normalizing a class name.
*/
val String.asPackagePath: String get() = this.replace('/', '.')
val String.asResourcePath: String get() = this.replace('.', '/')

View File

@ -1,6 +1,7 @@
package net.corda.djvm.costing package net.corda.djvm.costing
import net.corda.djvm.utilities.loggerFor import net.corda.djvm.utilities.loggerFor
import sandbox.net.corda.djvm.costing.ThresholdViolationError
/** /**
* Cost metric to be used in a sandbox environment. The metric has a threshold and a mechanism for reporting violations. * Cost metric to be used in a sandbox environment. The metric has a threshold and a mechanism for reporting violations.
@ -41,7 +42,7 @@ open class TypedRuntimeCost<T>(
if (thresholdPredicate(newValue)) { if (thresholdPredicate(newValue)) {
val message = errorMessage(currentThread) val message = errorMessage(currentThread)
logger.error("Threshold breached; {}", message) logger.error("Threshold breached; {}", message)
throw ThresholdViolationException(message) throw ThresholdViolationError(message)
} }
} }

View File

@ -2,6 +2,7 @@ package net.corda.djvm.execution
import net.corda.djvm.SandboxConfiguration import net.corda.djvm.SandboxConfiguration
import net.corda.djvm.source.ClassSource import net.corda.djvm.source.ClassSource
import java.util.function.Function
/** /**
* The executor is responsible for spinning up a deterministic, sandboxed environment and launching the referenced code * The executor is responsible for spinning up a deterministic, sandboxed environment and launching the referenced code
@ -12,14 +13,14 @@ import net.corda.djvm.source.ClassSource
* @param configuration The configuration of the sandbox. * @param configuration The configuration of the sandbox.
*/ */
class DeterministicSandboxExecutor<TInput, TOutput>( class DeterministicSandboxExecutor<TInput, TOutput>(
configuration: SandboxConfiguration = SandboxConfiguration.DEFAULT configuration: SandboxConfiguration
) : SandboxExecutor<TInput, TOutput>(configuration) { ) : SandboxExecutor<TInput, TOutput>(configuration) {
/** /**
* Short-hand for running a [SandboxedRunnable] in a sandbox by its type reference. * Short-hand for running a [Function] in a sandbox by its type reference.
*/ */
inline fun <reified TRunnable : SandboxedRunnable<TInput, TOutput>> run(input: TInput): inline fun <reified TRunnable : Function<in TInput, out TOutput>> run(input: TInput):
ExecutionSummaryWithResult<TOutput?> { ExecutionSummaryWithResult<TOutput> {
return run(ClassSource.fromClassName(TRunnable::class.java.name), input) return run(ClassSource.fromClassName(TRunnable::class.java.name), input)
} }

View File

@ -1,6 +0,0 @@
package net.corda.djvm.execution
/**
* Functionality runnable by a sandbox executor, marked for discoverability.
*/
interface DiscoverableRunnable

View File

@ -1,7 +1,7 @@
package net.corda.djvm.execution package net.corda.djvm.execution
/** /**
* The execution profile of a [SandboxedRunnable] when run in a sandbox. * The execution profile of a [java.util.function.Function] when run in a sandbox.
* *
* @property allocationCostThreshold The threshold placed on allocations. * @property allocationCostThreshold The threshold placed on allocations.
* @property invocationCostThreshold The threshold placed on invocations. * @property invocationCostThreshold The threshold placed on invocations.

View File

@ -1,9 +1,9 @@
package net.corda.djvm.execution package net.corda.djvm.execution
/** /**
* The summary of the execution of a [SandboxedRunnable] in a sandbox. This class has no representation of the outcome, * The summary of the execution of a [java.util.function.Function] in a sandbox. This class has no representation of the
* and is typically used when there has been a pre-mature exit from the sandbox, for instance, if an exception was * outcome, and is typically used when there has been a pre-mature exit from the sandbox, for instance, if an exception
* thrown. * was thrown.
* *
* @property costs The costs accumulated when running the sandboxed code. * @property costs The costs accumulated when running the sandboxed code.
*/ */

View File

@ -1,7 +1,7 @@
package net.corda.djvm.execution package net.corda.djvm.execution
/** /**
* The summary of the execution of a [SandboxedRunnable] in a sandbox. * The summary of the execution of a [java.util.function.Function] in a sandbox.
* *
* @property result The outcome of the sandboxed operation. * @property result The outcome of the sandboxed operation.
* @see ExecutionSummary * @see ExecutionSummary

View File

@ -2,7 +2,6 @@ package net.corda.djvm.execution
import net.corda.djvm.SandboxConfiguration import net.corda.djvm.SandboxConfiguration
import net.corda.djvm.SandboxRuntimeContext import net.corda.djvm.SandboxRuntimeContext
import net.corda.djvm.analysis.AnalysisContext
import net.corda.djvm.messages.MessageCollection import net.corda.djvm.messages.MessageCollection
import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.djvm.rewiring.SandboxClassLoader
import net.corda.djvm.rewiring.SandboxClassLoadingException import net.corda.djvm.rewiring.SandboxClassLoadingException
@ -16,8 +15,7 @@ import kotlin.concurrent.thread
*/ */
class IsolatedTask( class IsolatedTask(
private val identifier: String, private val identifier: String,
private val configuration: SandboxConfiguration, private val configuration: SandboxConfiguration
private val context: AnalysisContext
) { ) {
/** /**
@ -32,12 +30,12 @@ class IsolatedTask(
var exception: Throwable? = null var exception: Throwable? = null
thread(name = threadName, isDaemon = true) { thread(name = threadName, isDaemon = true) {
logger.trace("Entering isolated runtime environment...") logger.trace("Entering isolated runtime environment...")
SandboxRuntimeContext(configuration, context.inputClasses).use { SandboxRuntimeContext(configuration).use {
output = try { output = try {
action(runnable) action(runnable)
} catch (ex: Throwable) { } catch (ex: Throwable) {
logger.error("Exception caught in isolated runtime environment", ex) logger.error("Exception caught in isolated runtime environment", ex)
exception = ex exception = (ex as? LinkageError)?.cause ?: ex
null null
} }
costs = CostSummary( costs = CostSummary(
@ -84,7 +82,7 @@ class IsolatedTask(
) )
/** /**
* The class loader to use for loading the [SandboxedRunnable] and any referenced code in [SandboxExecutor.run]. * The class loader to use for loading the [java.util.function.Function] and any referenced code in [SandboxExecutor.run].
*/ */
val classLoader: SandboxClassLoader val classLoader: SandboxClassLoader
get() = SandboxRuntimeContext.instance.classLoader get() = SandboxRuntimeContext.instance.classLoader

View File

@ -11,7 +11,6 @@ import net.corda.djvm.rewiring.SandboxClassLoadingException
import net.corda.djvm.source.ClassSource import net.corda.djvm.source.ClassSource
import net.corda.djvm.utilities.loggerFor import net.corda.djvm.utilities.loggerFor
import net.corda.djvm.validation.ReferenceValidationSummary import net.corda.djvm.validation.ReferenceValidationSummary
import net.corda.djvm.validation.ReferenceValidator
import java.lang.reflect.InvocationTargetException import java.lang.reflect.InvocationTargetException
/** /**
@ -22,7 +21,7 @@ import java.lang.reflect.InvocationTargetException
* @property configuration The configuration of sandbox. * @property configuration The configuration of sandbox.
*/ */
open class SandboxExecutor<in TInput, out TOutput>( open class SandboxExecutor<in TInput, out TOutput>(
protected val configuration: SandboxConfiguration = SandboxConfiguration.DEFAULT protected val configuration: SandboxConfiguration
) { ) {
private val classModule = configuration.analysisConfiguration.classModule private val classModule = configuration.analysisConfiguration.classModule
@ -32,12 +31,7 @@ open class SandboxExecutor<in TInput, out TOutput>(
private val whitelist = configuration.analysisConfiguration.whitelist private val whitelist = configuration.analysisConfiguration.whitelist
/** /**
* Module used to validate all traversable references before instantiating and executing a [SandboxedRunnable]. * Executes a [java.util.function.Function] implementation.
*/
private val referenceValidator = ReferenceValidator(configuration.analysisConfiguration)
/**
* Executes a [SandboxedRunnable] implementation.
* *
* @param runnableClass The entry point of the sandboxed code to run. * @param runnableClass The entry point of the sandboxed code to run.
* @param input The input to provide to the sandboxed environment. * @param input The input to provide to the sandboxed environment.
@ -50,7 +44,7 @@ open class SandboxExecutor<in TInput, out TOutput>(
open fun run( open fun run(
runnableClass: ClassSource, runnableClass: ClassSource,
input: TInput input: TInput
): ExecutionSummaryWithResult<TOutput?> { ): ExecutionSummaryWithResult<TOutput> {
// 1. We first do a breath first traversal of the class hierarchy, starting from the requested class. // 1. We first do a breath first traversal of the class hierarchy, starting from the requested class.
// The branching is defined by class references from referencesFromLocation. // The branching is defined by class references from referencesFromLocation.
// 2. For each class we run validation against defined rules. // 2. For each class we run validation against defined rules.
@ -63,22 +57,22 @@ open class SandboxExecutor<in TInput, out TOutput>(
// 6. For execution, we then load the top-level class, implementing the SandboxedRunnable interface, again and // 6. For execution, we then load the top-level class, implementing the SandboxedRunnable interface, again and
// and consequently hit the cache. Once loaded, we can execute the code on the spawned thread, i.e., in an // and consequently hit the cache. Once loaded, we can execute the code on the spawned thread, i.e., in an
// isolated environment. // isolated environment.
logger.trace("Executing {} with input {}...", runnableClass, input) logger.debug("Executing {} with input {}...", runnableClass, input)
// TODO Class sources can be analyzed in parallel, although this require making the analysis context thread-safe // TODO Class sources can be analyzed in parallel, although this require making the analysis context thread-safe
// To do so, one could start by batching the first X classes from the class sources and analyse each one in // To do so, one could start by batching the first X classes from the class sources and analyse each one in
// parallel, caching any intermediate state and subsequently process enqueued sources in parallel batches as well. // parallel, caching any intermediate state and subsequently process enqueued sources in parallel batches as well.
// Note that this would require some rework of the [IsolatedTask] and the class loader to bypass the limitation // Note that this would require some rework of the [IsolatedTask] and the class loader to bypass the limitation
// of caching and state preserved in thread-local contexts. // of caching and state preserved in thread-local contexts.
val classSources = listOf(runnableClass) val classSources = listOf(runnableClass)
val context = AnalysisContext.fromConfiguration(configuration.analysisConfiguration, classSources) val context = AnalysisContext.fromConfiguration(configuration.analysisConfiguration)
val result = IsolatedTask(runnableClass.qualifiedClassName, configuration, context).run { val result = IsolatedTask(runnableClass.qualifiedClassName, configuration).run {
validate(context, classLoader, classSources) validate(context, classLoader, classSources)
val loadedClass = classLoader.loadClassAndBytes(runnableClass, context) val loadedClass = classLoader.loadClassAndBytes(runnableClass, context)
val instance = loadedClass.type.newInstance() val instance = loadedClass.type.newInstance()
val method = loadedClass.type.getMethod("run", Any::class.java) val method = loadedClass.type.getMethod("apply", Any::class.java)
try { try {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
method.invoke(instance, input) as? TOutput? method.invoke(instance, input) as? TOutput
} catch (ex: InvocationTargetException) { } catch (ex: InvocationTargetException) {
throw ex.targetException throw ex.targetException
} }
@ -105,8 +99,8 @@ open class SandboxExecutor<in TInput, out TOutput>(
* @return A [LoadedClass] with the class' byte code, type and name. * @return A [LoadedClass] with the class' byte code, type and name.
*/ */
fun load(classSource: ClassSource): LoadedClass { fun load(classSource: ClassSource): LoadedClass {
val context = AnalysisContext.fromConfiguration(configuration.analysisConfiguration, listOf(classSource)) val context = AnalysisContext.fromConfiguration(configuration.analysisConfiguration)
val result = IsolatedTask("LoadClass", configuration, context).run { val result = IsolatedTask("LoadClass", configuration).run {
classLoader.loadClassAndBytes(classSource, context) classLoader.loadClassAndBytes(classSource, context)
} }
return result.output ?: throw ClassNotFoundException(classSource.qualifiedClassName) return result.output ?: throw ClassNotFoundException(classSource.qualifiedClassName)
@ -125,8 +119,8 @@ open class SandboxExecutor<in TInput, out TOutput>(
@Throws(SandboxClassLoadingException::class) @Throws(SandboxClassLoadingException::class)
fun validate(vararg classSources: ClassSource): ReferenceValidationSummary { fun validate(vararg classSources: ClassSource): ReferenceValidationSummary {
logger.trace("Validating {}...", classSources) logger.trace("Validating {}...", classSources)
val context = AnalysisContext.fromConfiguration(configuration.analysisConfiguration, classSources.toList()) val context = AnalysisContext.fromConfiguration(configuration.analysisConfiguration)
val result = IsolatedTask("Validation", configuration, context).run { val result = IsolatedTask("Validation", configuration).run {
validate(context, classLoader, classSources.toList()) validate(context, classLoader, classSources.toList())
} }
logger.trace("Validation of {} resulted in {}", classSources, result) logger.trace("Validation of {} resulted in {}", classSources, result)
@ -172,10 +166,6 @@ open class SandboxExecutor<in TInput, out TOutput>(
} }
failOnReportedErrorsInContext(context) failOnReportedErrorsInContext(context)
// Validate all references in class hierarchy before proceeding.
referenceValidator.validate(context, classLoader.analyzer)
failOnReportedErrorsInContext(context)
return ReferenceValidationSummary(context.classes, context.messages, context.classOrigins) return ReferenceValidationSummary(context.classes, context.messages, context.classOrigins)
} }
@ -185,7 +175,7 @@ open class SandboxExecutor<in TInput, out TOutput>(
private inline fun processClassQueue( private inline fun processClassQueue(
vararg elements: ClassSource, action: QueueProcessor<ClassSource>.(ClassSource, String) -> Unit vararg elements: ClassSource, action: QueueProcessor<ClassSource>.(ClassSource, String) -> Unit
) { ) {
QueueProcessor({ it.qualifiedClassName }, *elements).process { classSource -> QueueProcessor(ClassSource::qualifiedClassName, *elements).process { classSource ->
val className = classResolver.reverse(classModule.getBinaryClassName(classSource.qualifiedClassName)) val className = classResolver.reverse(classModule.getBinaryClassName(classSource.qualifiedClassName))
if (!whitelist.matches(className)) { if (!whitelist.matches(className)) {
action(classSource, className) action(classSource, className)

View File

@ -1,19 +0,0 @@
package net.corda.djvm.execution
/**
* Functionality runnable by a sandbox executor.
*/
interface SandboxedRunnable<in TInput, out TOutput> : DiscoverableRunnable {
/**
* The entry point of the sandboxed functionality to be run.
*
* @param input The input to pass in to the entry point.
*
* @returns The output to pass back to the caller after the sandboxed code has finished running.
* @throws Exception The function can throw an exception, in which case the exception gets passed to the caller.
*/
@Throws(Exception::class)
fun run(input: TInput): TOutput?
}

View File

@ -53,7 +53,7 @@ class MemberFormatter(
* Check whether or not a signature is for a method. * Check whether or not a signature is for a method.
*/ */
fun isMethod(abbreviatedSignature: String): Boolean { fun isMethod(abbreviatedSignature: String): Boolean {
return abbreviatedSignature.startsWith("(") return abbreviatedSignature.startsWith('(')
} }
/** /**

View File

@ -82,8 +82,8 @@ class ClassHierarchy(
return findAncestors(get(className)).plus(get(OBJECT_NAME)) return findAncestors(get(className)).plus(get(OBJECT_NAME))
.asSequence() .asSequence()
.filterNotNull() .filterNotNull()
.map { memberModule.getFromClass(it, memberName, signature) } .mapNotNull { memberModule.getFromClass(it, memberName, signature) }
.firstOrNull { it != null } .firstOrNull()
.apply { .apply {
logger.trace("Getting rooted member for {}.{}:{} yields {}", className, memberName, signature, this) logger.trace("Getting rooted member for {}.{}:{} yields {}", className, memberName, signature, this)
} }

View File

@ -1,5 +1,8 @@
package net.corda.djvm.references package net.corda.djvm.references
import net.corda.djvm.code.asPackagePath
import net.corda.djvm.code.asResourcePath
/** /**
* Class-specific functionality. * Class-specific functionality.
*/ */
@ -42,14 +45,12 @@ class ClassModule : AnnotationModule() {
/** /**
* Get the binary version of a class name. * Get the binary version of a class name.
*/ */
fun getBinaryClassName(name: String) = fun getBinaryClassName(name: String) = normalizeClassName(name).asResourcePath
normalizeClassName(name).replace('.', '/')
/** /**
* Get the formatted version of a class name. * Get the formatted version of a class name.
*/ */
fun getFormattedClassName(name: String) = fun getFormattedClassName(name: String) = normalizeClassName(name).asPackagePath
normalizeClassName(name).replace('/', '.')
/** /**
* Get the short name of a class. * Get the short name of a class.

View File

@ -1,5 +1,13 @@
package net.corda.djvm.references package net.corda.djvm.references
import net.corda.djvm.code.EmitterModule
/**
* Alias for a handler which will replace an entire
* method body with a block of byte-code.
*/
typealias MethodBody = (EmitterModule) -> Unit
/** /**
* Representation of a class member. * Representation of a class member.
* *
@ -11,6 +19,7 @@ package net.corda.djvm.references
* @property annotations The names of the annotations the member is attributed. * @property annotations The names of the annotations the member is attributed.
* @property exceptions The names of the exceptions that the member can throw. * @property exceptions The names of the exceptions that the member can throw.
* @property value The default value of a field. * @property value The default value of a field.
* @property body One or more handlers to replace the method body with new byte-code.
*/ */
data class Member( data class Member(
override val access: Int, override val access: Int,
@ -20,5 +29,6 @@ data class Member(
val genericsDetails: String, val genericsDetails: String,
val annotations: MutableSet<String> = mutableSetOf(), val annotations: MutableSet<String> = mutableSetOf(),
val exceptions: MutableSet<String> = mutableSetOf(), val exceptions: MutableSet<String> = mutableSetOf(),
val value: Any? = null val value: Any? = null,
val body: List<MethodBody> = emptyList()
) : MemberInformation, EntityWithAccessFlag ) : MemberInformation, EntityWithAccessFlag

View File

@ -33,14 +33,14 @@ class MemberModule : AnnotationModule() {
* Check if member is a field. * Check if member is a field.
*/ */
fun isField(member: MemberInformation): Boolean { fun isField(member: MemberInformation): Boolean {
return !member.signature.startsWith("(") return !member.signature.startsWith('(')
} }
/** /**
* Check if member is a method. * Check if member is a method.
*/ */
fun isMethod(member: MemberInformation): Boolean { fun isMethod(member: MemberInformation): Boolean {
return member.signature.startsWith("(") return member.signature.startsWith('(')
} }
/** /**

View File

@ -16,7 +16,11 @@ class ReferenceMap(
private val referencesPerLocation: MutableMap<String, MutableSet<ReferenceWithLocation>> = hashMapOf() private val referencesPerLocation: MutableMap<String, MutableSet<ReferenceWithLocation>> = hashMapOf()
private var numberOfReferences = 0 /**
* The number of references in the map.
*/
var numberOfReferences = 0
private set
/** /**
* Add source location association to a target member. * Add source location association to a target member.
@ -50,12 +54,6 @@ class ReferenceMap(
return referencesPerLocation.getOrElse(key(className, memberName, signature)) { emptySet() } return referencesPerLocation.getOrElse(key(className, memberName, signature)) { emptySet() }
} }
/**
* The number of member references in the map.
*/
val size: Int
get() = numberOfReferences
/** /**
* Get iterator for all the references in the map. * Get iterator for all the references in the map.
*/ */

View File

@ -27,7 +27,7 @@ open class ClassRewriter(
* @param context The context in which the class is being analyzed and processed. * @param context The context in which the class is being analyzed and processed.
*/ */
fun rewrite(reader: ClassReader, context: AnalysisContext): ByteCode { fun rewrite(reader: ClassReader, context: AnalysisContext): ByteCode {
logger.trace("Rewriting class {}...", reader.className) logger.debug("Rewriting class {}...", reader.className)
val writer = SandboxClassWriter(reader, classLoader) val writer = SandboxClassWriter(reader, classLoader)
val classRemapper = ClassRemapper(writer, remapper) val classRemapper = ClassRemapper(writer, remapper)
val visitor = ClassMutator( val visitor = ClassMutator(

View File

@ -3,29 +3,31 @@ package net.corda.djvm.rewiring
import net.corda.djvm.SandboxConfiguration import net.corda.djvm.SandboxConfiguration
import net.corda.djvm.analysis.AnalysisContext import net.corda.djvm.analysis.AnalysisContext
import net.corda.djvm.analysis.ClassAndMemberVisitor import net.corda.djvm.analysis.ClassAndMemberVisitor
import net.corda.djvm.code.asResourcePath
import net.corda.djvm.references.ClassReference import net.corda.djvm.references.ClassReference
import net.corda.djvm.source.ClassSource import net.corda.djvm.source.ClassSource
import net.corda.djvm.source.SourceClassLoader
import net.corda.djvm.utilities.loggerFor import net.corda.djvm.utilities.loggerFor
import net.corda.djvm.validation.RuleValidator import net.corda.djvm.validation.RuleValidator
/** /**
* Class loader that enables registration of rewired classes. * Class loader that enables registration of rewired classes.
* *
* @property configuration The configuration to use for the sandbox. * @param configuration The configuration to use for the sandbox.
* @property context The context in which analysis and processing is performed. * @property context The context in which analysis and processing is performed.
*/ */
class SandboxClassLoader( class SandboxClassLoader(
val configuration: SandboxConfiguration, configuration: SandboxConfiguration,
val context: AnalysisContext private val context: AnalysisContext
) : ClassLoader() { ) : ClassLoader(null) {
private val analysisConfiguration = configuration.analysisConfiguration
/** /**
* The instance used to validate that any loaded class complies with the specified rules. * The instance used to validate that any loaded class complies with the specified rules.
*/ */
private val ruleValidator: RuleValidator = RuleValidator( private val ruleValidator: RuleValidator = RuleValidator(
rules = configuration.rules, rules = configuration.rules,
configuration = configuration.analysisConfiguration configuration = analysisConfiguration
) )
/** /**
@ -37,12 +39,12 @@ class SandboxClassLoader(
/** /**
* Set of classes that should be left untouched due to pinning. * Set of classes that should be left untouched due to pinning.
*/ */
private val pinnedClasses = configuration.analysisConfiguration.pinnedClasses private val pinnedClasses = analysisConfiguration.pinnedClasses
/** /**
* Set of classes that should be left untouched due to whitelisting. * Set of classes that should be left untouched due to whitelisting.
*/ */
private val whitelistedClasses = configuration.analysisConfiguration.whitelist private val whitelistedClasses = analysisConfiguration.whitelist
/** /**
* Cache of loaded classes. * Cache of loaded classes.
@ -52,10 +54,7 @@ class SandboxClassLoader(
/** /**
* The class loader used to find classes on the extended class path. * The class loader used to find classes on the extended class path.
*/ */
private val supportingClassLoader = SourceClassLoader( private val supportingClassLoader = analysisConfiguration.supportingClassLoader
configuration.analysisConfiguration.classPath,
configuration.analysisConfiguration.classResolver
)
/** /**
* The re-writer to use for registered classes. * The re-writer to use for registered classes.
@ -83,9 +82,9 @@ class SandboxClassLoader(
* @return The resulting <tt>Class</tt> object and its byte code representation. * @return The resulting <tt>Class</tt> object and its byte code representation.
*/ */
fun loadClassAndBytes(source: ClassSource, context: AnalysisContext): LoadedClass { fun loadClassAndBytes(source: ClassSource, context: AnalysisContext): LoadedClass {
logger.trace("Loading class {}, origin={}...", source.qualifiedClassName, source.origin) logger.debug("Loading class {}, origin={}...", source.qualifiedClassName, source.origin)
val name = configuration.analysisConfiguration.classResolver.reverseNormalized(source.qualifiedClassName) val name = analysisConfiguration.classResolver.reverseNormalized(source.qualifiedClassName)
val resolvedName = configuration.analysisConfiguration.classResolver.resolveNormalized(name) val resolvedName = analysisConfiguration.classResolver.resolveNormalized(name)
// Check if the class has already been loaded. // Check if the class has already been loaded.
val loadedClass = loadedClasses[name] val loadedClass = loadedClasses[name]
@ -99,14 +98,14 @@ class SandboxClassLoader(
// Analyse the class if not matching the whitelist. // Analyse the class if not matching the whitelist.
val readClassName = reader.className val readClassName = reader.className
if (!configuration.analysisConfiguration.whitelist.matches(readClassName)) { if (!analysisConfiguration.whitelist.matches(readClassName)) {
logger.trace("Class {} does not match with the whitelist", source.qualifiedClassName) logger.trace("Class {} does not match with the whitelist", source.qualifiedClassName)
logger.trace("Analyzing class {}...", source.qualifiedClassName) logger.trace("Analyzing class {}...", source.qualifiedClassName)
analyzer.analyze(reader, context) analyzer.analyze(reader, context)
} }
// Check if the class should be left untouched. // Check if the class should be left untouched.
val qualifiedName = name.replace('.', '/') val qualifiedName = name.asResourcePath
if (qualifiedName in pinnedClasses) { if (qualifiedName in pinnedClasses) {
logger.trace("Class {} is marked as pinned", source.qualifiedClassName) logger.trace("Class {} is marked as pinned", source.qualifiedClassName)
val pinnedClasses = LoadedClass( val pinnedClasses = LoadedClass(
@ -146,7 +145,7 @@ class SandboxClassLoader(
context.recordClassOrigin(name, ClassReference(source.origin)) context.recordClassOrigin(name, ClassReference(source.origin))
} }
logger.trace("Loaded class {}, bytes={}, isModified={}", logger.debug("Loaded class {}, bytes={}, isModified={}",
source.qualifiedClassName, byteCode.bytes.size, byteCode.isModified) source.qualifiedClassName, byteCode.bytes.size, byteCode.isModified)
return classWithByteCode return classWithByteCode

View File

@ -1,5 +1,6 @@
package net.corda.djvm.rewiring package net.corda.djvm.rewiring
import net.corda.djvm.code.asPackagePath
import org.objectweb.asm.ClassReader import org.objectweb.asm.ClassReader
import org.objectweb.asm.ClassWriter import org.objectweb.asm.ClassWriter
import org.objectweb.asm.ClassWriter.COMPUTE_FRAMES import org.objectweb.asm.ClassWriter.COMPUTE_FRAMES
@ -35,12 +36,12 @@ open class SandboxClassWriter(
type2 == OBJECT_NAME -> return type2 type2 == OBJECT_NAME -> return type2
} }
val class1 = try { val class1 = try {
classLoader.loadClass(type1.replace('/', '.')) classLoader.loadClass(type1.asPackagePath)
} catch (exception: Exception) { } catch (exception: Exception) {
throw TypeNotPresentException(type1, exception) throw TypeNotPresentException(type1, exception)
} }
val class2 = try { val class2 = try {
classLoader.loadClass(type2.replace('/', '.')) classLoader.loadClass(type2.asPackagePath)
} catch (exception: Exception) { } catch (exception: Exception) {
throw TypeNotPresentException(type2, exception) throw TypeNotPresentException(type2, exception)
} }

View File

@ -18,7 +18,7 @@ abstract class ClassRule : Rule {
*/ */
abstract fun validate(context: RuleContext, clazz: ClassRepresentation) abstract fun validate(context: RuleContext, clazz: ClassRepresentation)
override fun validate(context: RuleContext, clazz: ClassRepresentation?, member: Member?, instruction: Instruction?) { final override fun validate(context: RuleContext, clazz: ClassRepresentation?, member: Member?, instruction: Instruction?) {
// Only run validation step if applied to the class itself. // Only run validation step if applied to the class itself.
if (clazz != null && member == null && instruction == null) { if (clazz != null && member == null && instruction == null) {
validate(context, clazz) validate(context, clazz)

View File

@ -18,7 +18,7 @@ abstract class InstructionRule : Rule {
*/ */
abstract fun validate(context: RuleContext, instruction: Instruction) abstract fun validate(context: RuleContext, instruction: Instruction)
override fun validate(context: RuleContext, clazz: ClassRepresentation?, member: Member?, instruction: Instruction?) { final override fun validate(context: RuleContext, clazz: ClassRepresentation?, member: Member?, instruction: Instruction?) {
// Only run validation step if applied to the class member itself. // Only run validation step if applied to the class member itself.
if (clazz != null && member != null && instruction != null) { if (clazz != null && member != null && instruction != null) {
validate(context, instruction) validate(context, instruction)

View File

@ -18,7 +18,7 @@ abstract class MemberRule : Rule {
*/ */
abstract fun validate(context: RuleContext, member: Member) abstract fun validate(context: RuleContext, member: Member)
override fun validate(context: RuleContext, clazz: ClassRepresentation?, member: Member?, instruction: Instruction?) { final override fun validate(context: RuleContext, clazz: ClassRepresentation?, member: Member?, instruction: Instruction?) {
// Only run validation step if applied to the class member itself. // Only run validation step if applied to the class member itself.
if (clazz != null && member != null && instruction == null) { if (clazz != null && member != null && instruction == null) {
validate(context, member) validate(context, member)

View File

@ -1,17 +0,0 @@
package net.corda.djvm.rules.implementation
import net.corda.djvm.code.Instruction
import net.corda.djvm.code.Instruction.Companion.OP_BREAKPOINT
import net.corda.djvm.rules.InstructionRule
import net.corda.djvm.validation.RuleContext
/**
* Rule that checks for invalid breakpoint instructions.
*/
class DisallowBreakpoints : InstructionRule() {
override fun validate(context: RuleContext, instruction: Instruction) = context.validate {
fail("Disallowed breakpoint in method") given (instruction.operation == OP_BREAKPOINT)
}
}

View File

@ -1,36 +1,16 @@
package net.corda.djvm.rules.implementation package net.corda.djvm.rules.implementation
import net.corda.djvm.code.Emitter import net.corda.djvm.code.*
import net.corda.djvm.code.EmitterContext
import net.corda.djvm.code.Instruction
import net.corda.djvm.code.instructions.CodeLabel import net.corda.djvm.code.instructions.CodeLabel
import net.corda.djvm.code.instructions.TryCatchBlock import net.corda.djvm.code.instructions.TryCatchBlock
import net.corda.djvm.costing.ThresholdViolationException
import net.corda.djvm.rules.InstructionRule
import net.corda.djvm.validation.RuleContext
import org.objectweb.asm.Label import org.objectweb.asm.Label
import sandbox.net.corda.djvm.costing.ThresholdViolationError
/** /**
* Rule that checks for attempted catches of [ThreadDeath], [ThresholdViolationException], [StackOverflowError], * Rule that checks for attempted catches of [ThreadDeath], [ThresholdViolationError],
* [OutOfMemoryError], [Error] or [Throwable]. * [StackOverflowError], [OutOfMemoryError], [Error] or [Throwable].
*/ */
class DisallowCatchingBlacklistedExceptions : InstructionRule(), Emitter { class DisallowCatchingBlacklistedExceptions : Emitter {
override fun validate(context: RuleContext, instruction: Instruction) = context.validate {
if (instruction is TryCatchBlock) {
val typeName = context.classModule.getFormattedClassName(instruction.typeName)
warn("Injected runtime check for catch-block for type $typeName") given
(instruction.typeName in disallowedExceptionTypes)
fail("Disallowed catch of ThreadDeath exception") given
(instruction.typeName == threadDeathException)
fail("Disallowed catch of stack overflow exception") given
(instruction.typeName == stackOverflowException)
fail("Disallowed catch of out of memory exception") given
(instruction.typeName == outOfMemoryException)
fail("Disallowed catch of threshold violation exception") given
(instruction.typeName.endsWith(ThresholdViolationException::class.java.simpleName))
}
}
override fun emit(context: EmitterContext, instruction: Instruction) = context.emit { override fun emit(context: EmitterContext, instruction: Instruction) = context.emit {
if (instruction is TryCatchBlock && instruction.typeName in disallowedExceptionTypes) { if (instruction is TryCatchBlock && instruction.typeName in disallowedExceptionTypes) {
@ -46,13 +26,27 @@ class DisallowCatchingBlacklistedExceptions : InstructionRule(), Emitter {
private fun isExceptionHandler(label: Label) = label in handlers private fun isExceptionHandler(label: Label) = label in handlers
companion object { companion object {
private const val threadDeathException = "java/lang/ThreadDeath"
private const val stackOverflowException = "java/lang/StackOverflowError"
private const val outOfMemoryException = "java/lang/OutOfMemoryError"
// Any of [ThreadDeath]'s throwable super-classes need explicit checking.
private val disallowedExceptionTypes = setOf( private val disallowedExceptionTypes = setOf(
ruleViolationError,
thresholdViolationError,
/**
* These errors indicate that the JVM is failing,
* so don't allow these to be caught either.
*/
"java/lang/StackOverflowError",
"java/lang/OutOfMemoryError",
/**
* These are immediate super-classes for our explicit errors.
*/
"java/lang/VirtualMachineError",
"java/lang/ThreadDeath",
/**
* Any of [ThreadDeath] and [VirtualMachineError]'s throwable
* super-classes also need explicit checking.
*/
"java/lang/Throwable", "java/lang/Throwable",
"java/lang/Error" "java/lang/Error"
) )

View File

@ -1,17 +0,0 @@
package net.corda.djvm.rules.implementation
import net.corda.djvm.references.Member
import net.corda.djvm.rules.MemberRule
import net.corda.djvm.validation.RuleContext
/**
* Rule that checks for invalid use of finalizers.
*/
class DisallowFinalizerMethods : MemberRule() {
override fun validate(context: RuleContext, member: Member) = context.validate {
fail("Disallowed finalizer method") given ("${member.memberName}:${member.signature}" == "finalize:()V")
// TODO Make this rule simply erase the finalize() method and continue execution.
}
}

View File

@ -1,17 +0,0 @@
package net.corda.djvm.rules.implementation
import net.corda.djvm.references.Member
import net.corda.djvm.rules.MemberRule
import net.corda.djvm.validation.RuleContext
import java.lang.reflect.Modifier
/**
* Rule that checks for invalid use of native methods.
*/
class DisallowNativeMethods : MemberRule() {
override fun validate(context: RuleContext, member: Member) = context.validate {
fail("Disallowed native method") given Modifier.isNative(member.access)
}
}

View File

@ -0,0 +1,42 @@
package net.corda.djvm.rules.implementation
import net.corda.djvm.code.Emitter
import net.corda.djvm.code.EmitterContext
import net.corda.djvm.code.Instruction
import net.corda.djvm.code.instructions.MemberAccessInstruction
import net.corda.djvm.formatting.MemberFormatter
import org.objectweb.asm.Opcodes.*
import sandbox.net.corda.djvm.rules.RuleViolationError
/**
* Some non-deterministic APIs belong to pinned classes and so cannot be stubbed out.
* Replace their invocations with exceptions instead.
*/
class DisallowNonDeterministicMethods : Emitter {
override fun emit(context: EmitterContext, instruction: Instruction) = context.emit {
if (instruction is MemberAccessInstruction && isForbidden(instruction)) {
when (instruction.operation) {
INVOKEVIRTUAL -> {
throwException<RuleViolationError>("Disallowed reference to API; ${memberFormatter.format(instruction.member)}")
preventDefault()
}
}
}
}
private fun isClassReflection(instruction: MemberAccessInstruction): Boolean =
(instruction.owner == "java/lang/Class") && (
((instruction.memberName == "newInstance" && instruction.signature == "()Ljava/lang/Object;")
|| instruction.signature.contains("Ljava/lang/reflect/"))
)
private fun isObjectMonitor(instruction: MemberAccessInstruction): Boolean =
(instruction.signature == "()V" && (instruction.memberName == "notify" || instruction.memberName == "notifyAll" || instruction.memberName == "wait"))
|| (instruction.memberName == "wait" && (instruction.signature == "(J)V" || instruction.signature == "(JI)V"))
private fun isForbidden(instruction: MemberAccessInstruction): Boolean
= instruction.isMethod && (isClassReflection(instruction) || isObjectMonitor(instruction))
private val memberFormatter = MemberFormatter()
}

View File

@ -1,30 +0,0 @@
package net.corda.djvm.rules.implementation
import net.corda.djvm.code.Instruction
import net.corda.djvm.code.instructions.MemberAccessInstruction
import net.corda.djvm.formatting.MemberFormatter
import net.corda.djvm.rules.InstructionRule
import net.corda.djvm.validation.RuleContext
/**
* Rule that checks for illegal references to reflection APIs.
*/
class DisallowReflection : InstructionRule() {
override fun validate(context: RuleContext, instruction: Instruction) = context.validate {
// TODO Enable controlled use of reflection APIs
if (instruction is MemberAccessInstruction) {
invalidReflectionUsage(instruction) given
("java/lang/Class" in instruction.owner && instruction.memberName == "newInstance")
invalidReflectionUsage(instruction) given (instruction.owner.startsWith("java/lang/reflect/"))
invalidReflectionUsage(instruction) given (instruction.owner.startsWith("java/lang/invoke/"))
invalidReflectionUsage(instruction) given (instruction.owner.startsWith("sun/"))
}
}
private fun RuleContext.invalidReflectionUsage(instruction: MemberAccessInstruction) =
this.fail("Disallowed reference to reflection API; ${memberFormatter.format(instruction.member)}")
private val memberFormatter = MemberFormatter()
}

View File

@ -0,0 +1,19 @@
package net.corda.djvm.rules.implementation
import net.corda.djvm.code.Emitter
import net.corda.djvm.code.EmitterContext
import net.corda.djvm.code.Instruction
import net.corda.djvm.code.Instruction.Companion.OP_BREAKPOINT
/**
* Rule that deletes invalid breakpoint instructions.
*/
class IgnoreBreakpoints : Emitter {
override fun emit(context: EmitterContext, instruction: Instruction) = context.emit {
when (instruction.operation) {
OP_BREAKPOINT -> preventDefault()
}
}
}

View File

@ -3,20 +3,13 @@ package net.corda.djvm.rules.implementation
import net.corda.djvm.code.Emitter import net.corda.djvm.code.Emitter
import net.corda.djvm.code.EmitterContext import net.corda.djvm.code.EmitterContext
import net.corda.djvm.code.Instruction import net.corda.djvm.code.Instruction
import net.corda.djvm.rules.InstructionRule
import net.corda.djvm.validation.RuleContext
import org.objectweb.asm.Opcodes.* import org.objectweb.asm.Opcodes.*
/** /**
* Rule that warns about the use of synchronized code blocks. This class also exposes an emitter that rewrites pertinent * An emitter that rewrites monitoring instructions to [POP]s, as these replacements will remove
* monitoring instructions to [POP]'s, as these replacements will remove the object references that [MONITORENTER] and * the object references that [MONITORENTER] and [MONITOREXIT] anticipate to be on the stack.
* [MONITOREXIT] anticipate to be on the stack.
*/ */
class IgnoreSynchronizedBlocks : InstructionRule(), Emitter { class IgnoreSynchronizedBlocks : Emitter {
override fun validate(context: RuleContext, instruction: Instruction) = context.validate {
inform("Stripped monitoring instruction") given (instruction.operation in setOf(MONITORENTER, MONITOREXIT))
}
override fun emit(context: EmitterContext, instruction: Instruction) = context.emit { override fun emit(context: EmitterContext, instruction: Instruction) = context.emit {
when (instruction.operation) { when (instruction.operation) {

View File

@ -0,0 +1,35 @@
package net.corda.djvm.rules.implementation
import net.corda.djvm.analysis.AnalysisRuntimeContext
import net.corda.djvm.code.EmitterModule
import net.corda.djvm.code.MemberDefinitionProvider
import net.corda.djvm.references.Member
import java.lang.reflect.Modifier
/**
* Rule that replaces a finalize() method with a simple stub.
*/
class StubOutFinalizerMethods : MemberDefinitionProvider {
override fun define(context: AnalysisRuntimeContext, member: Member) = when {
/**
* Discard any other method body and replace with stub that just returns.
* Other [MemberDefinitionProvider]s are expected to append to this list
* and not replace its contents!
*/
isFinalizer(member) -> member.copy(body = listOf(::writeMethodBody))
else -> member
}
private fun writeMethodBody(emitter: EmitterModule): Unit = with(emitter) {
returnVoid()
}
/**
* No need to rewrite [Object.finalize] or [Enum.finalize]; ignore these.
*/
private fun isFinalizer(member: Member): Boolean
= member.memberName == "finalize" && member.signature == "()V"
&& !member.className.startsWith("java/lang/")
&& !Modifier.isAbstract(member.access)
}

View File

@ -0,0 +1,36 @@
package net.corda.djvm.rules.implementation
import net.corda.djvm.analysis.AnalysisRuntimeContext
import net.corda.djvm.code.EmitterModule
import net.corda.djvm.code.MemberDefinitionProvider
import net.corda.djvm.references.Member
import org.objectweb.asm.Opcodes.*
import sandbox.net.corda.djvm.rules.RuleViolationError
import java.lang.reflect.Modifier
/**
* Rule that replaces a native method with a stub that throws an exception.
*/
class StubOutNativeMethods : MemberDefinitionProvider {
override fun define(context: AnalysisRuntimeContext, member: Member) = when {
isNative(member) -> member.copy(
access = member.access and ACC_NATIVE.inv(),
body = member.body + if (isForStubbing(member)) ::writeStubMethodBody else ::writeExceptionMethodBody
)
else -> member
}
private fun writeExceptionMethodBody(emitter: EmitterModule): Unit = with(emitter) {
lineNumber(0)
throwException(RuleViolationError::class.java, "Native method has been deleted")
}
private fun writeStubMethodBody(emitter: EmitterModule): Unit = with(emitter) {
returnVoid()
}
private fun isForStubbing(member: Member): Boolean = member.signature == "()V" && member.memberName == "registerNatives"
private fun isNative(member: Member): Boolean = Modifier.isNative(member.access)
}

View File

@ -0,0 +1,35 @@
package net.corda.djvm.rules.implementation
import net.corda.djvm.analysis.AnalysisRuntimeContext
import net.corda.djvm.code.EmitterModule
import net.corda.djvm.code.MemberDefinitionProvider
import net.corda.djvm.references.Member
import org.objectweb.asm.Opcodes.*
import sandbox.net.corda.djvm.rules.RuleViolationError
/**
* Replace reflection APIs with stubs that throw exceptions. Only for unpinned classes.
*/
class StubOutReflectionMethods : MemberDefinitionProvider {
override fun define(context: AnalysisRuntimeContext, member: Member): Member = when {
isConcreteApi(member) && isReflection(member) -> member.copy(body = member.body + ::writeMethodBody)
else -> member
}
private fun writeMethodBody(emitter: EmitterModule): Unit = with(emitter) {
lineNumber(0)
throwException(RuleViolationError::class.java, "Disallowed reference to reflection API")
}
// The method must be public and with a Java implementation.
private fun isConcreteApi(member: Member): Boolean = member.access and (ACC_PUBLIC or ACC_ABSTRACT or ACC_NATIVE) == ACC_PUBLIC
private fun isReflection(member: Member): Boolean {
return member.className.startsWith("java/lang/reflect/")
|| member.className.startsWith("java/lang/invoke/")
|| member.className.startsWith("sun/reflect/")
|| member.className == "sun/misc/Unsafe"
|| member.className == "sun/misc/VM"
}
}

View File

@ -3,6 +3,7 @@ package net.corda.djvm.source
import net.corda.djvm.analysis.AnalysisContext import net.corda.djvm.analysis.AnalysisContext
import net.corda.djvm.analysis.ClassResolver import net.corda.djvm.analysis.ClassResolver
import net.corda.djvm.analysis.SourceLocation import net.corda.djvm.analysis.SourceLocation
import net.corda.djvm.code.asResourcePath
import net.corda.djvm.messages.Message import net.corda.djvm.messages.Message
import net.corda.djvm.messages.Severity import net.corda.djvm.messages.Severity
import net.corda.djvm.rewiring.SandboxClassLoadingException import net.corda.djvm.rewiring.SandboxClassLoadingException
@ -17,18 +18,11 @@ import java.nio.file.Path
import java.nio.file.Paths import java.nio.file.Paths
import kotlin.streams.toList import kotlin.streams.toList
/** abstract class AbstractSourceClassLoader(
* Customizable class loader that allows the user to explicitly specify additional JARs and directories to scan.
*
* @param paths The directories and explicit JAR files to scan.
* @property classResolver The resolver to use to derive the original name of a requested class.
* @property resolvedUrls The resolved URLs that get passed to the underlying class loader.
*/
open class SourceClassLoader(
paths: List<Path>, paths: List<Path>,
private val classResolver: ClassResolver, private val classResolver: ClassResolver,
val resolvedUrls: Array<URL> = resolvePaths(paths) parent: ClassLoader?
) : URLClassLoader(resolvedUrls, SourceClassLoader::class.java.classLoader) { ) : URLClassLoader(resolvePaths(paths), parent) {
/** /**
* Open a [ClassReader] for the provided class name. * Open a [ClassReader] for the provided class name.
@ -36,7 +30,7 @@ open class SourceClassLoader(
fun classReader( fun classReader(
className: String, context: AnalysisContext, origin: String? = null className: String, context: AnalysisContext, origin: String? = null
): ClassReader { ): ClassReader {
val originalName = classResolver.reverse(className.replace('.', '/')) val originalName = classResolver.reverse(className.asResourcePath)
return try { return try {
logger.trace("Opening ClassReader for class {}...", originalName) logger.trace("Opening ClassReader for class {}...", originalName)
getResourceAsStream("$originalName.class").use { getResourceAsStream("$originalName.class").use {
@ -71,16 +65,16 @@ open class SourceClassLoader(
return super.loadClass(originalName, resolve) return super.loadClass(originalName, resolve)
} }
private companion object { protected companion object {
@JvmStatic
private val logger = loggerFor<SourceClassLoader>() protected val logger = loggerFor<SourceClassLoader>()
private fun resolvePaths(paths: List<Path>): Array<URL> { private fun resolvePaths(paths: List<Path>): Array<URL> {
return paths.map(this::expandPath).flatMap { return paths.map(this::expandPath).flatMap {
when { when {
!Files.exists(it) -> throw FileNotFoundException("File not found; $it") !Files.exists(it) -> throw FileNotFoundException("File not found; $it")
Files.isDirectory(it) -> { Files.isDirectory(it) -> {
listOf(it.toURL()) + Files.list(it).filter(::isJarFile).map { it.toURL() }.toList() listOf(it.toURL()) + Files.list(it).filter(::isJarFile).map { jar -> jar.toURL() }.toList()
} }
Files.isReadable(it) && isJarFile(it) -> listOf(it.toURL()) Files.isReadable(it) && isJarFile(it) -> listOf(it.toURL())
else -> throw IllegalArgumentException("Expected JAR or class file, but found $it") else -> throw IllegalArgumentException("Expected JAR or class file, but found $it")
@ -100,7 +94,7 @@ open class SourceClassLoader(
private fun isJarFile(path: Path) = path.toString().endsWith(".jar", true) private fun isJarFile(path: Path) = path.toString().endsWith(".jar", true)
private fun Path.toURL() = this.toUri().toURL() private fun Path.toURL(): URL = this.toUri().toURL()
private val homeDirectory: Path private val homeDirectory: Path
get() = Paths.get(System.getProperty("user.home")) get() = Paths.get(System.getProperty("user.home"))
@ -108,3 +102,68 @@ open class SourceClassLoader(
} }
} }
/**
* Class loader to manage an optional JAR of replacement Java APIs.
* @param bootstrapJar The location of the JAR containing the Java APIs.
* @param classResolver The resolver to use to derive the original name of a requested class.
*/
class BootstrapClassLoader(
bootstrapJar: Path,
classResolver: ClassResolver
) : AbstractSourceClassLoader(listOf(bootstrapJar), classResolver, null) {
/**
* Only search our own jars for the given resource.
*/
override fun getResource(name: String): URL? = findResource(name)
}
/**
* Customizable class loader that allows the user to explicitly specify additional JARs and directories to scan.
*
* @param paths The directories and explicit JAR files to scan.
* @property classResolver The resolver to use to derive the original name of a requested class.
* @property bootstrap The [BootstrapClassLoader] containing the Java APIs for the sandbox.
*/
class SourceClassLoader(
paths: List<Path>,
classResolver: ClassResolver,
private val bootstrap: BootstrapClassLoader? = null
) : AbstractSourceClassLoader(paths, classResolver, SourceClassLoader::class.java.classLoader) {
/**
* First check the bootstrap classloader, if we have one.
* Otherwise check our parent classloader, followed by
* the user-supplied jars.
*/
override fun getResource(name: String): URL? {
if (bootstrap != null) {
val resource = bootstrap.findResource(name)
if (resource != null) {
return resource
} else if (isJvmInternal(name)) {
logger.error("Denying request for actual {}", name)
return null
}
}
return parent?.getResource(name) ?: findResource(name)
}
/**
* Deny all requests for DJVM classes from any user-supplied jars.
*/
override fun findResource(name: String): URL? {
return if (name.startsWith("net/corda/djvm/")) null else super.findResource(name)
}
/**
* Does [name] exist within any of the packages reserved for Java itself?
*/
private fun isJvmInternal(name: String): Boolean = name.startsWith("java/")
|| name.startsWith("javax/")
|| name.startsWith("com/sun/")
|| name.startsWith("sun/")
|| name.startsWith("jdk/")
}

View File

@ -1,219 +0,0 @@
package net.corda.djvm.validation
import net.corda.djvm.analysis.AnalysisConfiguration
import net.corda.djvm.analysis.AnalysisContext
import net.corda.djvm.analysis.ClassAndMemberVisitor
import net.corda.djvm.execution.SandboxedRunnable
import net.corda.djvm.formatting.MemberFormatter
import net.corda.djvm.messages.Message
import net.corda.djvm.messages.Severity
import net.corda.djvm.references.*
import net.corda.djvm.rewiring.SandboxClassLoadingException
import net.corda.djvm.utilities.loggerFor
/**
* Module used to validate all traversable references before instantiating and executing a [SandboxedRunnable].
*
* @param configuration The analysis configuration to use for the validation.
* @property memberFormatter Module with functionality for formatting class members.
*/
class ReferenceValidator(
private val configuration: AnalysisConfiguration,
private val memberFormatter: MemberFormatter = MemberFormatter()
) {
/**
* Container holding the current state of the validation.
*
* @property context The context in which references are to be validated.
* @property analyzer Underlying analyzer used for processing classes.
*/
private class State(
val context: AnalysisContext,
val analyzer: ClassAndMemberVisitor
)
/**
* Validate whether or not the classes in a class hierarchy can be safely instantiated and run in a sandbox by
* checking that all references are rooted in deterministic code.
*
* @param context The context in which the check should be made.
* @param analyzer Underlying analyzer used for processing classes.
*/
fun validate(context: AnalysisContext, analyzer: ClassAndMemberVisitor): ReferenceValidationSummary =
State(context, analyzer).let { state ->
logger.trace("Validating {} references across {} class(es)...",
context.references.size, context.classes.size)
context.references.process { validateReference(state, it) }
logger.trace("Reference validation completed; {} class(es) and {} message(s)",
context.references.size, context.classes.size)
ReferenceValidationSummary(state.context.classes, state.context.messages, state.context.classOrigins)
}
/**
* Construct a message from an invalid reference and its source location.
*/
private fun referenceToMessage(referenceWithLocation: ReferenceWithLocation): Message {
val (location, reference, description) = referenceWithLocation
val referenceMessage = when {
reference is ClassReference ->
"Invalid reference to class ${configuration.classModule.getFormattedClassName(reference.className)}"
reference is MemberReference && configuration.memberModule.isConstructor(reference) ->
"Invalid reference to constructor ${memberFormatter.format(reference)}"
reference is MemberReference && configuration.memberModule.isField(reference) ->
"Invalid reference to field ${memberFormatter.format(reference)}"
reference is MemberReference && configuration.memberModule.isMethod(reference) ->
"Invalid reference to method ${memberFormatter.format(reference)}"
else ->
"Invalid reference to $reference"
}
val message = if (description.isNotBlank()) {
"$referenceMessage, $description"
} else {
referenceMessage
}
return Message(message, Severity.ERROR, location)
}
/**
* Validate a reference made from a class or class member.
*/
private fun validateReference(state: State, reference: EntityReference) {
if (configuration.whitelist.matches(reference.className)) {
// The referenced class has been whitelisted - no need to go any further.
return
}
when (reference) {
is ClassReference -> {
logger.trace("Validating class reference {}", reference)
val clazz = getClass(state, reference.className)
val reason = when (clazz) {
null -> Reason(Reason.Code.NON_EXISTENT_CLASS)
else -> getReasonFromEntity(clazz)
}
if (reason != null) {
logger.trace("Recorded invalid class reference to {}; reason = {}", reference, reason)
state.context.messages.addAll(state.context.references.locationsFromReference(reference).map {
referenceToMessage(ReferenceWithLocation(it, reference, reason.description))
})
}
}
is MemberReference -> {
logger.trace("Validating member reference {}", reference)
// Ensure that the dependent class is loaded and analyzed
val clazz = getClass(state, reference.className)
val member = state.context.classes.getMember(
reference.className, reference.memberName, reference.signature
)
val reason = when {
clazz == null -> Reason(Reason.Code.NON_EXISTENT_CLASS)
member == null -> Reason(Reason.Code.NON_EXISTENT_MEMBER)
else -> getReasonFromEntity(state, member)
}
if (reason != null) {
logger.trace("Recorded invalid member reference to {}; reason = {}", reference, reason)
state.context.messages.addAll(state.context.references.locationsFromReference(reference).map {
referenceToMessage(ReferenceWithLocation(it, reference, reason.description))
})
}
}
}
}
/**
* Get a class from the class hierarchy by its binary name.
*/
private fun getClass(state: State, className: String, originClass: String? = null): ClassRepresentation? {
val name = if (configuration.classModule.isArray(className)) {
val arrayType = arrayTypeExtractor.find(className)?.groupValues?.get(1)
when (arrayType) {
null -> "java/lang/Object"
else -> arrayType
}
} else {
className
}
var clazz = state.context.classes[name]
if (clazz == null) {
logger.trace("Loading and analyzing referenced class {}...", name)
val origin = state.context.references
.locationsFromReference(ClassReference(name))
.map { it.className }
.firstOrNull() ?: originClass
state.analyzer.analyze(name, state.context, origin)
clazz = state.context.classes[name]
}
if (clazz == null) {
logger.warn("Failed to load class {}", name)
state.context.messages.add(Message("Referenced class not found; $name", Severity.ERROR))
}
clazz?.apply {
val ancestors = listOf(superClass) + interfaces
for (ancestor in ancestors.filter(String::isNotBlank)) {
getClass(state, ancestor, clazz.name)
}
}
return clazz
}
/**
* Check if a top-level class definition is considered safe or not.
*/
private fun isNonDeterministic(state: State, className: String): Boolean = when {
configuration.whitelist.matches(className) -> false
else -> {
try {
getClass(state, className)?.let {
isNonDeterministic(it)
} ?: true
} catch (exception: SandboxClassLoadingException) {
true // Failed to load the class, which means the class is non-deterministic.
}
}
}
/**
* Check if a top-level class definition is considered safe or not.
*/
private fun isNonDeterministic(clazz: ClassRepresentation) =
getReasonFromEntity(clazz) != null
/**
* Derive what reason to give to the end-user for an invalid class.
*/
private fun getReasonFromEntity(clazz: ClassRepresentation): Reason? = when {
configuration.whitelist.matches(clazz.name) -> null
configuration.whitelist.inNamespace(clazz.name) -> Reason(Reason.Code.NOT_WHITELISTED)
configuration.classModule.isNonDeterministic(clazz) -> Reason(Reason.Code.ANNOTATED)
else -> null
}
/**
* Derive what reason to give to the end-user for an invalid member.
*/
private fun getReasonFromEntity(state: State, member: Member): Reason? = when {
configuration.whitelist.matches(member.reference) -> null
configuration.whitelist.inNamespace(member.reference) -> Reason(Reason.Code.NOT_WHITELISTED)
configuration.memberModule.isNonDeterministic(member) -> Reason(Reason.Code.ANNOTATED)
else -> {
val invalidClasses = configuration.memberModule.findReferencedClasses(member)
.filter { isNonDeterministic(state, it) }
if (invalidClasses.isNotEmpty()) {
Reason(Reason.Code.INVALID_CLASS, invalidClasses.map {
configuration.classModule.getFormattedClassName(it)
})
} else {
null
}
}
}
private companion object {
private val logger = loggerFor<ReferenceValidator>()
private val arrayTypeExtractor = "^\\[*L([^;]+);$".toRegex()
}
}

View File

@ -21,9 +21,9 @@ import org.objectweb.asm.ClassVisitor
*/ */
class RuleValidator( class RuleValidator(
private val rules: List<Rule> = emptyList(), private val rules: List<Rule> = emptyList(),
configuration: AnalysisConfiguration = AnalysisConfiguration(), configuration: AnalysisConfiguration,
classVisitor: ClassVisitor? = null classVisitor: ClassVisitor? = null
) : ClassAndMemberVisitor(classVisitor, configuration = configuration) { ) : ClassAndMemberVisitor(configuration, classVisitor) {
/** /**
* Apply the set of rules to the traversed class and record any violations. * Apply the set of rules to the traversed class and record any violations.

View File

@ -2,7 +2,7 @@ package sandbox.net.corda.djvm.costing
import net.corda.djvm.SandboxRuntimeContext import net.corda.djvm.SandboxRuntimeContext
import net.corda.djvm.costing.RuntimeCostSummary import net.corda.djvm.costing.RuntimeCostSummary
import net.corda.djvm.costing.ThresholdViolationException import org.objectweb.asm.Type
/** /**
* Class for keeping a tally on various runtime metrics, like number of jumps, allocations, invocations, etc. The * Class for keeping a tally on various runtime metrics, like number of jumps, allocations, invocations, etc. The
@ -24,7 +24,8 @@ object RuntimeCostAccounter {
/** /**
* The type name of the [RuntimeCostAccounter] class; referenced from instrumentors. * The type name of the [RuntimeCostAccounter] class; referenced from instrumentors.
*/ */
const val TYPE_NAME: String = "sandbox/net/corda/djvm/costing/RuntimeCostAccounter" @JvmField
val TYPE_NAME: String = Type.getInternalName(this::class.java)
/** /**
* Known / estimated allocation costs. * Known / estimated allocation costs.
@ -35,14 +36,12 @@ object RuntimeCostAccounter {
) )
/** /**
* Re-throw exception if it is of type [ThreadDeath] or [ThresholdViolationException]. * Re-throw exception if it is of type [ThreadDeath] or [VirtualMachineError].
*/ */
@JvmStatic @JvmStatic
fun checkCatch(exception: Throwable) { fun checkCatch(exception: Throwable) {
if (exception is ThreadDeath) { when (exception) {
throw exception is ThreadDeath, is VirtualMachineError -> throw exception
} else if (exception is ThresholdViolationException) {
throw exception
} }
} }

View File

@ -1,4 +1,4 @@
package net.corda.djvm.costing package sandbox.net.corda.djvm.costing
/** /**
* Exception thrown when a sandbox threshold is violated. This will kill the current thread and consequently exit the * Exception thrown when a sandbox threshold is violated. This will kill the current thread and consequently exit the
@ -6,6 +6,4 @@ package net.corda.djvm.costing
* *
* @property message The description of the condition causing the problem. * @property message The description of the condition causing the problem.
*/ */
class ThresholdViolationException( class ThresholdViolationError(override val message: String) : ThreadDeath()
override val message: String
) : ThreadDeath()

View File

@ -0,0 +1,10 @@
package sandbox.net.corda.djvm.rules
/**
* Exception thrown when a sandbox rule is violated at runtime.
* This will kill the current thread and consequently exit the
* sandbox.
*
* @property message The description of the condition causing the problem.
*/
class RuleViolationError(override val message: String) : ThreadDeath()

View File

@ -16,10 +16,16 @@ import net.corda.djvm.rules.Rule
import net.corda.djvm.source.ClassSource import net.corda.djvm.source.ClassSource
import net.corda.djvm.utilities.Discovery import net.corda.djvm.utilities.Discovery
import net.corda.djvm.validation.RuleValidator import net.corda.djvm.validation.RuleValidator
import org.junit.After
import org.junit.Assert.assertEquals
import org.objectweb.asm.ClassReader import org.objectweb.asm.ClassReader
import org.objectweb.asm.ClassVisitor import org.objectweb.asm.ClassWriter
import org.objectweb.asm.Type import org.objectweb.asm.Type
import java.lang.reflect.InvocationTargetException import java.lang.reflect.InvocationTargetException
import java.nio.file.Path
import java.nio.file.Paths
import kotlin.concurrent.thread
import kotlin.reflect.jvm.jvmName
abstract class TestBase { abstract class TestBase {
@ -33,8 +39,10 @@ abstract class TestBase {
val BLANK = emptySet<Any>() val BLANK = emptySet<Any>()
val DEFAULT = (ALL_RULES + ALL_EMITTERS + ALL_DEFINITION_PROVIDERS) val DEFAULT = (ALL_RULES + ALL_EMITTERS + ALL_DEFINITION_PROVIDERS).distinctBy(Any::javaClass)
.toSet().distinctBy { it.javaClass }
val DETERMINISTIC_RT: Path = Paths.get(
System.getProperty("deterministic-rt.path") ?: throw AssertionError("deterministic-rt.path property not set"))
/** /**
* Get the full name of type [T]. * Get the full name of type [T].
@ -46,13 +54,18 @@ abstract class TestBase {
/** /**
* Default analysis configuration. * Default analysis configuration.
*/ */
val configuration = AnalysisConfiguration(Whitelist.MINIMAL) val configuration = AnalysisConfiguration(Whitelist.MINIMAL, bootstrapJar = DETERMINISTIC_RT)
/** /**
* Default analysis context * Default analysis context
*/ */
val context: AnalysisContext val context: AnalysisContext
get() = AnalysisContext.fromConfiguration(configuration, emptyList()) get() = AnalysisContext.fromConfiguration(configuration)
@After
fun destroy() {
configuration.close()
}
/** /**
* Short-hand for analysing and validating a class. * Short-hand for analysing and validating a class.
@ -62,15 +75,16 @@ abstract class TestBase {
noinline block: (RuleValidator.(AnalysisContext) -> Unit) noinline block: (RuleValidator.(AnalysisContext) -> Unit)
) { ) {
val reader = ClassReader(T::class.java.name) val reader = ClassReader(T::class.java.name)
val configuration = AnalysisConfiguration(minimumSeverityLevel = minimumSeverityLevel) AnalysisConfiguration(
val validator = RuleValidator(ALL_RULES, configuration) minimumSeverityLevel = minimumSeverityLevel,
val context = AnalysisContext.fromConfiguration( classPath = listOf(DETERMINISTIC_RT)
configuration, ).use { analysisConfiguration ->
listOf(ClassSource.fromClassName(reader.className)) val validator = RuleValidator(ALL_RULES, analysisConfiguration)
) val context = AnalysisContext.fromConfiguration(analysisConfiguration)
validator.analyze(reader, context) validator.analyze(reader, context)
block(validator, context) block(validator, context)
} }
}
/** /**
* Short-hand for analysing a class. * Short-hand for analysing a class.
@ -113,27 +127,26 @@ abstract class TestBase {
} }
} }
var thrownException: Throwable? = null var thrownException: Throwable? = null
Thread { thread {
try { try {
val pinnedTestClasses = pinnedClasses.map(Type::getInternalName).toSet() val pinnedTestClasses = pinnedClasses.map(Type::getInternalName).toSet()
val analysisConfiguration = AnalysisConfiguration( AnalysisConfiguration(
whitelist = whitelist, whitelist = whitelist,
bootstrapJar = DETERMINISTIC_RT,
additionalPinnedClasses = pinnedTestClasses, additionalPinnedClasses = pinnedTestClasses,
minimumSeverityLevel = minimumSeverityLevel minimumSeverityLevel = minimumSeverityLevel
) ).use { analysisConfiguration ->
SandboxRuntimeContext(SandboxConfiguration.of( SandboxRuntimeContext(SandboxConfiguration.of(
executionProfile, rules, emitters, definitionProviders, enableTracing, analysisConfiguration executionProfile, rules, emitters, definitionProviders, enableTracing, analysisConfiguration
), classSources).use { )).use {
assertThat(runtimeCosts).areZero() assertThat(runtimeCosts).areZero()
action(this) action(this)
} }
}
} catch (exception: Throwable) { } catch (exception: Throwable) {
thrownException = exception thrownException = exception
} }
}.apply { }.join()
start()
join()
}
throw thrownException ?: return throw thrownException ?: return
} }
@ -145,8 +158,12 @@ abstract class TestBase {
/** /**
* Create a new instance of a class using the sandbox class loader. * Create a new instance of a class using the sandbox class loader.
*/ */
inline fun <reified T : Callable> SandboxRuntimeContext.newCallable() = inline fun <reified T : Callable> SandboxRuntimeContext.newCallable(): LoadedClass = loadClass<T>()
classLoader.loadClassAndBytes(ClassSource.fromClassName(T::class.java.name), context)
inline fun <reified T : Any> SandboxRuntimeContext.loadClass(): LoadedClass = loadClass(T::class.jvmName)
fun SandboxRuntimeContext.loadClass(className: String): LoadedClass =
classLoader.loadClassAndBytes(ClassSource.fromClassName(className), context)
/** /**
* Run the entry-point of the loaded [Callable] class. * Run the entry-point of the loaded [Callable] class.
@ -164,6 +181,10 @@ abstract class TestBase {
/** /**
* Stub visitor. * Stub visitor.
*/ */
protected class Visitor : ClassVisitor(ClassAndMemberVisitor.API_VERSION) protected class Writer : ClassWriter(COMPUTE_FRAMES or COMPUTE_MAXS) {
init {
assertEquals(ClassAndMemberVisitor.API_VERSION, api)
}
}
} }

View File

@ -21,7 +21,7 @@ class ClassAndMemberVisitorTest : TestBase() {
@Test @Test
fun `can traverse classes`() { fun `can traverse classes`() {
val classesVisited = mutableSetOf<ClassRepresentation>() val classesVisited = mutableSetOf<ClassRepresentation>()
val visitor = object : ClassAndMemberVisitor() { val visitor = object : ClassAndMemberVisitor(configuration, null) {
override fun visitClass(clazz: ClassRepresentation): ClassRepresentation { override fun visitClass(clazz: ClassRepresentation): ClassRepresentation {
classesVisited.add(clazz) classesVisited.add(clazz)
return clazz return clazz
@ -47,7 +47,7 @@ class ClassAndMemberVisitorTest : TestBase() {
@Test @Test
fun `can traverse fields`() { fun `can traverse fields`() {
val membersVisited = mutableSetOf<Member>() val membersVisited = mutableSetOf<Member>()
val visitor = object : ClassAndMemberVisitor() { val visitor = object : ClassAndMemberVisitor(configuration, null) {
override fun visitField(clazz: ClassRepresentation, field: Member): Member { override fun visitField(clazz: ClassRepresentation, field: Member): Member {
membersVisited.add(field) membersVisited.add(field)
return field return field
@ -77,7 +77,7 @@ class ClassAndMemberVisitorTest : TestBase() {
@Test @Test
fun `can traverse methods`() { fun `can traverse methods`() {
val membersVisited = mutableSetOf<Member>() val membersVisited = mutableSetOf<Member>()
val visitor = object : ClassAndMemberVisitor() { val visitor = object : ClassAndMemberVisitor(configuration, null) {
override fun visitMethod(clazz: ClassRepresentation, method: Member): Member { override fun visitMethod(clazz: ClassRepresentation, method: Member): Member {
membersVisited.add(method) membersVisited.add(method)
return method return method
@ -102,7 +102,7 @@ class ClassAndMemberVisitorTest : TestBase() {
@Test @Test
fun `can traverse class annotations`() { fun `can traverse class annotations`() {
val annotations = mutableSetOf<String>() val annotations = mutableSetOf<String>()
val visitor = object : ClassAndMemberVisitor() { val visitor = object : ClassAndMemberVisitor(configuration, null) {
override fun visitClassAnnotation(clazz: ClassRepresentation, descriptor: String) { override fun visitClassAnnotation(clazz: ClassRepresentation, descriptor: String) {
annotations.add(descriptor) annotations.add(descriptor)
} }
@ -118,9 +118,21 @@ class ClassAndMemberVisitorTest : TestBase() {
private class TestClassWithAnnotations private class TestClassWithAnnotations
@Test @Test
fun `can traverse member annotations`() { fun `cannot traverse member annotations when reading`() {
val annotations = mutableSetOf<String>() val annotations = mutableSetOf<String>()
val visitor = object : ClassAndMemberVisitor() { val visitor = object : ClassAndMemberVisitor(configuration, null) {
override fun visitMemberAnnotation(clazz: ClassRepresentation, member: Member, descriptor: String) {
annotations.add("${member.memberName}:$descriptor")
}
}
visitor.analyze<TestClassWithMemberAnnotations>(context)
assertThat(annotations).isEmpty()
}
@Test
fun `can traverse member annotations when writing`() {
val annotations = mutableSetOf<String>()
val visitor = object : ClassAndMemberVisitor(configuration, Writer()) {
override fun visitMemberAnnotation(clazz: ClassRepresentation, member: Member, descriptor: String) { override fun visitMemberAnnotation(clazz: ClassRepresentation, member: Member, descriptor: String) {
annotations.add("${member.memberName}:$descriptor") annotations.add("${member.memberName}:$descriptor")
} }
@ -146,7 +158,7 @@ class ClassAndMemberVisitorTest : TestBase() {
@Test @Test
fun `can traverse class sources`() { fun `can traverse class sources`() {
val sources = mutableSetOf<String>() val sources = mutableSetOf<String>()
val visitor = object : ClassAndMemberVisitor() { val visitor = object : ClassAndMemberVisitor(configuration, null) {
override fun visitSource(clazz: ClassRepresentation, source: String) { override fun visitSource(clazz: ClassRepresentation, source: String) {
sources.add(source) sources.add(source)
} }
@ -160,9 +172,21 @@ class ClassAndMemberVisitorTest : TestBase() {
} }
@Test @Test
fun `can traverse instructions`() { fun `does not traverse instructions when reading`() {
val instructions = mutableSetOf<Pair<Member, Instruction>>() val instructions = mutableSetOf<Pair<Member, Instruction>>()
val visitor = object : ClassAndMemberVisitor() { val visitor = object : ClassAndMemberVisitor(configuration, null) {
override fun visitInstruction(method: Member, emitter: EmitterModule, instruction: Instruction) {
instructions.add(Pair(method, instruction))
}
}
visitor.analyze<TestClassWithCode>(context)
assertThat(instructions).isEmpty()
}
@Test
fun `can traverse instructions when writing`() {
val instructions = mutableSetOf<Pair<Member, Instruction>>()
val visitor = object : ClassAndMemberVisitor(configuration, Writer()) {
override fun visitInstruction(method: Member, emitter: EmitterModule, instruction: Instruction) { override fun visitInstruction(method: Member, emitter: EmitterModule, instruction: Instruction) {
instructions.add(Pair(method, instruction)) instructions.add(Pair(method, instruction))
} }

View File

@ -1,66 +0,0 @@
package net.corda.djvm.analysis
import net.corda.djvm.TestBase
import net.corda.djvm.execution.SandboxedRunnable
import net.corda.djvm.validation.ReferenceValidator
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
class ReferenceValidatorTest : TestBase() {
private fun validator(whitelist: Whitelist = Whitelist.MINIMAL) =
ReferenceValidator(AnalysisConfiguration(whitelist))
@Test
fun `can validate when there are no references`() = analyze { context ->
analyze<EmptyRunnable>(context)
val (_, messages) = validator().validate(context, this)
assertThat(messages.count).isEqualTo(0)
}
private class EmptyRunnable : SandboxedRunnable<Int, Int> {
override fun run(input: Int): Int? {
return null
}
}
@Test
fun `can validate when there are references`() = analyze { context ->
analyze<RunnableWithReferences>(context)
analyze<TestRandom>(context)
val (_, messages) = validator().validate(context, this)
assertThat(messages.count).isEqualTo(0)
}
private class RunnableWithReferences : SandboxedRunnable<Int, Int> {
override fun run(input: Int): Int? {
return TestRandom().nextInt()
}
}
private class TestRandom {
external fun nextInt(): Int
}
@Test
fun `can validate when there are transient references`() = analyze { context ->
analyze<RunnableWithTransientReferences>(context)
analyze<ReferencedClass>(context)
analyze<TestRandom>(context)
val (_, messages) = validator().validate(context, this)
assertThat(messages.count).isEqualTo(0)
}
private class RunnableWithTransientReferences : SandboxedRunnable<Int, Int> {
override fun run(input: Int): Int? {
return ReferencedClass().test()
}
}
private class ReferencedClass {
fun test(): Int {
return TestRandom().nextInt()
}
}
}

View File

@ -1,27 +1,31 @@
package net.corda.djvm.assertions package net.corda.djvm.assertions
import net.corda.djvm.rewiring.LoadedClass import net.corda.djvm.rewiring.LoadedClass
import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions.*
class AssertiveClassWithByteCode(private val loadedClass: LoadedClass) { class AssertiveClassWithByteCode(private val loadedClass: LoadedClass) {
fun isSandboxed(): AssertiveClassWithByteCode { fun isSandboxed(): AssertiveClassWithByteCode {
Assertions.assertThat(loadedClass.type.name).startsWith("sandbox.") assertThat(loadedClass.type.name).startsWith("sandbox.")
return this return this
} }
fun hasNotBeenModified(): AssertiveClassWithByteCode { fun hasNotBeenModified(): AssertiveClassWithByteCode {
Assertions.assertThat(loadedClass.byteCode.isModified) assertThat(loadedClass.byteCode.isModified)
.`as`("Byte code has been modified") .`as`("Byte code has been modified")
.isEqualTo(false) .isEqualTo(false)
return this return this
} }
fun hasBeenModified(): AssertiveClassWithByteCode { fun hasBeenModified(): AssertiveClassWithByteCode {
Assertions.assertThat(loadedClass.byteCode.isModified) assertThat(loadedClass.byteCode.isModified)
.`as`("Byte code has been modified") .`as`("Byte code has not been modified")
.isEqualTo(true) .isEqualTo(true)
return this return this
} }
fun hasClassName(className: String): AssertiveClassWithByteCode {
assertThat(loadedClass.type.name).isEqualTo(className)
return this
}
} }

View File

@ -10,7 +10,7 @@ open class AssertiveReferenceMap(private val references: ReferenceMap) {
fun hasCount(count: Int): AssertiveReferenceMap { fun hasCount(count: Int): AssertiveReferenceMap {
val allReferences = references.joinToString("\n") { " - $it" } val allReferences = references.joinToString("\n") { " - $it" }
Assertions.assertThat(references.size) Assertions.assertThat(references.numberOfReferences)
.overridingErrorMessage("Expected $count reference(s), found:\n$allReferences") .overridingErrorMessage("Expected $count reference(s), found:\n$allReferences")
.isEqualTo(count) .isEqualTo(count)
return this return this

View File

@ -21,7 +21,7 @@ class ClassMutatorTest : TestBase() {
} }
} }
val context = context val context = context
val mutator = ClassMutator(Visitor(), configuration, listOf(definitionProvider)) val mutator = ClassMutator(Writer(), configuration, listOf(definitionProvider))
mutator.analyze<TestClass>(context) mutator.analyze<TestClass>(context)
assertThat(hasProvidedDefinition).isTrue() assertThat(hasProvidedDefinition).isTrue()
assertThat(context.classes.get<TestClass>().access or ACC_STRICT).isNotEqualTo(0) assertThat(context.classes.get<TestClass>().access or ACC_STRICT).isNotEqualTo(0)
@ -39,7 +39,7 @@ class ClassMutatorTest : TestBase() {
} }
} }
val context = context val context = context
val mutator = ClassMutator(Visitor(), configuration, listOf(definitionProvider)) val mutator = ClassMutator(Writer(), configuration, listOf(definitionProvider))
mutator.analyze<TestClassWithMembers>(context) mutator.analyze<TestClassWithMembers>(context)
assertThat(hasProvidedDefinition).isTrue() assertThat(hasProvidedDefinition).isTrue()
for (member in context.classes.get<TestClassWithMembers>().members.values) { for (member in context.classes.get<TestClassWithMembers>().members.values) {
@ -67,7 +67,7 @@ class ClassMutatorTest : TestBase() {
} }
} }
val context = context val context = context
val mutator = ClassMutator(Visitor(), configuration, emitters = listOf(emitter)) val mutator = ClassMutator(Writer(), configuration, emitters = listOf(emitter))
mutator.analyze<TestClassWithMembers>(context) mutator.analyze<TestClassWithMembers>(context)
assertThat(hasEmittedCode).isTrue() assertThat(hasEmittedCode).isTrue()
assertThat(shouldPreventDefault).isTrue() assertThat(shouldPreventDefault).isTrue()

View File

@ -7,6 +7,7 @@ import org.junit.Test
import org.objectweb.asm.ClassVisitor import org.objectweb.asm.ClassVisitor
import org.objectweb.asm.MethodVisitor import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes.NEW import org.objectweb.asm.Opcodes.NEW
import org.objectweb.asm.Type
class EmitterModuleTest : TestBase() { class EmitterModuleTest : TestBase() {
@ -14,15 +15,15 @@ class EmitterModuleTest : TestBase() {
fun `can emit code to method body`() { fun `can emit code to method body`() {
var hasEmittedTypeInstruction = false var hasEmittedTypeInstruction = false
val methodVisitor = object : MethodVisitor(ClassAndMemberVisitor.API_VERSION) { val methodVisitor = object : MethodVisitor(ClassAndMemberVisitor.API_VERSION) {
override fun visitTypeInsn(opcode: Int, type: String?) { override fun visitTypeInsn(opcode: Int, type: String) {
if (opcode == NEW && type == java.lang.String::class.java.name) { if (opcode == NEW && type == Type.getInternalName(java.lang.String::class.java)) {
hasEmittedTypeInstruction = true hasEmittedTypeInstruction = true
} }
} }
} }
val visitor = object : ClassVisitor(ClassAndMemberVisitor.API_VERSION) { val visitor = object : ClassVisitor(ClassAndMemberVisitor.API_VERSION) {
override fun visitMethod( override fun visitMethod(
access: Int, name: String?, descriptor: String?, signature: String?, exceptions: Array<out String>? access: Int, name: String, descriptor: String, signature: String?, exceptions: Array<out String>?
): MethodVisitor { ): MethodVisitor {
return methodVisitor return methodVisitor
} }

View File

@ -3,12 +3,13 @@ package net.corda.djvm.costing
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.Test import org.junit.Test
import sandbox.net.corda.djvm.costing.ThresholdViolationError
class RuntimeCostTest { class RuntimeCostTest {
@Test @Test
fun `can increment cost`() { fun `can increment cost`() {
val cost = RuntimeCost(10, { "failed" }) val cost = RuntimeCost(10) { "failed" }
cost.increment(1) cost.increment(1)
assertThat(cost.value).isEqualTo(1) assertThat(cost.value).isEqualTo(1)
} }
@ -16,8 +17,8 @@ class RuntimeCostTest {
@Test @Test
fun `cannot increment cost beyond threshold`() { fun `cannot increment cost beyond threshold`() {
Thread { Thread {
val cost = RuntimeCost(10, { "failed in ${it.name}" }) val cost = RuntimeCost(10) { "failed in ${it.name}" }
assertThatExceptionOfType(ThresholdViolationException::class.java) assertThatExceptionOfType(ThresholdViolationError::class.java)
.isThrownBy { cost.increment(11) } .isThrownBy { cost.increment(11) }
.withMessage("failed in Foo") .withMessage("failed in Foo")
assertThat(cost.value).isEqualTo(11) assertThat(cost.value).isEqualTo(11)

View File

@ -6,13 +6,15 @@ import foo.bar.sandbox.toNumber
import net.corda.djvm.TestBase import net.corda.djvm.TestBase
import net.corda.djvm.analysis.Whitelist import net.corda.djvm.analysis.Whitelist
import net.corda.djvm.assertions.AssertionExtensions.withProblem import net.corda.djvm.assertions.AssertionExtensions.withProblem
import net.corda.djvm.costing.ThresholdViolationException
import net.corda.djvm.rewiring.SandboxClassLoadingException import net.corda.djvm.rewiring.SandboxClassLoadingException
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.Test import org.junit.Test
import sandbox.net.corda.djvm.costing.ThresholdViolationError
import sandbox.net.corda.djvm.rules.RuleViolationError
import java.nio.file.Files import java.nio.file.Files
import java.util.* import java.util.*
import java.util.function.Function
class SandboxExecutorTest : TestBase() { class SandboxExecutorTest : TestBase() {
@ -24,8 +26,8 @@ class SandboxExecutorTest : TestBase() {
assertThat(result).isEqualTo("sandbox") assertThat(result).isEqualTo("sandbox")
} }
class TestSandboxedRunnable : SandboxedRunnable<Int, String> { class TestSandboxedRunnable : Function<Int, String> {
override fun run(input: Int): String? { override fun apply(input: Int): String {
return "sandbox" return "sandbox"
} }
} }
@ -42,8 +44,8 @@ class SandboxExecutorTest : TestBase() {
.withMessageContaining("Contract constraint violated") .withMessageContaining("Contract constraint violated")
} }
class Contract : SandboxedRunnable<Transaction?, Unit> { class Contract : Function<Transaction?, Unit> {
override fun run(input: Transaction?) { override fun apply(input: Transaction?) {
throw IllegalArgumentException("Contract constraint violated") throw IllegalArgumentException("Contract constraint violated")
} }
} }
@ -58,8 +60,8 @@ class SandboxExecutorTest : TestBase() {
assertThat(result).isEqualTo(0xfed_c0de + 2) assertThat(result).isEqualTo(0xfed_c0de + 2)
} }
class TestObjectHashCode : SandboxedRunnable<Int, Int> { class TestObjectHashCode : Function<Int, Int> {
override fun run(input: Int): Int? { override fun apply(input: Int): Int {
val obj = Object() val obj = Object()
val hash1 = obj.hashCode() val hash1 = obj.hashCode()
val hash2 = obj.hashCode() val hash2 = obj.hashCode()
@ -76,8 +78,8 @@ class SandboxExecutorTest : TestBase() {
assertThat(result).isEqualTo(0xfed_c0de + 1) assertThat(result).isEqualTo(0xfed_c0de + 1)
} }
class TestObjectHashCodeWithHierarchy : SandboxedRunnable<Int, Int> { class TestObjectHashCodeWithHierarchy : Function<Int, Int> {
override fun run(input: Int): Int? { override fun apply(input: Int): Int {
val obj = MyObject() val obj = MyObject()
return obj.hashCode() return obj.hashCode()
} }
@ -91,9 +93,9 @@ class SandboxExecutorTest : TestBase() {
.withMessageContaining("terminated due to excessive use of looping") .withMessageContaining("terminated due to excessive use of looping")
} }
class TestThresholdBreach : SandboxedRunnable<Int, Int> { class TestThresholdBreach : Function<Int, Int> {
private var x = 0 private var x = 0
override fun run(input: Int): Int? { override fun apply(input: Int): Int {
for (i in 0..1_000_000) { for (i in 0..1_000_000) {
x += 1 x += 1
} }
@ -109,8 +111,8 @@ class SandboxExecutorTest : TestBase() {
.withCauseInstanceOf(StackOverflowError::class.java) .withCauseInstanceOf(StackOverflowError::class.java)
} }
class TestStackOverflow : SandboxedRunnable<Int, Int> { class TestStackOverflow : Function<Int, Int> {
override fun run(input: Int): Int? { override fun apply(input: Int): Int {
return a() return a()
} }
@ -124,11 +126,12 @@ class SandboxExecutorTest : TestBase() {
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration) val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
assertThatExceptionOfType(SandboxException::class.java) assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestKotlinMetaClasses>(0) } .isThrownBy { contractExecutor.run<TestKotlinMetaClasses>(0) }
.withMessageContaining("java/util/Random.<clinit>(): Disallowed reference to reflection API; sun.misc.Unsafe.getUnsafe()") .withCauseInstanceOf(RuleViolationError::class.java)
.withMessageContaining("Disallowed reference to reflection API")
} }
class TestKotlinMetaClasses : SandboxedRunnable<Int, Int> { class TestKotlinMetaClasses : Function<Int, Int> {
override fun run(input: Int): Int? { override fun apply(input: Int): Int {
val someNumber = testRandom() val someNumber = testRandom()
return "12345".toNumber() * someNumber return "12345".toNumber() * someNumber
} }
@ -139,30 +142,32 @@ class SandboxExecutorTest : TestBase() {
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration) val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
assertThatExceptionOfType(SandboxException::class.java) assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestNonDeterministicCode>(0) } .isThrownBy { contractExecutor.run<TestNonDeterministicCode>(0) }
.withCauseInstanceOf(SandboxClassLoadingException::class.java) .withCauseInstanceOf(RuleViolationError::class.java)
.withProblem("java/util/Random.<clinit>(): Disallowed reference to reflection API; sun.misc.Unsafe.getUnsafe()") .withProblem("Disallowed reference to reflection API")
} }
class TestNonDeterministicCode : SandboxedRunnable<Int, Int> { class TestNonDeterministicCode : Function<Int, Int> {
override fun run(input: Int): Int? { override fun apply(input: Int): Int {
return Random().nextInt() return Random().nextInt()
} }
} }
@Test @Test
fun `cannot execute runnable that catches ThreadDeath`() = sandbox(DEFAULT) { fun `cannot execute runnable that catches ThreadDeath`() = sandbox(DEFAULT) {
TestCatchThreadDeath().apply {
assertThat(apply(0)).isEqualTo(1)
}
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration) val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
assertThatExceptionOfType(SandboxException::class.java) assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestCatchThreadDeath>(0) } .isThrownBy { contractExecutor.run<TestCatchThreadDeath>(0) }
.withCauseInstanceOf(SandboxClassLoadingException::class.java) .withCauseExactlyInstanceOf(ThreadDeath::class.java)
.withMessageContaining("Disallowed catch of ThreadDeath exception")
.withMessageContaining(TestCatchThreadDeath::class.java.simpleName)
} }
class TestCatchThreadDeath : SandboxedRunnable<Int, Int> { class TestCatchThreadDeath : Function<Int, Int> {
override fun run(input: Int): Int? { override fun apply(input: Int): Int {
return try { return try {
0 throw ThreadDeath()
} catch (exception: ThreadDeath) { } catch (exception: ThreadDeath) {
1 1
} }
@ -170,20 +175,46 @@ class SandboxExecutorTest : TestBase() {
} }
@Test @Test
fun `cannot execute runnable that catches ThresholdViolationException`() = sandbox(DEFAULT) { fun `cannot execute runnable that catches ThresholdViolationError`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration) TestCatchThresholdViolationError().apply {
assertThatExceptionOfType(SandboxException::class.java) assertThat(apply(0)).isEqualTo(1)
.isThrownBy { contractExecutor.run<TestCatchThresholdViolationException>(0) }
.withCauseInstanceOf(SandboxClassLoadingException::class.java)
.withMessageContaining("Disallowed catch of threshold violation exception")
.withMessageContaining(TestCatchThresholdViolationException::class.java.simpleName)
} }
class TestCatchThresholdViolationException : SandboxedRunnable<Int, Int> { val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
override fun run(input: Int): Int? { assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestCatchThresholdViolationError>(0) }
.withCauseExactlyInstanceOf(ThresholdViolationError::class.java)
.withMessageContaining("Can't catch this!")
}
class TestCatchThresholdViolationError : Function<Int, Int> {
override fun apply(input: Int): Int {
return try { return try {
0 throw ThresholdViolationError("Can't catch this!")
} catch (exception: ThresholdViolationException) { } catch (exception: ThresholdViolationError) {
1
}
}
}
@Test
fun `cannot execute runnable that catches RuleViolationError`() = sandbox(DEFAULT) {
TestCatchRuleViolationError().apply {
assertThat(apply(0)).isEqualTo(1)
}
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestCatchRuleViolationError>(0) }
.withCauseExactlyInstanceOf(RuleViolationError::class.java)
.withMessageContaining("Can't catch this!")
}
class TestCatchRuleViolationError : Function<Int, Int> {
override fun apply(input: Int): Int {
return try {
throw RuleViolationError("Can't catch this!")
} catch (exception: RuleViolationError) {
1 1
} }
} }
@ -209,12 +240,12 @@ class SandboxExecutorTest : TestBase() {
fun `cannot catch ThreadDeath`() = sandbox(DEFAULT) { fun `cannot catch ThreadDeath`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration) val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
assertThatExceptionOfType(SandboxException::class.java) assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestCatchThrowableErrorAndThreadDeath>(3) } .isThrownBy { contractExecutor.run<TestCatchThrowableErrorsAndThreadDeath>(3) }
.withCauseInstanceOf(ThreadDeath::class.java) .withCauseInstanceOf(ThreadDeath::class.java)
} }
class TestCatchThrowableAndError : SandboxedRunnable<Int, Int> { class TestCatchThrowableAndError : Function<Int, Int> {
override fun run(input: Int): Int? { override fun apply(input: Int): Int {
return try { return try {
when (input) { when (input) {
1 -> throw Throwable() 1 -> throw Throwable()
@ -229,13 +260,27 @@ class SandboxExecutorTest : TestBase() {
} }
} }
class TestCatchThrowableErrorAndThreadDeath : SandboxedRunnable<Int, Int> { class TestCatchThrowableErrorsAndThreadDeath : Function<Int, Int> {
override fun run(input: Int): Int? { override fun apply(input: Int): Int {
return try { return try {
when (input) { when (input) {
1 -> throw Throwable() 1 -> throw Throwable()
2 -> throw Error() 2 -> throw Error()
3 -> throw ThreadDeath() 3 -> try {
throw ThreadDeath()
} catch (ex: ThreadDeath) {
3
}
4 -> try {
throw StackOverflowError("FAKE OVERFLOW!")
} catch (ex: StackOverflowError) {
4
}
5 -> try {
throw OutOfMemoryError("FAKE OOM!")
} catch (ex: OutOfMemoryError) {
5
}
else -> 0 else -> 0
} }
} catch (exception: Error) { } catch (exception: Error) {
@ -246,6 +291,24 @@ class SandboxExecutorTest : TestBase() {
} }
} }
@Test
fun `cannot catch stack-overflow error`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestCatchThrowableErrorsAndThreadDeath>(4) }
.withCauseInstanceOf(StackOverflowError::class.java)
.withMessageContaining("FAKE OVERFLOW!")
}
@Test
fun `cannot catch out-of-memory error`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestCatchThrowableErrorsAndThreadDeath>(5) }
.withCauseInstanceOf(OutOfMemoryError::class.java)
.withMessageContaining("FAKE OOM!")
}
@Test @Test
fun `cannot persist state across sessions`() = sandbox(DEFAULT) { fun `cannot persist state across sessions`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration) val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
@ -256,8 +319,8 @@ class SandboxExecutorTest : TestBase() {
.isEqualTo(1) .isEqualTo(1)
} }
class TestStatePersistence : SandboxedRunnable<Int, Int> { class TestStatePersistence : Function<Int, Int> {
override fun run(input: Int): Int? { override fun apply(input: Int): Int {
ReferencedClass.value += 1 ReferencedClass.value += 1
return ReferencedClass.value return ReferencedClass.value
} }
@ -274,11 +337,11 @@ class SandboxExecutorTest : TestBase() {
assertThatExceptionOfType(SandboxException::class.java) assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestIO>(0) } .isThrownBy { contractExecutor.run<TestIO>(0) }
.withCauseInstanceOf(SandboxClassLoadingException::class.java) .withCauseInstanceOf(SandboxClassLoadingException::class.java)
.withMessageContaining("Files.walk(Path, Integer, FileVisitOption[]): Disallowed dynamic invocation in method") .withMessageContaining("Class file not found; java/nio/file/Files.class")
} }
class TestIO : SandboxedRunnable<Int, Int> { class TestIO : Function<Int, Int> {
override fun run(input: Int): Int? { override fun apply(input: Int): Int {
val file = Files.createTempFile("test", ".dat") val file = Files.createTempFile("test", ".dat")
Files.newBufferedWriter(file).use { Files.newBufferedWriter(file).use {
it.write("Hello world!") it.write("Hello world!")
@ -292,14 +355,13 @@ class SandboxExecutorTest : TestBase() {
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration) val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
assertThatExceptionOfType(SandboxException::class.java) assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestReflection>(0) } .isThrownBy { contractExecutor.run<TestReflection>(0) }
.withCauseInstanceOf(SandboxClassLoadingException::class.java) .withCauseInstanceOf(RuleViolationError::class.java)
.withMessageContaining("Disallowed reference to reflection API") .withMessageContaining("Disallowed reference to API;")
.withMessageContaining("java.lang.Class.newInstance()") .withMessageContaining("java.lang.Class.newInstance()")
.withMessageContaining("java.lang.reflect.Method.invoke(Object, Object[])")
} }
class TestReflection : SandboxedRunnable<Int, Int> { class TestReflection : Function<Int, Int> {
override fun run(input: Int): Int? { override fun apply(input: Int): Int {
val clazz = Object::class.java val clazz = Object::class.java
val obj = clazz.newInstance() val obj = clazz.newInstance()
val result = clazz.methods.first().invoke(obj) val result = clazz.methods.first().invoke(obj)
@ -307,4 +369,150 @@ class SandboxExecutorTest : TestBase() {
} }
} }
@Test
fun `can load and execute code that uses notify()`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, String>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestMonitors>(1) }
.withCauseInstanceOf(RuleViolationError::class.java)
.withMessageContaining("Disallowed reference to API;")
.withMessageContaining("java.lang.Object.notify()")
}
@Test
fun `can load and execute code that uses notifyAll()`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, String>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestMonitors>(2) }
.withCauseInstanceOf(RuleViolationError::class.java)
.withMessageContaining("Disallowed reference to API;")
.withMessageContaining("java.lang.Object.notifyAll()")
}
@Test
fun `can load and execute code that uses wait()`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, String>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestMonitors>(3) }
.withCauseInstanceOf(RuleViolationError::class.java)
.withMessageContaining("Disallowed reference to API;")
.withMessageContaining("java.lang.Object.wait()")
}
@Test
fun `can load and execute code that uses wait(long)`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, String>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestMonitors>(4) }
.withCauseInstanceOf(RuleViolationError::class.java)
.withMessageContaining("Disallowed reference to API;")
.withMessageContaining("java.lang.Object.wait(Long)")
}
@Test
fun `can load and execute code that uses wait(long,int)`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, String>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestMonitors>(5) }
.withCauseInstanceOf(RuleViolationError::class.java)
.withMessageContaining("Disallowed reference to API;")
.withMessageContaining("java.lang.Object.wait(Long, Integer)")
}
@Test
fun `code after forbidden APIs is intact`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, String>(configuration)
assertThat(contractExecutor.run<TestMonitors>(0).result)
.isEqualTo("unknown")
}
class TestMonitors : Function<Int, String> {
override fun apply(input: Int): String {
return synchronized(this) {
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
val javaObject = this as java.lang.Object
when(input) {
1 -> {
javaObject.notify()
"notify"
}
2 -> {
javaObject.notifyAll()
"notifyAll"
}
3 -> {
javaObject.wait()
"wait"
}
4 -> {
javaObject.wait(100)
"wait(100)"
}
5 -> {
javaObject.wait(100, 10)
"wait(100, 10)"
}
else -> "unknown"
}
}
}
}
@Test
fun `can load and execute code that has a native method`() = sandbox(DEFAULT) {
assertThatExceptionOfType(UnsatisfiedLinkError::class.java)
.isThrownBy { TestNativeMethod().apply(0) }
.withMessageContaining("TestNativeMethod.evilDeeds()I")
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestNativeMethod>(0) }
.withCauseInstanceOf(RuleViolationError::class.java)
.withMessageContaining("Native method has been deleted")
}
class TestNativeMethod : Function<Int, Int> {
override fun apply(input: Int): Int {
return evilDeeds()
}
private external fun evilDeeds(): Int
}
@Test
fun `check arrays still work`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, Array<Int>>(configuration)
contractExecutor.run<TestArray>(100).apply {
assertThat(result).isEqualTo(arrayOf(100))
}
}
class TestArray : Function<Int, Array<Int>> {
override fun apply(input: Int): Array<Int> {
return listOf(input).toTypedArray()
}
}
@Test
fun `can load and execute class that has finalize`() = sandbox(DEFAULT) {
assertThatExceptionOfType(UnsupportedOperationException::class.java)
.isThrownBy { TestFinalizeMethod().apply(100) }
.withMessageContaining("Very Bad Thing")
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
contractExecutor.run<TestFinalizeMethod>(100).apply {
assertThat(result).isEqualTo(100)
}
}
class TestFinalizeMethod : Function<Int, Int> {
override fun apply(input: Int): Int {
finalize()
return input
}
private fun finalize() {
throw UnsupportedOperationException("Very Bad Thing")
}
}
} }

View File

@ -6,10 +6,11 @@ import foo.bar.sandbox.Empty
import foo.bar.sandbox.StrictFloat import foo.bar.sandbox.StrictFloat
import net.corda.djvm.TestBase import net.corda.djvm.TestBase
import net.corda.djvm.assertions.AssertionExtensions.assertThat import net.corda.djvm.assertions.AssertionExtensions.assertThat
import net.corda.djvm.costing.ThresholdViolationException
import net.corda.djvm.execution.ExecutionProfile import net.corda.djvm.execution.ExecutionProfile
import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.Test import org.junit.Test
import sandbox.net.corda.djvm.costing.ThresholdViolationError
import java.nio.file.Paths
class ClassRewriterTest : TestBase() { class ClassRewriterTest : TestBase() {
@ -45,7 +46,7 @@ class ClassRewriterTest : TestBase() {
val callable = newCallable<B>() val callable = newCallable<B>()
assertThat(callable).hasBeenModified() assertThat(callable).hasBeenModified()
assertThat(callable).isSandboxed() assertThat(callable).isSandboxed()
assertThatExceptionOfType(ThresholdViolationException::class.java).isThrownBy { assertThatExceptionOfType(ThresholdViolationError::class.java).isThrownBy {
callable.createAndInvoke() callable.createAndInvoke()
}.withMessageContaining("terminated due to excessive use of looping") }.withMessageContaining("terminated due to excessive use of looping")
assertThat(runtimeCosts) assertThat(runtimeCosts)
@ -61,4 +62,44 @@ class ClassRewriterTest : TestBase() {
callable.createAndInvoke() callable.createAndInvoke()
} }
@Test
fun `can load a Java API that still exists in Java runtime`() = sandbox(DEFAULT) {
assertThat(loadClass<MutableList<*>>())
.hasClassName("sandbox.java.util.List")
.hasBeenModified()
}
@Test
fun `cannot load a Java API that was deleted from Java runtime`() = sandbox(DEFAULT) {
assertThatExceptionOfType(SandboxClassLoadingException::class.java)
.isThrownBy { loadClass<Paths>() }
.withMessageContaining("Class file not found; java/nio/file/Paths.class")
}
@Test
fun `load internal Sun class that still exists in Java runtime`() = sandbox(DEFAULT) {
assertThat(loadClass<sun.misc.Unsafe>())
.hasClassName("sandbox.sun.misc.Unsafe")
.hasBeenModified()
}
@Test
fun `cannot load internal Sun class that was deleted from Java runtime`() = sandbox(DEFAULT) {
assertThatExceptionOfType(SandboxClassLoadingException::class.java)
.isThrownBy { loadClass<sun.misc.Timer>() }
.withMessageContaining("Class file not found; sun/misc/Timer.class")
}
@Test
fun `can load local class`() = sandbox(DEFAULT) {
assertThat(loadClass<Example>())
.hasClassName("sandbox.net.corda.djvm.rewiring.ClassRewriterTest\$Example")
.hasBeenModified()
}
class Example : java.util.function.Function<Int, Int> {
override fun apply(input: Int): Int {
return input
}
}
} }

View File

@ -9,19 +9,6 @@ import java.util.*
class ReferenceExtractorTest : TestBase() { class ReferenceExtractorTest : TestBase() {
@Test
fun `can find method references`() = validate<A> { context ->
assertThat(context.references)
.hasClass("java/util/Random")
.withLocationCount(1)
.hasMember("java/lang/Object", "<init>", "()V")
.withLocationCount(1)
.hasMember("java/util/Random", "<init>", "()V")
.withLocationCount(1)
.hasMember("java/util/Random", "nextInt", "()I")
.withLocationCount(1)
}
class A : Callable { class A : Callable {
override fun call() { override fun call() {
synchronized(this) { synchronized(this) {
@ -30,21 +17,6 @@ class ReferenceExtractorTest : TestBase() {
} }
} }
@Test
fun `can find field references`() = validate<B> { context ->
assertThat(context.references)
.hasMember(Type.getInternalName(B::class.java), "foo", "Ljava/lang/String;")
}
class B {
@JvmField
val foo: String = ""
fun test(): String {
return foo
}
}
@Test @Test
fun `can find class references`() = validate<C> { context -> fun `can find class references`() = validate<C> { context ->
assertThat(context.references) assertThat(context.references)

View File

@ -42,8 +42,7 @@ class RuleValidatorTest : TestBase() {
assertThat(context.messages) assertThat(context.messages)
.hasErrorCount(0) .hasErrorCount(0)
.hasWarningCount(0) .hasWarningCount(0)
.hasInfoCount(1) .hasInfoCount(0)
.withMessage("Stripped monitoring instruction")
.hasTraceCount(4) .hasTraceCount(4)
.withMessage("Synchronization specifier will be ignored") .withMessage("Synchronization specifier will be ignored")
.withMessage("Strict floating-point arithmetic will be applied") .withMessage("Strict floating-point arithmetic will be applied")

View File

@ -67,7 +67,7 @@ class SourceClassLoaderTest {
val (first, second) = this val (first, second) = this
val directory = first.parent val directory = first.parent
val classLoader = SourceClassLoader(listOf(directory), classResolver) val classLoader = SourceClassLoader(listOf(directory), classResolver)
assertThat(classLoader.resolvedUrls).anySatisfy { assertThat(classLoader.urLs).anySatisfy {
assertThat(it).isEqualTo(first.toUri().toURL()) assertThat(it).isEqualTo(first.toUri().toURL())
}.anySatisfy { }.anySatisfy {
assertThat(it).isEqualTo(second.toUri().toURL()) assertThat(it).isEqualTo(second.toUri().toURL())