ENT-1906: Update DJVM to load deterministic rt.jar into sandbox. (#3973)

* Update DJVM to load deterministic rt.jar into sandbox.
* Disallow invocations of notify(), notifyAll() and wait() APIs.
* Pass entire derivedMember to MemberVisitorImpl.
* Updates after review.
* Refactor MethodBody handlers to use EmitterModule.
This commit is contained in:
Chris Rankin 2018-09-27 17:28:22 +01:00 committed by GitHub
parent e92ad538cf
commit d35a47bc82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
64 changed files with 1006 additions and 810 deletions

View File

@ -3,6 +3,7 @@ plugins {
}
apply plugin: 'net.corda.plugins.publish-utils'
apply plugin: 'com.jfrog.artifactory'
apply plugin: 'idea'
description 'Corda deterministic JVM sandbox'
@ -11,8 +12,18 @@ ext {
asm_version = '6.1.1'
}
repositories {
maven {
url "$artifactory_contextUrl/corda-dev"
}
}
configurations {
testCompile.extendsFrom shadow
jdkRt.resolutionStrategy {
// Always check the repository for a newer SNAPSHOT.
cacheChangingModulesFor 0, 'seconds'
}
}
dependencies {
@ -32,6 +43,7 @@ dependencies {
testCompile "junit:junit:$junit_version"
testCompile "org.assertj:assertj-core:$assertj_version"
testCompile "org.apache.logging.log4j:log4j-slf4j-impl:$log4j_version"
jdkRt "net.corda:deterministic-rt:latest.integration"
}
jar.enabled = false
@ -43,6 +55,10 @@ shadowJar {
}
assemble.dependsOn shadowJar
tasks.withType(Test) {
systemProperty 'deterministic-rt.path', configurations.jdkRt.asPath
}
artifacts {
publish shadowJar
}
@ -51,3 +67,10 @@ publish {
dependenciesFrom configurations.shadow
name shadowJar.baseName
}
idea {
module {
downloadJavadoc = true
downloadSources = true
}
}

View File

@ -63,7 +63,7 @@ abstract class ClassCommand : CommandBase() {
private lateinit var classLoader: ClassLoader
protected var executor = SandboxExecutor<Any, Any>()
protected var executor = SandboxExecutor<Any, Any>(SandboxConfiguration.DEFAULT)
private var derivedWhitelist: Whitelist = Whitelist.MINIMAL
@ -114,7 +114,7 @@ abstract class ClassCommand : CommandBase() {
}
private fun findDiscoverableRunnables(filters: Array<String>): List<Class<*>> {
val classes = find<DiscoverableRunnable>()
val classes = find<java.util.function.Function<*,*>>()
val applicableFilters = filters
.filter { !isJarFile(it) && !isFullClassName(it) }
val filteredClasses = applicableFilters
@ -125,7 +125,7 @@ abstract class ClassCommand : CommandBase() {
}
if (applicableFilters.isNotEmpty() && filteredClasses.isEmpty()) {
throw Exception("Could not find any classes implementing ${SandboxedRunnable::class.java.simpleName} " +
throw Exception("Could not find any classes implementing ${java.util.function.Function::class.java.simpleName} " +
"whose name matches '${applicableFilters.joinToString(" ")}'")
}
@ -189,7 +189,7 @@ abstract class ClassCommand : CommandBase() {
profile = profile,
rules = if (ignoreRules) { emptyList() } else { Discovery.find() },
emitters = ignoreEmitters.emptyListIfTrueOtherwiseNull(),
definitionProviders = if(ignoreDefinitionProviders) { emptyList() } else { Discovery.find() },
definitionProviders = if (ignoreDefinitionProviders) { emptyList() } else { Discovery.find() },
enableTracing = !disableTracing,
analysisConfiguration = AnalysisConfiguration(
whitelist = whitelist,

View File

@ -1,6 +1,5 @@
package net.corda.djvm.tools.cli
import net.corda.djvm.execution.SandboxedRunnable
import net.corda.djvm.source.ClassSource
import picocli.CommandLine.Command
import picocli.CommandLine.Parameters
@ -20,7 +19,7 @@ class RunCommand : ClassCommand() {
var classes: Array<String> = emptyArray()
override fun processClasses(classes: List<Class<*>>) {
val interfaceName = SandboxedRunnable::class.java.simpleName
val interfaceName = java.util.function.Function::class.java.simpleName
for (clazz in classes) {
if (!clazz.interfaces.any { it.simpleName == interfaceName }) {
printError("Class is not an instance of $interfaceName; ${clazz.name}")

View File

@ -33,52 +33,53 @@ class WhitelistGenerateCommand : CommandBase() {
override fun validateArguments() = paths.isNotEmpty()
override fun handleCommand(): Boolean {
val entries = mutableListOf<String>()
val visitor = object : ClassAndMemberVisitor() {
override fun visitClass(clazz: ClassRepresentation): ClassRepresentation {
entries.add(clazz.name)
return super.visitClass(clazz)
}
val entries = AnalysisConfiguration().use { configuration ->
val entries = mutableListOf<String>()
val visitor = object : ClassAndMemberVisitor(configuration, null) {
override fun visitClass(clazz: ClassRepresentation): ClassRepresentation {
entries.add(clazz.name)
return super.visitClass(clazz)
}
override fun visitMethod(clazz: ClassRepresentation, method: Member): Member {
visitMember(clazz, method)
return super.visitMethod(clazz, method)
}
override fun visitMethod(clazz: ClassRepresentation, method: Member): Member {
visitMember(clazz, method)
return super.visitMethod(clazz, method)
}
override fun visitField(clazz: ClassRepresentation, field: Member): Member {
visitMember(clazz, field)
return super.visitField(clazz, field)
}
override fun visitField(clazz: ClassRepresentation, field: Member): Member {
visitMember(clazz, field)
return super.visitField(clazz, field)
}
private fun visitMember(clazz: ClassRepresentation, member: Member) {
entries.add("${clazz.name}.${member.memberName}:${member.signature}")
private fun visitMember(clazz: ClassRepresentation, member: Member) {
entries.add("${clazz.name}.${member.memberName}:${member.signature}")
}
}
val context = AnalysisContext.fromConfiguration(configuration)
for (path in paths) {
ClassSource.fromPath(path).getStreamIterator().forEach {
visitor.analyze(it, context)
}
}
entries
}
val context = AnalysisContext.fromConfiguration(AnalysisConfiguration(), emptyList())
for (path in paths) {
ClassSource.fromPath(path).getStreamIterator().forEach {
visitor.analyze(it, context)
}
}
val output = output
if (output != null) {
Files.newOutputStream(output, StandardOpenOption.CREATE).use {
GZIPOutputStream(it).use {
PrintStream(it).use {
it.println("""
output?.also {
Files.newOutputStream(it, StandardOpenOption.CREATE).use { out ->
GZIPOutputStream(out).use { gzip ->
PrintStream(gzip).use { pout ->
pout.println("""
|java/.*
|javax/.*
|jdk/.*
|com/sun/.*
|sun/.*
|---
""".trimMargin().trim())
printEntries(it, entries)
printEntries(pout, entries)
}
}
}
} else {
printEntries(System.out, entries)
}
} ?: printEntries(System.out, entries)
return true
}

View File

@ -3,25 +3,20 @@ package net.corda.djvm
import net.corda.djvm.analysis.AnalysisContext
import net.corda.djvm.costing.RuntimeCostSummary
import net.corda.djvm.rewiring.SandboxClassLoader
import net.corda.djvm.source.ClassSource
/**
* The context in which a sandboxed operation is run.
*
* @property configuration The configuration of the sandbox.
* @property inputClasses The classes passed in for analysis.
*/
class SandboxRuntimeContext(
val configuration: SandboxConfiguration,
private val inputClasses: List<ClassSource>
) {
class SandboxRuntimeContext(val configuration: SandboxConfiguration) {
/**
* The class loader to use inside the sandbox.
*/
val classLoader: SandboxClassLoader = SandboxClassLoader(
configuration,
AnalysisContext.fromConfiguration(configuration.analysisConfiguration, inputClasses)
AnalysisContext.fromConfiguration(configuration.analysisConfiguration)
)
/**
@ -35,7 +30,7 @@ class SandboxRuntimeContext(
fun use(action: SandboxRuntimeContext.() -> Unit) {
SandboxRuntimeContext.instance = this
try {
this.action()
action(this)
} finally {
threadLocalContext.remove()
}
@ -43,9 +38,7 @@ class SandboxRuntimeContext(
companion object {
private val threadLocalContext = object : ThreadLocal<SandboxRuntimeContext?>() {
override fun initialValue(): SandboxRuntimeContext? = null
}
private val threadLocalContext = ThreadLocal<SandboxRuntimeContext?>()
/**
* When called from within a sandbox, this returns the context for the current sandbox thread.

View File

@ -1,9 +1,15 @@
package net.corda.djvm.analysis
import net.corda.djvm.code.ruleViolationError
import net.corda.djvm.code.thresholdViolationError
import net.corda.djvm.messages.Severity
import net.corda.djvm.references.ClassModule
import net.corda.djvm.references.MemberModule
import net.corda.djvm.source.BootstrapClassLoader
import net.corda.djvm.source.SourceClassLoader
import sandbox.net.corda.djvm.costing.RuntimeCostAccounter
import java.io.Closeable
import java.io.IOException
import java.nio.file.Path
/**
@ -13,7 +19,8 @@ import java.nio.file.Path
* @param additionalPinnedClasses Classes that have already been declared in the sandbox namespace and that should be
* made available inside the sandboxed environment.
* @property minimumSeverityLevel The minimum severity level to log and report.
* @property classPath The extended class path to use for the analysis.
* @param classPath The extended class path to use for the analysis.
* @param bootstrapJar The location of a jar containing the Java APIs.
* @property analyzeAnnotations Analyze annotations despite not being explicitly referenced.
* @property prefixFilters Only record messages where the originating class name matches one of the provided prefixes.
* If none are provided, all messages will be reported.
@ -24,32 +31,47 @@ class AnalysisConfiguration(
val whitelist: Whitelist = Whitelist.MINIMAL,
additionalPinnedClasses: Set<String> = emptySet(),
val minimumSeverityLevel: Severity = Severity.WARNING,
val classPath: List<Path> = emptyList(),
classPath: List<Path> = emptyList(),
bootstrapJar: Path? = null,
val analyzeAnnotations: Boolean = false,
val prefixFilters: List<String> = emptyList(),
val classModule: ClassModule = ClassModule(),
val memberModule: MemberModule = MemberModule()
) {
) : Closeable {
/**
* Classes that have already been declared in the sandbox namespace and that should be made
* available inside the sandboxed environment.
*/
val pinnedClasses: Set<String> = setOf(SANDBOXED_OBJECT, RUNTIME_COST_ACCOUNTER) + additionalPinnedClasses
val pinnedClasses: Set<String> = setOf(
SANDBOXED_OBJECT,
RuntimeCostAccounter.TYPE_NAME,
ruleViolationError,
thresholdViolationError
) + additionalPinnedClasses
/**
* Functionality used to resolve the qualified name and relevant information about a class.
*/
val classResolver: ClassResolver = ClassResolver(pinnedClasses, whitelist, SANDBOX_PREFIX)
private val bootstrapClassLoader = bootstrapJar?.let { BootstrapClassLoader(it, classResolver) }
val supportingClassLoader = SourceClassLoader(classPath, classResolver, bootstrapClassLoader)
@Throws(IOException::class)
override fun close() {
supportingClassLoader.use {
bootstrapClassLoader?.close()
}
}
companion object {
/**
* The package name prefix to use for classes loaded into a sandbox.
*/
private const val SANDBOX_PREFIX: String = "sandbox/"
private const val SANDBOXED_OBJECT = "sandbox/java/lang/Object"
private const val RUNTIME_COST_ACCOUNTER = RuntimeCostAccounter.TYPE_NAME
private const val SANDBOXED_OBJECT = SANDBOX_PREFIX + "java/lang/Object"
}
}

View File

@ -1,10 +1,10 @@
package net.corda.djvm.analysis
import net.corda.djvm.code.asPackagePath
import net.corda.djvm.messages.MessageCollection
import net.corda.djvm.references.ClassHierarchy
import net.corda.djvm.references.EntityReference
import net.corda.djvm.references.ReferenceMap
import net.corda.djvm.source.ClassSource
/**
* The context in which one or more classes are analysed.
@ -13,13 +13,11 @@ import net.corda.djvm.source.ClassSource
* @property classes List of class definitions that have been analyzed.
* @property references A collection of all referenced members found during analysis together with the locations from
* where each member has been accessed or invoked.
* @property inputClasses The classes passed in for analysis.
*/
class AnalysisContext private constructor(
val messages: MessageCollection,
val classes: ClassHierarchy,
val references: ReferenceMap,
val inputClasses: List<ClassSource>
val references: ReferenceMap
) {
private val origins = mutableMapOf<String, MutableSet<EntityReference>>()
@ -28,7 +26,7 @@ class AnalysisContext private constructor(
* Record a class origin in the current analysis context.
*/
fun recordClassOrigin(name: String, origin: EntityReference) {
origins.getOrPut(name.normalize()) { mutableSetOf() }.add(origin)
origins.getOrPut(name.asPackagePath) { mutableSetOf() }.add(origin)
}
/**
@ -42,20 +40,14 @@ class AnalysisContext private constructor(
/**
* Create a new analysis context from provided configuration.
*/
fun fromConfiguration(configuration: AnalysisConfiguration, classes: List<ClassSource>): AnalysisContext {
fun fromConfiguration(configuration: AnalysisConfiguration): AnalysisContext {
return AnalysisContext(
MessageCollection(configuration.minimumSeverityLevel, configuration.prefixFilters),
ClassHierarchy(configuration.classModule, configuration.memberModule),
ReferenceMap(configuration.classModule),
classes
ReferenceMap(configuration.classModule)
)
}
/**
* Local extension method for normalizing a class name.
*/
private fun String.normalize() = this.replace("/", ".")
}
}

View File

@ -4,30 +4,25 @@ import net.corda.djvm.code.EmitterModule
import net.corda.djvm.code.Instruction
import net.corda.djvm.code.instructions.*
import net.corda.djvm.messages.Message
import net.corda.djvm.references.ClassReference
import net.corda.djvm.references.ClassRepresentation
import net.corda.djvm.references.Member
import net.corda.djvm.references.MemberReference
import net.corda.djvm.source.SourceClassLoader
import net.corda.djvm.references.*
import org.objectweb.asm.*
import java.io.InputStream
/**
* Functionality for traversing a class and its members.
*
* @property classVisitor Class visitor to use when traversing the structure of classes.
* @property configuration The configuration to use for the analysis
* @property classVisitor Class visitor to use when traversing the structure of classes.
*/
open class ClassAndMemberVisitor(
private val classVisitor: ClassVisitor? = null,
private val configuration: AnalysisConfiguration = AnalysisConfiguration()
private val configuration: AnalysisConfiguration,
private val classVisitor: ClassVisitor?
) {
/**
* Holds a reference to the currently used analysis context.
*/
protected var analysisContext: AnalysisContext =
AnalysisContext.fromConfiguration(configuration, emptyList())
protected var analysisContext: AnalysisContext = AnalysisContext.fromConfiguration(configuration)
/**
* Holds a link to the class currently being traversed.
@ -44,12 +39,6 @@ open class ClassAndMemberVisitor(
*/
private var sourceLocation = SourceLocation()
/**
* The class loader used to find classes on the extended class path.
*/
private val supportingClassLoader =
SourceClassLoader(configuration.classPath, configuration.classResolver)
/**
* Analyze class by using the provided qualified name of the class.
*/
@ -63,7 +52,7 @@ open class ClassAndMemberVisitor(
* @param origin The originating class for the analysis.
*/
fun analyze(className: String, context: AnalysisContext, origin: String? = null) {
supportingClassLoader.classReader(className, context, origin).apply {
configuration.supportingClassLoader.classReader(className, context, origin).apply {
analyze(this, context)
}
}
@ -167,7 +156,8 @@ open class ClassAndMemberVisitor(
}
/**
* Run action with a guard that populates [messages] based on the output.
* Run action with a guard that populates [AnalysisRuntimeContext.messages]
* based on the output.
*/
private inline fun captureExceptions(action: () -> Unit): Boolean {
return try {
@ -229,9 +219,7 @@ open class ClassAndMemberVisitor(
ClassRepresentation(version, access, name, superClassName, interfaceNames, genericsDetails = signature ?: "").also {
currentClass = it
currentMember = null
sourceLocation = SourceLocation(
className = name
)
sourceLocation = SourceLocation(className = name)
}
captureExceptions {
currentClass = visitClass(currentClass!!)
@ -251,7 +239,7 @@ open class ClassAndMemberVisitor(
override fun visitEnd() {
configuration.classModule
.getClassReferencesFromClass(currentClass!!, configuration.analyzeAnnotations)
.forEach { recordTypeReference(it) }
.forEach(::recordTypeReference)
captureExceptions {
visitClassEnd(currentClass!!)
}
@ -306,14 +294,15 @@ open class ClassAndMemberVisitor(
configuration.memberModule.addToClass(clazz, visitedMember ?: member)
return if (processMember) {
val derivedMember = visitedMember ?: member
val targetVisitor = super.visitMethod(
derivedMember.access,
derivedMember.memberName,
derivedMember.signature,
signature,
derivedMember.exceptions.toTypedArray()
)
MethodVisitorImpl(targetVisitor)
super.visitMethod(
derivedMember.access,
derivedMember.memberName,
derivedMember.signature,
signature,
derivedMember.exceptions.toTypedArray()
)?.let { targetVisitor ->
MethodVisitorImpl(targetVisitor, derivedMember)
}
} else {
null
}
@ -340,14 +329,15 @@ open class ClassAndMemberVisitor(
configuration.memberModule.addToClass(clazz, visitedMember ?: member)
return if (processMember) {
val derivedMember = visitedMember ?: member
val targetVisitor = super.visitField(
derivedMember.access,
derivedMember.memberName,
derivedMember.signature,
signature,
derivedMember.value
)
FieldVisitorImpl(targetVisitor)
super.visitField(
derivedMember.access,
derivedMember.memberName,
derivedMember.signature,
signature,
derivedMember.value
)?.let { targetVisitor ->
FieldVisitorImpl(targetVisitor)
}
} else {
null
}
@ -359,7 +349,8 @@ open class ClassAndMemberVisitor(
* Visitor used to traverse and analyze a method.
*/
private inner class MethodVisitorImpl(
targetVisitor: MethodVisitor?
targetVisitor: MethodVisitor,
private val method: Member
) : MethodVisitor(API_VERSION, targetVisitor) {
/**
@ -387,6 +378,16 @@ open class ClassAndMemberVisitor(
return super.visitAnnotation(desc, visible)
}
/**
* Write any new method body code, assuming the definition providers
* have provided any. This handler will not be visited if this method
* has no existing code.
*/
override fun visitCode() {
tryReplaceMethodBody()
super.visitCode()
}
/**
* Extract information about provided field access instruction.
*/
@ -493,6 +494,29 @@ open class ClassAndMemberVisitor(
}
}
/**
* Finish visiting this method, writing any new method body byte-code
* if we haven't written it already. This would (presumably) only happen
* for methods that previously had no body, e.g. native methods.
*/
override fun visitEnd() {
tryReplaceMethodBody()
super.visitEnd()
}
private fun tryReplaceMethodBody() {
if (method.body.isNotEmpty() && (mv != null)) {
EmitterModule(mv).apply {
for (body in method.body) {
body(this)
}
}
mv.visitMaxs(-1, -1)
mv.visitEnd()
mv = null
}
}
/**
* Helper function used to streamline the access to an instruction and to catch any related processing errors.
*/
@ -517,7 +541,7 @@ open class ClassAndMemberVisitor(
* Visitor used to traverse and analyze a field.
*/
private inner class FieldVisitorImpl(
targetVisitor: FieldVisitor?
targetVisitor: FieldVisitor
) : FieldVisitor(API_VERSION, targetVisitor) {
/**

View File

@ -1,5 +1,8 @@
package net.corda.djvm.analysis
import net.corda.djvm.code.asPackagePath
import net.corda.djvm.code.asResourcePath
/**
* Functionality for resolving the class name of a sandboxable class.
*
@ -32,12 +35,12 @@ class ClassResolver(
*/
fun resolve(name: String): String {
return when {
name.startsWith("[") && name.endsWith(";") -> {
name.startsWith('[') && name.endsWith(';') -> {
complexArrayTypeRegex.replace(name) {
"${it.groupValues[1]}L${resolveName(it.groupValues[2])};"
}
}
name.startsWith("[") && !name.endsWith(";") -> name
name.startsWith('[') && !name.endsWith(';') -> name
else -> resolveName(name)
}
}
@ -46,7 +49,7 @@ class ClassResolver(
* Resolve the class name from a fully qualified normalized name.
*/
fun resolveNormalized(name: String): String {
return resolve(name.replace('.', '/')).replace('/', '.')
return resolve(name.asResourcePath).asPackagePath
}
/**
@ -96,7 +99,7 @@ class ClassResolver(
* Reverse the resolution of a class name from a fully qualified normalized name.
*/
fun reverseNormalized(name: String): String {
return reverse(name.replace('.', '/')).replace('/', '.')
return reverse(name.asResourcePath).asPackagePath
}
/**

View File

@ -117,7 +117,8 @@ open class Whitelist private constructor(
"^java/lang/Throwable(\\..*)?$".toRegex(),
"^java/lang/Void(\\..*)?$".toRegex(),
"^java/lang/.*Error(\\..*)?$".toRegex(),
"^java/lang/.*Exception(\\..*)?$".toRegex()
"^java/lang/.*Exception(\\..*)?$".toRegex(),
"^java/lang/reflect/Array(\\..*)?$".toRegex()
)
/**

View File

@ -20,7 +20,7 @@ class ClassMutator(
private val configuration: AnalysisConfiguration,
private val definitionProviders: List<DefinitionProvider> = emptyList(),
private val emitters: List<Emitter> = emptyList()
) : ClassAndMemberVisitor(classVisitor, configuration = configuration) {
) : ClassAndMemberVisitor(configuration, classVisitor) {
/**
* Tracks whether any modifications have been applied to any of the processed class(es) and pertinent members.
@ -82,7 +82,8 @@ class ClassMutator(
*/
override fun visitInstruction(method: Member, emitter: EmitterModule, instruction: Instruction) {
val context = EmitterContext(currentAnalysisContext(), configuration, emitter)
Processor.processEntriesOfType<Emitter>(emitters, analysisContext.messages) {
// We need to apply the tracing emitters before the non-tracing ones.
Processor.processEntriesOfType<Emitter>(emitters.sortedByDescending(Emitter::isTracer), analysisContext.messages) {
it.emit(context, instruction)
}
if (!emitter.emitDefaultInstruction || emitter.hasEmittedCustomCode) {

View File

@ -1,7 +1,9 @@
package net.corda.djvm.code
import org.objectweb.asm.Label
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes
import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.Type
import sandbox.net.corda.djvm.costing.RuntimeCostAccounter
/**
@ -29,7 +31,7 @@ class EmitterModule(
/**
* Emit instruction for creating a new object of type [typeName].
*/
fun new(typeName: String, opcode: Int = Opcodes.NEW) {
fun new(typeName: String, opcode: Int = NEW) {
hasEmittedCustomCode = true
methodVisitor.visitTypeInsn(opcode, typeName)
}
@ -38,7 +40,7 @@ class EmitterModule(
* Emit instruction for creating a new object of type [T].
*/
inline fun <reified T> new() {
new(T::class.java.name)
new(Type.getInternalName(T::class.java))
}
/**
@ -62,7 +64,7 @@ class EmitterModule(
*/
fun invokeStatic(owner: String, name: String, descriptor: String, isInterface: Boolean = false) {
hasEmittedCustomCode = true
methodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, owner, name, descriptor, isInterface)
methodVisitor.visitMethodInsn(INVOKESTATIC, owner, name, descriptor, isInterface)
}
/**
@ -70,14 +72,14 @@ class EmitterModule(
*/
fun invokeSpecial(owner: String, name: String, descriptor: String, isInterface: Boolean = false) {
hasEmittedCustomCode = true
methodVisitor.visitMethodInsn(Opcodes.INVOKESPECIAL, owner, name, descriptor, isInterface)
methodVisitor.visitMethodInsn(INVOKESPECIAL, owner, name, descriptor, isInterface)
}
/**
* Emit instruction for invoking a special method on class [T], e.g. a constructor or a method on a super-type.
*/
inline fun <reified T> invokeSpecial(name: String, descriptor: String, isInterface: Boolean = false) {
invokeSpecial(T::class.java.name, name, descriptor, isInterface)
invokeSpecial(Type.getInternalName(T::class.java), name, descriptor, isInterface)
}
/**
@ -85,7 +87,7 @@ class EmitterModule(
*/
fun pop() {
hasEmittedCustomCode = true
methodVisitor.visitInsn(Opcodes.POP)
methodVisitor.visitInsn(POP)
}
/**
@ -93,19 +95,40 @@ class EmitterModule(
*/
fun duplicate() {
hasEmittedCustomCode = true
methodVisitor.visitInsn(Opcodes.DUP)
methodVisitor.visitInsn(DUP)
}
/**
* Emit a sequence of instructions for instantiating and throwing an exception based on the provided message.
*/
fun throwError(message: String) {
fun <T : Throwable> throwException(exceptionType: Class<T>, message: String) {
hasEmittedCustomCode = true
new<java.lang.Exception>()
methodVisitor.visitInsn(Opcodes.DUP)
val exceptionName = Type.getInternalName(exceptionType)
new(exceptionName)
methodVisitor.visitInsn(DUP)
methodVisitor.visitLdcInsn(message)
invokeSpecial<java.lang.Exception>("<init>", "(Ljava/lang/String;)V")
methodVisitor.visitInsn(Opcodes.ATHROW)
invokeSpecial(exceptionName, "<init>", "(Ljava/lang/String;)V")
methodVisitor.visitInsn(ATHROW)
}
inline fun <reified T : Throwable> throwException(message: String) = throwException(T::class.java, message)
/**
* Emit instruction for returning from "void" method.
*/
fun returnVoid() {
methodVisitor.visitInsn(RETURN)
hasEmittedCustomCode = true
}
/**
* Emit instructions for a new line number.
*/
fun lineNumber(line: Int) {
val label = Label()
methodVisitor.visitLabel(label)
methodVisitor.visitLineNumber(line, label)
hasEmittedCustomCode = true
}
/**

View File

@ -0,0 +1,15 @@
@file:JvmName("Types")
package net.corda.djvm.code
import org.objectweb.asm.Type
import sandbox.net.corda.djvm.costing.ThresholdViolationError
import sandbox.net.corda.djvm.rules.RuleViolationError
val ruleViolationError: String = Type.getInternalName(RuleViolationError::class.java)
val thresholdViolationError: String = Type.getInternalName(ThresholdViolationError::class.java)
/**
* Local extension method for normalizing a class name.
*/
val String.asPackagePath: String get() = this.replace('/', '.')
val String.asResourcePath: String get() = this.replace('.', '/')

View File

@ -1,6 +1,7 @@
package net.corda.djvm.costing
import net.corda.djvm.utilities.loggerFor
import sandbox.net.corda.djvm.costing.ThresholdViolationError
/**
* Cost metric to be used in a sandbox environment. The metric has a threshold and a mechanism for reporting violations.
@ -41,7 +42,7 @@ open class TypedRuntimeCost<T>(
if (thresholdPredicate(newValue)) {
val message = errorMessage(currentThread)
logger.error("Threshold breached; {}", message)
throw ThresholdViolationException(message)
throw ThresholdViolationError(message)
}
}

View File

@ -2,6 +2,7 @@ package net.corda.djvm.execution
import net.corda.djvm.SandboxConfiguration
import net.corda.djvm.source.ClassSource
import java.util.function.Function
/**
* The executor is responsible for spinning up a deterministic, sandboxed environment and launching the referenced code
@ -12,14 +13,14 @@ import net.corda.djvm.source.ClassSource
* @param configuration The configuration of the sandbox.
*/
class DeterministicSandboxExecutor<TInput, TOutput>(
configuration: SandboxConfiguration = SandboxConfiguration.DEFAULT
configuration: SandboxConfiguration
) : SandboxExecutor<TInput, TOutput>(configuration) {
/**
* Short-hand for running a [SandboxedRunnable] in a sandbox by its type reference.
* Short-hand for running a [Function] in a sandbox by its type reference.
*/
inline fun <reified TRunnable : SandboxedRunnable<TInput, TOutput>> run(input: TInput):
ExecutionSummaryWithResult<TOutput?> {
inline fun <reified TRunnable : Function<in TInput, out TOutput>> run(input: TInput):
ExecutionSummaryWithResult<TOutput> {
return run(ClassSource.fromClassName(TRunnable::class.java.name), input)
}

View File

@ -1,6 +0,0 @@
package net.corda.djvm.execution
/**
* Functionality runnable by a sandbox executor, marked for discoverability.
*/
interface DiscoverableRunnable

View File

@ -1,7 +1,7 @@
package net.corda.djvm.execution
/**
* The execution profile of a [SandboxedRunnable] when run in a sandbox.
* The execution profile of a [java.util.function.Function] when run in a sandbox.
*
* @property allocationCostThreshold The threshold placed on allocations.
* @property invocationCostThreshold The threshold placed on invocations.

View File

@ -1,9 +1,9 @@
package net.corda.djvm.execution
/**
* The summary of the execution of a [SandboxedRunnable] in a sandbox. This class has no representation of the outcome,
* and is typically used when there has been a pre-mature exit from the sandbox, for instance, if an exception was
* thrown.
* The summary of the execution of a [java.util.function.Function] in a sandbox. This class has no representation of the
* outcome, and is typically used when there has been a pre-mature exit from the sandbox, for instance, if an exception
* was thrown.
*
* @property costs The costs accumulated when running the sandboxed code.
*/

View File

@ -1,7 +1,7 @@
package net.corda.djvm.execution
/**
* The summary of the execution of a [SandboxedRunnable] in a sandbox.
* The summary of the execution of a [java.util.function.Function] in a sandbox.
*
* @property result The outcome of the sandboxed operation.
* @see ExecutionSummary

View File

@ -2,7 +2,6 @@ package net.corda.djvm.execution
import net.corda.djvm.SandboxConfiguration
import net.corda.djvm.SandboxRuntimeContext
import net.corda.djvm.analysis.AnalysisContext
import net.corda.djvm.messages.MessageCollection
import net.corda.djvm.rewiring.SandboxClassLoader
import net.corda.djvm.rewiring.SandboxClassLoadingException
@ -16,8 +15,7 @@ import kotlin.concurrent.thread
*/
class IsolatedTask(
private val identifier: String,
private val configuration: SandboxConfiguration,
private val context: AnalysisContext
private val configuration: SandboxConfiguration
) {
/**
@ -32,12 +30,12 @@ class IsolatedTask(
var exception: Throwable? = null
thread(name = threadName, isDaemon = true) {
logger.trace("Entering isolated runtime environment...")
SandboxRuntimeContext(configuration, context.inputClasses).use {
SandboxRuntimeContext(configuration).use {
output = try {
action(runnable)
} catch (ex: Throwable) {
logger.error("Exception caught in isolated runtime environment", ex)
exception = ex
exception = (ex as? LinkageError)?.cause ?: ex
null
}
costs = CostSummary(
@ -84,7 +82,7 @@ class IsolatedTask(
)
/**
* The class loader to use for loading the [SandboxedRunnable] and any referenced code in [SandboxExecutor.run].
* The class loader to use for loading the [java.util.function.Function] and any referenced code in [SandboxExecutor.run].
*/
val classLoader: SandboxClassLoader
get() = SandboxRuntimeContext.instance.classLoader

View File

@ -11,7 +11,6 @@ import net.corda.djvm.rewiring.SandboxClassLoadingException
import net.corda.djvm.source.ClassSource
import net.corda.djvm.utilities.loggerFor
import net.corda.djvm.validation.ReferenceValidationSummary
import net.corda.djvm.validation.ReferenceValidator
import java.lang.reflect.InvocationTargetException
/**
@ -22,7 +21,7 @@ import java.lang.reflect.InvocationTargetException
* @property configuration The configuration of sandbox.
*/
open class SandboxExecutor<in TInput, out TOutput>(
protected val configuration: SandboxConfiguration = SandboxConfiguration.DEFAULT
protected val configuration: SandboxConfiguration
) {
private val classModule = configuration.analysisConfiguration.classModule
@ -32,12 +31,7 @@ open class SandboxExecutor<in TInput, out TOutput>(
private val whitelist = configuration.analysisConfiguration.whitelist
/**
* Module used to validate all traversable references before instantiating and executing a [SandboxedRunnable].
*/
private val referenceValidator = ReferenceValidator(configuration.analysisConfiguration)
/**
* Executes a [SandboxedRunnable] implementation.
* Executes a [java.util.function.Function] implementation.
*
* @param runnableClass The entry point of the sandboxed code to run.
* @param input The input to provide to the sandboxed environment.
@ -50,7 +44,7 @@ open class SandboxExecutor<in TInput, out TOutput>(
open fun run(
runnableClass: ClassSource,
input: TInput
): ExecutionSummaryWithResult<TOutput?> {
): ExecutionSummaryWithResult<TOutput> {
// 1. We first do a breath first traversal of the class hierarchy, starting from the requested class.
// The branching is defined by class references from referencesFromLocation.
// 2. For each class we run validation against defined rules.
@ -63,22 +57,22 @@ open class SandboxExecutor<in TInput, out TOutput>(
// 6. For execution, we then load the top-level class, implementing the SandboxedRunnable interface, again and
// and consequently hit the cache. Once loaded, we can execute the code on the spawned thread, i.e., in an
// isolated environment.
logger.trace("Executing {} with input {}...", runnableClass, input)
logger.debug("Executing {} with input {}...", runnableClass, input)
// TODO Class sources can be analyzed in parallel, although this require making the analysis context thread-safe
// To do so, one could start by batching the first X classes from the class sources and analyse each one in
// parallel, caching any intermediate state and subsequently process enqueued sources in parallel batches as well.
// Note that this would require some rework of the [IsolatedTask] and the class loader to bypass the limitation
// of caching and state preserved in thread-local contexts.
val classSources = listOf(runnableClass)
val context = AnalysisContext.fromConfiguration(configuration.analysisConfiguration, classSources)
val result = IsolatedTask(runnableClass.qualifiedClassName, configuration, context).run {
val context = AnalysisContext.fromConfiguration(configuration.analysisConfiguration)
val result = IsolatedTask(runnableClass.qualifiedClassName, configuration).run {
validate(context, classLoader, classSources)
val loadedClass = classLoader.loadClassAndBytes(runnableClass, context)
val instance = loadedClass.type.newInstance()
val method = loadedClass.type.getMethod("run", Any::class.java)
val method = loadedClass.type.getMethod("apply", Any::class.java)
try {
@Suppress("UNCHECKED_CAST")
method.invoke(instance, input) as? TOutput?
method.invoke(instance, input) as? TOutput
} catch (ex: InvocationTargetException) {
throw ex.targetException
}
@ -105,8 +99,8 @@ open class SandboxExecutor<in TInput, out TOutput>(
* @return A [LoadedClass] with the class' byte code, type and name.
*/
fun load(classSource: ClassSource): LoadedClass {
val context = AnalysisContext.fromConfiguration(configuration.analysisConfiguration, listOf(classSource))
val result = IsolatedTask("LoadClass", configuration, context).run {
val context = AnalysisContext.fromConfiguration(configuration.analysisConfiguration)
val result = IsolatedTask("LoadClass", configuration).run {
classLoader.loadClassAndBytes(classSource, context)
}
return result.output ?: throw ClassNotFoundException(classSource.qualifiedClassName)
@ -125,8 +119,8 @@ open class SandboxExecutor<in TInput, out TOutput>(
@Throws(SandboxClassLoadingException::class)
fun validate(vararg classSources: ClassSource): ReferenceValidationSummary {
logger.trace("Validating {}...", classSources)
val context = AnalysisContext.fromConfiguration(configuration.analysisConfiguration, classSources.toList())
val result = IsolatedTask("Validation", configuration, context).run {
val context = AnalysisContext.fromConfiguration(configuration.analysisConfiguration)
val result = IsolatedTask("Validation", configuration).run {
validate(context, classLoader, classSources.toList())
}
logger.trace("Validation of {} resulted in {}", classSources, result)
@ -172,10 +166,6 @@ open class SandboxExecutor<in TInput, out TOutput>(
}
failOnReportedErrorsInContext(context)
// Validate all references in class hierarchy before proceeding.
referenceValidator.validate(context, classLoader.analyzer)
failOnReportedErrorsInContext(context)
return ReferenceValidationSummary(context.classes, context.messages, context.classOrigins)
}
@ -185,7 +175,7 @@ open class SandboxExecutor<in TInput, out TOutput>(
private inline fun processClassQueue(
vararg elements: ClassSource, action: QueueProcessor<ClassSource>.(ClassSource, String) -> Unit
) {
QueueProcessor({ it.qualifiedClassName }, *elements).process { classSource ->
QueueProcessor(ClassSource::qualifiedClassName, *elements).process { classSource ->
val className = classResolver.reverse(classModule.getBinaryClassName(classSource.qualifiedClassName))
if (!whitelist.matches(className)) {
action(classSource, className)

View File

@ -1,19 +0,0 @@
package net.corda.djvm.execution
/**
* Functionality runnable by a sandbox executor.
*/
interface SandboxedRunnable<in TInput, out TOutput> : DiscoverableRunnable {
/**
* The entry point of the sandboxed functionality to be run.
*
* @param input The input to pass in to the entry point.
*
* @returns The output to pass back to the caller after the sandboxed code has finished running.
* @throws Exception The function can throw an exception, in which case the exception gets passed to the caller.
*/
@Throws(Exception::class)
fun run(input: TInput): TOutput?
}

View File

@ -53,7 +53,7 @@ class MemberFormatter(
* Check whether or not a signature is for a method.
*/
fun isMethod(abbreviatedSignature: String): Boolean {
return abbreviatedSignature.startsWith("(")
return abbreviatedSignature.startsWith('(')
}
/**

View File

@ -82,8 +82,8 @@ class ClassHierarchy(
return findAncestors(get(className)).plus(get(OBJECT_NAME))
.asSequence()
.filterNotNull()
.map { memberModule.getFromClass(it, memberName, signature) }
.firstOrNull { it != null }
.mapNotNull { memberModule.getFromClass(it, memberName, signature) }
.firstOrNull()
.apply {
logger.trace("Getting rooted member for {}.{}:{} yields {}", className, memberName, signature, this)
}

View File

@ -1,5 +1,8 @@
package net.corda.djvm.references
import net.corda.djvm.code.asPackagePath
import net.corda.djvm.code.asResourcePath
/**
* Class-specific functionality.
*/
@ -42,14 +45,12 @@ class ClassModule : AnnotationModule() {
/**
* Get the binary version of a class name.
*/
fun getBinaryClassName(name: String) =
normalizeClassName(name).replace('.', '/')
fun getBinaryClassName(name: String) = normalizeClassName(name).asResourcePath
/**
* Get the formatted version of a class name.
*/
fun getFormattedClassName(name: String) =
normalizeClassName(name).replace('/', '.')
fun getFormattedClassName(name: String) = normalizeClassName(name).asPackagePath
/**
* Get the short name of a class.

View File

@ -1,5 +1,13 @@
package net.corda.djvm.references
import net.corda.djvm.code.EmitterModule
/**
* Alias for a handler which will replace an entire
* method body with a block of byte-code.
*/
typealias MethodBody = (EmitterModule) -> Unit
/**
* Representation of a class member.
*
@ -11,6 +19,7 @@ package net.corda.djvm.references
* @property annotations The names of the annotations the member is attributed.
* @property exceptions The names of the exceptions that the member can throw.
* @property value The default value of a field.
* @property body One or more handlers to replace the method body with new byte-code.
*/
data class Member(
override val access: Int,
@ -20,5 +29,6 @@ data class Member(
val genericsDetails: String,
val annotations: MutableSet<String> = mutableSetOf(),
val exceptions: MutableSet<String> = mutableSetOf(),
val value: Any? = null
val value: Any? = null,
val body: List<MethodBody> = emptyList()
) : MemberInformation, EntityWithAccessFlag

View File

@ -33,14 +33,14 @@ class MemberModule : AnnotationModule() {
* Check if member is a field.
*/
fun isField(member: MemberInformation): Boolean {
return !member.signature.startsWith("(")
return !member.signature.startsWith('(')
}
/**
* Check if member is a method.
*/
fun isMethod(member: MemberInformation): Boolean {
return member.signature.startsWith("(")
return member.signature.startsWith('(')
}
/**

View File

@ -16,7 +16,11 @@ class ReferenceMap(
private val referencesPerLocation: MutableMap<String, MutableSet<ReferenceWithLocation>> = hashMapOf()
private var numberOfReferences = 0
/**
* The number of references in the map.
*/
var numberOfReferences = 0
private set
/**
* Add source location association to a target member.
@ -50,12 +54,6 @@ class ReferenceMap(
return referencesPerLocation.getOrElse(key(className, memberName, signature)) { emptySet() }
}
/**
* The number of member references in the map.
*/
val size: Int
get() = numberOfReferences
/**
* Get iterator for all the references in the map.
*/

View File

@ -27,7 +27,7 @@ open class ClassRewriter(
* @param context The context in which the class is being analyzed and processed.
*/
fun rewrite(reader: ClassReader, context: AnalysisContext): ByteCode {
logger.trace("Rewriting class {}...", reader.className)
logger.debug("Rewriting class {}...", reader.className)
val writer = SandboxClassWriter(reader, classLoader)
val classRemapper = ClassRemapper(writer, remapper)
val visitor = ClassMutator(

View File

@ -3,29 +3,31 @@ package net.corda.djvm.rewiring
import net.corda.djvm.SandboxConfiguration
import net.corda.djvm.analysis.AnalysisContext
import net.corda.djvm.analysis.ClassAndMemberVisitor
import net.corda.djvm.code.asResourcePath
import net.corda.djvm.references.ClassReference
import net.corda.djvm.source.ClassSource
import net.corda.djvm.source.SourceClassLoader
import net.corda.djvm.utilities.loggerFor
import net.corda.djvm.validation.RuleValidator
/**
* Class loader that enables registration of rewired classes.
*
* @property configuration The configuration to use for the sandbox.
* @param configuration The configuration to use for the sandbox.
* @property context The context in which analysis and processing is performed.
*/
class SandboxClassLoader(
val configuration: SandboxConfiguration,
val context: AnalysisContext
) : ClassLoader() {
configuration: SandboxConfiguration,
private val context: AnalysisContext
) : ClassLoader(null) {
private val analysisConfiguration = configuration.analysisConfiguration
/**
* The instance used to validate that any loaded class complies with the specified rules.
*/
private val ruleValidator: RuleValidator = RuleValidator(
rules = configuration.rules,
configuration = configuration.analysisConfiguration
configuration = analysisConfiguration
)
/**
@ -37,12 +39,12 @@ class SandboxClassLoader(
/**
* Set of classes that should be left untouched due to pinning.
*/
private val pinnedClasses = configuration.analysisConfiguration.pinnedClasses
private val pinnedClasses = analysisConfiguration.pinnedClasses
/**
* Set of classes that should be left untouched due to whitelisting.
*/
private val whitelistedClasses = configuration.analysisConfiguration.whitelist
private val whitelistedClasses = analysisConfiguration.whitelist
/**
* Cache of loaded classes.
@ -52,10 +54,7 @@ class SandboxClassLoader(
/**
* The class loader used to find classes on the extended class path.
*/
private val supportingClassLoader = SourceClassLoader(
configuration.analysisConfiguration.classPath,
configuration.analysisConfiguration.classResolver
)
private val supportingClassLoader = analysisConfiguration.supportingClassLoader
/**
* The re-writer to use for registered classes.
@ -83,9 +82,9 @@ class SandboxClassLoader(
* @return The resulting <tt>Class</tt> object and its byte code representation.
*/
fun loadClassAndBytes(source: ClassSource, context: AnalysisContext): LoadedClass {
logger.trace("Loading class {}, origin={}...", source.qualifiedClassName, source.origin)
val name = configuration.analysisConfiguration.classResolver.reverseNormalized(source.qualifiedClassName)
val resolvedName = configuration.analysisConfiguration.classResolver.resolveNormalized(name)
logger.debug("Loading class {}, origin={}...", source.qualifiedClassName, source.origin)
val name = analysisConfiguration.classResolver.reverseNormalized(source.qualifiedClassName)
val resolvedName = analysisConfiguration.classResolver.resolveNormalized(name)
// Check if the class has already been loaded.
val loadedClass = loadedClasses[name]
@ -99,14 +98,14 @@ class SandboxClassLoader(
// Analyse the class if not matching the whitelist.
val readClassName = reader.className
if (!configuration.analysisConfiguration.whitelist.matches(readClassName)) {
if (!analysisConfiguration.whitelist.matches(readClassName)) {
logger.trace("Class {} does not match with the whitelist", source.qualifiedClassName)
logger.trace("Analyzing class {}...", source.qualifiedClassName)
analyzer.analyze(reader, context)
}
// Check if the class should be left untouched.
val qualifiedName = name.replace('.', '/')
val qualifiedName = name.asResourcePath
if (qualifiedName in pinnedClasses) {
logger.trace("Class {} is marked as pinned", source.qualifiedClassName)
val pinnedClasses = LoadedClass(
@ -146,7 +145,7 @@ class SandboxClassLoader(
context.recordClassOrigin(name, ClassReference(source.origin))
}
logger.trace("Loaded class {}, bytes={}, isModified={}",
logger.debug("Loaded class {}, bytes={}, isModified={}",
source.qualifiedClassName, byteCode.bytes.size, byteCode.isModified)
return classWithByteCode

View File

@ -1,5 +1,6 @@
package net.corda.djvm.rewiring
import net.corda.djvm.code.asPackagePath
import org.objectweb.asm.ClassReader
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.ClassWriter.COMPUTE_FRAMES
@ -35,12 +36,12 @@ open class SandboxClassWriter(
type2 == OBJECT_NAME -> return type2
}
val class1 = try {
classLoader.loadClass(type1.replace('/', '.'))
classLoader.loadClass(type1.asPackagePath)
} catch (exception: Exception) {
throw TypeNotPresentException(type1, exception)
}
val class2 = try {
classLoader.loadClass(type2.replace('/', '.'))
classLoader.loadClass(type2.asPackagePath)
} catch (exception: Exception) {
throw TypeNotPresentException(type2, exception)
}

View File

@ -18,7 +18,7 @@ abstract class ClassRule : Rule {
*/
abstract fun validate(context: RuleContext, clazz: ClassRepresentation)
override fun validate(context: RuleContext, clazz: ClassRepresentation?, member: Member?, instruction: Instruction?) {
final override fun validate(context: RuleContext, clazz: ClassRepresentation?, member: Member?, instruction: Instruction?) {
// Only run validation step if applied to the class itself.
if (clazz != null && member == null && instruction == null) {
validate(context, clazz)

View File

@ -18,7 +18,7 @@ abstract class InstructionRule : Rule {
*/
abstract fun validate(context: RuleContext, instruction: Instruction)
override fun validate(context: RuleContext, clazz: ClassRepresentation?, member: Member?, instruction: Instruction?) {
final override fun validate(context: RuleContext, clazz: ClassRepresentation?, member: Member?, instruction: Instruction?) {
// Only run validation step if applied to the class member itself.
if (clazz != null && member != null && instruction != null) {
validate(context, instruction)

View File

@ -18,7 +18,7 @@ abstract class MemberRule : Rule {
*/
abstract fun validate(context: RuleContext, member: Member)
override fun validate(context: RuleContext, clazz: ClassRepresentation?, member: Member?, instruction: Instruction?) {
final override fun validate(context: RuleContext, clazz: ClassRepresentation?, member: Member?, instruction: Instruction?) {
// Only run validation step if applied to the class member itself.
if (clazz != null && member != null && instruction == null) {
validate(context, member)

View File

@ -1,17 +0,0 @@
package net.corda.djvm.rules.implementation
import net.corda.djvm.code.Instruction
import net.corda.djvm.code.Instruction.Companion.OP_BREAKPOINT
import net.corda.djvm.rules.InstructionRule
import net.corda.djvm.validation.RuleContext
/**
* Rule that checks for invalid breakpoint instructions.
*/
class DisallowBreakpoints : InstructionRule() {
override fun validate(context: RuleContext, instruction: Instruction) = context.validate {
fail("Disallowed breakpoint in method") given (instruction.operation == OP_BREAKPOINT)
}
}

View File

@ -1,36 +1,16 @@
package net.corda.djvm.rules.implementation
import net.corda.djvm.code.Emitter
import net.corda.djvm.code.EmitterContext
import net.corda.djvm.code.Instruction
import net.corda.djvm.code.*
import net.corda.djvm.code.instructions.CodeLabel
import net.corda.djvm.code.instructions.TryCatchBlock
import net.corda.djvm.costing.ThresholdViolationException
import net.corda.djvm.rules.InstructionRule
import net.corda.djvm.validation.RuleContext
import org.objectweb.asm.Label
import sandbox.net.corda.djvm.costing.ThresholdViolationError
/**
* Rule that checks for attempted catches of [ThreadDeath], [ThresholdViolationException], [StackOverflowError],
* [OutOfMemoryError], [Error] or [Throwable].
* Rule that checks for attempted catches of [ThreadDeath], [ThresholdViolationError],
* [StackOverflowError], [OutOfMemoryError], [Error] or [Throwable].
*/
class DisallowCatchingBlacklistedExceptions : InstructionRule(), Emitter {
override fun validate(context: RuleContext, instruction: Instruction) = context.validate {
if (instruction is TryCatchBlock) {
val typeName = context.classModule.getFormattedClassName(instruction.typeName)
warn("Injected runtime check for catch-block for type $typeName") given
(instruction.typeName in disallowedExceptionTypes)
fail("Disallowed catch of ThreadDeath exception") given
(instruction.typeName == threadDeathException)
fail("Disallowed catch of stack overflow exception") given
(instruction.typeName == stackOverflowException)
fail("Disallowed catch of out of memory exception") given
(instruction.typeName == outOfMemoryException)
fail("Disallowed catch of threshold violation exception") given
(instruction.typeName.endsWith(ThresholdViolationException::class.java.simpleName))
}
}
class DisallowCatchingBlacklistedExceptions : Emitter {
override fun emit(context: EmitterContext, instruction: Instruction) = context.emit {
if (instruction is TryCatchBlock && instruction.typeName in disallowedExceptionTypes) {
@ -46,13 +26,27 @@ class DisallowCatchingBlacklistedExceptions : InstructionRule(), Emitter {
private fun isExceptionHandler(label: Label) = label in handlers
companion object {
private const val threadDeathException = "java/lang/ThreadDeath"
private const val stackOverflowException = "java/lang/StackOverflowError"
private const val outOfMemoryException = "java/lang/OutOfMemoryError"
// Any of [ThreadDeath]'s throwable super-classes need explicit checking.
private val disallowedExceptionTypes = setOf(
ruleViolationError,
thresholdViolationError,
/**
* These errors indicate that the JVM is failing,
* so don't allow these to be caught either.
*/
"java/lang/StackOverflowError",
"java/lang/OutOfMemoryError",
/**
* These are immediate super-classes for our explicit errors.
*/
"java/lang/VirtualMachineError",
"java/lang/ThreadDeath",
/**
* Any of [ThreadDeath] and [VirtualMachineError]'s throwable
* super-classes also need explicit checking.
*/
"java/lang/Throwable",
"java/lang/Error"
)

View File

@ -1,17 +0,0 @@
package net.corda.djvm.rules.implementation
import net.corda.djvm.references.Member
import net.corda.djvm.rules.MemberRule
import net.corda.djvm.validation.RuleContext
/**
* Rule that checks for invalid use of finalizers.
*/
class DisallowFinalizerMethods : MemberRule() {
override fun validate(context: RuleContext, member: Member) = context.validate {
fail("Disallowed finalizer method") given ("${member.memberName}:${member.signature}" == "finalize:()V")
// TODO Make this rule simply erase the finalize() method and continue execution.
}
}

View File

@ -1,17 +0,0 @@
package net.corda.djvm.rules.implementation
import net.corda.djvm.references.Member
import net.corda.djvm.rules.MemberRule
import net.corda.djvm.validation.RuleContext
import java.lang.reflect.Modifier
/**
* Rule that checks for invalid use of native methods.
*/
class DisallowNativeMethods : MemberRule() {
override fun validate(context: RuleContext, member: Member) = context.validate {
fail("Disallowed native method") given Modifier.isNative(member.access)
}
}

View File

@ -0,0 +1,42 @@
package net.corda.djvm.rules.implementation
import net.corda.djvm.code.Emitter
import net.corda.djvm.code.EmitterContext
import net.corda.djvm.code.Instruction
import net.corda.djvm.code.instructions.MemberAccessInstruction
import net.corda.djvm.formatting.MemberFormatter
import org.objectweb.asm.Opcodes.*
import sandbox.net.corda.djvm.rules.RuleViolationError
/**
* Some non-deterministic APIs belong to pinned classes and so cannot be stubbed out.
* Replace their invocations with exceptions instead.
*/
class DisallowNonDeterministicMethods : Emitter {
override fun emit(context: EmitterContext, instruction: Instruction) = context.emit {
if (instruction is MemberAccessInstruction && isForbidden(instruction)) {
when (instruction.operation) {
INVOKEVIRTUAL -> {
throwException<RuleViolationError>("Disallowed reference to API; ${memberFormatter.format(instruction.member)}")
preventDefault()
}
}
}
}
private fun isClassReflection(instruction: MemberAccessInstruction): Boolean =
(instruction.owner == "java/lang/Class") && (
((instruction.memberName == "newInstance" && instruction.signature == "()Ljava/lang/Object;")
|| instruction.signature.contains("Ljava/lang/reflect/"))
)
private fun isObjectMonitor(instruction: MemberAccessInstruction): Boolean =
(instruction.signature == "()V" && (instruction.memberName == "notify" || instruction.memberName == "notifyAll" || instruction.memberName == "wait"))
|| (instruction.memberName == "wait" && (instruction.signature == "(J)V" || instruction.signature == "(JI)V"))
private fun isForbidden(instruction: MemberAccessInstruction): Boolean
= instruction.isMethod && (isClassReflection(instruction) || isObjectMonitor(instruction))
private val memberFormatter = MemberFormatter()
}

View File

@ -1,30 +0,0 @@
package net.corda.djvm.rules.implementation
import net.corda.djvm.code.Instruction
import net.corda.djvm.code.instructions.MemberAccessInstruction
import net.corda.djvm.formatting.MemberFormatter
import net.corda.djvm.rules.InstructionRule
import net.corda.djvm.validation.RuleContext
/**
* Rule that checks for illegal references to reflection APIs.
*/
class DisallowReflection : InstructionRule() {
override fun validate(context: RuleContext, instruction: Instruction) = context.validate {
// TODO Enable controlled use of reflection APIs
if (instruction is MemberAccessInstruction) {
invalidReflectionUsage(instruction) given
("java/lang/Class" in instruction.owner && instruction.memberName == "newInstance")
invalidReflectionUsage(instruction) given (instruction.owner.startsWith("java/lang/reflect/"))
invalidReflectionUsage(instruction) given (instruction.owner.startsWith("java/lang/invoke/"))
invalidReflectionUsage(instruction) given (instruction.owner.startsWith("sun/"))
}
}
private fun RuleContext.invalidReflectionUsage(instruction: MemberAccessInstruction) =
this.fail("Disallowed reference to reflection API; ${memberFormatter.format(instruction.member)}")
private val memberFormatter = MemberFormatter()
}

View File

@ -0,0 +1,19 @@
package net.corda.djvm.rules.implementation
import net.corda.djvm.code.Emitter
import net.corda.djvm.code.EmitterContext
import net.corda.djvm.code.Instruction
import net.corda.djvm.code.Instruction.Companion.OP_BREAKPOINT
/**
* Rule that deletes invalid breakpoint instructions.
*/
class IgnoreBreakpoints : Emitter {
override fun emit(context: EmitterContext, instruction: Instruction) = context.emit {
when (instruction.operation) {
OP_BREAKPOINT -> preventDefault()
}
}
}

View File

@ -3,20 +3,13 @@ package net.corda.djvm.rules.implementation
import net.corda.djvm.code.Emitter
import net.corda.djvm.code.EmitterContext
import net.corda.djvm.code.Instruction
import net.corda.djvm.rules.InstructionRule
import net.corda.djvm.validation.RuleContext
import org.objectweb.asm.Opcodes.*
/**
* Rule that warns about the use of synchronized code blocks. This class also exposes an emitter that rewrites pertinent
* monitoring instructions to [POP]'s, as these replacements will remove the object references that [MONITORENTER] and
* [MONITOREXIT] anticipate to be on the stack.
* An emitter that rewrites monitoring instructions to [POP]s, as these replacements will remove
* the object references that [MONITORENTER] and [MONITOREXIT] anticipate to be on the stack.
*/
class IgnoreSynchronizedBlocks : InstructionRule(), Emitter {
override fun validate(context: RuleContext, instruction: Instruction) = context.validate {
inform("Stripped monitoring instruction") given (instruction.operation in setOf(MONITORENTER, MONITOREXIT))
}
class IgnoreSynchronizedBlocks : Emitter {
override fun emit(context: EmitterContext, instruction: Instruction) = context.emit {
when (instruction.operation) {

View File

@ -0,0 +1,35 @@
package net.corda.djvm.rules.implementation
import net.corda.djvm.analysis.AnalysisRuntimeContext
import net.corda.djvm.code.EmitterModule
import net.corda.djvm.code.MemberDefinitionProvider
import net.corda.djvm.references.Member
import java.lang.reflect.Modifier
/**
* Rule that replaces a finalize() method with a simple stub.
*/
class StubOutFinalizerMethods : MemberDefinitionProvider {
override fun define(context: AnalysisRuntimeContext, member: Member) = when {
/**
* Discard any other method body and replace with stub that just returns.
* Other [MemberDefinitionProvider]s are expected to append to this list
* and not replace its contents!
*/
isFinalizer(member) -> member.copy(body = listOf(::writeMethodBody))
else -> member
}
private fun writeMethodBody(emitter: EmitterModule): Unit = with(emitter) {
returnVoid()
}
/**
* No need to rewrite [Object.finalize] or [Enum.finalize]; ignore these.
*/
private fun isFinalizer(member: Member): Boolean
= member.memberName == "finalize" && member.signature == "()V"
&& !member.className.startsWith("java/lang/")
&& !Modifier.isAbstract(member.access)
}

View File

@ -0,0 +1,36 @@
package net.corda.djvm.rules.implementation
import net.corda.djvm.analysis.AnalysisRuntimeContext
import net.corda.djvm.code.EmitterModule
import net.corda.djvm.code.MemberDefinitionProvider
import net.corda.djvm.references.Member
import org.objectweb.asm.Opcodes.*
import sandbox.net.corda.djvm.rules.RuleViolationError
import java.lang.reflect.Modifier
/**
* Rule that replaces a native method with a stub that throws an exception.
*/
class StubOutNativeMethods : MemberDefinitionProvider {
override fun define(context: AnalysisRuntimeContext, member: Member) = when {
isNative(member) -> member.copy(
access = member.access and ACC_NATIVE.inv(),
body = member.body + if (isForStubbing(member)) ::writeStubMethodBody else ::writeExceptionMethodBody
)
else -> member
}
private fun writeExceptionMethodBody(emitter: EmitterModule): Unit = with(emitter) {
lineNumber(0)
throwException(RuleViolationError::class.java, "Native method has been deleted")
}
private fun writeStubMethodBody(emitter: EmitterModule): Unit = with(emitter) {
returnVoid()
}
private fun isForStubbing(member: Member): Boolean = member.signature == "()V" && member.memberName == "registerNatives"
private fun isNative(member: Member): Boolean = Modifier.isNative(member.access)
}

View File

@ -0,0 +1,35 @@
package net.corda.djvm.rules.implementation
import net.corda.djvm.analysis.AnalysisRuntimeContext
import net.corda.djvm.code.EmitterModule
import net.corda.djvm.code.MemberDefinitionProvider
import net.corda.djvm.references.Member
import org.objectweb.asm.Opcodes.*
import sandbox.net.corda.djvm.rules.RuleViolationError
/**
* Replace reflection APIs with stubs that throw exceptions. Only for unpinned classes.
*/
class StubOutReflectionMethods : MemberDefinitionProvider {
override fun define(context: AnalysisRuntimeContext, member: Member): Member = when {
isConcreteApi(member) && isReflection(member) -> member.copy(body = member.body + ::writeMethodBody)
else -> member
}
private fun writeMethodBody(emitter: EmitterModule): Unit = with(emitter) {
lineNumber(0)
throwException(RuleViolationError::class.java, "Disallowed reference to reflection API")
}
// The method must be public and with a Java implementation.
private fun isConcreteApi(member: Member): Boolean = member.access and (ACC_PUBLIC or ACC_ABSTRACT or ACC_NATIVE) == ACC_PUBLIC
private fun isReflection(member: Member): Boolean {
return member.className.startsWith("java/lang/reflect/")
|| member.className.startsWith("java/lang/invoke/")
|| member.className.startsWith("sun/reflect/")
|| member.className == "sun/misc/Unsafe"
|| member.className == "sun/misc/VM"
}
}

View File

@ -3,6 +3,7 @@ package net.corda.djvm.source
import net.corda.djvm.analysis.AnalysisContext
import net.corda.djvm.analysis.ClassResolver
import net.corda.djvm.analysis.SourceLocation
import net.corda.djvm.code.asResourcePath
import net.corda.djvm.messages.Message
import net.corda.djvm.messages.Severity
import net.corda.djvm.rewiring.SandboxClassLoadingException
@ -17,18 +18,11 @@ import java.nio.file.Path
import java.nio.file.Paths
import kotlin.streams.toList
/**
* Customizable class loader that allows the user to explicitly specify additional JARs and directories to scan.
*
* @param paths The directories and explicit JAR files to scan.
* @property classResolver The resolver to use to derive the original name of a requested class.
* @property resolvedUrls The resolved URLs that get passed to the underlying class loader.
*/
open class SourceClassLoader(
paths: List<Path>,
private val classResolver: ClassResolver,
val resolvedUrls: Array<URL> = resolvePaths(paths)
) : URLClassLoader(resolvedUrls, SourceClassLoader::class.java.classLoader) {
abstract class AbstractSourceClassLoader(
paths: List<Path>,
private val classResolver: ClassResolver,
parent: ClassLoader?
) : URLClassLoader(resolvePaths(paths), parent) {
/**
* Open a [ClassReader] for the provided class name.
@ -36,7 +30,7 @@ open class SourceClassLoader(
fun classReader(
className: String, context: AnalysisContext, origin: String? = null
): ClassReader {
val originalName = classResolver.reverse(className.replace('.', '/'))
val originalName = classResolver.reverse(className.asResourcePath)
return try {
logger.trace("Opening ClassReader for class {}...", originalName)
getResourceAsStream("$originalName.class").use {
@ -71,16 +65,16 @@ open class SourceClassLoader(
return super.loadClass(originalName, resolve)
}
private companion object {
private val logger = loggerFor<SourceClassLoader>()
protected companion object {
@JvmStatic
protected val logger = loggerFor<SourceClassLoader>()
private fun resolvePaths(paths: List<Path>): Array<URL> {
return paths.map(this::expandPath).flatMap {
when {
!Files.exists(it) -> throw FileNotFoundException("File not found; $it")
Files.isDirectory(it) -> {
listOf(it.toURL()) + Files.list(it).filter(::isJarFile).map { it.toURL() }.toList()
listOf(it.toURL()) + Files.list(it).filter(::isJarFile).map { jar -> jar.toURL() }.toList()
}
Files.isReadable(it) && isJarFile(it) -> listOf(it.toURL())
else -> throw IllegalArgumentException("Expected JAR or class file, but found $it")
@ -100,11 +94,76 @@ open class SourceClassLoader(
private fun isJarFile(path: Path) = path.toString().endsWith(".jar", true)
private fun Path.toURL() = this.toUri().toURL()
private fun Path.toURL(): URL = this.toUri().toURL()
private val homeDirectory: Path
get() = Paths.get(System.getProperty("user.home"))
}
}
/**
* Class loader to manage an optional JAR of replacement Java APIs.
* @param bootstrapJar The location of the JAR containing the Java APIs.
* @param classResolver The resolver to use to derive the original name of a requested class.
*/
class BootstrapClassLoader(
bootstrapJar: Path,
classResolver: ClassResolver
) : AbstractSourceClassLoader(listOf(bootstrapJar), classResolver, null) {
/**
* Only search our own jars for the given resource.
*/
override fun getResource(name: String): URL? = findResource(name)
}
/**
* Customizable class loader that allows the user to explicitly specify additional JARs and directories to scan.
*
* @param paths The directories and explicit JAR files to scan.
* @property classResolver The resolver to use to derive the original name of a requested class.
* @property bootstrap The [BootstrapClassLoader] containing the Java APIs for the sandbox.
*/
class SourceClassLoader(
paths: List<Path>,
classResolver: ClassResolver,
private val bootstrap: BootstrapClassLoader? = null
) : AbstractSourceClassLoader(paths, classResolver, SourceClassLoader::class.java.classLoader) {
/**
* First check the bootstrap classloader, if we have one.
* Otherwise check our parent classloader, followed by
* the user-supplied jars.
*/
override fun getResource(name: String): URL? {
if (bootstrap != null) {
val resource = bootstrap.findResource(name)
if (resource != null) {
return resource
} else if (isJvmInternal(name)) {
logger.error("Denying request for actual {}", name)
return null
}
}
return parent?.getResource(name) ?: findResource(name)
}
/**
* Deny all requests for DJVM classes from any user-supplied jars.
*/
override fun findResource(name: String): URL? {
return if (name.startsWith("net/corda/djvm/")) null else super.findResource(name)
}
/**
* Does [name] exist within any of the packages reserved for Java itself?
*/
private fun isJvmInternal(name: String): Boolean = name.startsWith("java/")
|| name.startsWith("javax/")
|| name.startsWith("com/sun/")
|| name.startsWith("sun/")
|| name.startsWith("jdk/")
}

View File

@ -1,219 +0,0 @@
package net.corda.djvm.validation
import net.corda.djvm.analysis.AnalysisConfiguration
import net.corda.djvm.analysis.AnalysisContext
import net.corda.djvm.analysis.ClassAndMemberVisitor
import net.corda.djvm.execution.SandboxedRunnable
import net.corda.djvm.formatting.MemberFormatter
import net.corda.djvm.messages.Message
import net.corda.djvm.messages.Severity
import net.corda.djvm.references.*
import net.corda.djvm.rewiring.SandboxClassLoadingException
import net.corda.djvm.utilities.loggerFor
/**
* Module used to validate all traversable references before instantiating and executing a [SandboxedRunnable].
*
* @param configuration The analysis configuration to use for the validation.
* @property memberFormatter Module with functionality for formatting class members.
*/
class ReferenceValidator(
private val configuration: AnalysisConfiguration,
private val memberFormatter: MemberFormatter = MemberFormatter()
) {
/**
* Container holding the current state of the validation.
*
* @property context The context in which references are to be validated.
* @property analyzer Underlying analyzer used for processing classes.
*/
private class State(
val context: AnalysisContext,
val analyzer: ClassAndMemberVisitor
)
/**
* Validate whether or not the classes in a class hierarchy can be safely instantiated and run in a sandbox by
* checking that all references are rooted in deterministic code.
*
* @param context The context in which the check should be made.
* @param analyzer Underlying analyzer used for processing classes.
*/
fun validate(context: AnalysisContext, analyzer: ClassAndMemberVisitor): ReferenceValidationSummary =
State(context, analyzer).let { state ->
logger.trace("Validating {} references across {} class(es)...",
context.references.size, context.classes.size)
context.references.process { validateReference(state, it) }
logger.trace("Reference validation completed; {} class(es) and {} message(s)",
context.references.size, context.classes.size)
ReferenceValidationSummary(state.context.classes, state.context.messages, state.context.classOrigins)
}
/**
* Construct a message from an invalid reference and its source location.
*/
private fun referenceToMessage(referenceWithLocation: ReferenceWithLocation): Message {
val (location, reference, description) = referenceWithLocation
val referenceMessage = when {
reference is ClassReference ->
"Invalid reference to class ${configuration.classModule.getFormattedClassName(reference.className)}"
reference is MemberReference && configuration.memberModule.isConstructor(reference) ->
"Invalid reference to constructor ${memberFormatter.format(reference)}"
reference is MemberReference && configuration.memberModule.isField(reference) ->
"Invalid reference to field ${memberFormatter.format(reference)}"
reference is MemberReference && configuration.memberModule.isMethod(reference) ->
"Invalid reference to method ${memberFormatter.format(reference)}"
else ->
"Invalid reference to $reference"
}
val message = if (description.isNotBlank()) {
"$referenceMessage, $description"
} else {
referenceMessage
}
return Message(message, Severity.ERROR, location)
}
/**
* Validate a reference made from a class or class member.
*/
private fun validateReference(state: State, reference: EntityReference) {
if (configuration.whitelist.matches(reference.className)) {
// The referenced class has been whitelisted - no need to go any further.
return
}
when (reference) {
is ClassReference -> {
logger.trace("Validating class reference {}", reference)
val clazz = getClass(state, reference.className)
val reason = when (clazz) {
null -> Reason(Reason.Code.NON_EXISTENT_CLASS)
else -> getReasonFromEntity(clazz)
}
if (reason != null) {
logger.trace("Recorded invalid class reference to {}; reason = {}", reference, reason)
state.context.messages.addAll(state.context.references.locationsFromReference(reference).map {
referenceToMessage(ReferenceWithLocation(it, reference, reason.description))
})
}
}
is MemberReference -> {
logger.trace("Validating member reference {}", reference)
// Ensure that the dependent class is loaded and analyzed
val clazz = getClass(state, reference.className)
val member = state.context.classes.getMember(
reference.className, reference.memberName, reference.signature
)
val reason = when {
clazz == null -> Reason(Reason.Code.NON_EXISTENT_CLASS)
member == null -> Reason(Reason.Code.NON_EXISTENT_MEMBER)
else -> getReasonFromEntity(state, member)
}
if (reason != null) {
logger.trace("Recorded invalid member reference to {}; reason = {}", reference, reason)
state.context.messages.addAll(state.context.references.locationsFromReference(reference).map {
referenceToMessage(ReferenceWithLocation(it, reference, reason.description))
})
}
}
}
}
/**
* Get a class from the class hierarchy by its binary name.
*/
private fun getClass(state: State, className: String, originClass: String? = null): ClassRepresentation? {
val name = if (configuration.classModule.isArray(className)) {
val arrayType = arrayTypeExtractor.find(className)?.groupValues?.get(1)
when (arrayType) {
null -> "java/lang/Object"
else -> arrayType
}
} else {
className
}
var clazz = state.context.classes[name]
if (clazz == null) {
logger.trace("Loading and analyzing referenced class {}...", name)
val origin = state.context.references
.locationsFromReference(ClassReference(name))
.map { it.className }
.firstOrNull() ?: originClass
state.analyzer.analyze(name, state.context, origin)
clazz = state.context.classes[name]
}
if (clazz == null) {
logger.warn("Failed to load class {}", name)
state.context.messages.add(Message("Referenced class not found; $name", Severity.ERROR))
}
clazz?.apply {
val ancestors = listOf(superClass) + interfaces
for (ancestor in ancestors.filter(String::isNotBlank)) {
getClass(state, ancestor, clazz.name)
}
}
return clazz
}
/**
* Check if a top-level class definition is considered safe or not.
*/
private fun isNonDeterministic(state: State, className: String): Boolean = when {
configuration.whitelist.matches(className) -> false
else -> {
try {
getClass(state, className)?.let {
isNonDeterministic(it)
} ?: true
} catch (exception: SandboxClassLoadingException) {
true // Failed to load the class, which means the class is non-deterministic.
}
}
}
/**
* Check if a top-level class definition is considered safe or not.
*/
private fun isNonDeterministic(clazz: ClassRepresentation) =
getReasonFromEntity(clazz) != null
/**
* Derive what reason to give to the end-user for an invalid class.
*/
private fun getReasonFromEntity(clazz: ClassRepresentation): Reason? = when {
configuration.whitelist.matches(clazz.name) -> null
configuration.whitelist.inNamespace(clazz.name) -> Reason(Reason.Code.NOT_WHITELISTED)
configuration.classModule.isNonDeterministic(clazz) -> Reason(Reason.Code.ANNOTATED)
else -> null
}
/**
* Derive what reason to give to the end-user for an invalid member.
*/
private fun getReasonFromEntity(state: State, member: Member): Reason? = when {
configuration.whitelist.matches(member.reference) -> null
configuration.whitelist.inNamespace(member.reference) -> Reason(Reason.Code.NOT_WHITELISTED)
configuration.memberModule.isNonDeterministic(member) -> Reason(Reason.Code.ANNOTATED)
else -> {
val invalidClasses = configuration.memberModule.findReferencedClasses(member)
.filter { isNonDeterministic(state, it) }
if (invalidClasses.isNotEmpty()) {
Reason(Reason.Code.INVALID_CLASS, invalidClasses.map {
configuration.classModule.getFormattedClassName(it)
})
} else {
null
}
}
}
private companion object {
private val logger = loggerFor<ReferenceValidator>()
private val arrayTypeExtractor = "^\\[*L([^;]+);$".toRegex()
}
}

View File

@ -21,9 +21,9 @@ import org.objectweb.asm.ClassVisitor
*/
class RuleValidator(
private val rules: List<Rule> = emptyList(),
configuration: AnalysisConfiguration = AnalysisConfiguration(),
configuration: AnalysisConfiguration,
classVisitor: ClassVisitor? = null
) : ClassAndMemberVisitor(classVisitor, configuration = configuration) {
) : ClassAndMemberVisitor(configuration, classVisitor) {
/**
* Apply the set of rules to the traversed class and record any violations.

View File

@ -2,7 +2,7 @@ package sandbox.net.corda.djvm.costing
import net.corda.djvm.SandboxRuntimeContext
import net.corda.djvm.costing.RuntimeCostSummary
import net.corda.djvm.costing.ThresholdViolationException
import org.objectweb.asm.Type
/**
* Class for keeping a tally on various runtime metrics, like number of jumps, allocations, invocations, etc. The
@ -24,7 +24,8 @@ object RuntimeCostAccounter {
/**
* The type name of the [RuntimeCostAccounter] class; referenced from instrumentors.
*/
const val TYPE_NAME: String = "sandbox/net/corda/djvm/costing/RuntimeCostAccounter"
@JvmField
val TYPE_NAME: String = Type.getInternalName(this::class.java)
/**
* Known / estimated allocation costs.
@ -35,14 +36,12 @@ object RuntimeCostAccounter {
)
/**
* Re-throw exception if it is of type [ThreadDeath] or [ThresholdViolationException].
* Re-throw exception if it is of type [ThreadDeath] or [VirtualMachineError].
*/
@JvmStatic
fun checkCatch(exception: Throwable) {
if (exception is ThreadDeath) {
throw exception
} else if (exception is ThresholdViolationException) {
throw exception
when (exception) {
is ThreadDeath, is VirtualMachineError -> throw exception
}
}

View File

@ -1,4 +1,4 @@
package net.corda.djvm.costing
package sandbox.net.corda.djvm.costing
/**
* Exception thrown when a sandbox threshold is violated. This will kill the current thread and consequently exit the
@ -6,6 +6,4 @@ package net.corda.djvm.costing
*
* @property message The description of the condition causing the problem.
*/
class ThresholdViolationException(
override val message: String
) : ThreadDeath()
class ThresholdViolationError(override val message: String) : ThreadDeath()

View File

@ -0,0 +1,10 @@
package sandbox.net.corda.djvm.rules
/**
* Exception thrown when a sandbox rule is violated at runtime.
* This will kill the current thread and consequently exit the
* sandbox.
*
* @property message The description of the condition causing the problem.
*/
class RuleViolationError(override val message: String) : ThreadDeath()

View File

@ -16,10 +16,16 @@ import net.corda.djvm.rules.Rule
import net.corda.djvm.source.ClassSource
import net.corda.djvm.utilities.Discovery
import net.corda.djvm.validation.RuleValidator
import org.junit.After
import org.junit.Assert.assertEquals
import org.objectweb.asm.ClassReader
import org.objectweb.asm.ClassVisitor
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.Type
import java.lang.reflect.InvocationTargetException
import java.nio.file.Path
import java.nio.file.Paths
import kotlin.concurrent.thread
import kotlin.reflect.jvm.jvmName
abstract class TestBase {
@ -33,8 +39,10 @@ abstract class TestBase {
val BLANK = emptySet<Any>()
val DEFAULT = (ALL_RULES + ALL_EMITTERS + ALL_DEFINITION_PROVIDERS)
.toSet().distinctBy { it.javaClass }
val DEFAULT = (ALL_RULES + ALL_EMITTERS + ALL_DEFINITION_PROVIDERS).distinctBy(Any::javaClass)
val DETERMINISTIC_RT: Path = Paths.get(
System.getProperty("deterministic-rt.path") ?: throw AssertionError("deterministic-rt.path property not set"))
/**
* Get the full name of type [T].
@ -46,13 +54,18 @@ abstract class TestBase {
/**
* Default analysis configuration.
*/
val configuration = AnalysisConfiguration(Whitelist.MINIMAL)
val configuration = AnalysisConfiguration(Whitelist.MINIMAL, bootstrapJar = DETERMINISTIC_RT)
/**
* Default analysis context
*/
val context: AnalysisContext
get() = AnalysisContext.fromConfiguration(configuration, emptyList())
get() = AnalysisContext.fromConfiguration(configuration)
@After
fun destroy() {
configuration.close()
}
/**
* Short-hand for analysing and validating a class.
@ -62,14 +75,15 @@ abstract class TestBase {
noinline block: (RuleValidator.(AnalysisContext) -> Unit)
) {
val reader = ClassReader(T::class.java.name)
val configuration = AnalysisConfiguration(minimumSeverityLevel = minimumSeverityLevel)
val validator = RuleValidator(ALL_RULES, configuration)
val context = AnalysisContext.fromConfiguration(
configuration,
listOf(ClassSource.fromClassName(reader.className))
)
validator.analyze(reader, context)
block(validator, context)
AnalysisConfiguration(
minimumSeverityLevel = minimumSeverityLevel,
classPath = listOf(DETERMINISTIC_RT)
).use { analysisConfiguration ->
val validator = RuleValidator(ALL_RULES, analysisConfiguration)
val context = AnalysisContext.fromConfiguration(analysisConfiguration)
validator.analyze(reader, context)
block(validator, context)
}
}
/**
@ -113,27 +127,26 @@ abstract class TestBase {
}
}
var thrownException: Throwable? = null
Thread {
thread {
try {
val pinnedTestClasses = pinnedClasses.map(Type::getInternalName).toSet()
val analysisConfiguration = AnalysisConfiguration(
whitelist = whitelist,
additionalPinnedClasses = pinnedTestClasses,
minimumSeverityLevel = minimumSeverityLevel
)
SandboxRuntimeContext(SandboxConfiguration.of(
executionProfile, rules, emitters, definitionProviders, enableTracing, analysisConfiguration
), classSources).use {
assertThat(runtimeCosts).areZero()
action(this)
AnalysisConfiguration(
whitelist = whitelist,
bootstrapJar = DETERMINISTIC_RT,
additionalPinnedClasses = pinnedTestClasses,
minimumSeverityLevel = minimumSeverityLevel
).use { analysisConfiguration ->
SandboxRuntimeContext(SandboxConfiguration.of(
executionProfile, rules, emitters, definitionProviders, enableTracing, analysisConfiguration
)).use {
assertThat(runtimeCosts).areZero()
action(this)
}
}
} catch (exception: Throwable) {
thrownException = exception
}
}.apply {
start()
join()
}
}.join()
throw thrownException ?: return
}
@ -145,8 +158,12 @@ abstract class TestBase {
/**
* Create a new instance of a class using the sandbox class loader.
*/
inline fun <reified T : Callable> SandboxRuntimeContext.newCallable() =
classLoader.loadClassAndBytes(ClassSource.fromClassName(T::class.java.name), context)
inline fun <reified T : Callable> SandboxRuntimeContext.newCallable(): LoadedClass = loadClass<T>()
inline fun <reified T : Any> SandboxRuntimeContext.loadClass(): LoadedClass = loadClass(T::class.jvmName)
fun SandboxRuntimeContext.loadClass(className: String): LoadedClass =
classLoader.loadClassAndBytes(ClassSource.fromClassName(className), context)
/**
* Run the entry-point of the loaded [Callable] class.
@ -164,6 +181,10 @@ abstract class TestBase {
/**
* Stub visitor.
*/
protected class Visitor : ClassVisitor(ClassAndMemberVisitor.API_VERSION)
protected class Writer : ClassWriter(COMPUTE_FRAMES or COMPUTE_MAXS) {
init {
assertEquals(ClassAndMemberVisitor.API_VERSION, api)
}
}
}

View File

@ -21,7 +21,7 @@ class ClassAndMemberVisitorTest : TestBase() {
@Test
fun `can traverse classes`() {
val classesVisited = mutableSetOf<ClassRepresentation>()
val visitor = object : ClassAndMemberVisitor() {
val visitor = object : ClassAndMemberVisitor(configuration, null) {
override fun visitClass(clazz: ClassRepresentation): ClassRepresentation {
classesVisited.add(clazz)
return clazz
@ -47,7 +47,7 @@ class ClassAndMemberVisitorTest : TestBase() {
@Test
fun `can traverse fields`() {
val membersVisited = mutableSetOf<Member>()
val visitor = object : ClassAndMemberVisitor() {
val visitor = object : ClassAndMemberVisitor(configuration, null) {
override fun visitField(clazz: ClassRepresentation, field: Member): Member {
membersVisited.add(field)
return field
@ -77,7 +77,7 @@ class ClassAndMemberVisitorTest : TestBase() {
@Test
fun `can traverse methods`() {
val membersVisited = mutableSetOf<Member>()
val visitor = object : ClassAndMemberVisitor() {
val visitor = object : ClassAndMemberVisitor(configuration, null) {
override fun visitMethod(clazz: ClassRepresentation, method: Member): Member {
membersVisited.add(method)
return method
@ -102,7 +102,7 @@ class ClassAndMemberVisitorTest : TestBase() {
@Test
fun `can traverse class annotations`() {
val annotations = mutableSetOf<String>()
val visitor = object : ClassAndMemberVisitor() {
val visitor = object : ClassAndMemberVisitor(configuration, null) {
override fun visitClassAnnotation(clazz: ClassRepresentation, descriptor: String) {
annotations.add(descriptor)
}
@ -118,9 +118,21 @@ class ClassAndMemberVisitorTest : TestBase() {
private class TestClassWithAnnotations
@Test
fun `can traverse member annotations`() {
fun `cannot traverse member annotations when reading`() {
val annotations = mutableSetOf<String>()
val visitor = object : ClassAndMemberVisitor() {
val visitor = object : ClassAndMemberVisitor(configuration, null) {
override fun visitMemberAnnotation(clazz: ClassRepresentation, member: Member, descriptor: String) {
annotations.add("${member.memberName}:$descriptor")
}
}
visitor.analyze<TestClassWithMemberAnnotations>(context)
assertThat(annotations).isEmpty()
}
@Test
fun `can traverse member annotations when writing`() {
val annotations = mutableSetOf<String>()
val visitor = object : ClassAndMemberVisitor(configuration, Writer()) {
override fun visitMemberAnnotation(clazz: ClassRepresentation, member: Member, descriptor: String) {
annotations.add("${member.memberName}:$descriptor")
}
@ -146,7 +158,7 @@ class ClassAndMemberVisitorTest : TestBase() {
@Test
fun `can traverse class sources`() {
val sources = mutableSetOf<String>()
val visitor = object : ClassAndMemberVisitor() {
val visitor = object : ClassAndMemberVisitor(configuration, null) {
override fun visitSource(clazz: ClassRepresentation, source: String) {
sources.add(source)
}
@ -160,9 +172,21 @@ class ClassAndMemberVisitorTest : TestBase() {
}
@Test
fun `can traverse instructions`() {
fun `does not traverse instructions when reading`() {
val instructions = mutableSetOf<Pair<Member, Instruction>>()
val visitor = object : ClassAndMemberVisitor() {
val visitor = object : ClassAndMemberVisitor(configuration, null) {
override fun visitInstruction(method: Member, emitter: EmitterModule, instruction: Instruction) {
instructions.add(Pair(method, instruction))
}
}
visitor.analyze<TestClassWithCode>(context)
assertThat(instructions).isEmpty()
}
@Test
fun `can traverse instructions when writing`() {
val instructions = mutableSetOf<Pair<Member, Instruction>>()
val visitor = object : ClassAndMemberVisitor(configuration, Writer()) {
override fun visitInstruction(method: Member, emitter: EmitterModule, instruction: Instruction) {
instructions.add(Pair(method, instruction))
}

View File

@ -1,66 +0,0 @@
package net.corda.djvm.analysis
import net.corda.djvm.TestBase
import net.corda.djvm.execution.SandboxedRunnable
import net.corda.djvm.validation.ReferenceValidator
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
class ReferenceValidatorTest : TestBase() {
private fun validator(whitelist: Whitelist = Whitelist.MINIMAL) =
ReferenceValidator(AnalysisConfiguration(whitelist))
@Test
fun `can validate when there are no references`() = analyze { context ->
analyze<EmptyRunnable>(context)
val (_, messages) = validator().validate(context, this)
assertThat(messages.count).isEqualTo(0)
}
private class EmptyRunnable : SandboxedRunnable<Int, Int> {
override fun run(input: Int): Int? {
return null
}
}
@Test
fun `can validate when there are references`() = analyze { context ->
analyze<RunnableWithReferences>(context)
analyze<TestRandom>(context)
val (_, messages) = validator().validate(context, this)
assertThat(messages.count).isEqualTo(0)
}
private class RunnableWithReferences : SandboxedRunnable<Int, Int> {
override fun run(input: Int): Int? {
return TestRandom().nextInt()
}
}
private class TestRandom {
external fun nextInt(): Int
}
@Test
fun `can validate when there are transient references`() = analyze { context ->
analyze<RunnableWithTransientReferences>(context)
analyze<ReferencedClass>(context)
analyze<TestRandom>(context)
val (_, messages) = validator().validate(context, this)
assertThat(messages.count).isEqualTo(0)
}
private class RunnableWithTransientReferences : SandboxedRunnable<Int, Int> {
override fun run(input: Int): Int? {
return ReferencedClass().test()
}
}
private class ReferencedClass {
fun test(): Int {
return TestRandom().nextInt()
}
}
}

View File

@ -1,27 +1,31 @@
package net.corda.djvm.assertions
import net.corda.djvm.rewiring.LoadedClass
import org.assertj.core.api.Assertions
import org.assertj.core.api.Assertions.*
class AssertiveClassWithByteCode(private val loadedClass: LoadedClass) {
fun isSandboxed(): AssertiveClassWithByteCode {
Assertions.assertThat(loadedClass.type.name).startsWith("sandbox.")
assertThat(loadedClass.type.name).startsWith("sandbox.")
return this
}
fun hasNotBeenModified(): AssertiveClassWithByteCode {
Assertions.assertThat(loadedClass.byteCode.isModified)
assertThat(loadedClass.byteCode.isModified)
.`as`("Byte code has been modified")
.isEqualTo(false)
return this
}
fun hasBeenModified(): AssertiveClassWithByteCode {
Assertions.assertThat(loadedClass.byteCode.isModified)
.`as`("Byte code has been modified")
assertThat(loadedClass.byteCode.isModified)
.`as`("Byte code has not been modified")
.isEqualTo(true)
return this
}
fun hasClassName(className: String): AssertiveClassWithByteCode {
assertThat(loadedClass.type.name).isEqualTo(className)
return this
}
}

View File

@ -10,7 +10,7 @@ open class AssertiveReferenceMap(private val references: ReferenceMap) {
fun hasCount(count: Int): AssertiveReferenceMap {
val allReferences = references.joinToString("\n") { " - $it" }
Assertions.assertThat(references.size)
Assertions.assertThat(references.numberOfReferences)
.overridingErrorMessage("Expected $count reference(s), found:\n$allReferences")
.isEqualTo(count)
return this

View File

@ -21,7 +21,7 @@ class ClassMutatorTest : TestBase() {
}
}
val context = context
val mutator = ClassMutator(Visitor(), configuration, listOf(definitionProvider))
val mutator = ClassMutator(Writer(), configuration, listOf(definitionProvider))
mutator.analyze<TestClass>(context)
assertThat(hasProvidedDefinition).isTrue()
assertThat(context.classes.get<TestClass>().access or ACC_STRICT).isNotEqualTo(0)
@ -39,7 +39,7 @@ class ClassMutatorTest : TestBase() {
}
}
val context = context
val mutator = ClassMutator(Visitor(), configuration, listOf(definitionProvider))
val mutator = ClassMutator(Writer(), configuration, listOf(definitionProvider))
mutator.analyze<TestClassWithMembers>(context)
assertThat(hasProvidedDefinition).isTrue()
for (member in context.classes.get<TestClassWithMembers>().members.values) {
@ -67,7 +67,7 @@ class ClassMutatorTest : TestBase() {
}
}
val context = context
val mutator = ClassMutator(Visitor(), configuration, emitters = listOf(emitter))
val mutator = ClassMutator(Writer(), configuration, emitters = listOf(emitter))
mutator.analyze<TestClassWithMembers>(context)
assertThat(hasEmittedCode).isTrue()
assertThat(shouldPreventDefault).isTrue()

View File

@ -7,6 +7,7 @@ import org.junit.Test
import org.objectweb.asm.ClassVisitor
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes.NEW
import org.objectweb.asm.Type
class EmitterModuleTest : TestBase() {
@ -14,15 +15,15 @@ class EmitterModuleTest : TestBase() {
fun `can emit code to method body`() {
var hasEmittedTypeInstruction = false
val methodVisitor = object : MethodVisitor(ClassAndMemberVisitor.API_VERSION) {
override fun visitTypeInsn(opcode: Int, type: String?) {
if (opcode == NEW && type == java.lang.String::class.java.name) {
override fun visitTypeInsn(opcode: Int, type: String) {
if (opcode == NEW && type == Type.getInternalName(java.lang.String::class.java)) {
hasEmittedTypeInstruction = true
}
}
}
val visitor = object : ClassVisitor(ClassAndMemberVisitor.API_VERSION) {
override fun visitMethod(
access: Int, name: String?, descriptor: String?, signature: String?, exceptions: Array<out String>?
access: Int, name: String, descriptor: String, signature: String?, exceptions: Array<out String>?
): MethodVisitor {
return methodVisitor
}

View File

@ -3,12 +3,13 @@ package net.corda.djvm.costing
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.Test
import sandbox.net.corda.djvm.costing.ThresholdViolationError
class RuntimeCostTest {
@Test
fun `can increment cost`() {
val cost = RuntimeCost(10, { "failed" })
val cost = RuntimeCost(10) { "failed" }
cost.increment(1)
assertThat(cost.value).isEqualTo(1)
}
@ -16,8 +17,8 @@ class RuntimeCostTest {
@Test
fun `cannot increment cost beyond threshold`() {
Thread {
val cost = RuntimeCost(10, { "failed in ${it.name}" })
assertThatExceptionOfType(ThresholdViolationException::class.java)
val cost = RuntimeCost(10) { "failed in ${it.name}" }
assertThatExceptionOfType(ThresholdViolationError::class.java)
.isThrownBy { cost.increment(11) }
.withMessage("failed in Foo")
assertThat(cost.value).isEqualTo(11)

View File

@ -6,13 +6,15 @@ import foo.bar.sandbox.toNumber
import net.corda.djvm.TestBase
import net.corda.djvm.analysis.Whitelist
import net.corda.djvm.assertions.AssertionExtensions.withProblem
import net.corda.djvm.costing.ThresholdViolationException
import net.corda.djvm.rewiring.SandboxClassLoadingException
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.Test
import sandbox.net.corda.djvm.costing.ThresholdViolationError
import sandbox.net.corda.djvm.rules.RuleViolationError
import java.nio.file.Files
import java.util.*
import java.util.function.Function
class SandboxExecutorTest : TestBase() {
@ -24,8 +26,8 @@ class SandboxExecutorTest : TestBase() {
assertThat(result).isEqualTo("sandbox")
}
class TestSandboxedRunnable : SandboxedRunnable<Int, String> {
override fun run(input: Int): String? {
class TestSandboxedRunnable : Function<Int, String> {
override fun apply(input: Int): String {
return "sandbox"
}
}
@ -42,8 +44,8 @@ class SandboxExecutorTest : TestBase() {
.withMessageContaining("Contract constraint violated")
}
class Contract : SandboxedRunnable<Transaction?, Unit> {
override fun run(input: Transaction?) {
class Contract : Function<Transaction?, Unit> {
override fun apply(input: Transaction?) {
throw IllegalArgumentException("Contract constraint violated")
}
}
@ -58,8 +60,8 @@ class SandboxExecutorTest : TestBase() {
assertThat(result).isEqualTo(0xfed_c0de + 2)
}
class TestObjectHashCode : SandboxedRunnable<Int, Int> {
override fun run(input: Int): Int? {
class TestObjectHashCode : Function<Int, Int> {
override fun apply(input: Int): Int {
val obj = Object()
val hash1 = obj.hashCode()
val hash2 = obj.hashCode()
@ -76,8 +78,8 @@ class SandboxExecutorTest : TestBase() {
assertThat(result).isEqualTo(0xfed_c0de + 1)
}
class TestObjectHashCodeWithHierarchy : SandboxedRunnable<Int, Int> {
override fun run(input: Int): Int? {
class TestObjectHashCodeWithHierarchy : Function<Int, Int> {
override fun apply(input: Int): Int {
val obj = MyObject()
return obj.hashCode()
}
@ -91,9 +93,9 @@ class SandboxExecutorTest : TestBase() {
.withMessageContaining("terminated due to excessive use of looping")
}
class TestThresholdBreach : SandboxedRunnable<Int, Int> {
class TestThresholdBreach : Function<Int, Int> {
private var x = 0
override fun run(input: Int): Int? {
override fun apply(input: Int): Int {
for (i in 0..1_000_000) {
x += 1
}
@ -109,8 +111,8 @@ class SandboxExecutorTest : TestBase() {
.withCauseInstanceOf(StackOverflowError::class.java)
}
class TestStackOverflow : SandboxedRunnable<Int, Int> {
override fun run(input: Int): Int? {
class TestStackOverflow : Function<Int, Int> {
override fun apply(input: Int): Int {
return a()
}
@ -124,11 +126,12 @@ class SandboxExecutorTest : TestBase() {
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestKotlinMetaClasses>(0) }
.withMessageContaining("java/util/Random.<clinit>(): Disallowed reference to reflection API; sun.misc.Unsafe.getUnsafe()")
.withCauseInstanceOf(RuleViolationError::class.java)
.withMessageContaining("Disallowed reference to reflection API")
}
class TestKotlinMetaClasses : SandboxedRunnable<Int, Int> {
override fun run(input: Int): Int? {
class TestKotlinMetaClasses : Function<Int, Int> {
override fun apply(input: Int): Int {
val someNumber = testRandom()
return "12345".toNumber() * someNumber
}
@ -139,30 +142,32 @@ class SandboxExecutorTest : TestBase() {
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestNonDeterministicCode>(0) }
.withCauseInstanceOf(SandboxClassLoadingException::class.java)
.withProblem("java/util/Random.<clinit>(): Disallowed reference to reflection API; sun.misc.Unsafe.getUnsafe()")
.withCauseInstanceOf(RuleViolationError::class.java)
.withProblem("Disallowed reference to reflection API")
}
class TestNonDeterministicCode : SandboxedRunnable<Int, Int> {
override fun run(input: Int): Int? {
class TestNonDeterministicCode : Function<Int, Int> {
override fun apply(input: Int): Int {
return Random().nextInt()
}
}
@Test
fun `cannot execute runnable that catches ThreadDeath`() = sandbox(DEFAULT) {
TestCatchThreadDeath().apply {
assertThat(apply(0)).isEqualTo(1)
}
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestCatchThreadDeath>(0) }
.withCauseInstanceOf(SandboxClassLoadingException::class.java)
.withMessageContaining("Disallowed catch of ThreadDeath exception")
.withMessageContaining(TestCatchThreadDeath::class.java.simpleName)
.withCauseExactlyInstanceOf(ThreadDeath::class.java)
}
class TestCatchThreadDeath : SandboxedRunnable<Int, Int> {
override fun run(input: Int): Int? {
class TestCatchThreadDeath : Function<Int, Int> {
override fun apply(input: Int): Int {
return try {
0
throw ThreadDeath()
} catch (exception: ThreadDeath) {
1
}
@ -170,20 +175,46 @@ class SandboxExecutorTest : TestBase() {
}
@Test
fun `cannot execute runnable that catches ThresholdViolationException`() = sandbox(DEFAULT) {
fun `cannot execute runnable that catches ThresholdViolationError`() = sandbox(DEFAULT) {
TestCatchThresholdViolationError().apply {
assertThat(apply(0)).isEqualTo(1)
}
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestCatchThresholdViolationException>(0) }
.withCauseInstanceOf(SandboxClassLoadingException::class.java)
.withMessageContaining("Disallowed catch of threshold violation exception")
.withMessageContaining(TestCatchThresholdViolationException::class.java.simpleName)
.isThrownBy { contractExecutor.run<TestCatchThresholdViolationError>(0) }
.withCauseExactlyInstanceOf(ThresholdViolationError::class.java)
.withMessageContaining("Can't catch this!")
}
class TestCatchThresholdViolationException : SandboxedRunnable<Int, Int> {
override fun run(input: Int): Int? {
class TestCatchThresholdViolationError : Function<Int, Int> {
override fun apply(input: Int): Int {
return try {
0
} catch (exception: ThresholdViolationException) {
throw ThresholdViolationError("Can't catch this!")
} catch (exception: ThresholdViolationError) {
1
}
}
}
@Test
fun `cannot execute runnable that catches RuleViolationError`() = sandbox(DEFAULT) {
TestCatchRuleViolationError().apply {
assertThat(apply(0)).isEqualTo(1)
}
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestCatchRuleViolationError>(0) }
.withCauseExactlyInstanceOf(RuleViolationError::class.java)
.withMessageContaining("Can't catch this!")
}
class TestCatchRuleViolationError : Function<Int, Int> {
override fun apply(input: Int): Int {
return try {
throw RuleViolationError("Can't catch this!")
} catch (exception: RuleViolationError) {
1
}
}
@ -209,12 +240,12 @@ class SandboxExecutorTest : TestBase() {
fun `cannot catch ThreadDeath`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestCatchThrowableErrorAndThreadDeath>(3) }
.isThrownBy { contractExecutor.run<TestCatchThrowableErrorsAndThreadDeath>(3) }
.withCauseInstanceOf(ThreadDeath::class.java)
}
class TestCatchThrowableAndError : SandboxedRunnable<Int, Int> {
override fun run(input: Int): Int? {
class TestCatchThrowableAndError : Function<Int, Int> {
override fun apply(input: Int): Int {
return try {
when (input) {
1 -> throw Throwable()
@ -229,13 +260,27 @@ class SandboxExecutorTest : TestBase() {
}
}
class TestCatchThrowableErrorAndThreadDeath : SandboxedRunnable<Int, Int> {
override fun run(input: Int): Int? {
class TestCatchThrowableErrorsAndThreadDeath : Function<Int, Int> {
override fun apply(input: Int): Int {
return try {
when (input) {
1 -> throw Throwable()
2 -> throw Error()
3 -> throw ThreadDeath()
3 -> try {
throw ThreadDeath()
} catch (ex: ThreadDeath) {
3
}
4 -> try {
throw StackOverflowError("FAKE OVERFLOW!")
} catch (ex: StackOverflowError) {
4
}
5 -> try {
throw OutOfMemoryError("FAKE OOM!")
} catch (ex: OutOfMemoryError) {
5
}
else -> 0
}
} catch (exception: Error) {
@ -246,6 +291,24 @@ class SandboxExecutorTest : TestBase() {
}
}
@Test
fun `cannot catch stack-overflow error`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestCatchThrowableErrorsAndThreadDeath>(4) }
.withCauseInstanceOf(StackOverflowError::class.java)
.withMessageContaining("FAKE OVERFLOW!")
}
@Test
fun `cannot catch out-of-memory error`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestCatchThrowableErrorsAndThreadDeath>(5) }
.withCauseInstanceOf(OutOfMemoryError::class.java)
.withMessageContaining("FAKE OOM!")
}
@Test
fun `cannot persist state across sessions`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
@ -256,8 +319,8 @@ class SandboxExecutorTest : TestBase() {
.isEqualTo(1)
}
class TestStatePersistence : SandboxedRunnable<Int, Int> {
override fun run(input: Int): Int? {
class TestStatePersistence : Function<Int, Int> {
override fun apply(input: Int): Int {
ReferencedClass.value += 1
return ReferencedClass.value
}
@ -274,11 +337,11 @@ class SandboxExecutorTest : TestBase() {
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestIO>(0) }
.withCauseInstanceOf(SandboxClassLoadingException::class.java)
.withMessageContaining("Files.walk(Path, Integer, FileVisitOption[]): Disallowed dynamic invocation in method")
.withMessageContaining("Class file not found; java/nio/file/Files.class")
}
class TestIO : SandboxedRunnable<Int, Int> {
override fun run(input: Int): Int? {
class TestIO : Function<Int, Int> {
override fun apply(input: Int): Int {
val file = Files.createTempFile("test", ".dat")
Files.newBufferedWriter(file).use {
it.write("Hello world!")
@ -292,14 +355,13 @@ class SandboxExecutorTest : TestBase() {
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestReflection>(0) }
.withCauseInstanceOf(SandboxClassLoadingException::class.java)
.withMessageContaining("Disallowed reference to reflection API")
.withCauseInstanceOf(RuleViolationError::class.java)
.withMessageContaining("Disallowed reference to API;")
.withMessageContaining("java.lang.Class.newInstance()")
.withMessageContaining("java.lang.reflect.Method.invoke(Object, Object[])")
}
class TestReflection : SandboxedRunnable<Int, Int> {
override fun run(input: Int): Int? {
class TestReflection : Function<Int, Int> {
override fun apply(input: Int): Int {
val clazz = Object::class.java
val obj = clazz.newInstance()
val result = clazz.methods.first().invoke(obj)
@ -307,4 +369,150 @@ class SandboxExecutorTest : TestBase() {
}
}
@Test
fun `can load and execute code that uses notify()`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, String>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestMonitors>(1) }
.withCauseInstanceOf(RuleViolationError::class.java)
.withMessageContaining("Disallowed reference to API;")
.withMessageContaining("java.lang.Object.notify()")
}
@Test
fun `can load and execute code that uses notifyAll()`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, String>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestMonitors>(2) }
.withCauseInstanceOf(RuleViolationError::class.java)
.withMessageContaining("Disallowed reference to API;")
.withMessageContaining("java.lang.Object.notifyAll()")
}
@Test
fun `can load and execute code that uses wait()`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, String>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestMonitors>(3) }
.withCauseInstanceOf(RuleViolationError::class.java)
.withMessageContaining("Disallowed reference to API;")
.withMessageContaining("java.lang.Object.wait()")
}
@Test
fun `can load and execute code that uses wait(long)`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, String>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestMonitors>(4) }
.withCauseInstanceOf(RuleViolationError::class.java)
.withMessageContaining("Disallowed reference to API;")
.withMessageContaining("java.lang.Object.wait(Long)")
}
@Test
fun `can load and execute code that uses wait(long,int)`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, String>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestMonitors>(5) }
.withCauseInstanceOf(RuleViolationError::class.java)
.withMessageContaining("Disallowed reference to API;")
.withMessageContaining("java.lang.Object.wait(Long, Integer)")
}
@Test
fun `code after forbidden APIs is intact`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, String>(configuration)
assertThat(contractExecutor.run<TestMonitors>(0).result)
.isEqualTo("unknown")
}
class TestMonitors : Function<Int, String> {
override fun apply(input: Int): String {
return synchronized(this) {
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
val javaObject = this as java.lang.Object
when(input) {
1 -> {
javaObject.notify()
"notify"
}
2 -> {
javaObject.notifyAll()
"notifyAll"
}
3 -> {
javaObject.wait()
"wait"
}
4 -> {
javaObject.wait(100)
"wait(100)"
}
5 -> {
javaObject.wait(100, 10)
"wait(100, 10)"
}
else -> "unknown"
}
}
}
}
@Test
fun `can load and execute code that has a native method`() = sandbox(DEFAULT) {
assertThatExceptionOfType(UnsatisfiedLinkError::class.java)
.isThrownBy { TestNativeMethod().apply(0) }
.withMessageContaining("TestNativeMethod.evilDeeds()I")
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
assertThatExceptionOfType(SandboxException::class.java)
.isThrownBy { contractExecutor.run<TestNativeMethod>(0) }
.withCauseInstanceOf(RuleViolationError::class.java)
.withMessageContaining("Native method has been deleted")
}
class TestNativeMethod : Function<Int, Int> {
override fun apply(input: Int): Int {
return evilDeeds()
}
private external fun evilDeeds(): Int
}
@Test
fun `check arrays still work`() = sandbox(DEFAULT) {
val contractExecutor = DeterministicSandboxExecutor<Int, Array<Int>>(configuration)
contractExecutor.run<TestArray>(100).apply {
assertThat(result).isEqualTo(arrayOf(100))
}
}
class TestArray : Function<Int, Array<Int>> {
override fun apply(input: Int): Array<Int> {
return listOf(input).toTypedArray()
}
}
@Test
fun `can load and execute class that has finalize`() = sandbox(DEFAULT) {
assertThatExceptionOfType(UnsupportedOperationException::class.java)
.isThrownBy { TestFinalizeMethod().apply(100) }
.withMessageContaining("Very Bad Thing")
val contractExecutor = DeterministicSandboxExecutor<Int, Int>(configuration)
contractExecutor.run<TestFinalizeMethod>(100).apply {
assertThat(result).isEqualTo(100)
}
}
class TestFinalizeMethod : Function<Int, Int> {
override fun apply(input: Int): Int {
finalize()
return input
}
private fun finalize() {
throw UnsupportedOperationException("Very Bad Thing")
}
}
}

View File

@ -6,10 +6,11 @@ import foo.bar.sandbox.Empty
import foo.bar.sandbox.StrictFloat
import net.corda.djvm.TestBase
import net.corda.djvm.assertions.AssertionExtensions.assertThat
import net.corda.djvm.costing.ThresholdViolationException
import net.corda.djvm.execution.ExecutionProfile
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.Test
import sandbox.net.corda.djvm.costing.ThresholdViolationError
import java.nio.file.Paths
class ClassRewriterTest : TestBase() {
@ -45,7 +46,7 @@ class ClassRewriterTest : TestBase() {
val callable = newCallable<B>()
assertThat(callable).hasBeenModified()
assertThat(callable).isSandboxed()
assertThatExceptionOfType(ThresholdViolationException::class.java).isThrownBy {
assertThatExceptionOfType(ThresholdViolationError::class.java).isThrownBy {
callable.createAndInvoke()
}.withMessageContaining("terminated due to excessive use of looping")
assertThat(runtimeCosts)
@ -61,4 +62,44 @@ class ClassRewriterTest : TestBase() {
callable.createAndInvoke()
}
@Test
fun `can load a Java API that still exists in Java runtime`() = sandbox(DEFAULT) {
assertThat(loadClass<MutableList<*>>())
.hasClassName("sandbox.java.util.List")
.hasBeenModified()
}
@Test
fun `cannot load a Java API that was deleted from Java runtime`() = sandbox(DEFAULT) {
assertThatExceptionOfType(SandboxClassLoadingException::class.java)
.isThrownBy { loadClass<Paths>() }
.withMessageContaining("Class file not found; java/nio/file/Paths.class")
}
@Test
fun `load internal Sun class that still exists in Java runtime`() = sandbox(DEFAULT) {
assertThat(loadClass<sun.misc.Unsafe>())
.hasClassName("sandbox.sun.misc.Unsafe")
.hasBeenModified()
}
@Test
fun `cannot load internal Sun class that was deleted from Java runtime`() = sandbox(DEFAULT) {
assertThatExceptionOfType(SandboxClassLoadingException::class.java)
.isThrownBy { loadClass<sun.misc.Timer>() }
.withMessageContaining("Class file not found; sun/misc/Timer.class")
}
@Test
fun `can load local class`() = sandbox(DEFAULT) {
assertThat(loadClass<Example>())
.hasClassName("sandbox.net.corda.djvm.rewiring.ClassRewriterTest\$Example")
.hasBeenModified()
}
class Example : java.util.function.Function<Int, Int> {
override fun apply(input: Int): Int {
return input
}
}
}

View File

@ -9,19 +9,6 @@ import java.util.*
class ReferenceExtractorTest : TestBase() {
@Test
fun `can find method references`() = validate<A> { context ->
assertThat(context.references)
.hasClass("java/util/Random")
.withLocationCount(1)
.hasMember("java/lang/Object", "<init>", "()V")
.withLocationCount(1)
.hasMember("java/util/Random", "<init>", "()V")
.withLocationCount(1)
.hasMember("java/util/Random", "nextInt", "()I")
.withLocationCount(1)
}
class A : Callable {
override fun call() {
synchronized(this) {
@ -30,21 +17,6 @@ class ReferenceExtractorTest : TestBase() {
}
}
@Test
fun `can find field references`() = validate<B> { context ->
assertThat(context.references)
.hasMember(Type.getInternalName(B::class.java), "foo", "Ljava/lang/String;")
}
class B {
@JvmField
val foo: String = ""
fun test(): String {
return foo
}
}
@Test
fun `can find class references`() = validate<C> { context ->
assertThat(context.references)

View File

@ -42,8 +42,7 @@ class RuleValidatorTest : TestBase() {
assertThat(context.messages)
.hasErrorCount(0)
.hasWarningCount(0)
.hasInfoCount(1)
.withMessage("Stripped monitoring instruction")
.hasInfoCount(0)
.hasTraceCount(4)
.withMessage("Synchronization specifier will be ignored")
.withMessage("Strict floating-point arithmetic will be applied")

View File

@ -67,7 +67,7 @@ class SourceClassLoaderTest {
val (first, second) = this
val directory = first.parent
val classLoader = SourceClassLoader(listOf(directory), classResolver)
assertThat(classLoader.resolvedUrls).anySatisfy {
assertThat(classLoader.urLs).anySatisfy {
assertThat(it).isEqualTo(first.toUri().toURL())
}.anySatisfy {
assertThat(it).isEqualTo(second.toUri().toURL())