CORDA-1325: Bootstrapper unable to whitelist two versions of the same contract simultaneously (#2980)

Also added unit tests
This commit is contained in:
Shams Asari 2018-04-24 10:51:24 +01:00 committed by GitHub
parent 10c559a3f3
commit 65525d74e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 202 additions and 41 deletions

View File

@ -31,6 +31,7 @@ import rx.subjects.PublishSubject
import rx.subjects.UnicastSubject
import java.io.*
import java.lang.reflect.Field
import java.lang.reflect.Modifier
import java.math.BigDecimal
import java.net.HttpURLConnection
import java.net.HttpURLConnection.HTTP_OK
@ -306,6 +307,10 @@ fun TransactionBuilder.toLedgerTransaction(services: ServicesForResolution, seri
/** Convenience method to get the package name of a class literal. */
val KClass<*>.packageName: String get() = java.`package`.name
inline val Class<*>.isAbstractClass: Boolean get() = Modifier.isAbstract(modifiers)
inline val Class<*>.isConcreteClass: Boolean get() = !isInterface && !isAbstractClass
fun URI.toPath(): Path = Paths.get(this)
fun URL.toPath(): Path = toURI().toPath()

View File

@ -5,13 +5,10 @@ import net.corda.core.contracts.Contract
import net.corda.core.contracts.ContractClassName
import net.corda.core.contracts.UpgradedContract
import net.corda.core.contracts.UpgradedContractWithLegacyConstraint
import net.corda.core.internal.copyTo
import net.corda.core.internal.deleteIfExists
import net.corda.core.internal.logElapsedTime
import net.corda.core.internal.read
import net.corda.core.crypto.SecureHash
import net.corda.core.internal.*
import org.slf4j.LoggerFactory
import java.io.InputStream
import java.lang.reflect.Modifier
import java.net.URLClassLoader
import java.nio.file.Files
import java.nio.file.Path
@ -22,22 +19,33 @@ import java.util.Collections.singleton
// relationships between those interfaces, therefore they have to be listed explicitly.
val coreContractClasses = setOf(Contract::class, UpgradedContractWithLegacyConstraint::class, UpgradedContract::class)
/**
* Scans the jar for contracts.
* @returns: found contract class names or null if none found
*/
fun scanJarForContracts(cordappJar: Path): List<ContractClassName> {
val scanResult = FastClasspathScanner()
// A set of a single element may look odd, but if this is removed "Path" which itself is an `Iterable`
// is getting broken into pieces to scan individually, which doesn't yield desired effect.
.overrideClasspath(singleton(cordappJar))
.scan()
val contracts = coreContractClasses.flatMap { contractClass -> scanResult.getNamesOfClassesImplementing(contractClass.qualifiedName) }.distinct()
interface ContractsJar {
val hash: SecureHash
fun scan(): List<ContractClassName>
}
// Only keep instantiable contracts
return URLClassLoader(arrayOf(cordappJar.toUri().toURL()), Contract::class.java.classLoader).use {
contracts.map(it::loadClass).filter { !it.isInterface && !Modifier.isAbstract(it.modifiers) }
}.map { it.name }
class ContractsJarFile(private val file: Path) : ContractsJar {
override val hash: SecureHash by lazy(LazyThreadSafetyMode.NONE, file::hash)
override fun scan(): List<ContractClassName> {
val scanResult = FastClasspathScanner()
// A set of a single element may look odd, but if this is removed "Path" which itself is an `Iterable`
// is getting broken into pieces to scan individually, which doesn't yield desired effect.
.overrideClasspath(singleton(file))
.scan()
val contractClassNames = coreContractClasses
.flatMap { scanResult.getNamesOfClassesImplementing(it.qualifiedName) }
.toSet()
return URLClassLoader(arrayOf(file.toUri().toURL()), Contract::class.java.classLoader).use { cl ->
contractClassNames.mapNotNull {
val contractClass = cl.loadClass(it)
// Only keep instantiable contracts
if (contractClass.isConcreteClass) contractClass.name else null
}
}
}
}
private val logger = LoggerFactory.getLogger("ClassloaderUtils")
@ -48,7 +56,7 @@ fun <T> withContractsInJar(jarInputStream: InputStream, withContracts: (List<Con
jarInputStream.copyTo(tempFile, StandardCopyOption.REPLACE_EXISTING)
val cordappJar = tempFile.toAbsolutePath()
val contracts = logElapsedTime("Contracts loading for '$cordappJar'", logger) {
scanJarForContracts(cordappJar)
ContractsJarFile(tempFile.toAbsolutePath()).scan()
}
return tempFile.read { withContracts(contracts, it) }
} finally {

View File

@ -2,6 +2,7 @@ package net.corda.nodeapi.internal.network
import com.typesafe.config.ConfigFactory
import net.corda.cordform.CordformNode
import net.corda.core.contracts.ContractClassName
import net.corda.core.identity.Party
import net.corda.core.internal.*
import net.corda.core.internal.concurrent.fork
@ -16,10 +17,11 @@ import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal._contextSerializationEnv
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.seconds
import net.corda.nodeapi.internal.ContractsJar
import net.corda.nodeapi.internal.ContractsJarFile
import net.corda.nodeapi.internal.DEV_ROOT_CA
import net.corda.nodeapi.internal.SignedNodeInfo
import net.corda.nodeapi.internal.network.NodeInfoFilesCopier.Companion.NODE_INFO_FILE_NAME_PREFIX
import net.corda.nodeapi.internal.scanJarForContracts
import net.corda.nodeapi.internal.serialization.AMQP_P2P_CONTEXT
import net.corda.nodeapi.internal.serialization.CordaSerializationMagic
import net.corda.nodeapi.internal.serialization.SerializationFactoryImpl
@ -78,7 +80,7 @@ class NetworkBootstrapper {
println("Gathering notary identities")
val notaryInfos = gatherNotaryInfos(nodeInfoFiles)
println("Generating contract implementations whitelist")
val newWhitelist = generateWhitelist(existingNetParams, directory / EXCLUDE_WHITELIST_FILE_NAME, cordappJars)
val newWhitelist = generateWhitelist(existingNetParams, readExcludeWhitelist(directory), cordappJars.map(::ContractsJarFile))
val netParams = installNetworkParameters(notaryInfos, newWhitelist, existingNetParams, nodeDirs)
println("${if (existingNetParams == null) "New" else "Updated"} $netParams")
println("Bootstrapping complete!")
@ -228,29 +230,32 @@ class NetworkBootstrapper {
return networkParameters
}
private fun generateWhitelist(networkParameters: NetworkParameters?,
excludeWhitelistFile: Path,
cordappJars: List<Path>): Map<String, List<AttachmentId>> {
@VisibleForTesting
internal fun generateWhitelist(networkParameters: NetworkParameters?,
excludeContracts: List<ContractClassName>,
cordappJars: List<ContractsJar>): Map<ContractClassName, List<AttachmentId>> {
val existingWhitelist = networkParameters?.whitelistedContractImplementations ?: emptyMap()
val excludeContracts = readExcludeWhitelist(excludeWhitelistFile)
if (excludeContracts.isNotEmpty()) {
println("Exclude contracts from whitelist: ${excludeContracts.joinToString()}")
existingWhitelist.keys.forEach {
require(it !in excludeContracts) { "$it is already part of the existing whitelist and cannot be excluded." }
}
}
val newWhiteList = cordappJars.flatMap { cordappJar ->
val jarHash = cordappJar.hash
scanJarForContracts(cordappJar).map { contract -> contract to jarHash }
}.filter { (contractClassName, _) -> contractClassName !in excludeContracts }.toMap()
val newWhiteList = cordappJars
.flatMap { jar -> (jar.scan() - excludeContracts).map { it to jar.hash } }
.toMultiMap()
return (newWhiteList.keys + existingWhitelist.keys).map { contractClassName ->
val existing = existingWhitelist[contractClassName] ?: emptyList()
val newHash = newWhiteList[contractClassName]
contractClassName to (if (newHash == null || newHash in existing) existing else existing + newHash)
}.toMap()
return (newWhiteList.keys + existingWhitelist.keys).associateBy({ it }) {
val existingHashes = existingWhitelist[it] ?: emptyList()
val newHashes = newWhiteList[it] ?: emptyList()
(existingHashes + newHashes).distinct()
}
}
private fun readExcludeWhitelist(file: Path): List<String> {
private fun readExcludeWhitelist(directory: Path): List<String> {
val file = directory / EXCLUDE_WHITELIST_FILE_NAME
return if (file.exists()) file.readAllLines().map(String::trim) else emptyList()
}

View File

@ -0,0 +1,141 @@
package net.corda.nodeapi.internal.network
import net.corda.core.contracts.ContractClassName
import net.corda.core.crypto.SecureHash
import net.corda.core.node.services.AttachmentId
import net.corda.nodeapi.internal.ContractsJar
import net.corda.testing.common.internal.testNetworkParameters
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatIllegalArgumentException
import org.junit.Test
class NetworkBootstrapperTest {
@Test
fun `no jars against empty whitelist`() {
val whitelist = generateWhitelist(emptyMap(), emptyList(), emptyList())
assertThat(whitelist).isEmpty()
}
@Test
fun `no jars against single whitelist`() {
val existingWhitelist = mapOf("class1" to listOf(SecureHash.randomSHA256()))
val newWhitelist = generateWhitelist(existingWhitelist, emptyList(), emptyList())
assertThat(newWhitelist).isEqualTo(existingWhitelist)
}
@Test
fun `empty jar against empty whitelist`() {
val whitelist = generateWhitelist(emptyMap(), emptyList(), listOf(TestContractsJar(contractClassNames = emptyList())))
assertThat(whitelist).isEmpty()
}
@Test
fun `empty jar against single whitelist`() {
val existingWhitelist = mapOf("class1" to listOf(SecureHash.randomSHA256()))
val newWhitelist = generateWhitelist(existingWhitelist, emptyList(), listOf(TestContractsJar(contractClassNames = emptyList())))
assertThat(newWhitelist).isEqualTo(existingWhitelist)
}
@Test
fun `jar with single contract against empty whitelist`() {
val jar = TestContractsJar(contractClassNames = listOf("class1"))
val whitelist = generateWhitelist(emptyMap(), emptyList(), listOf(jar))
assertThat(whitelist).isEqualTo(mapOf(
"class1" to listOf(jar.hash)
))
}
@Test
fun `single contract jar against single whitelist of different contract`() {
val class1JarHash = SecureHash.randomSHA256()
val existingWhitelist = mapOf("class1" to listOf(class1JarHash))
val jar = TestContractsJar(contractClassNames = listOf("class2"))
val whitelist = generateWhitelist(existingWhitelist, emptyList(), listOf(jar))
assertThat(whitelist).isEqualTo(mapOf(
"class1" to listOf(class1JarHash),
"class2" to listOf(jar.hash)
))
}
@Test
fun `same jar with single contract`() {
val jarHash = SecureHash.randomSHA256()
val existingWhitelist = mapOf("class1" to listOf(jarHash))
val jar = TestContractsJar(hash = jarHash, contractClassNames = listOf("class1"))
val newWhitelist = generateWhitelist(existingWhitelist, emptyList(), listOf(jar))
assertThat(newWhitelist).isEqualTo(existingWhitelist)
}
@Test
fun `jar with updated contract`() {
val previousJarHash = SecureHash.randomSHA256()
val existingWhitelist = mapOf("class1" to listOf(previousJarHash))
val newContractsJar = TestContractsJar(contractClassNames = listOf("class1"))
val newWhitelist = generateWhitelist(existingWhitelist, emptyList(), listOf(newContractsJar))
assertThat(newWhitelist).isEqualTo(mapOf(
"class1" to listOf(previousJarHash, newContractsJar.hash)
))
}
@Test
fun `jar with one existing contract and one new one`() {
val previousJarHash = SecureHash.randomSHA256()
val existingWhitelist = mapOf("class1" to listOf(previousJarHash))
val newContractsJar = TestContractsJar(contractClassNames = listOf("class1", "class2"))
val newWhitelist = generateWhitelist(existingWhitelist, emptyList(), listOf(newContractsJar))
assertThat(newWhitelist).isEqualTo(mapOf(
"class1" to listOf(previousJarHash, newContractsJar.hash),
"class2" to listOf(newContractsJar.hash)
))
}
@Test
fun `two versions of the same contract`() {
val version1Jar = TestContractsJar(contractClassNames = listOf("class1"))
val version2Jar = TestContractsJar(contractClassNames = listOf("class1"))
val newWhitelist = generateWhitelist(emptyMap(), emptyList(), listOf(version1Jar, version2Jar))
assertThat(newWhitelist).isEqualTo(mapOf(
"class1" to listOf(version1Jar.hash, version2Jar.hash)
))
}
@Test
fun `jar with single new contract that's excluded`() {
val jar = TestContractsJar(contractClassNames = listOf("class1"))
val whitelist = generateWhitelist(emptyMap(), listOf("class1"), listOf(jar))
assertThat(whitelist).isEmpty()
}
@Test
fun `jar with two new contracts, one of which is excluded`() {
val jar = TestContractsJar(contractClassNames = listOf("class1", "class2"))
val whitelist = generateWhitelist(emptyMap(), listOf("class1"), listOf(jar))
assertThat(whitelist).isEqualTo(mapOf(
"class2" to listOf(jar.hash)
))
}
@Test
fun `jar with updated contract but it's excluded`() {
val existingWhitelist = mapOf("class1" to listOf(SecureHash.randomSHA256()))
val jar = TestContractsJar(contractClassNames = listOf("class1"))
assertThatIllegalArgumentException().isThrownBy {
generateWhitelist(existingWhitelist, listOf("class1"), listOf(jar))
}
}
private fun generateWhitelist(existingWhitelist: Map<String, List<AttachmentId>>,
excludeContracts: List<ContractClassName>,
contractJars: List<TestContractsJar>): Map<String, List<AttachmentId>> {
return NetworkBootstrapper().generateWhitelist(
testNetworkParameters(whitelistedContractImplementations = existingWhitelist),
excludeContracts,
contractJars
)
}
data class TestContractsJar(override val hash: SecureHash = SecureHash.randomSHA256(),
private val contractClassNames: List<ContractClassName>) : ContractsJar {
override fun scan(): List<ContractClassName> = contractClassNames
}
}

View File

@ -2,6 +2,7 @@ package net.corda.testing.common.internal
import net.corda.core.node.NetworkParameters
import net.corda.core.node.NotaryInfo
import net.corda.core.node.services.AttachmentId
import java.time.Instant
fun testNetworkParameters(
@ -11,15 +12,16 @@ fun testNetworkParameters(
maxMessageSize: Int = 10485760,
// TODO: Make this configurable and consistence across driver, bootstrapper, demobench and NetworkMapServer
maxTransactionSize: Int = maxMessageSize,
whitelistedContractImplementations: Map<String, List<AttachmentId>> = emptyMap(),
epoch: Int = 1
): NetworkParameters {
return NetworkParameters(
minimumPlatformVersion = minimumPlatformVersion,
notaries = notaries,
modifiedTime = modifiedTime,
maxMessageSize = maxMessageSize,
maxTransactionSize = maxTransactionSize,
epoch = epoch,
whitelistedContractImplementations = emptyMap()
whitelistedContractImplementations = whitelistedContractImplementations,
modifiedTime = modifiedTime,
epoch = epoch
)
}