From 01799cfc2d04c12f7ee1b527eed00dc35a60c8bb Mon Sep 17 00:00:00 2001 From: Chris Rankin Date: Thu, 25 Oct 2018 11:12:20 +0100 Subject: [PATCH] ENT-1906: Allow sandboxes to share a parent classloader. (#4103) * Allow sandboxes to share a parent classloader. * Tidy up DJVM test code. * Add review fixes. * Declare SandboxClassLoadingException as a RuntimeException. This preserves the SandboxClassLoader.loadClass(String) contract from Java. * Also add extra test cases. --- djvm/build.gradle | 3 +- .../net/corda/djvm/tools/cli/ClassCommand.kt | 8 +- .../tools/cli/WhitelistGenerateCommand.kt | 2 +- .../main/java/sandbox/java/lang/Object.java | 4 +- .../main/java/sandbox/java/lang/String.java | 5 +- .../main/java/sandbox/java/lang/System.java | 11 +- .../net/corda/djvm/SandboxConfiguration.kt | 15 +- .../net/corda/djvm/SandboxRuntimeContext.kt | 22 ++- .../djvm/analysis/AnalysisConfiguration.kt | 126 +++++++++---- .../corda/djvm/execution/SandboxExecutor.kt | 12 +- .../corda/djvm/messages/MessageCollection.kt | 23 +++ .../corda/djvm/rewiring/SandboxClassLoader.kt | 169 ++++++++++++----- .../rewiring/SandboxClassLoadingException.kt | 4 +- .../corda/djvm/source/SourceClassLoader.kt | 140 ++++++++------ .../src/main/kotlin/sandbox/java/lang/DJVM.kt | 14 +- .../djvm/execution/SandboxEnumJavaTest.java | 12 +- .../execution/SandboxThrowableJavaTest.java | 5 +- .../net/corda/djvm/DJVMExceptionTest.kt | 104 +++++++++-- .../test/kotlin/net/corda/djvm/DJVMTest.kt | 117 +++++++----- .../test/kotlin/net/corda/djvm/TestBase.kt | 175 ++++++++++++++++-- .../test/kotlin/net/corda/djvm/Utilities.kt | 10 - .../djvm/assertions/AssertionExtensions.kt | 2 + .../assertions/AssertiveClassWithByteCode.kt | 5 + .../djvm/assertions/AssertiveDJVMObject.kt | 26 +++ .../corda/djvm/execution/SandboxEnumTest.kt | 10 +- .../djvm/execution/SandboxExecutorTest.kt | 78 ++++---- .../djvm/execution/SandboxThrowableTest.kt | 6 +- .../corda/djvm/rewiring/ClassRewriterTest.kt | 28 ++- 28 files changed, 819 insertions(+), 317 deletions(-) create mode 100644 djvm/src/test/kotlin/net/corda/djvm/assertions/AssertiveDJVMObject.kt diff --git a/djvm/build.gradle b/djvm/build.gradle index 1b33bdd3ae..64bd5bbd53 100644 --- a/djvm/build.gradle +++ b/djvm/build.gradle @@ -57,7 +57,8 @@ shadowJar { // we will generate better versions from deterministic-rt.jar. exclude 'sandbox/java/lang/Appendable.class' exclude 'sandbox/java/lang/CharSequence.class' - exclude 'sandbox/java/lang/Character\$*.class' + exclude 'sandbox/java/lang/Character\$Subset.class' + exclude 'sandbox/java/lang/Character\$Unicode*.class' exclude 'sandbox/java/lang/Comparable.class' exclude 'sandbox/java/lang/Enum.class' exclude 'sandbox/java/lang/Iterable.class' diff --git a/djvm/cli/src/main/kotlin/net/corda/djvm/tools/cli/ClassCommand.kt b/djvm/cli/src/main/kotlin/net/corda/djvm/tools/cli/ClassCommand.kt index e22d9c084f..e0fdb7f745 100644 --- a/djvm/cli/src/main/kotlin/net/corda/djvm/tools/cli/ClassCommand.kt +++ b/djvm/cli/src/main/kotlin/net/corda/djvm/tools/cli/ClassCommand.kt @@ -191,12 +191,14 @@ abstract class ClassCommand : CommandBase() { emitters = ignoreEmitters.emptyListIfTrueOtherwiseNull(), definitionProviders = if (ignoreDefinitionProviders) { emptyList() } else { Discovery.find() }, enableTracing = !disableTracing, - analysisConfiguration = AnalysisConfiguration( + analysisConfiguration = AnalysisConfiguration.createRoot( whitelist = whitelist, minimumSeverityLevel = level, - classPath = getClasspath(), analyzeAnnotations = analyzeAnnotations, - prefixFilters = prefixFilters.toList() + prefixFilters = prefixFilters.toList(), + sourceClassLoaderFactory = { classResolver, bootstrapClassLoader -> + SourceClassLoader(getClasspath(), classResolver, bootstrapClassLoader) + } ) ) } diff --git a/djvm/cli/src/main/kotlin/net/corda/djvm/tools/cli/WhitelistGenerateCommand.kt b/djvm/cli/src/main/kotlin/net/corda/djvm/tools/cli/WhitelistGenerateCommand.kt index b578dbffb3..85a8084654 100644 --- a/djvm/cli/src/main/kotlin/net/corda/djvm/tools/cli/WhitelistGenerateCommand.kt +++ b/djvm/cli/src/main/kotlin/net/corda/djvm/tools/cli/WhitelistGenerateCommand.kt @@ -33,7 +33,7 @@ class WhitelistGenerateCommand : CommandBase() { override fun validateArguments() = paths.isNotEmpty() override fun handleCommand(): Boolean { - val entries = AnalysisConfiguration().use { configuration -> + val entries = AnalysisConfiguration.createRoot().use { configuration -> val entries = mutableListOf() val visitor = object : ClassAndMemberVisitor(configuration, null) { override fun visitClass(clazz: ClassRepresentation): ClassRepresentation { diff --git a/djvm/src/main/java/sandbox/java/lang/Object.java b/djvm/src/main/java/sandbox/java/lang/Object.java index 62ac16d4dd..bcceaecbf2 100644 --- a/djvm/src/main/java/sandbox/java/lang/Object.java +++ b/djvm/src/main/java/sandbox/java/lang/Object.java @@ -45,8 +45,8 @@ public class Object { private static java.lang.Object unwrap(java.lang.Object arg) { if (arg instanceof Object) { return ((Object) arg).fromDJVM(); - } else if (Object[].class.isAssignableFrom(arg.getClass())) { - return fromDJVM((Object[]) arg); + } else if (java.lang.Object[].class.isAssignableFrom(arg.getClass())) { + return fromDJVM((java.lang.Object[]) arg); } else { return arg; } diff --git a/djvm/src/main/java/sandbox/java/lang/String.java b/djvm/src/main/java/sandbox/java/lang/String.java index 476669bfe9..7d9165f3b8 100644 --- a/djvm/src/main/java/sandbox/java/lang/String.java +++ b/djvm/src/main/java/sandbox/java/lang/String.java @@ -1,5 +1,6 @@ package sandbox.java.lang; +import net.corda.djvm.SandboxRuntimeContext; import org.jetbrains.annotations.NotNull; import sandbox.java.nio.charset.Charset; import sandbox.java.util.Comparator; @@ -8,7 +9,6 @@ import sandbox.java.util.Locale; import java.io.Serializable; import java.io.UnsupportedEncodingException; import java.lang.reflect.Constructor; -import java.util.Map; @SuppressWarnings("unused") public final class String extends Object implements Comparable, CharSequence, Serializable { @@ -24,7 +24,6 @@ public final class String extends Object implements Comparable, CharSequ private static final String TRUE = new String("true"); private static final String FALSE = new String("false"); - private static final Map INTERNAL = new java.util.HashMap<>(); private static final Constructor SHARED; static { @@ -335,7 +334,7 @@ public final class String extends Object implements Comparable, CharSequ return toDJVM(value.trim()); } - public String intern() { return INTERNAL.computeIfAbsent(value, s -> this); } + public String intern() { return (String) SandboxRuntimeContext.getInstance().intern(value, this); } public char[] toCharArray() { return value.toCharArray(); diff --git a/djvm/src/main/java/sandbox/java/lang/System.java b/djvm/src/main/java/sandbox/java/lang/System.java index 95525d0b50..b68a36d83c 100644 --- a/djvm/src/main/java/sandbox/java/lang/System.java +++ b/djvm/src/main/java/sandbox/java/lang/System.java @@ -1,20 +1,15 @@ package sandbox.java.lang; +import net.corda.djvm.SandboxRuntimeContext; + @SuppressWarnings({"WeakerAccess", "unused"}) public final class System extends Object { private System() {} - /* - * This class is duplicated into every sandbox, where everything is single-threaded. - */ - private static final java.util.Map objectHashCodes = new java.util.LinkedHashMap<>(); - private static int objectCounter = 0; - public static int identityHashCode(java.lang.Object obj) { int nativeHashCode = java.lang.System.identityHashCode(obj); - // TODO Instead of using a magic offset below, one could take in a per-context seed - return objectHashCodes.computeIfAbsent(nativeHashCode, i -> ++objectCounter + 0xfed_c0de); + return SandboxRuntimeContext.getInstance().getHashCodeFor(nativeHashCode); } public static final String lineSeparator = String.toDJVM("\n"); diff --git a/djvm/src/main/kotlin/net/corda/djvm/SandboxConfiguration.kt b/djvm/src/main/kotlin/net/corda/djvm/SandboxConfiguration.kt index ffd233df25..bacc8670b0 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/SandboxConfiguration.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/SandboxConfiguration.kt @@ -5,6 +5,7 @@ import net.corda.djvm.code.DefinitionProvider import net.corda.djvm.code.EMIT_TRACING import net.corda.djvm.code.Emitter import net.corda.djvm.execution.ExecutionProfile +import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.djvm.rules.Rule import net.corda.djvm.utilities.Discovery @@ -16,24 +17,28 @@ import net.corda.djvm.utilities.Discovery * @property definitionProviders The meta-data providers to apply to class and member definitions. * @property executionProfile The execution profile to use in the sandbox. * @property analysisConfiguration The configuration used in the analysis of classes. + * @property parentClassLoader The [SandboxClassLoader] that this sandbox will use as a parent. */ -@Suppress("unused") class SandboxConfiguration private constructor( val rules: List, val emitters: List, val definitionProviders: List, val executionProfile: ExecutionProfile, - val analysisConfiguration: AnalysisConfiguration + val analysisConfiguration: AnalysisConfiguration, + val parentClassLoader: SandboxClassLoader? ) { + @Suppress("unused") companion object { /** * Default configuration for the deterministic sandbox. */ + @JvmField val DEFAULT = SandboxConfiguration.of() /** * Configuration with no emitters, rules, meta-data providers or runtime thresholds. */ + @JvmField val EMPTY = SandboxConfiguration.of( ExecutionProfile.UNLIMITED, emptyList(), emptyList(), emptyList() ) @@ -47,7 +52,8 @@ class SandboxConfiguration private constructor( emitters: List? = null, definitionProviders: List = Discovery.find(), enableTracing: Boolean = true, - analysisConfiguration: AnalysisConfiguration = AnalysisConfiguration() + analysisConfiguration: AnalysisConfiguration = AnalysisConfiguration.createRoot(), + parentClassLoader: SandboxClassLoader? = null ) = SandboxConfiguration( executionProfile = profile, rules = rules, @@ -55,7 +61,8 @@ class SandboxConfiguration private constructor( enableTracing || it.priority > EMIT_TRACING }, definitionProviders = definitionProviders, - analysisConfiguration = analysisConfiguration + analysisConfiguration = analysisConfiguration, + parentClassLoader = parentClassLoader ) } } diff --git a/djvm/src/main/kotlin/net/corda/djvm/SandboxRuntimeContext.kt b/djvm/src/main/kotlin/net/corda/djvm/SandboxRuntimeContext.kt index d717c9074e..9bfe692244 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/SandboxRuntimeContext.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/SandboxRuntimeContext.kt @@ -1,6 +1,5 @@ package net.corda.djvm -import net.corda.djvm.analysis.AnalysisContext import net.corda.djvm.costing.RuntimeCostSummary import net.corda.djvm.rewiring.SandboxClassLoader @@ -14,16 +13,27 @@ class SandboxRuntimeContext(val configuration: SandboxConfiguration) { /** * The class loader to use inside the sandbox. */ - val classLoader: SandboxClassLoader = SandboxClassLoader( - configuration, - AnalysisContext.fromConfiguration(configuration.analysisConfiguration) - ) + val classLoader: SandboxClassLoader = SandboxClassLoader.createFor(configuration) /** * A summary of the currently accumulated runtime costs (for, e.g., memory allocations, invocations, etc.). */ val runtimeCosts = RuntimeCostSummary(configuration.executionProfile) + private val hashCodes: MutableMap = mutableMapOf() + private var objectCounter: Int = 0 + + // TODO Instead of using a magic offset below, one could take in a per-context seed + fun getHashCodeFor(nativeHashCode: Int): Int { + return hashCodes.computeIfAbsent(nativeHashCode) { ++objectCounter + MAGIC_HASH_OFFSET } + } + + private val internStrings: MutableMap = mutableMapOf() + + fun intern(key: String, value: Any): Any { + return internStrings.computeIfAbsent(key) { value } + } + /** * Run a set of actions within the provided sandbox context. */ @@ -39,10 +49,12 @@ class SandboxRuntimeContext(val configuration: SandboxConfiguration) { companion object { private val threadLocalContext = ThreadLocal() + private const val MAGIC_HASH_OFFSET = 0xfed_c0de /** * When called from within a sandbox, this returns the context for the current sandbox thread. */ + @JvmStatic var instance: SandboxRuntimeContext get() = threadLocalContext.get() ?: throw IllegalStateException("SandboxContext has not been initialized before use") diff --git a/djvm/src/main/kotlin/net/corda/djvm/analysis/AnalysisConfiguration.kt b/djvm/src/main/kotlin/net/corda/djvm/analysis/AnalysisConfiguration.kt index ff71421a54..f17059dd18 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/analysis/AnalysisConfiguration.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/analysis/AnalysisConfiguration.kt @@ -8,6 +8,7 @@ import net.corda.djvm.references.ClassModule import net.corda.djvm.references.Member import net.corda.djvm.references.MemberModule import net.corda.djvm.references.MethodBody +import net.corda.djvm.source.AbstractSourceClassLoader import net.corda.djvm.source.BootstrapClassLoader import net.corda.djvm.source.SourceClassLoader import org.objectweb.asm.Opcodes.* @@ -21,36 +22,36 @@ import java.nio.file.Path * The configuration to use for an analysis. * * @property whitelist The whitelist of class names. - * @param additionalPinnedClasses Classes that have already been declared in the sandbox namespace and that should be - * made available inside the sandboxed environment. + * @property pinnedClasses Classes that have already been declared in the sandbox namespace and that should be + * made available inside the sandboxed environment. These classes belong to the application + * classloader and so are shared across all sandboxes. + * @property classResolver Functionality used to resolve the qualified name and relevant information about a class. + * @property exceptionResolver Resolves the internal names of synthetic exception classes. * @property minimumSeverityLevel The minimum severity level to log and report. - * @param classPath The extended class path to use for the analysis. - * @param bootstrapJar The location of a jar containing the Java APIs. * @property analyzeAnnotations Analyze annotations despite not being explicitly referenced. * @property prefixFilters Only record messages where the originating class name matches one of the provided prefixes. * If none are provided, all messages will be reported. * @property classModule Module for handling evolution of a class hierarchy during analysis. * @property memberModule Module for handling the specification and inspection of class members. + * @property bootstrapClassLoader Optional provider for the Java API classes. + * @property supportingClassLoader ClassLoader providing the classes to run inside the sandbox. + * @property isRootConfiguration Effectively, whether we are allowed to close [bootstrapClassLoader]. */ -class AnalysisConfiguration( - val whitelist: Whitelist = Whitelist.MINIMAL, - additionalPinnedClasses: Set = emptySet(), - val minimumSeverityLevel: Severity = Severity.WARNING, - classPath: List = emptyList(), - bootstrapJar: Path? = null, - val analyzeAnnotations: Boolean = false, - val prefixFilters: List = emptyList(), - val classModule: ClassModule = ClassModule(), - val memberModule: MemberModule = MemberModule() +class AnalysisConfiguration private constructor( + val whitelist: Whitelist, + val pinnedClasses: Set, + val classResolver: ClassResolver, + val exceptionResolver: ExceptionResolver, + val minimumSeverityLevel: Severity, + val analyzeAnnotations: Boolean, + val prefixFilters: List, + val classModule: ClassModule, + val memberModule: MemberModule, + private val bootstrapClassLoader: BootstrapClassLoader?, + val supportingClassLoader: AbstractSourceClassLoader, + private val isRootConfiguration: Boolean ) : Closeable { - /** - * Classes that have already been declared in the sandbox namespace and that should be made - * available inside the sandboxed environment. These classes belong to the application - * classloader and so are shared across all sandboxes. - */ - val pinnedClasses: Set = MANDATORY_PINNED_CLASSES + additionalPinnedClasses - /** * These interfaces are modified as they are mapped into the sandbox by * having their unsandboxed version "stitched in" as a super-interface. @@ -63,26 +64,39 @@ class AnalysisConfiguration( */ val stitchedClasses: Map> get() = STITCHED_CLASSES - /** - * Functionality used to resolve the qualified name and relevant information about a class. - */ - val classResolver: ClassResolver = ClassResolver(pinnedClasses, TEMPLATE_CLASSES, whitelist, SANDBOX_PREFIX) - - /** - * Resolves the internal names of synthetic exception classes. - */ - val exceptionResolver: ExceptionResolver = ExceptionResolver(JVM_EXCEPTIONS, pinnedClasses, SANDBOX_PREFIX) - - private val bootstrapClassLoader = bootstrapJar?.let { BootstrapClassLoader(it, classResolver) } - val supportingClassLoader = SourceClassLoader(classPath, classResolver, bootstrapClassLoader) - @Throws(IOException::class) override fun close() { supportingClassLoader.use { - bootstrapClassLoader?.close() + if (isRootConfiguration) { + bootstrapClassLoader?.close() + } } } + /** + * Creates a child [AnalysisConfiguration] with this instance as its parent. + * The child inherits the same [whitelist], [pinnedClasses] and [bootstrapClassLoader]. + */ + fun createChild( + classPaths: List = emptyList(), + newMinimumSeverityLevel: Severity? + ): AnalysisConfiguration { + return AnalysisConfiguration( + whitelist = whitelist, + pinnedClasses = pinnedClasses, + classResolver = classResolver, + exceptionResolver = exceptionResolver, + minimumSeverityLevel = newMinimumSeverityLevel ?: minimumSeverityLevel, + analyzeAnnotations = analyzeAnnotations, + prefixFilters = prefixFilters, + classModule = classModule, + memberModule = memberModule, + bootstrapClassLoader = bootstrapClassLoader, + supportingClassLoader = SourceClassLoader(classPaths, classResolver, bootstrapClassLoader), + isRootConfiguration = false + ) + } + fun isTemplateClass(className: String): Boolean = className in TEMPLATE_CLASSES fun isPinnedClass(className: String): Boolean = className in pinnedClasses @@ -107,7 +121,7 @@ class AnalysisConfiguration( /** * These classes will be duplicated into every sandbox's - * classloader. + * parent classloader. */ private val TEMPLATE_CLASSES: Set = setOf( java.lang.Boolean::class.java, @@ -131,6 +145,7 @@ class AnalysisConfiguration( ).sandboxed() + setOf( "sandbox/Task", "sandbox/TaskTypes", + "sandbox/java/lang/Character\$Cache", "sandbox/java/lang/DJVM", "sandbox/java/lang/DJVMException", "sandbox/java/lang/DJVMThrowableWrapper", @@ -139,8 +154,8 @@ class AnalysisConfiguration( ) /** - * These are thrown by the JVM itself, and so - * we need to handle them without wrapping them. + * These exceptions are thrown by the JVM itself, and + * so we need to handle them without wrapping them. * * Note that this set is closed, i.e. every one * of these exceptions' [Throwable] super classes @@ -271,6 +286,41 @@ class AnalysisConfiguration( private fun Set>.sandboxed(): Set = map(Companion::sandboxed).toSet() private fun Iterable.mapByClassName(): Map> = groupBy(Member::className).mapValues(Map.Entry>::value) + + /** + * @see [AnalysisConfiguration] + */ + fun createRoot( + whitelist: Whitelist = Whitelist.MINIMAL, + additionalPinnedClasses: Set = emptySet(), + minimumSeverityLevel: Severity = Severity.WARNING, + analyzeAnnotations: Boolean = false, + prefixFilters: List = emptyList(), + classModule: ClassModule = ClassModule(), + memberModule: MemberModule = MemberModule(), + bootstrapClassLoader: BootstrapClassLoader? = null, + sourceClassLoaderFactory: (ClassResolver, BootstrapClassLoader?) -> AbstractSourceClassLoader = { classResolver, bootstrapCL -> + SourceClassLoader(emptyList(), classResolver, bootstrapCL) + } + ): AnalysisConfiguration { + val pinnedClasses = MANDATORY_PINNED_CLASSES + additionalPinnedClasses + val classResolver = ClassResolver(pinnedClasses, TEMPLATE_CLASSES, whitelist, SANDBOX_PREFIX) + + return AnalysisConfiguration( + whitelist = whitelist, + pinnedClasses = pinnedClasses, + classResolver = classResolver, + exceptionResolver = ExceptionResolver(JVM_EXCEPTIONS, pinnedClasses, SANDBOX_PREFIX), + minimumSeverityLevel = minimumSeverityLevel, + analyzeAnnotations = analyzeAnnotations, + prefixFilters = prefixFilters, + classModule = classModule, + memberModule = memberModule, + bootstrapClassLoader = bootstrapClassLoader, + supportingClassLoader = sourceClassLoaderFactory(classResolver, bootstrapClassLoader), + isRootConfiguration = true + ) + } } private open class MethodBuilder( diff --git a/djvm/src/main/kotlin/net/corda/djvm/execution/SandboxExecutor.kt b/djvm/src/main/kotlin/net/corda/djvm/execution/SandboxExecutor.kt index 245e00902d..b989cdf477 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/execution/SandboxExecutor.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/execution/SandboxExecutor.kt @@ -71,14 +71,14 @@ open class SandboxExecutor( // Load the "entry-point" task class into the sandbox. This task will marshall // the input and outputs between Java types and sandbox wrapper types. - val taskClass = Class.forName("sandbox.Task", false, classLoader) + val taskClass = classLoader.loadClass("sandbox.Task") // Create the user's task object inside the sandbox. - val runnable = classLoader.loadForSandbox(runnableClass, context).type.newInstance() + val runnable = classLoader.loadClassForSandbox(runnableClass).newInstance() // Fetch this sandbox's instance of Class so we can retrieve Task(Function) // and then instantiate the Task. - val functionClass = Class.forName("sandbox.java.util.function.Function", false, classLoader) + val functionClass = classLoader.loadClass("sandbox.java.util.function.Function") val task = taskClass.getDeclaredConstructor(functionClass).newInstance(runnable) // Execute the task... @@ -114,7 +114,7 @@ open class SandboxExecutor( fun load(classSource: ClassSource): LoadedClass { val context = AnalysisContext.fromConfiguration(configuration.analysisConfiguration) val result = IsolatedTask("LoadClass", configuration).run { - classLoader.loadForSandbox(classSource, context) + classLoader.copyEmpty(context).loadForSandbox(classSource) } return result.output ?: throw ClassNotFoundException(classSource.qualifiedClassName) } @@ -159,11 +159,13 @@ open class SandboxExecutor( ): ReferenceValidationSummary { processClassQueue(*classSources.toTypedArray()) { classSource, className -> val didLoad = try { - classLoader.loadForSandbox(classSource, context) + classLoader.copyEmpty(context).loadClassForSandbox(classSource) true } catch (exception: SandboxClassLoadingException) { // Continue; all warnings and errors are captured in [context.messages] false + } finally { + context.messages.acceptProvisional() } if (didLoad) { context.classes[className]?.apply { diff --git a/djvm/src/main/kotlin/net/corda/djvm/messages/MessageCollection.kt b/djvm/src/main/kotlin/net/corda/djvm/messages/MessageCollection.kt index 43f90fdd0d..63c53e871d 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/messages/MessageCollection.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/messages/MessageCollection.kt @@ -22,6 +22,7 @@ class MessageCollection( private val memberMessages = mutableMapOf>() + private val provisional = mutableListOf() private var cachedEntries: List? = null /** @@ -58,6 +59,28 @@ class MessageCollection( } } + /** + * Hold this message until we've decided whether or not it's real. + */ + fun provisionalAdd(message: Message) { + provisional.add(message) + } + + /** + * Discard all provisional messages. + */ + fun clearProvisional() { + provisional.clear() + } + + /** + * Accept all provisional messages. + */ + fun acceptProvisional() { + addAll(provisional) + clearProvisional() + } + /** * Get all recorded messages for a given class. */ diff --git a/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassLoader.kt b/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassLoader.kt index 4dbeae7ab2..c1b3475534 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassLoader.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassLoader.kt @@ -1,6 +1,7 @@ package net.corda.djvm.rewiring import net.corda.djvm.SandboxConfiguration +import net.corda.djvm.analysis.AnalysisConfiguration import net.corda.djvm.analysis.AnalysisContext import net.corda.djvm.analysis.ClassAndMemberVisitor import net.corda.djvm.analysis.ExceptionResolver.Companion.getDJVMExceptionOwner @@ -8,30 +9,32 @@ import net.corda.djvm.analysis.ExceptionResolver.Companion.isDJVMException import net.corda.djvm.code.asPackagePath import net.corda.djvm.code.asResourcePath import net.corda.djvm.references.ClassReference +import net.corda.djvm.source.AbstractSourceClassLoader import net.corda.djvm.source.ClassSource import net.corda.djvm.utilities.loggerFor import net.corda.djvm.validation.RuleValidator +import org.objectweb.asm.Type /** * Class loader that enables registration of rewired classes. * - * @param configuration The configuration to use for the sandbox. + * @property analysisConfiguration The configuration to use for the analysis. + * @property ruleValidator The instance used to validate that any loaded class complies with the specified rules. + * @property supportingClassLoader The class loader used to find classes on the extended class path. + * @property rewriter The re-writer to use for registered classes. * @property context The context in which analysis and processing is performed. + * @param throwableClass This sandbox's definition of [sandbox.java.lang.Throwable]. + * @param parent This classloader's parent classloader. */ -class SandboxClassLoader( - configuration: SandboxConfiguration, - private val context: AnalysisContext -) : ClassLoader() { - - private val analysisConfiguration = configuration.analysisConfiguration - - /** - * The instance used to validate that any loaded class complies with the specified rules. - */ - private val ruleValidator: RuleValidator = RuleValidator( - rules = configuration.rules, - configuration = analysisConfiguration - ) +class SandboxClassLoader private constructor( + private val analysisConfiguration: AnalysisConfiguration, + private val ruleValidator: RuleValidator, + private val supportingClassLoader: AbstractSourceClassLoader, + private val rewriter: ClassRewriter, + private val context: AnalysisContext, + throwableClass: Class<*>?, + parent: ClassLoader? +) : ClassLoader(parent ?: getSystemClassLoader()) { /** * The analyzer used to traverse the class hierarchy. @@ -50,36 +53,65 @@ class SandboxClassLoader( private val loadedClasses = mutableMapOf() /** - * The class loader used to find classes on the extended class path. + * We need to load [sandbox.java.lang.Throwable] up front, so that we can + * identify sandboxed exception classes. */ - private val supportingClassLoader = analysisConfiguration.supportingClassLoader - - /** - * The re-writer to use for registered classes. - */ - private val rewriter: ClassRewriter = ClassRewriter(configuration, supportingClassLoader) - - /** - * We need to load this class up front, so that we can identify sandboxed exception classes. - */ - private val throwableClass: Class<*> - - init { - // Bootstrap the loading of the sandboxed Throwable class. + private val throwableClass: Class<*> = throwableClass ?: run { loadClassAndBytes(ClassSource.fromClassName("sandbox.java.lang.Object"), context) loadClassAndBytes(ClassSource.fromClassName("sandbox.java.lang.StackTraceElement"), context) - throwableClass = loadClassAndBytes(ClassSource.fromClassName("sandbox.java.lang.Throwable"), context).type + loadClassAndBytes(ClassSource.fromClassName("sandbox.java.lang.Throwable"), context).type } + /** + * Creates an empty [SandboxClassLoader] with exactly the same + * configuration as this one, but with the given [AnalysisContext]. + * @param newContext The [AnalysisContext] to use for the child classloader. + */ + fun copyEmpty(newContext: AnalysisContext) = SandboxClassLoader( + analysisConfiguration, + ruleValidator, + supportingClassLoader, + rewriter, + newContext, + throwableClass, + parent + ) + /** * Given a class name, provide its corresponding [LoadedClass] for the sandbox. + * This class may have been loaded by a parent classloader really. */ - fun loadForSandbox(name: String, context: AnalysisContext): LoadedClass { - return loadClassAndBytes(ClassSource.fromClassName(analysisConfiguration.classResolver.resolveNormalized(name)), context) + @Throws(ClassNotFoundException::class) + fun loadForSandbox(className: String): LoadedClass { + val sandboxClass = loadClassForSandbox(className) + val sandboxName = Type.getInternalName(sandboxClass) + var loader = this + while(true) { + val loaded = loader.loadedClasses[sandboxName] + if (loaded != null) { + return loaded + } + loader = loader.parent as? SandboxClassLoader ?: return LoadedClass(sandboxClass, UNMODIFIED) + } } - fun loadForSandbox(source: ClassSource, context: AnalysisContext): LoadedClass { - return loadForSandbox(source.qualifiedClassName, context) + @Throws(ClassNotFoundException::class) + fun loadForSandbox(source: ClassSource): LoadedClass { + return loadForSandbox(source.qualifiedClassName) + } + + private fun loadClassForSandbox(className: String): Class<*> { + val sandboxName = analysisConfiguration.classResolver.resolveNormalized(className) + return try { + loadClass(sandboxName) + } finally { + context.messages.acceptProvisional() + } + } + + @Throws(ClassNotFoundException::class) + fun loadClassForSandbox(source: ClassSource): Class<*> { + return loadClassForSandbox(source.qualifiedClassName) } /** @@ -95,10 +127,19 @@ class SandboxClassLoader( var clazz = findLoadedClass(name) if (clazz == null) { val source = ClassSource.fromClassName(name) - clazz = if (analysisConfiguration.isSandboxClass(source.internalClassName)) { - loadSandboxClass(source, context).type - } else { - super.loadClass(name, resolve) + val isSandboxClass = analysisConfiguration.isSandboxClass(source.internalClassName) + + if (!isSandboxClass || parent is SandboxClassLoader) { + try { + clazz = super.loadClass(name, resolve) + } catch (e: ClassNotFoundException) { + } catch (e: SandboxClassLoadingException) { + e.messages.clearProvisional() + } + } + + if (clazz == null && isSandboxClass) { + clazz = loadSandboxClass(source, context).type } } if (resolve) { @@ -107,15 +148,31 @@ class SandboxClassLoader( return clazz } + /** + * A sandboxed exception class cannot be thrown, and so we may also need to create a + * synthetic throwable wrapper for it. Or perhaps we've just been asked to load the + * synthetic wrapper class belonging to an exception that we haven't loaded yet? + * Either way, we need to load the sandboxed exception first so that we know what + * the synthetic wrapper's super-class needs to be. + */ private fun loadSandboxClass(source: ClassSource, context: AnalysisContext): LoadedClass { return if (isDJVMException(source.internalClassName)) { /** * We need to load a DJVMException's owner class before we can create - * its wrapper exception. And loading the owner should also create the - * wrapper class automatically. + * its wrapper exception. And loading the owner should then also create + * the wrapper class automatically. */ loadedClasses.getOrElse(source.internalClassName) { - loadSandboxClass(ClassSource.fromClassName(getDJVMExceptionOwner(source.qualifiedClassName)), context) + val exceptionOwner = ClassSource.fromClassName(getDJVMExceptionOwner(source.qualifiedClassName)) + if (!analysisConfiguration.isJvmException(exceptionOwner.internalClassName)) { + /** + * JVM Exceptions belong to the parent classloader, and so will never + * be found inside a child classloader. Which means we must not try to + * create a duplicate inside any child classloaders either. Hence we + * re-invoke [loadClass] which will delegate back to the parent. + */ + loadClass(exceptionOwner.qualifiedClassName, false) + } loadedClasses[source.internalClassName] } ?: throw ClassNotFoundException(source.qualifiedClassName) } else { @@ -171,6 +228,7 @@ class SandboxClassLoader( } // Check if any errors were found during analysis. + context.messages.acceptProvisional() if (context.messages.errorCount > 0) { logger.debug("Errors detected after analyzing class {}", request.qualifiedClassName) throw SandboxClassLoadingException(context) @@ -214,6 +272,10 @@ class SandboxClassLoader( } } + /** + * Check whether the synthetic throwable wrapper already + * exists for this exception, and create it if it doesn't. + */ private fun loadWrapperFor(throwable: Class<*>): LoadedClass { val className = analysisConfiguration.exceptionResolver.getThrowableName(throwable) return loadedClasses.getOrPut(className) { @@ -223,9 +285,30 @@ class SandboxClassLoader( } } - private companion object { + companion object { private val logger = loggerFor() private val UNMODIFIED = ByteCode(ByteArray(0), false) + + /** + * Factory function to create a [SandboxClassLoader]. + * @param configuration The [SandboxConfiguration] containing the classloader's configuration parameters. + */ + fun createFor(configuration: SandboxConfiguration): SandboxClassLoader { + val analysisConfiguration = configuration.analysisConfiguration + val supportingClassLoader = analysisConfiguration.supportingClassLoader + val parentClassLoader = configuration.parentClassLoader + + return SandboxClassLoader( + analysisConfiguration = analysisConfiguration, + supportingClassLoader = supportingClassLoader, + ruleValidator = RuleValidator(rules = configuration.rules, + configuration = analysisConfiguration), + rewriter = ClassRewriter(configuration, supportingClassLoader), + context = AnalysisContext.fromConfiguration(analysisConfiguration), + throwableClass = parentClassLoader?.throwableClass, + parent = parentClassLoader + ) + } } } diff --git a/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassLoadingException.kt b/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassLoadingException.kt index faa5f95282..cdadbda37f 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassLoadingException.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/rewiring/SandboxClassLoadingException.kt @@ -19,7 +19,7 @@ class SandboxClassLoadingException( val messages: MessageCollection = context.messages, val classes: ClassHierarchy = context.classes, val classOrigins: Map> = context.classOrigins -) : Exception("Failed to load class") { +) : RuntimeException("Failed to load class") { /** * The detailed description of the exception. @@ -28,7 +28,7 @@ class SandboxClassLoadingException( get() = StringBuilder().apply { appendln(super.message) for (message in messages.sorted().map(Message::toString).distinct()) { - appendln(" - $message") + append(" - ").appendln(message) } }.toString().trimEnd('\r', '\n') diff --git a/djvm/src/main/kotlin/net/corda/djvm/source/SourceClassLoader.kt b/djvm/src/main/kotlin/net/corda/djvm/source/SourceClassLoader.kt index fbc0f1b0f0..a644d550ac 100644 --- a/djvm/src/main/kotlin/net/corda/djvm/source/SourceClassLoader.kt +++ b/djvm/src/main/kotlin/net/corda/djvm/source/SourceClassLoader.kt @@ -1,5 +1,7 @@ +@file:JvmName("SourceClassLoaderTools") package net.corda.djvm.source +import net.corda.djvm.analysis.AnalysisConfiguration.Companion.SANDBOX_PREFIX import net.corda.djvm.analysis.AnalysisContext import net.corda.djvm.analysis.ClassResolver import net.corda.djvm.analysis.ExceptionResolver.Companion.getDJVMExceptionOwner @@ -33,19 +35,23 @@ abstract class AbstractSourceClassLoader( className: String, context: AnalysisContext, origin: String? = null ): ClassReader { val originalName = classResolver.reverse(className.asResourcePath) + + fun throwClassLoadingError(): Nothing { + context.messages.provisionalAdd(Message( + message ="Class file not found; $originalName.class", + severity = Severity.ERROR, + location = SourceLocation(origin ?: "") + )) + throw SandboxClassLoadingException(context) + } + return try { logger.trace("Opening ClassReader for class {}...", originalName) - getResourceAsStream("$originalName.class").use { + getResourceAsStream("$originalName.class")?.use { ClassReader(it) - } + } ?: run(::throwClassLoadingError) } catch (exception: IOException) { - context.messages.add(Message( - message ="Class file not found; $originalName.class", - severity = Severity.ERROR, - location = SourceLocation(origin ?: "") - )) - logger.error("Failed to open ClassReader for class", exception) - throw SandboxClassLoadingException(context) + throwClassLoadingError() } } @@ -78,50 +84,16 @@ abstract class AbstractSourceClassLoader( protected companion object { @JvmStatic protected val logger = loggerFor() - - private fun resolvePaths(paths: List): Array { - return paths.map(this::expandPath).flatMap { - when { - !Files.exists(it) -> throw FileNotFoundException("File not found; $it") - Files.isDirectory(it) -> { - listOf(it.toURL()) + Files.list(it).filter(::isJarFile).map { jar -> jar.toURL() }.toList() - } - Files.isReadable(it) && isJarFile(it) -> listOf(it.toURL()) - else -> throw IllegalArgumentException("Expected JAR or class file, but found $it") - } - }.apply { - logger.trace("Resolved paths: {}", this) - }.toTypedArray() - } - - private fun expandPath(path: Path): Path { - val pathString = path.toString() - if (pathString.startsWith("~/")) { - return homeDirectory.resolve(pathString.removePrefix("~/")) - } - return path - } - - private fun isJarFile(path: Path) = path.toString().endsWith(".jar", true) - - private fun Path.toURL(): URL = this.toUri().toURL() - - private val homeDirectory: Path - get() = Paths.get(System.getProperty("user.home")) - } - } /** * Class loader to manage an optional JAR of replacement Java APIs. * @param bootstrapJar The location of the JAR containing the Java APIs. - * @param classResolver The resolver to use to derive the original name of a requested class. */ class BootstrapClassLoader( - bootstrapJar: Path, - classResolver: ClassResolver -) : AbstractSourceClassLoader(listOf(bootstrapJar), classResolver, null) { + bootstrapJar: Path +) : URLClassLoader(resolvePaths(listOf(bootstrapJar)), null) { /** * Only search our own jars for the given resource. @@ -129,6 +101,37 @@ class BootstrapClassLoader( override fun getResource(name: String): URL? = findResource(name) } +/** + * Class loader that only provides our built-in sandbox classes. + * @param classResolver The resolver to use to derive the original name of a requested class. + */ +class SandboxSourceClassLoader( + classResolver: ClassResolver, + private val bootstrap: BootstrapClassLoader +) : AbstractSourceClassLoader(emptyList(), classResolver, SandboxSourceClassLoader::class.java.classLoader) { + + /** + * Always check the bootstrap classloader first. If we're requesting + * built-in sandbox classes then delegate to our parent classloader, + * otherwise deny the request. + */ + override fun getResource(name: String): URL? { + val resource = bootstrap.findResource(name) + if (resource != null) { + return resource + } else if (isJvmInternal(name)) { + logger.error("Denying request for actual {}", name) + return null + } + + return if (name.startsWith(SANDBOX_PREFIX)) { + parent.getResource(name) + } else { + null + } + } +} + /** * Customizable class loader that allows the user to explicitly specify additional JARs and directories to scan. * @@ -168,12 +171,41 @@ class SourceClassLoader( return if (name.startsWith("net/corda/djvm/")) null else super.findResource(name) } - /** - * Does [name] exist within any of the packages reserved for Java itself? - */ - private fun isJvmInternal(name: String): Boolean = name.startsWith("java/") - || name.startsWith("javax/") - || name.startsWith("com/sun/") - || name.startsWith("sun/") - || name.startsWith("jdk/") -} \ No newline at end of file +} + +private fun resolvePaths(paths: List): Array { + return paths.map(::expandPath).flatMap { + when { + !Files.exists(it) -> throw FileNotFoundException("File not found; $it") + Files.isDirectory(it) -> { + listOf(it.toURL()) + Files.list(it).filter(::isJarFile).map { jar -> jar.toURL() }.toList() + } + Files.isReadable(it) && isJarFile(it) -> listOf(it.toURL()) + else -> throw IllegalArgumentException("Expected JAR or class file, but found $it") + } + }.toTypedArray() +} + +private fun expandPath(path: Path): Path { + val pathString = path.toString() + if (pathString.startsWith("~/")) { + return homeDirectory.resolve(pathString.removePrefix("~/")) + } + return path +} + +private fun isJarFile(path: Path) = path.toString().endsWith(".jar", true) + +private fun Path.toURL(): URL = this.toUri().toURL() + +private val homeDirectory: Path + get() = Paths.get(System.getProperty("user.home")) + +/** + * Does [name] exist within any of the packages reserved for Java itself? + */ +private fun isJvmInternal(name: String): Boolean = name.startsWith("java/") + || name.startsWith("javax/") + || name.startsWith("com/sun/") + || name.startsWith("sun/") + || name.startsWith("jdk/") diff --git a/djvm/src/main/kotlin/sandbox/java/lang/DJVM.kt b/djvm/src/main/kotlin/sandbox/java/lang/DJVM.kt index a098d78020..5aaeb6adee 100644 --- a/djvm/src/main/kotlin/sandbox/java/lang/DJVM.kt +++ b/djvm/src/main/kotlin/sandbox/java/lang/DJVM.kt @@ -2,6 +2,7 @@ @file:Suppress("unused") package sandbox.java.lang +import net.corda.djvm.SandboxRuntimeContext import net.corda.djvm.analysis.AnalysisConfiguration.Companion.JVM_EXCEPTIONS import net.corda.djvm.analysis.ExceptionResolver.Companion.getDJVMException import net.corda.djvm.rules.implementation.* @@ -42,14 +43,14 @@ fun Any.sandbox(): Any { private fun Array<*>.fromDJVMArray(): Array<*> = Object.fromDJVM(this) /** - * These functions use the "current" classloader, i.e. classloader - * that owns this DJVM class. + * Use the sandbox's classloader explicitly, because this class + * might belong to the shared parent classloader. */ @Throws(ClassNotFoundException::class) -internal fun Class<*>.toDJVMType(): Class<*> = Class.forName(name.toSandboxPackage()) +internal fun Class<*>.toDJVMType(): Class<*> = SandboxRuntimeContext.instance.classLoader.loadClass(name.toSandboxPackage()) @Throws(ClassNotFoundException::class) -internal fun Class<*>.fromDJVMType(): Class<*> = Class.forName(name.fromSandboxPackage()) +internal fun Class<*>.fromDJVMType(): Class<*> = SandboxRuntimeContext.instance.classLoader.loadClass(name.fromSandboxPackage()) private fun kotlin.String.toSandboxPackage(): kotlin.String { return if (startsWith(SANDBOX_PREFIX)) { @@ -190,10 +191,11 @@ fun fromDJVM(t: Throwable?): kotlin.Throwable { val sandboxedName = t!!.javaClass.name if (Type.getInternalName(t.javaClass) in JVM_EXCEPTIONS) { // We map these exceptions to their equivalent JVM classes. - Class.forName(sandboxedName.fromSandboxPackage()).createJavaThrowable(t) + SandboxRuntimeContext.instance.classLoader.loadClass(sandboxedName.fromSandboxPackage()) + .createJavaThrowable(t) } else { // Whereas the sandbox creates a synthetic throwable wrapper for these. - Class.forName(getDJVMException(sandboxedName)) + SandboxRuntimeContext.instance.classLoader.loadClass(getDJVMException(sandboxedName)) .getDeclaredConstructor(sandboxThrowable) .newInstance(t) as kotlin.Throwable } diff --git a/djvm/src/test/java/net/corda/djvm/execution/SandboxEnumJavaTest.java b/djvm/src/test/java/net/corda/djvm/execution/SandboxEnumJavaTest.java index 0343d9517d..4237ecf3d1 100644 --- a/djvm/src/test/java/net/corda/djvm/execution/SandboxEnumJavaTest.java +++ b/djvm/src/test/java/net/corda/djvm/execution/SandboxEnumJavaTest.java @@ -12,13 +12,11 @@ import java.util.Map; import java.util.function.Function; import java.util.stream.Stream; -import static java.util.Collections.emptySet; - public class SandboxEnumJavaTest extends TestBase { @Test public void testEnumInsideSandbox() { - sandbox(new Object[]{ DEFAULT }, emptySet(), WARNING, true, ctx -> { + parentedSandbox(WARNING, true, ctx -> { SandboxExecutor executor = new DeterministicSandboxExecutor<>(ctx.getConfiguration()); ExecutionSummaryWithResult output = WithJava.run(executor, TransformEnum.class, 0); assertThat(output.getResult()) @@ -29,7 +27,7 @@ public class SandboxEnumJavaTest extends TestBase { @Test public void testReturnEnumFromSandbox() { - sandbox(new Object[]{ DEFAULT }, emptySet(), WARNING, true, ctx -> { + parentedSandbox(WARNING, true, ctx -> { SandboxExecutor executor = new DeterministicSandboxExecutor<>(ctx.getConfiguration()); ExecutionSummaryWithResult output = WithJava.run(executor, FetchEnum.class, "THREE"); assertThat(output.getResult()) @@ -40,7 +38,7 @@ public class SandboxEnumJavaTest extends TestBase { @Test public void testWeCanIdentifyClassAsEnum() { - sandbox(new Object[]{ DEFAULT }, emptySet(), WARNING, true, ctx -> { + parentedSandbox(WARNING, true, ctx -> { SandboxExecutor executor = new DeterministicSandboxExecutor<>(ctx.getConfiguration()); ExecutionSummaryWithResult output = WithJava.run(executor, AssertEnum.class, ExampleEnum.THREE); assertThat(output.getResult()).isTrue(); @@ -50,7 +48,7 @@ public class SandboxEnumJavaTest extends TestBase { @Test public void testWeCanCreateEnumMap() { - sandbox(new Object[]{ DEFAULT }, emptySet(), WARNING, true, ctx -> { + parentedSandbox(WARNING, true, ctx -> { SandboxExecutor executor = new DeterministicSandboxExecutor<>(ctx.getConfiguration()); ExecutionSummaryWithResult output = WithJava.run(executor, UseEnumMap.class, ExampleEnum.TWO); assertThat(output.getResult()).isEqualTo(1); @@ -60,7 +58,7 @@ public class SandboxEnumJavaTest extends TestBase { @Test public void testWeCanCreateEnumSet() { - sandbox(new Object[]{ DEFAULT }, emptySet(), WARNING, true, ctx -> { + parentedSandbox(WARNING, true, ctx -> { SandboxExecutor executor = new DeterministicSandboxExecutor<>(ctx.getConfiguration()); ExecutionSummaryWithResult output = WithJava.run(executor, UseEnumSet.class, ExampleEnum.ONE); assertThat(output.getResult()).isTrue(); diff --git a/djvm/src/test/java/net/corda/djvm/execution/SandboxThrowableJavaTest.java b/djvm/src/test/java/net/corda/djvm/execution/SandboxThrowableJavaTest.java index 26203f641b..b6ea607dd6 100644 --- a/djvm/src/test/java/net/corda/djvm/execution/SandboxThrowableJavaTest.java +++ b/djvm/src/test/java/net/corda/djvm/execution/SandboxThrowableJavaTest.java @@ -12,14 +12,13 @@ import java.util.LinkedList; import java.util.List; import java.util.function.Function; -import static java.util.Collections.emptySet; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; public class SandboxThrowableJavaTest extends TestBase { @Test public void testUserExceptionHandling() { - sandbox(new Object[]{ DEFAULT }, emptySet(), WARNING, true, ctx -> { + parentedSandbox(WARNING, true, ctx -> { SandboxExecutor executor = new DeterministicSandboxExecutor<>(ctx.getConfiguration()); ExecutionSummaryWithResult output = WithJava.run(executor, ThrowAndCatchJavaExample.class, "Hello World!"); assertThat(output.getResult()) @@ -30,7 +29,7 @@ public class SandboxThrowableJavaTest extends TestBase { @Test public void testCheckedExceptions() { - sandbox(new Object[]{ DEFAULT }, emptySet(), WARNING, true, ctx -> { + parentedSandbox(WARNING, true, ctx -> { SandboxExecutor executor = new DeterministicSandboxExecutor<>(ctx.getConfiguration()); ExecutionSummaryWithResult success = WithJava.run(executor, JavaWithCheckedExceptions.class, "http://localhost:8080/hello/world"); diff --git a/djvm/src/test/kotlin/net/corda/djvm/DJVMExceptionTest.kt b/djvm/src/test/kotlin/net/corda/djvm/DJVMExceptionTest.kt index 886b1efacc..57c71b4e55 100644 --- a/djvm/src/test/kotlin/net/corda/djvm/DJVMExceptionTest.kt +++ b/djvm/src/test/kotlin/net/corda/djvm/DJVMExceptionTest.kt @@ -1,15 +1,17 @@ package net.corda.djvm +import net.corda.djvm.assertions.AssertionExtensions.assertThatDJVM +import net.corda.djvm.rewiring.SandboxClassLoadingException import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.junit.Test import sandbox.SandboxFunction import sandbox.Task -import sandbox.java.lang.sandbox +import java.util.* -class DJVMExceptionTest { +class DJVMExceptionTest : TestBase() { @Test - fun testSingleException() { + fun testSingleException() = parentedSandbox { val result = Task(SingleExceptionTask()).apply("Hello World") assertThat(result).isInstanceOf(Throwable::class.java) result as Throwable @@ -22,7 +24,7 @@ class DJVMExceptionTest { } @Test - fun testMultipleExceptions() { + fun testMultipleExceptions() = parentedSandbox { val result = Task(MultipleExceptionsTask()).apply("Hello World") assertThat(result).isInstanceOf(Throwable::class.java) result as Throwable @@ -56,24 +58,92 @@ class DJVMExceptionTest { } @Test - fun testJavaThrowableToSandbox() { - val result = Throwable("Hello World").sandbox() - assertThat(result).isInstanceOf(sandbox.java.lang.Throwable::class.java) - result as sandbox.java.lang.Throwable + fun testJavaThrowableToSandbox() = parentedSandbox { + val djvm = DJVM(classLoader) + val helloWorld = djvm.stringOf("Hello World") - assertThat(result.message).isEqualTo("Hello World".toDJVM()) - assertThat(result.stackTrace).isNotEmpty() - assertThat(result.cause).isNull() + val result = djvm.sandbox(Throwable("Hello World")) + assertThatDJVM(result) + .hasClassName("sandbox.java.lang.Throwable") + .isAssignableFrom(djvm.throwableClass) + .hasGetterValue("getMessage", helloWorld) + .hasGetterNullValue("getCause") + + assertThat(result.getArray("getStackTrace")) + .hasOnlyElementsOfType(djvm.stackTraceElementClass) + .isNotEmpty() } @Test - fun testWeTryToCreateCorrectSandboxExceptionsAtRuntime() { + fun testWeCreateCorrectJVMExceptionAtRuntime() = parentedSandbox { + val djvm = DJVM(classLoader) + val helloWorld = djvm.stringOf("Hello World") + + val result = djvm.sandbox(RuntimeException("Hello World")) + assertThatDJVM(result) + .hasClassName("sandbox.java.lang.RuntimeException") + .isAssignableFrom(djvm.throwableClass) + .hasGetterValue("getMessage", helloWorld) + .hasGetterNullValue("getCause") + + assertThat(result.getArray("getStackTrace")) + .hasOnlyElementsOfType(djvm.stackTraceElementClass) + .isNotEmpty() + assertThatExceptionOfType(ClassNotFoundException::class.java) - .isThrownBy { Exception("Hello World").sandbox() } - .withMessage("sandbox.java.lang.Exception") + .isThrownBy { djvm.classFor("sandbox.java.lang.RuntimeException\$1DJVM") } + .withMessage("sandbox.java.lang.RuntimeException\$1DJVM") + } + + @Test + fun testWeCreateCorrectSyntheticExceptionAtRuntime() = parentedSandbox { + val djvm = DJVM(classLoader) + + val result = djvm.sandbox(EmptyStackException()) + assertThatDJVM(result) + .hasClassName("sandbox.java.util.EmptyStackException") + .isAssignableFrom(djvm.throwableClass) + .hasGetterNullValue("getMessage") + .hasGetterNullValue("getCause") + + assertThat(result.getArray("getStackTrace")) + .hasOnlyElementsOfType(djvm.stackTraceElementClass) + .isNotEmpty() + + assertThatDJVM(djvm.classFor("sandbox.java.util.EmptyStackException\$1DJVM")) + .isAssignableFrom(RuntimeException::class.java) + } + + @Test + fun testWeCannotCreateSyntheticExceptionForNonException() = parentedSandbox { + val djvm = DJVM(classLoader) assertThatExceptionOfType(ClassNotFoundException::class.java) - .isThrownBy { RuntimeException("Hello World").sandbox() } - .withMessage("sandbox.java.lang.RuntimeException") + .isThrownBy { djvm.classFor("sandbox.java.util.LinkedList\$1DJVM") } + .withMessage("sandbox.java.util.LinkedList\$1DJVM") + } + + /** + * This scenario should never happen in practice. We just need to be sure + * that the classloader can handle it. + */ + @Test + fun testWeCannotCreateSyntheticExceptionForImaginaryJavaClass() = parentedSandbox { + val djvm = DJVM(classLoader) + assertThatExceptionOfType(SandboxClassLoadingException::class.java) + .isThrownBy { djvm.classFor("sandbox.java.util.DoesNotExist\$1DJVM") } + .withMessageContaining("Failed to load class") + } + + /** + * This scenario should never happen in practice. We just need to be sure + * that the classloader can handle it. + */ + @Test + fun testWeCannotCreateSyntheticExceptionForImaginaryUserClass() = parentedSandbox { + val djvm = DJVM(classLoader) + assertThatExceptionOfType(SandboxClassLoadingException::class.java) + .isThrownBy { djvm.classFor("sandbox.com.example.DoesNotExist\$1DJVM") } + .withMessageContaining("Failed to load class") } } @@ -92,7 +162,7 @@ class MultipleExceptionsTask : SandboxFunction.toLineNumbers(): IntArray { diff --git a/djvm/src/test/kotlin/net/corda/djvm/DJVMTest.kt b/djvm/src/test/kotlin/net/corda/djvm/DJVMTest.kt index 37048ee7f2..80bff50e9b 100644 --- a/djvm/src/test/kotlin/net/corda/djvm/DJVMTest.kt +++ b/djvm/src/test/kotlin/net/corda/djvm/DJVMTest.kt @@ -4,9 +4,8 @@ import org.assertj.core.api.Assertions.* import org.junit.Assert.* import org.junit.Test import sandbox.java.lang.sandbox -import sandbox.java.lang.unsandbox -class DJVMTest { +class DJVMTest : TestBase() { @Test fun testDJVMString() { @@ -16,39 +15,59 @@ class DJVMTest { } @Test - fun testSimpleIntegerFormats() { - val result = sandbox.java.lang.String.format("%d-%d-%d-%d".toDJVM(), - 10.toDJVM(), 999999L.toDJVM(), 1234.toShort().toDJVM(), 108.toByte().toDJVM()).toString() + fun testSimpleIntegerFormats() = parentedSandbox { + val result = with(DJVM(classLoader)) { + stringClass.getMethod("format", stringClass, Array::class.java) + .invoke(null, + stringOf("%d-%d-%d-%d"), + arrayOf(intOf(10), longOf(999999L), shortOf(1234), byteOf(108)) + ).toString() + } assertEquals("10-999999-1234-108", result) } @Test - fun testHexFormat() { - val result = sandbox.java.lang.String.format("%0#6x".toDJVM(), 768.toDJVM()).toString() + fun testHexFormat() = parentedSandbox { + val result = with(DJVM(classLoader)) { + stringClass.getMethod("format", stringClass, Array::class.java) + .invoke(null, stringOf("%0#6x"), arrayOf(intOf(768))).toString() + } assertEquals("0x0300", result) } @Test - fun testDoubleFormat() { - val result = sandbox.java.lang.String.format("%9.4f".toDJVM(), 1234.5678.toDJVM()).toString() + fun testDoubleFormat() = parentedSandbox { + val result = with(DJVM(classLoader)) { + stringClass.getMethod("format", stringClass, Array::class.java) + .invoke(null, stringOf("%9.4f"), arrayOf(doubleOf(1234.5678))).toString() + } assertEquals("1234.5678", result) } @Test - fun testFloatFormat() { - val result = sandbox.java.lang.String.format("%7.2f".toDJVM(), 1234.5678f.toDJVM()).toString() + fun testFloatFormat() = parentedSandbox { + val result = with(DJVM(classLoader)) { + stringClass.getMethod("format", stringClass, Array::class.java) + .invoke(null, stringOf("%7.2f"), arrayOf(floatOf(1234.5678f))).toString() + } assertEquals("1234.57", result) } @Test - fun testCharFormat() { - val result = sandbox.java.lang.String.format("[%c]".toDJVM(), 'A'.toDJVM()).toString() + fun testCharFormat() = parentedSandbox { + val result = with(DJVM(classLoader)) { + stringClass.getMethod("format", stringClass, Array::class.java) + .invoke(null, stringOf("[%c]"), arrayOf(charOf('A'))).toString() + } assertEquals("[A]", result) } @Test - fun testObjectFormat() { - val result = sandbox.java.lang.String.format("%s".toDJVM(), object : sandbox.java.lang.Object() {}).toString() + fun testObjectFormat() = parentedSandbox { + val result = with(DJVM(classLoader)) { + stringClass.getMethod("format", stringClass, Array::class.java) + .invoke(null, stringOf("%s"), arrayOf(object : sandbox.java.lang.Object() {})).toString() + } assertThat(result).startsWith("sandbox.java.lang.Object@") } @@ -59,48 +78,60 @@ class DJVMTest { } @Test - fun testSandboxingArrays() { - val result = arrayOf(1, 10L, "Hello World", '?', false, 1234.56).sandbox() - assertThat(result) - .isEqualTo(arrayOf(1.toDJVM(), 10L.toDJVM(), "Hello World".toDJVM(), '?'.toDJVM(), false.toDJVM(), 1234.56.toDJVM())) + fun testSandboxingArrays() = parentedSandbox { + with(DJVM(classLoader)) { + val result = sandbox(arrayOf(1, 10L, "Hello World", '?', false, 1234.56)) + assertThat(result).isEqualTo( + arrayOf(intOf(1), longOf(10), stringOf("Hello World"), charOf('?'), booleanOf(false), doubleOf(1234.56))) + } } @Test - fun testUnsandboxingObjectArray() { - val result = arrayOf(1.toDJVM(), 10L.toDJVM(), "Hello World".toDJVM(), '?'.toDJVM(), false.toDJVM(), 1234.56.toDJVM()).unsandbox() + fun testUnsandboxingObjectArray() = parentedSandbox { + val result = with(DJVM(classLoader)) { + unsandbox(arrayOf(intOf(1), longOf(10L), stringOf("Hello World"), charOf('?'), booleanOf(false), doubleOf(1234.56))) + } assertThat(result) - .isEqualTo(arrayOf(1, 10L, "Hello World", '?', false, 1234.56)) + .isEqualTo(arrayOf(1, 10L, "Hello World", '?', false, 1234.56)) } @Test - fun testSandboxingPrimitiveArray() { - val result = intArrayOf(1, 2, 3, 10).sandbox() + fun testSandboxingPrimitiveArray() = parentedSandbox { + val result = with(DJVM(classLoader)) { + sandbox(intArrayOf(1, 2, 3, 10)) + } assertThat(result).isEqualTo(intArrayOf(1, 2, 3, 10)) } @Test - fun testSandboxingIntegersAsObjectArray() { - val result = arrayOf(1, 2, 3, 10).sandbox() - assertThat(result).isEqualTo(arrayOf(1.toDJVM(), 2.toDJVM(), 3.toDJVM(), 10.toDJVM())) + fun testSandboxingIntegersAsObjectArray() = parentedSandbox { + with(DJVM(classLoader)) { + val result = sandbox(arrayOf(1, 2, 3, 10)) + assertThat(result).isEqualTo( + arrayOf(intOf(1), intOf(2), intOf(3), intOf(10)) + ) + } } @Test - fun testUnsandboxingArrays() { - val arr = arrayOf( - Array(1) { "Hello".toDJVM() }, - Array(1) { 1234000L.toDJVM() }, - Array(1) { 1234.toDJVM() }, - Array(1) { 923.toShort().toDJVM() }, - Array(1) { 27.toByte().toDJVM() }, - Array(1) { 'X'.toDJVM() }, - Array(1) { 987.65f.toDJVM() }, - Array(1) { 343.282.toDJVM() }, - Array(1) { true.toDJVM() }, - ByteArray(1) { 127.toByte() }, - CharArray(1) { '?'} - ) - val result = arr.unsandbox() as Array<*> - assertEquals(arr.size, result.size) + fun testUnsandboxingArrays() = parentedSandbox { + val (array, result) = with(DJVM(classLoader)) { + val arr = arrayOf( + objectArrayOf(stringOf("Hello")), + objectArrayOf(longOf(1234000L)), + objectArrayOf(intOf(1234)), + objectArrayOf(shortOf(923)), + objectArrayOf(byteOf(27)), + objectArrayOf(charOf('X')), + objectArrayOf(floatOf(987.65f)), + objectArrayOf(doubleOf(343.282)), + objectArrayOf(booleanOf(true)), + ByteArray(1) { 127.toByte() }, + CharArray(1) { '?' } + ) + Pair(arr, unsandbox(arr) as Array<*>) + } + assertEquals(array.size, result.size) assertArrayEquals(Array(1) { "Hello" }, result[0] as Array<*>) assertArrayEquals(Array(1) { 1234000L }, result[1] as Array<*>) assertArrayEquals(Array(1) { 1234 }, result[2] as Array<*>) diff --git a/djvm/src/test/kotlin/net/corda/djvm/TestBase.kt b/djvm/src/test/kotlin/net/corda/djvm/TestBase.kt index ad16eee53a..155015930d 100644 --- a/djvm/src/test/kotlin/net/corda/djvm/TestBase.kt +++ b/djvm/src/test/kotlin/net/corda/djvm/TestBase.kt @@ -12,13 +12,18 @@ import net.corda.djvm.execution.ExecutionProfile import net.corda.djvm.messages.Severity import net.corda.djvm.references.ClassHierarchy import net.corda.djvm.rewiring.LoadedClass +import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.djvm.rules.Rule import net.corda.djvm.rules.implementation.* +import net.corda.djvm.source.BootstrapClassLoader import net.corda.djvm.source.ClassSource +import net.corda.djvm.source.SandboxSourceClassLoader import net.corda.djvm.utilities.Discovery import net.corda.djvm.validation.RuleValidator import org.junit.After +import org.junit.AfterClass import org.junit.Assert.assertEquals +import org.junit.BeforeClass import org.objectweb.asm.ClassReader import org.objectweb.asm.ClassWriter import org.objectweb.asm.Type @@ -39,6 +44,7 @@ abstract class TestBase { // We need at least these emitters to handle the Java API classes. @JvmField val BASIC_EMITTERS: List = listOf( + AlwaysInheritFromSandboxedObject(), ArgumentUnwrapper(), HandleExceptionUnwrapper(), ReturnTypeWrapper(), @@ -51,7 +57,10 @@ abstract class TestBase { // We need at least these providers to handle the Java API classes. @JvmField - val BASIC_DEFINITION_PROVIDERS: List = listOf(StaticConstantRemover()) + val BASIC_DEFINITION_PROVIDERS: List = listOf( + AlwaysInheritFromSandboxedObject(), + StaticConstantRemover() + ) @JvmField val BLANK = emptySet() @@ -63,17 +72,52 @@ abstract class TestBase { val DETERMINISTIC_RT: Path = Paths.get( System.getProperty("deterministic-rt.path") ?: throw AssertionError("deterministic-rt.path property not set")) + private lateinit var parentConfiguration: SandboxConfiguration + lateinit var parentClassLoader: SandboxClassLoader + /** * Get the full name of type [T]. */ inline fun nameOf(prefix: String = "") = "$prefix${Type.getInternalName(T::class.java)}" + @BeforeClass + @JvmStatic + fun setupParentClassLoader() { + val rootConfiguration = AnalysisConfiguration.createRoot( + Whitelist.MINIMAL, + bootstrapClassLoader = BootstrapClassLoader(DETERMINISTIC_RT), + sourceClassLoaderFactory = { classResolver, bootstrapClassLoader -> + SandboxSourceClassLoader(classResolver, bootstrapClassLoader!!) + }, + additionalPinnedClasses = setOf( + Utilities::class.java + ).map(Type::getInternalName).toSet() + ) + parentConfiguration = SandboxConfiguration.of( + ExecutionProfile.UNLIMITED, + ALL_RULES, + ALL_EMITTERS, + ALL_DEFINITION_PROVIDERS, + true, + rootConfiguration + ) + parentClassLoader = SandboxClassLoader.createFor(parentConfiguration) + } + + @AfterClass + @JvmStatic + fun destroyRootContext() { + parentConfiguration.analysisConfiguration.close() + } } /** * Default analysis configuration. */ - val configuration = AnalysisConfiguration(Whitelist.MINIMAL, bootstrapJar = DETERMINISTIC_RT) + val configuration = AnalysisConfiguration.createRoot( + Whitelist.MINIMAL, + bootstrapClassLoader = BootstrapClassLoader(DETERMINISTIC_RT) + ) /** * Default analysis context @@ -94,9 +138,9 @@ abstract class TestBase { noinline block: (RuleValidator.(AnalysisContext) -> Unit) ) { val reader = ClassReader(T::class.java.name) - AnalysisConfiguration( + AnalysisConfiguration.createRoot( minimumSeverityLevel = minimumSeverityLevel, - bootstrapJar = DETERMINISTIC_RT + bootstrapClassLoader = BootstrapClassLoader(DETERMINISTIC_RT) ).use { analysisConfiguration -> val validator = RuleValidator(ALL_RULES, analysisConfiguration) val context = AnalysisContext.fromConfiguration(analysisConfiguration) @@ -110,11 +154,11 @@ abstract class TestBase { * the current thread, so this allows inspection of the cost summary object, etc. from within the provided delegate. */ fun sandbox( - vararg options: Any, - pinnedClasses: Set> = emptySet(), - minimumSeverityLevel: Severity = Severity.WARNING, - enableTracing: Boolean = true, - action: SandboxRuntimeContext.() -> Unit + vararg options: Any, + pinnedClasses: Set> = emptySet(), + minimumSeverityLevel: Severity = Severity.WARNING, + enableTracing: Boolean = true, + action: SandboxRuntimeContext.() -> Unit ) { val rules = mutableListOf() val emitters = mutableListOf().apply { addAll(BASIC_EMITTERS) } @@ -141,11 +185,11 @@ abstract class TestBase { thread { try { val pinnedTestClasses = pinnedClasses.map(Type::getInternalName).toSet() - AnalysisConfiguration( + AnalysisConfiguration.createRoot( whitelist = whitelist, - bootstrapJar = DETERMINISTIC_RT, additionalPinnedClasses = pinnedTestClasses, - minimumSeverityLevel = minimumSeverityLevel + minimumSeverityLevel = minimumSeverityLevel, + bootstrapClassLoader = BootstrapClassLoader(DETERMINISTIC_RT) ).use { analysisConfiguration -> SandboxRuntimeContext(SandboxConfiguration.of( executionProfile, @@ -166,6 +210,37 @@ abstract class TestBase { throw thrownException ?: return } + fun parentedSandbox( + minimumSeverityLevel: Severity = Severity.WARNING, + enableTracing: Boolean = true, + action: SandboxRuntimeContext.() -> Unit + ) { + var thrownException: Throwable? = null + thread { + try { + parentConfiguration.analysisConfiguration.createChild( + newMinimumSeverityLevel = minimumSeverityLevel + ).use { analysisConfiguration -> + SandboxRuntimeContext(SandboxConfiguration.of( + parentConfiguration.executionProfile, + parentConfiguration.rules, + parentConfiguration.emitters, + parentConfiguration.definitionProviders, + enableTracing, + analysisConfiguration, + parentClassLoader + )).use { + assertThat(runtimeCosts).areZero() + action(this) + } + } + } catch (exception: Throwable) { + thrownException = exception + } + }.join() + throw thrownException ?: return + } + /** * Get a class reference from a class hierarchy based on [T]. */ @@ -178,8 +253,7 @@ abstract class TestBase { inline fun SandboxRuntimeContext.loadClass(): LoadedClass = loadClass(T::class.jvmName) - fun SandboxRuntimeContext.loadClass(className: String): LoadedClass = - classLoader.loadForSandbox(className, context) + fun SandboxRuntimeContext.loadClass(className: String): LoadedClass = classLoader.loadForSandbox(className) /** * Run the entry-point of the loaded [Callable] class. @@ -203,4 +277,77 @@ abstract class TestBase { } } + @Suppress("MemberVisibilityCanBePrivate") + protected class DJVM(private val classLoader: ClassLoader) { + private val djvm: Class<*> = classFor("sandbox.java.lang.DJVM") + val objectClass: Class<*> by lazy { classFor("sandbox.java.lang.Object") } + val stringClass: Class<*> by lazy { classFor("sandbox.java.lang.String") } + val longClass: Class<*> by lazy { classFor("sandbox.java.lang.Long") } + val integerClass: Class<*> by lazy { classFor("sandbox.java.lang.Integer") } + val shortClass: Class<*> by lazy { classFor("sandbox.java.lang.Short") } + val byteClass: Class<*> by lazy { classFor("sandbox.java.lang.Byte") } + val characterClass: Class<*> by lazy { classFor("sandbox.java.lang.Character") } + val booleanClass: Class<*> by lazy { classFor("sandbox.java.lang.Boolean") } + val doubleClass: Class<*> by lazy { classFor("sandbox.java.lang.Double") } + val floatClass: Class<*> by lazy { classFor("sandbox.java.lang.Float") } + val throwableClass: Class<*> by lazy { classFor("sandbox.java.lang.Throwable") } + val stackTraceElementClass: Class<*> by lazy { classFor("sandbox.java.lang.StackTraceElement") } + + fun classFor(className: String): Class<*> = Class.forName(className, false, classLoader) + + fun sandbox(obj: Any): Any { + return djvm.getMethod("sandbox", Any::class.java).invoke(null, obj) + } + + fun unsandbox(obj: Any): Any { + return djvm.getMethod("unsandbox", Any::class.java).invoke(null, obj) + } + + fun stringOf(str: String): Any { + return stringClass.getMethod("toDJVM", String::class.java).invoke(null, str) + } + + fun longOf(l: Long): Any { + return longClass.getMethod("toDJVM", Long::class.javaObjectType).invoke(null, l) + } + + fun intOf(i: Int): Any { + return integerClass.getMethod("toDJVM", Int::class.javaObjectType).invoke(null, i) + } + + fun shortOf(i: Int): Any { + return shortClass.getMethod("toDJVM", Short::class.javaObjectType).invoke(null, i.toShort()) + } + + fun byteOf(i: Int): Any { + return byteClass.getMethod("toDJVM", Byte::class.javaObjectType).invoke(null, i.toByte()) + } + + fun charOf(c: Char): Any { + return characterClass.getMethod("toDJVM", Char::class.javaObjectType).invoke(null, c) + } + + fun booleanOf(bool: Boolean): Any { + return booleanClass.getMethod("toDJVM", Boolean::class.javaObjectType).invoke(null, bool) + } + + fun doubleOf(d: Double): Any { + return doubleClass.getMethod("toDJVM", Double::class.javaObjectType).invoke(null, d) + } + + fun floatOf(f: Float): Any { + return floatClass.getMethod("toDJVM", Float::class.javaObjectType).invoke(null, f) + } + + fun objectArrayOf(vararg objs: Any): Array { + @Suppress("unchecked_cast") + return (java.lang.reflect.Array.newInstance(objectClass, objs.size) as Array).also { + for (i in 0 until objs.size) { + it[i] = objectClass.cast(objs[i]) + } + } + } + } + + fun Any.getArray(methodName: String): Array<*> = javaClass.getMethod(methodName).invoke(this) as Array<*> } diff --git a/djvm/src/test/kotlin/net/corda/djvm/Utilities.kt b/djvm/src/test/kotlin/net/corda/djvm/Utilities.kt index 6313661b0c..fb5b36c998 100644 --- a/djvm/src/test/kotlin/net/corda/djvm/Utilities.kt +++ b/djvm/src/test/kotlin/net/corda/djvm/Utilities.kt @@ -12,13 +12,3 @@ object Utilities { fun throwThresholdViolationError(): Nothing = throw ThresholdViolationError("Can't catch this!") } - -fun String.toDJVM(): sandbox.java.lang.String = sandbox.java.lang.String.toDJVM(this) -fun Long.toDJVM(): sandbox.java.lang.Long = sandbox.java.lang.Long.toDJVM(this) -fun Int.toDJVM(): sandbox.java.lang.Integer = sandbox.java.lang.Integer.toDJVM(this) -fun Short.toDJVM(): sandbox.java.lang.Short = sandbox.java.lang.Short.toDJVM(this) -fun Byte.toDJVM(): sandbox.java.lang.Byte = sandbox.java.lang.Byte.toDJVM(this) -fun Float.toDJVM(): sandbox.java.lang.Float = sandbox.java.lang.Float.toDJVM(this) -fun Double.toDJVM(): sandbox.java.lang.Double = sandbox.java.lang.Double.toDJVM(this) -fun Char.toDJVM(): sandbox.java.lang.Character = sandbox.java.lang.Character.toDJVM(this) -fun Boolean.toDJVM(): sandbox.java.lang.Boolean = sandbox.java.lang.Boolean.toDJVM(this) \ No newline at end of file diff --git a/djvm/src/test/kotlin/net/corda/djvm/assertions/AssertionExtensions.kt b/djvm/src/test/kotlin/net/corda/djvm/assertions/AssertionExtensions.kt index a86706551a..09df337ff3 100644 --- a/djvm/src/test/kotlin/net/corda/djvm/assertions/AssertionExtensions.kt +++ b/djvm/src/test/kotlin/net/corda/djvm/assertions/AssertionExtensions.kt @@ -36,6 +36,8 @@ object AssertionExtensions { fun assertThat(references: ReferenceMap) = AssertiveReferenceMap(references) + fun assertThatDJVM(obj: Any) = AssertiveDJVMObject(obj) + inline fun IterableAssert.hasClass(): IterableAssert = this .`as`("HasClass(${T::class.java.name})") .anySatisfy { diff --git a/djvm/src/test/kotlin/net/corda/djvm/assertions/AssertiveClassWithByteCode.kt b/djvm/src/test/kotlin/net/corda/djvm/assertions/AssertiveClassWithByteCode.kt index 0957217f47..080968490f 100644 --- a/djvm/src/test/kotlin/net/corda/djvm/assertions/AssertiveClassWithByteCode.kt +++ b/djvm/src/test/kotlin/net/corda/djvm/assertions/AssertiveClassWithByteCode.kt @@ -24,6 +24,11 @@ class AssertiveClassWithByteCode(private val loadedClass: LoadedClass) { return this } + fun hasClassLoader(classLoader: ClassLoader): AssertiveClassWithByteCode { + assertThat(loadedClass.type.classLoader).isEqualTo(classLoader) + return this + } + fun hasClassName(className: String): AssertiveClassWithByteCode { assertThat(loadedClass.type.name).isEqualTo(className) return this diff --git a/djvm/src/test/kotlin/net/corda/djvm/assertions/AssertiveDJVMObject.kt b/djvm/src/test/kotlin/net/corda/djvm/assertions/AssertiveDJVMObject.kt new file mode 100644 index 0000000000..337597f2e2 --- /dev/null +++ b/djvm/src/test/kotlin/net/corda/djvm/assertions/AssertiveDJVMObject.kt @@ -0,0 +1,26 @@ +package net.corda.djvm.assertions + +import org.assertj.core.api.Assertions.* + +class AssertiveDJVMObject(private val djvmObj: Any) { + + fun hasClassName(className: String): AssertiveDJVMObject { + assertThat(djvmObj.javaClass.name).isEqualTo(className) + return this + } + + fun isAssignableFrom(clazz: Class<*>): AssertiveDJVMObject { + assertThat(djvmObj.javaClass.isAssignableFrom(clazz)) + return this + } + + fun hasGetterValue(methodName: String, value: Any): AssertiveDJVMObject { + assertThat(djvmObj.javaClass.getMethod(methodName).invoke(djvmObj)).isEqualTo(value) + return this + } + + fun hasGetterNullValue(methodName: String): AssertiveDJVMObject { + assertThat(djvmObj.javaClass.getMethod(methodName).invoke(djvmObj)).isNull() + return this + } +} \ No newline at end of file diff --git a/djvm/src/test/kotlin/net/corda/djvm/execution/SandboxEnumTest.kt b/djvm/src/test/kotlin/net/corda/djvm/execution/SandboxEnumTest.kt index af78c3183b..8f36267c90 100644 --- a/djvm/src/test/kotlin/net/corda/djvm/execution/SandboxEnumTest.kt +++ b/djvm/src/test/kotlin/net/corda/djvm/execution/SandboxEnumTest.kt @@ -8,7 +8,7 @@ import java.util.function.Function class SandboxEnumTest : TestBase() { @Test - fun `test enum inside sandbox`() = sandbox(DEFAULT) { + fun `test enum inside sandbox`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor>(configuration) contractExecutor.run(0).apply { assertThat(result).isEqualTo(arrayOf("ONE", "TWO", "THREE")) @@ -16,7 +16,7 @@ class SandboxEnumTest : TestBase() { } @Test - fun `return enum from sandbox`() = sandbox(DEFAULT) { + fun `return enum from sandbox`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) contractExecutor.run("THREE").apply { assertThat(result).isEqualTo(ExampleEnum.THREE) @@ -24,7 +24,7 @@ class SandboxEnumTest : TestBase() { } @Test - fun `test we can identify class as Enum`() = sandbox(DEFAULT) { + fun `test we can identify class as Enum`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) contractExecutor.run(ExampleEnum.THREE).apply { assertThat(result).isTrue() @@ -32,7 +32,7 @@ class SandboxEnumTest : TestBase() { } @Test - fun `test we can create EnumMap`() = sandbox(DEFAULT) { + fun `test we can create EnumMap`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) contractExecutor.run(ExampleEnum.TWO).apply { assertThat(result).isEqualTo(1) @@ -40,7 +40,7 @@ class SandboxEnumTest : TestBase() { } @Test - fun `test we can create EnumSet`() = sandbox(DEFAULT) { + fun `test we can create EnumSet`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) contractExecutor.run(ExampleEnum.ONE).apply { assertThat(result).isTrue() diff --git a/djvm/src/test/kotlin/net/corda/djvm/execution/SandboxExecutorTest.kt b/djvm/src/test/kotlin/net/corda/djvm/execution/SandboxExecutorTest.kt index 32fa876195..169b956e85 100644 --- a/djvm/src/test/kotlin/net/corda/djvm/execution/SandboxExecutorTest.kt +++ b/djvm/src/test/kotlin/net/corda/djvm/execution/SandboxExecutorTest.kt @@ -4,7 +4,7 @@ import foo.bar.sandbox.MyObject import foo.bar.sandbox.testClock import foo.bar.sandbox.toNumber import net.corda.djvm.TestBase -import net.corda.djvm.analysis.Whitelist +import net.corda.djvm.analysis.Whitelist.Companion.MINIMAL import net.corda.djvm.Utilities import net.corda.djvm.Utilities.throwRuleViolationError import net.corda.djvm.Utilities.throwThresholdViolationError @@ -22,7 +22,7 @@ import java.util.stream.Collectors.* class SandboxExecutorTest : TestBase() { @Test - fun `can load and execute runnable`() = sandbox(Whitelist.MINIMAL) { + fun `can load and execute runnable`() = sandbox(MINIMAL) { val contractExecutor = DeterministicSandboxExecutor(configuration) val summary = contractExecutor.run(1) val result = summary.result @@ -36,7 +36,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `can load and execute contract`() = sandbox( + fun `can load and execute contract`() = sandbox(DEFAULT, pinnedClasses = setOf(Transaction::class.java, Utilities::class.java) ) { val contractExecutor = DeterministicSandboxExecutor(configuration) @@ -56,7 +56,7 @@ class SandboxExecutorTest : TestBase() { data class Transaction(val id: Int) @Test - fun `can load and execute code that overrides object hash code`() = sandbox(DEFAULT) { + fun `can load and execute code that overrides object hash code`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) val summary = contractExecutor.run(0) val result = summary.result @@ -74,7 +74,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `can load and execute code that overrides object hash code when derived`() = sandbox(DEFAULT) { + fun `can load and execute code that overrides object hash code when derived`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) val summary = contractExecutor.run(0) val result = summary.result @@ -107,7 +107,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `can detect stack overflow`() = sandbox(DEFAULT) { + fun `can detect stack overflow`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(0) } @@ -141,7 +141,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `cannot execute runnable that references non-deterministic code`() = sandbox(DEFAULT) { + fun `cannot execute runnable that references non-deterministic code`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(0) } @@ -156,7 +156,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `cannot execute runnable that catches ThreadDeath`() = sandbox(DEFAULT, pinnedClasses = setOf(Utilities::class.java)) { + fun `cannot execute runnable that catches ThreadDeath`() = parentedSandbox { TestCatchThreadDeath().apply { assertThat(apply(0)).isEqualTo(1) } @@ -178,7 +178,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `cannot execute runnable that catches ThresholdViolationError`() = sandbox(DEFAULT, pinnedClasses = setOf(Utilities::class.java)) { + fun `cannot execute runnable that catches ThresholdViolationError`() = parentedSandbox { TestCatchThresholdViolationError().apply { assertThat(apply(0)).isEqualTo(1) } @@ -201,7 +201,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `cannot execute runnable that catches RuleViolationError`() = sandbox(DEFAULT, pinnedClasses = setOf(Utilities::class.java)) { + fun `cannot execute runnable that catches RuleViolationError`() = parentedSandbox { TestCatchRuleViolationError().apply { assertThat(apply(0)).isEqualTo(1) } @@ -224,7 +224,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `can catch Throwable`() = sandbox(DEFAULT, pinnedClasses = setOf(Utilities::class.java)) { + fun `can catch Throwable`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) contractExecutor.run(1).apply { assertThat(result).isEqualTo(1) @@ -232,7 +232,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `can catch Error`() = sandbox(DEFAULT, pinnedClasses = setOf(Utilities::class.java)) { + fun `can catch Error`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) contractExecutor.run(2).apply { assertThat(result).isEqualTo(2) @@ -240,7 +240,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `cannot catch ThreadDeath`() = sandbox(DEFAULT, pinnedClasses = setOf(Utilities::class.java)) { + fun `cannot catch ThreadDeath`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(3) } @@ -295,7 +295,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `cannot catch stack-overflow error`() = sandbox(DEFAULT, pinnedClasses = setOf(Utilities::class.java)) { + fun `cannot catch stack-overflow error`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(4) } @@ -304,7 +304,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `cannot catch out-of-memory error`() = sandbox(DEFAULT, pinnedClasses = setOf(Utilities::class.java)) { + fun `cannot catch out-of-memory error`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(5) } @@ -313,7 +313,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `cannot persist state across sessions`() = sandbox(DEFAULT) { + fun `cannot persist state across sessions`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) val result1 = contractExecutor.run(0) val result2 = contractExecutor.run(0) @@ -335,7 +335,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `can load and execute code that uses IO`() = sandbox(DEFAULT) { + fun `can load and execute code that uses IO`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(0) } @@ -354,7 +354,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `can load and execute code that uses reflection`() = sandbox(DEFAULT) { + fun `can load and execute code that uses reflection`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(0) } @@ -373,7 +373,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `can load and execute code that uses notify()`() = sandbox(DEFAULT) { + fun `can load and execute code that uses notify()`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(1) } @@ -383,7 +383,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `can load and execute code that uses notifyAll()`() = sandbox(DEFAULT) { + fun `can load and execute code that uses notifyAll()`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(2) } @@ -393,7 +393,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `can load and execute code that uses wait()`() = sandbox(DEFAULT) { + fun `can load and execute code that uses wait()`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(3) } @@ -403,7 +403,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `can load and execute code that uses wait(long)`() = sandbox(DEFAULT) { + fun `can load and execute code that uses wait(long)`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(4) } @@ -413,7 +413,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `can load and execute code that uses wait(long,int)`() = sandbox(DEFAULT) { + fun `can load and execute code that uses wait(long,int)`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run(5) } @@ -423,7 +423,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `code after forbidden APIs is intact`() = sandbox(DEFAULT) { + fun `code after forbidden APIs is intact`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) assertThat(contractExecutor.run(0).result) .isEqualTo("unknown") @@ -462,7 +462,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `can load and execute code that has a native method`() = sandbox(DEFAULT) { + fun `can load and execute code that has a native method`() = parentedSandbox { assertThatExceptionOfType(UnsatisfiedLinkError::class.java) .isThrownBy { TestNativeMethod().apply(0) } .withMessageContaining("TestNativeMethod.evilDeeds()I") @@ -483,7 +483,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `check arrays still work`() = sandbox(DEFAULT) { + fun `check arrays still work`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor>(configuration) contractExecutor.run(100).apply { assertThat(result).isEqualTo(arrayOf(100)) @@ -497,7 +497,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `check building a string`() = sandbox(DEFAULT) { + fun `check building a string`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) contractExecutor.run("Hello Sandbox!").apply { assertThat(result) @@ -522,7 +522,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `check System-arraycopy still works with Objects`() = sandbox(DEFAULT) { + fun `check System-arraycopy still works with Objects`() = parentedSandbox { val source = arrayOf("one", "two", "three") assertThat(TestArrayCopy().apply(source)) .isEqualTo(source) @@ -545,7 +545,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `test System-arraycopy still works with CharArray`() = sandbox(DEFAULT) { + fun `test System-arraycopy still works with CharArray`() = parentedSandbox { val source = CharArray(10) { '?' } val contractExecutor = DeterministicSandboxExecutor(configuration) contractExecutor.run(source).apply { @@ -564,7 +564,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `can load and execute class that has finalize`() = sandbox(DEFAULT) { + fun `can load and execute class that has finalize`() = parentedSandbox { assertThatExceptionOfType(UnsupportedOperationException::class.java) .isThrownBy { TestFinalizeMethod().apply(100) } .withMessageContaining("Very Bad Thing") @@ -587,7 +587,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `can execute parallel stream`() = sandbox(DEFAULT) { + fun `can execute parallel stream`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) contractExecutor.run("Pebble").apply { assertThat(result).isEqualTo("Five,Four,One,Pebble,Three,Two") @@ -605,7 +605,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `users cannot load our sandboxed classes`() = sandbox(DEFAULT) { + fun `users cannot load our sandboxed classes`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor>(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run("java.lang.DJVM") } @@ -614,7 +614,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `users can load sandboxed classes`() = sandbox(DEFAULT) { + fun `users can load sandboxed classes`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor>(configuration) contractExecutor.run("java.util.List").apply { assertThat(result?.name).isEqualTo("sandbox.java.util.List") @@ -628,7 +628,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `test case-insensitive string sorting`() = sandbox(DEFAULT) { + fun `test case-insensitive string sorting`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor, Array>(configuration) contractExecutor.run(arrayOf("Zelda", "angela", "BOB", "betsy", "ALBERT")).apply { assertThat(result).isEqualTo(arrayOf("ALBERT", "angela", "betsy", "BOB", "Zelda")) @@ -642,7 +642,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `test unicode characters`() = sandbox(DEFAULT) { + fun `test unicode characters`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) contractExecutor.run(0x01f600).apply { assertThat(result).isEqualTo("EMOTICONS") @@ -656,7 +656,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `test unicode scripts`() = sandbox(DEFAULT) { + fun `test unicode scripts`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) contractExecutor.run("COMMON").apply { assertThat(result).isEqualTo(Character.UnicodeScript.COMMON) @@ -671,7 +671,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `test users cannot define new classes`() = sandbox(DEFAULT) { + fun `test users cannot define new classes`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor>(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run("sandbox.java.lang.DJVM") } @@ -693,7 +693,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `test users cannot load new classes`() = sandbox(DEFAULT) { + fun `test users cannot load new classes`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor>(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run("sandbox.java.lang.DJVM") } @@ -714,7 +714,7 @@ class SandboxExecutorTest : TestBase() { } @Test - fun `test users cannot lookup classes`() = sandbox(DEFAULT) { + fun `test users cannot lookup classes`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor>(configuration) assertThatExceptionOfType(SandboxException::class.java) .isThrownBy { contractExecutor.run("sandbox.java.lang.DJVM") } diff --git a/djvm/src/test/kotlin/net/corda/djvm/execution/SandboxThrowableTest.kt b/djvm/src/test/kotlin/net/corda/djvm/execution/SandboxThrowableTest.kt index ae013a9c1e..69608b40b5 100644 --- a/djvm/src/test/kotlin/net/corda/djvm/execution/SandboxThrowableTest.kt +++ b/djvm/src/test/kotlin/net/corda/djvm/execution/SandboxThrowableTest.kt @@ -8,7 +8,7 @@ import java.util.function.Function class SandboxThrowableTest : TestBase() { @Test - fun `test user exception handling`() = sandbox(DEFAULT) { + fun `test user exception handling`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor>(configuration) contractExecutor.run("Hello World").apply { assertThat(result) @@ -17,7 +17,7 @@ class SandboxThrowableTest : TestBase() { } @Test - fun `test rethrowing an exception`() = sandbox(DEFAULT) { + fun `test rethrowing an exception`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor>(configuration) contractExecutor.run("Hello World").apply { assertThat(result) @@ -26,7 +26,7 @@ class SandboxThrowableTest : TestBase() { } @Test - fun `test JVM exceptions still propagate`() = sandbox(DEFAULT) { + fun `test JVM exceptions still propagate`() = parentedSandbox { val contractExecutor = DeterministicSandboxExecutor(configuration) contractExecutor.run(-1).apply { assertThat(result) diff --git a/djvm/src/test/kotlin/net/corda/djvm/rewiring/ClassRewriterTest.kt b/djvm/src/test/kotlin/net/corda/djvm/rewiring/ClassRewriterTest.kt index 68b60da1cd..8e1b013285 100644 --- a/djvm/src/test/kotlin/net/corda/djvm/rewiring/ClassRewriterTest.kt +++ b/djvm/src/test/kotlin/net/corda/djvm/rewiring/ClassRewriterTest.kt @@ -15,7 +15,7 @@ class ClassRewriterTest : TestBase() { @Test fun `empty transformer does nothing`() = sandbox(BLANK) { val callable = newCallable() - assertThat(callable).hasNotBeenModified() + assertThat(callable).isSandboxed() callable.createAndInvoke() assertThat(runtimeCosts).areZero() } @@ -130,6 +130,32 @@ class ClassRewriterTest : TestBase() { .hasInterface("sandbox.java.lang.CharSequence") .hasBeenModified() } + + @Test + fun `test Java class is owned by parent classloader`() = parentedSandbox { + val stringBuilderClass = loadClass().type + assertThat(stringBuilderClass.classLoader).isEqualTo(parentClassLoader) + } + + @Test + fun `test user class is owned by new classloader`() = parentedSandbox { + assertThat(loadClass()) + .hasClassLoader(classLoader) + .hasBeenModified() + } + + @Test + fun `test template class is owned by parent classloader`() = parentedSandbox { + assertThat(classLoader.loadForSandbox("sandbox.java.lang.DJVM")) + .hasClassLoader(parentClassLoader) + .hasNotBeenModified() + } + + @Test + fun `test pinned class is owned by application classloader`() = parentedSandbox { + val violationClass = loadClass().type + assertThat(violationClass).isEqualTo(ThresholdViolationError::class.java) + } } @Suppress("unused")