[CORDA-2431] - Small refactorings following-up on PR-4551 (#4564)

* Small refactorings following-up on PR-4551

* Adjust thread context class loader

* Address Shams' comments
This commit is contained in:
Dimos Raptis
2019-01-15 14:34:11 +00:00
committed by GitHub
parent 05ffb3d101
commit fbb00bff9c
9 changed files with 127 additions and 40 deletions

View File

@ -2,7 +2,6 @@ package net.corda.core.internal
import io.github.classgraph.ClassGraph
import net.corda.core.StubOutForDJVM
import kotlin.reflect.full.createInstance
/**
* Creates instances of all the classes in the classpath of the provided classloader, which implement the interface of the provided class.
@ -17,15 +16,26 @@ import kotlin.reflect.full.createInstance
* - either be a Kotlin object or have a constructor with no parameters (or only optional ones)
*/
@StubOutForDJVM
fun <T: Any> loadClassesImplementing(classloader: ClassLoader, clazz: Class<T>): Set<T> {
fun <T: Any> createInstancesOfClassesImplementing(classloader: ClassLoader, clazz: Class<T>): Set<T> {
return ClassGraph().addClassLoader(classloader)
.enableClassInfo()
.scan()
.use {
it.getClassesImplementing(clazz.name)
.filterNot { it.isAbstract }
.mapNotNull { classloader.loadClass(it.name).asSubclass(clazz) }
.map { it.kotlin.objectInstance ?: it.kotlin.createInstance() }
.map { classloader.loadClass(it.name).asSubclass(clazz) }
.map { it.kotlin.objectOrNewInstance() }
.toSet()
}
}
fun <T: Any?> executeWithThreadContextClassLoader(classloader: ClassLoader, fn: () -> T): T {
val threadClassLoader = Thread.currentThread().contextClassLoader
try {
Thread.currentThread().contextClassLoader = classloader
return fn()
} finally {
Thread.currentThread().contextClassLoader = threadClassLoader
}
}

View File

@ -2,7 +2,7 @@ package net.corda.core.serialization.internal
import net.corda.core.CordaException
import net.corda.core.KeepForDJVM
import net.corda.core.internal.loadClassesImplementing
import net.corda.core.internal.createInstancesOfClassesImplementing
import net.corda.core.contracts.Attachment
import net.corda.core.contracts.ContractAttachment
import net.corda.core.contracts.TransactionVerificationException.OverlappingAttachmentsException
@ -11,7 +11,6 @@ import net.corda.core.crypto.sha256
import net.corda.core.internal.*
import net.corda.core.internal.cordapp.targetPlatformVersion
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.MissingAttachmentsException
import net.corda.core.serialization.SerializationCustomSerializer
import net.corda.core.serialization.SerializationFactory
import net.corda.core.serialization.SerializationWhitelist
@ -196,7 +195,7 @@ internal object AttachmentsClassLoaderBuilder {
val serializationContext = cache.computeIfAbsent(attachmentIds) {
// Create classloader and load serializers, whitelisted classes
val transactionClassLoader = AttachmentsClassLoader(attachments)
val serializers = loadClassesImplementing(transactionClassLoader, SerializationCustomSerializer::class.java)
val serializers = createInstancesOfClassesImplementing(transactionClassLoader, SerializationCustomSerializer::class.java)
val whitelistedClasses = ServiceLoader.load(SerializationWhitelist::class.java, transactionClassLoader)
.flatMap { it.whitelist }
.toList()

View File

@ -1,11 +1,16 @@
package net.corda.core.internal
import com.nhaarman.mockito_kotlin.mock
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.Test
import java.lang.IllegalArgumentException
import java.lang.RuntimeException
class ClassLoadingUtilsTest {
private val temporaryClassLoader = mock<ClassLoader>()
interface BaseInterface {}
interface BaseInterface2 {}
@ -18,7 +23,7 @@ class ClassLoadingUtilsTest {
@Test
fun predicateClassAreLoadedSuccessfully() {
val classes = loadClassesImplementing(BaseInterface::class.java.classLoader, BaseInterface::class.java)
val classes = createInstancesOfClassesImplementing(BaseInterface::class.java.classLoader, BaseInterface::class.java)
val classNames = classes.map { it.javaClass.name }
@ -28,7 +33,27 @@ class ClassLoadingUtilsTest {
@Test(expected = IllegalArgumentException::class)
fun throwsExceptionWhenClassDoesNotContainProperConstructors() {
val classes = loadClassesImplementing(BaseInterface::class.java.classLoader, BaseInterface2::class.java)
val classes = createInstancesOfClassesImplementing(BaseInterface::class.java.classLoader, BaseInterface2::class.java)
}
@Test
fun `thread context class loader is adjusted, during the function execution`() {
val result = executeWithThreadContextClassLoader(temporaryClassLoader) {
assertThat(Thread.currentThread().contextClassLoader).isEqualTo(temporaryClassLoader)
true
}
assertThat(result).isTrue()
assertThat(Thread.currentThread().contextClassLoader).isNotEqualTo(temporaryClassLoader)
}
@Test
fun `thread context class loader is set to the initial, even in case of a failure`() {
assertThatThrownBy { executeWithThreadContextClassLoader(temporaryClassLoader) {
throw RuntimeException()
} }.isInstanceOf(RuntimeException::class.java)
assertThat(Thread.currentThread().contextClassLoader).isNotEqualTo(temporaryClassLoader)
}
}