mirror of
https://github.com/corda/corda.git
synced 2025-06-23 01:19:00 +00:00
[CORDA-2431] - Small refactorings following-up on PR-4551 (#4564)
* Small refactorings following-up on PR-4551 * Adjust thread context class loader * Address Shams' comments
This commit is contained in:
@ -2,7 +2,6 @@ package net.corda.core.internal
|
||||
|
||||
import io.github.classgraph.ClassGraph
|
||||
import net.corda.core.StubOutForDJVM
|
||||
import kotlin.reflect.full.createInstance
|
||||
|
||||
/**
|
||||
* Creates instances of all the classes in the classpath of the provided classloader, which implement the interface of the provided class.
|
||||
@ -17,15 +16,26 @@ import kotlin.reflect.full.createInstance
|
||||
* - either be a Kotlin object or have a constructor with no parameters (or only optional ones)
|
||||
*/
|
||||
@StubOutForDJVM
|
||||
fun <T: Any> loadClassesImplementing(classloader: ClassLoader, clazz: Class<T>): Set<T> {
|
||||
fun <T: Any> createInstancesOfClassesImplementing(classloader: ClassLoader, clazz: Class<T>): Set<T> {
|
||||
return ClassGraph().addClassLoader(classloader)
|
||||
.enableClassInfo()
|
||||
.scan()
|
||||
.use {
|
||||
it.getClassesImplementing(clazz.name)
|
||||
.filterNot { it.isAbstract }
|
||||
.mapNotNull { classloader.loadClass(it.name).asSubclass(clazz) }
|
||||
.map { it.kotlin.objectInstance ?: it.kotlin.createInstance() }
|
||||
.map { classloader.loadClass(it.name).asSubclass(clazz) }
|
||||
.map { it.kotlin.objectOrNewInstance() }
|
||||
.toSet()
|
||||
}
|
||||
}
|
||||
|
||||
fun <T: Any?> executeWithThreadContextClassLoader(classloader: ClassLoader, fn: () -> T): T {
|
||||
val threadClassLoader = Thread.currentThread().contextClassLoader
|
||||
try {
|
||||
Thread.currentThread().contextClassLoader = classloader
|
||||
return fn()
|
||||
} finally {
|
||||
Thread.currentThread().contextClassLoader = threadClassLoader
|
||||
}
|
||||
|
||||
}
|
@ -2,7 +2,7 @@ package net.corda.core.serialization.internal
|
||||
|
||||
import net.corda.core.CordaException
|
||||
import net.corda.core.KeepForDJVM
|
||||
import net.corda.core.internal.loadClassesImplementing
|
||||
import net.corda.core.internal.createInstancesOfClassesImplementing
|
||||
import net.corda.core.contracts.Attachment
|
||||
import net.corda.core.contracts.ContractAttachment
|
||||
import net.corda.core.contracts.TransactionVerificationException.OverlappingAttachmentsException
|
||||
@ -11,7 +11,6 @@ import net.corda.core.crypto.sha256
|
||||
import net.corda.core.internal.*
|
||||
import net.corda.core.internal.cordapp.targetPlatformVersion
|
||||
import net.corda.core.serialization.CordaSerializable
|
||||
import net.corda.core.serialization.MissingAttachmentsException
|
||||
import net.corda.core.serialization.SerializationCustomSerializer
|
||||
import net.corda.core.serialization.SerializationFactory
|
||||
import net.corda.core.serialization.SerializationWhitelist
|
||||
@ -196,7 +195,7 @@ internal object AttachmentsClassLoaderBuilder {
|
||||
val serializationContext = cache.computeIfAbsent(attachmentIds) {
|
||||
// Create classloader and load serializers, whitelisted classes
|
||||
val transactionClassLoader = AttachmentsClassLoader(attachments)
|
||||
val serializers = loadClassesImplementing(transactionClassLoader, SerializationCustomSerializer::class.java)
|
||||
val serializers = createInstancesOfClassesImplementing(transactionClassLoader, SerializationCustomSerializer::class.java)
|
||||
val whitelistedClasses = ServiceLoader.load(SerializationWhitelist::class.java, transactionClassLoader)
|
||||
.flatMap { it.whitelist }
|
||||
.toList()
|
||||
|
@ -1,11 +1,16 @@
|
||||
package net.corda.core.internal
|
||||
|
||||
import com.nhaarman.mockito_kotlin.mock
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.assertj.core.api.Assertions.assertThatThrownBy
|
||||
import org.junit.Test
|
||||
import java.lang.IllegalArgumentException
|
||||
import java.lang.RuntimeException
|
||||
|
||||
class ClassLoadingUtilsTest {
|
||||
|
||||
private val temporaryClassLoader = mock<ClassLoader>()
|
||||
|
||||
interface BaseInterface {}
|
||||
|
||||
interface BaseInterface2 {}
|
||||
@ -18,7 +23,7 @@ class ClassLoadingUtilsTest {
|
||||
|
||||
@Test
|
||||
fun predicateClassAreLoadedSuccessfully() {
|
||||
val classes = loadClassesImplementing(BaseInterface::class.java.classLoader, BaseInterface::class.java)
|
||||
val classes = createInstancesOfClassesImplementing(BaseInterface::class.java.classLoader, BaseInterface::class.java)
|
||||
|
||||
val classNames = classes.map { it.javaClass.name }
|
||||
|
||||
@ -28,7 +33,27 @@ class ClassLoadingUtilsTest {
|
||||
|
||||
@Test(expected = IllegalArgumentException::class)
|
||||
fun throwsExceptionWhenClassDoesNotContainProperConstructors() {
|
||||
val classes = loadClassesImplementing(BaseInterface::class.java.classLoader, BaseInterface2::class.java)
|
||||
val classes = createInstancesOfClassesImplementing(BaseInterface::class.java.classLoader, BaseInterface2::class.java)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `thread context class loader is adjusted, during the function execution`() {
|
||||
val result = executeWithThreadContextClassLoader(temporaryClassLoader) {
|
||||
assertThat(Thread.currentThread().contextClassLoader).isEqualTo(temporaryClassLoader)
|
||||
true
|
||||
}
|
||||
|
||||
assertThat(result).isTrue()
|
||||
assertThat(Thread.currentThread().contextClassLoader).isNotEqualTo(temporaryClassLoader)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `thread context class loader is set to the initial, even in case of a failure`() {
|
||||
assertThatThrownBy { executeWithThreadContextClassLoader(temporaryClassLoader) {
|
||||
throw RuntimeException()
|
||||
} }.isInstanceOf(RuntimeException::class.java)
|
||||
|
||||
assertThat(Thread.currentThread().contextClassLoader).isNotEqualTo(temporaryClassLoader)
|
||||
}
|
||||
|
||||
}
|
Reference in New Issue
Block a user