Merge branch 'release/os/4.8' into merge-release/os/4.7-release/os/4.8-2023-11-17-6

This commit is contained in:
Adel El-Beik 2023-11-18 18:10:07 +00:00 committed by GitHub
commit 548242e3cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
243 changed files with 7081 additions and 2449 deletions

View File

@ -1850,13 +1850,15 @@ public final class net.corda.core.crypto.CryptoUtils extends java.lang.Object
public static final boolean verify(java.security.PublicKey, byte[], byte[]) public static final boolean verify(java.security.PublicKey, byte[], byte[])
## ##
public interface net.corda.core.crypto.DigestAlgorithm public interface net.corda.core.crypto.DigestAlgorithm
@NotNull
public abstract byte[] componentDigest(byte[])
@NotNull @NotNull
public abstract byte[] digest(byte[]) public abstract byte[] digest(byte[])
@NotNull @NotNull
public abstract String getAlgorithm() public abstract String getAlgorithm()
public abstract int getDigestLength() public abstract int getDigestLength()
@NotNull @NotNull
public abstract byte[] preImageResistantDigest(byte[]) public abstract byte[] nonceDigest(byte[])
## ##
@CordaSerializable @CordaSerializable
public class net.corda.core.crypto.DigitalSignature extends net.corda.core.utilities.OpaqueBytes public class net.corda.core.crypto.DigitalSignature extends net.corda.core.utilities.OpaqueBytes
@ -5935,6 +5937,13 @@ public @interface net.corda.core.serialization.CordaSerializationTransformRename
public @interface net.corda.core.serialization.CordaSerializationTransformRenames public @interface net.corda.core.serialization.CordaSerializationTransformRenames
public abstract net.corda.core.serialization.CordaSerializationTransformRename[] value() public abstract net.corda.core.serialization.CordaSerializationTransformRename[] value()
## ##
public interface net.corda.core.serialization.CustomSerializationScheme
@NotNull
public abstract T deserialize(net.corda.core.utilities.ByteSequence, Class<T>, net.corda.core.serialization.SerializationSchemeContext)
public abstract int getSchemeId()
@NotNull
public abstract net.corda.core.utilities.ByteSequence serialize(T, net.corda.core.serialization.SerializationSchemeContext)
##
public @interface net.corda.core.serialization.DeprecatedConstructorForDeserialization public @interface net.corda.core.serialization.DeprecatedConstructorForDeserialization
public abstract int version() public abstract int version()
## ##
@ -6076,6 +6085,13 @@ public static final class net.corda.core.serialization.SerializationFactory$Comp
@NotNull @NotNull
public final net.corda.core.serialization.SerializationFactory getDefaultFactory() public final net.corda.core.serialization.SerializationFactory getDefaultFactory()
## ##
@DoNotImplement
public interface net.corda.core.serialization.SerializationSchemeContext
@NotNull
public abstract ClassLoader getDeserializationClassLoader()
@NotNull
public abstract net.corda.core.serialization.ClassWhitelist getWhitelist()
##
public interface net.corda.core.serialization.SerializationToken public interface net.corda.core.serialization.SerializationToken
@NotNull @NotNull
public abstract Object fromToken(net.corda.core.serialization.SerializeAsTokenContext) public abstract Object fromToken(net.corda.core.serialization.SerializeAsTokenContext)

View File

@ -70,6 +70,7 @@ pipeline {
stage('Compile') { stage('Compile') {
steps { steps {
dir(sameAgentFolder) { dir(sameAgentFolder) {
authenticateGradleWrapper()
sh script: [ sh script: [
'./gradlew', './gradlew',
COMMON_GRADLE_PARAMS, COMMON_GRADLE_PARAMS,

View File

@ -13,13 +13,13 @@
* the branch name of origin branch, it should match the current branch * the branch name of origin branch, it should match the current branch
* and it acts as a fail-safe inside {@code forwardMerger} pipeline * and it acts as a fail-safe inside {@code forwardMerger} pipeline
*/ */
String originBranch = 'release/os/4.7' String originBranch = 'release/os/4.8'
/** /**
* the branch name of target branch, it should be the branch with the next version * the branch name of target branch, it should be the branch with the next version
* after the one in current branch. * after the one in current branch.
*/ */
String targetBranch = 'release/os/4.8' String targetBranch = 'release/os/4.9'
/** /**
* Forward merge any changes between #originBranch and #targetBranch * Forward merge any changes between #originBranch and #targetBranch

View File

@ -56,6 +56,7 @@ pipeline {
stage('Unit Tests') { stage('Unit Tests') {
agent { label 'mswin' } agent { label 'mswin' }
steps { steps {
authenticateGradleWrapper()
bat "./gradlew --no-daemon " + bat "./gradlew --no-daemon " +
"--stacktrace " + "--stacktrace " +
"-Pcompilation.warningsAsErrors=false " + "-Pcompilation.warningsAsErrors=false " +

View File

@ -50,6 +50,7 @@ pipeline {
stages { stages {
stage('Compile') { stage('Compile') {
steps { steps {
authenticateGradleWrapper()
sh script: [ sh script: [
'./gradlew', './gradlew',
COMMON_GRADLE_PARAMS, COMMON_GRADLE_PARAMS,

View File

@ -3,6 +3,7 @@
* Jenkins pipeline to build Corda OS release branches and tags. * Jenkins pipeline to build Corda OS release branches and tags.
* PLEASE NOTE: we DO want to run a build for each commit!!! * PLEASE NOTE: we DO want to run a build for each commit!!!
*/ */
@Library('corda-shared-build-pipeline-steps')
/** /**
* Sense environment * Sense environment
@ -47,6 +48,7 @@ pipeline {
stages { stages {
stage('Unit Tests') { stage('Unit Tests') {
steps { steps {
authenticateGradleWrapper()
sh "./gradlew clean --continue test --info -Ptests.failFast=true" sh "./gradlew clean --continue test --info -Ptests.failFast=true"
} }
} }

View File

@ -30,6 +30,7 @@ pipeline {
stages { stages {
stage('Detekt check') { stage('Detekt check') {
steps { steps {
authenticateGradleWrapper()
sh "./gradlew --no-daemon --parallel --build-cache clean detekt" sh "./gradlew --no-daemon --parallel --build-cache clean detekt"
} }
} }
@ -54,6 +55,7 @@ pipeline {
GRADLE_USER_HOME = "/host_tmp/gradle" GRADLE_USER_HOME = "/host_tmp/gradle"
} }
steps { steps {
authenticateGradleWrapper()
sh 'mkdir -p ${GRADLE_USER_HOME}' sh 'mkdir -p ${GRADLE_USER_HOME}'
snykDeltaScan(env.SNYK_API_TOKEN, env.C4_OS_SNYK_ORG_ID) snykDeltaScan(env.SNYK_API_TOKEN, env.C4_OS_SNYK_ORG_ID)
} }

View File

@ -33,6 +33,7 @@ pipeline {
stage('Publish Archived API Docs to Artifactory') { stage('Publish Archived API Docs to Artifactory') {
when { tag pattern: /^docs-release-os-V(\d+\.\d+)(\.\d+){0,1}(-GA){0,1}(-\d{4}-\d\d-\d\d-\d{4}){0,1}$/, comparator: 'REGEXP' } when { tag pattern: /^docs-release-os-V(\d+\.\d+)(\.\d+){0,1}(-GA){0,1}(-\d{4}-\d\d-\d\d-\d{4}){0,1}$/, comparator: 'REGEXP' }
steps { steps {
authenticateGradleWrapper()
sh "./gradlew :clean :docs:artifactoryPublish -DpublishApiDocs" sh "./gradlew :clean :docs:artifactoryPublish -DpublishApiDocs"
} }
} }

View File

@ -29,6 +29,7 @@ pipeline {
stages { stages {
stage('Publish to Artifactory') { stage('Publish to Artifactory') {
steps { steps {
authenticateGradleWrapper()
rtServer ( rtServer (
id: 'R3-Artifactory', id: 'R3-Artifactory',
url: 'https://software.r3.com/artifactory', url: 'https://software.r3.com/artifactory',

View File

@ -70,6 +70,7 @@ pipeline {
stages { stages {
stage('Compile') { stage('Compile') {
steps { steps {
authenticateGradleWrapper()
sh script: [ sh script: [
'./gradlew', './gradlew',
COMMON_GRADLE_PARAMS, COMMON_GRADLE_PARAMS,
@ -168,6 +169,7 @@ pipeline {
} }
stage('Recompile') { stage('Recompile') {
steps { steps {
authenticateGradleWrapper()
sh script: [ sh script: [
'./gradlew', './gradlew',
COMMON_GRADLE_PARAMS, COMMON_GRADLE_PARAMS,

View File

@ -16,6 +16,7 @@ jobs:
with: with:
jiraBaseUrl: https://r3-cev.atlassian.net jiraBaseUrl: https://r3-cev.atlassian.net
project: CORDA project: CORDA
squad: Corda
issuetype: Bug issuetype: Bug
summary: ${{ github.event.issue.title }} summary: ${{ github.event.issue.title }}
labels: community labels: community

2
Jenkinsfile vendored
View File

@ -58,6 +58,7 @@ pipeline {
stages { stages {
stage('Compile') { stage('Compile') {
steps { steps {
authenticateGradleWrapper()
sh script: [ sh script: [
'./gradlew', './gradlew',
COMMON_GRADLE_PARAMS, COMMON_GRADLE_PARAMS,
@ -100,6 +101,7 @@ pipeline {
} }
stage('Recompile') { stage('Recompile') {
steps { steps {
authenticateGradleWrapper()
sh script: [ sh script: [
'./gradlew', './gradlew',
COMMON_GRADLE_PARAMS, COMMON_GRADLE_PARAMS,

View File

@ -101,7 +101,7 @@ buildscript {
ext.hibernate_version = '5.4.32.Final' ext.hibernate_version = '5.4.32.Final'
ext.h2_version = '1.4.199' // Update docs if renamed or removed. ext.h2_version = '1.4.199' // Update docs if renamed or removed.
ext.rxjava_version = '1.3.8' ext.rxjava_version = '1.3.8'
ext.dokka_version = '0.9.17' ext.dokka_version = '0.10.1'
ext.eddsa_version = '0.3.0' ext.eddsa_version = '0.3.0'
ext.dependency_checker_version = '5.2.0' ext.dependency_checker_version = '5.2.0'
ext.commons_collections_version = '4.3' ext.commons_collections_version = '4.3'

View File

@ -1,5 +1,6 @@
package net.corda.client.rpc package net.corda.client.rpc
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.client.rpc.internal.RPCClient import net.corda.client.rpc.internal.RPCClient
import net.corda.client.rpc.internal.ReconnectingCordaRPCOps import net.corda.client.rpc.internal.ReconnectingCordaRPCOps
import net.corda.client.rpc.internal.SerializationEnvironmentHelper import net.corda.client.rpc.internal.SerializationEnvironmentHelper
@ -52,7 +53,7 @@ class CordaRPCConnection private constructor(
sslConfiguration: ClientRpcSslOptions? = null, sslConfiguration: ClientRpcSslOptions? = null,
classLoader: ClassLoader? = null classLoader: ClassLoader? = null
): CordaRPCConnection { ): CordaRPCConnection {
val observersPool: ExecutorService = Executors.newCachedThreadPool() val observersPool: ExecutorService = Executors.newCachedThreadPool(DefaultThreadFactory("RPCObserver"))
return CordaRPCConnection(null, observersPool, ReconnectingCordaRPCOps( return CordaRPCConnection(null, observersPool, ReconnectingCordaRPCOps(
addresses, addresses,
username, username,

View File

@ -17,7 +17,6 @@ import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.rpcConnectorTcpTransport import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.rpcConnectorTcpTransport
import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.rpcConnectorTcpTransportsFromList
import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.rpcInternalClientTcpTransport import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.rpcInternalClientTcpTransport
import net.corda.nodeapi.internal.RoundRobinConnectionPolicy import net.corda.nodeapi.internal.RoundRobinConnectionPolicy
import net.corda.nodeapi.internal.config.SslConfiguration import net.corda.nodeapi.internal.config.SslConfiguration
@ -61,8 +60,12 @@ class RPCClient<I : RPCOps>(
sslConfiguration: ClientRpcSslOptions? = null, sslConfiguration: ClientRpcSslOptions? = null,
configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.DEFAULT, configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.DEFAULT,
serializationContext: SerializationContext = SerializationDefaults.RPC_CLIENT_CONTEXT serializationContext: SerializationContext = SerializationDefaults.RPC_CLIENT_CONTEXT
) : this(rpcConnectorTcpTransport(haAddressPool.first(), sslConfiguration), ) : this(
configuration, serializationContext, rpcConnectorTcpTransportsFromList(haAddressPool, sslConfiguration)) rpcConnectorTcpTransport(haAddressPool.first(), sslConfiguration),
configuration,
serializationContext,
haAddressPool.map { rpcConnectorTcpTransport(it, sslConfiguration) }
)
companion object { companion object {
private val log = contextLogger() private val log = contextLogger()

View File

@ -1,5 +1,6 @@
package net.corda.client.rpc.internal package net.corda.client.rpc.internal
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.client.rpc.ConnectionFailureException import net.corda.client.rpc.ConnectionFailureException
import net.corda.client.rpc.CordaRPCClient import net.corda.client.rpc.CordaRPCClient
import net.corda.client.rpc.CordaRPCClientConfiguration import net.corda.client.rpc.CordaRPCClientConfiguration
@ -99,7 +100,8 @@ class ReconnectingCordaRPCOps private constructor(
ErrorInterceptingHandler(reconnectingRPCConnection)) as CordaRPCOps ErrorInterceptingHandler(reconnectingRPCConnection)) as CordaRPCOps
} }
} }
private val retryFlowsPool = Executors.newScheduledThreadPool(1) private val retryFlowsPool = Executors.newScheduledThreadPool(1, DefaultThreadFactory("FlowRetry"))
/** /**
* This function runs a flow and retries until it completes successfully. * This function runs a flow and retries until it completes successfully.
* *

View File

@ -9,4 +9,4 @@ package net.corda.common.logging
* (originally added to source control for ease of use) * (originally added to source control for ease of use)
*/ */
internal const val CURRENT_MAJOR_RELEASE = "4.7-SNAPSHOT" internal const val CURRENT_MAJOR_RELEASE = "4.8-SNAPSHOT"

View File

@ -2,7 +2,7 @@
# because some versions here need to be matched by app authors in # because some versions here need to be matched by app authors in
# their own projects. So don't get fancy with syntax! # their own projects. So don't get fancy with syntax!
cordaVersion=4.7 cordaVersion=4.8
versionSuffix=SNAPSHOT versionSuffix=SNAPSHOT
gradlePluginsVersion=5.0.12 gradlePluginsVersion=5.0.12
kotlinVersion=1.2.71 kotlinVersion=1.2.71
@ -11,7 +11,7 @@ java8MinUpdateVersion=171
# When incrementing platformVersion make sure to update # # When incrementing platformVersion make sure to update #
# net.corda.core.internal.CordaUtilsKt.PLATFORM_VERSION as well. # # net.corda.core.internal.CordaUtilsKt.PLATFORM_VERSION as well. #
# ***************************************************************# # ***************************************************************#
platformVersion=9 platformVersion=10
guavaVersion=28.0-jre guavaVersion=28.0-jre
# Quasar version to use with Java 8: # Quasar version to use with Java 8:
quasarVersion=0.7.15_r3 quasarVersion=0.7.15_r3
@ -21,7 +21,7 @@ jdkClassifier11=jdk11
dockerJavaVersion=3.2.5 dockerJavaVersion=3.2.5
proguardVersion=6.1.1 proguardVersion=6.1.1
bouncycastleVersion=1.68 bouncycastleVersion=1.68
classgraphVersion=4.8.90 classgraphVersion=4.8.135
disruptorVersion=3.4.2 disruptorVersion=3.4.2
typesafeConfigVersion=1.3.4 typesafeConfigVersion=1.3.4
jsr305Version=3.0.2 jsr305Version=3.0.2

View File

@ -23,7 +23,10 @@ def javaHome = System.getProperty('java.home')
def jarBaseName = "corda-${project.name}".toString() def jarBaseName = "corda-${project.name}".toString()
configurations { configurations {
deterministicLibraries.extendsFrom api deterministicLibraries {
canBeConsumed = false
extendsFrom api
}
deterministicArtifacts.extendsFrom deterministicLibraries deterministicArtifacts.extendsFrom deterministicLibraries
} }
@ -59,7 +62,7 @@ def originalJar = coreJarTask.map { it.outputs.files.singleFile }
def patchCore = tasks.register('patchCore', Zip) { def patchCore = tasks.register('patchCore', Zip) {
dependsOn coreJarTask dependsOn coreJarTask
destinationDirectory = file("$buildDir/source-libs") destinationDirectory = layout.buildDirectory.dir('source-libs')
metadataCharset 'UTF-8' metadataCharset 'UTF-8'
archiveClassifier = 'transient' archiveClassifier = 'transient'
archiveExtension = 'jar' archiveExtension = 'jar'
@ -169,7 +172,7 @@ def determinise = tasks.register('determinise', ProGuardTask) {
def checkDeterminism = tasks.register('checkDeterminism', ProGuardTask) def checkDeterminism = tasks.register('checkDeterminism', ProGuardTask)
def metafix = tasks.register('metafix', MetaFixerTask) { def metafix = tasks.register('metafix', MetaFixerTask) {
outputDir file("$buildDir/libs") outputDir = layout.buildDirectory.dir('libs')
jars determinise jars determinise
suffix "" suffix ""

View File

@ -55,12 +55,16 @@ abstract class SerializationFactory {
* Change the current context inside the block to that supplied. * Change the current context inside the block to that supplied.
*/ */
fun <T> withCurrentContext(context: SerializationContext?, block: () -> T): T { fun <T> withCurrentContext(context: SerializationContext?, block: () -> T): T {
return if (context == null) {
block()
} else {
val priorContext = _currentContext val priorContext = _currentContext
if (context != null) _currentContext = context _currentContext = context
try { try {
return block() block()
} finally { } finally {
if (context != null) _currentContext = priorContext _currentContext = priorContext
}
} }
} }

View File

@ -3,7 +3,9 @@ plugins {
} }
configurations { configurations {
testData testData {
canBeResolved = false
}
} }
dependencies { dependencies {

View File

@ -9,7 +9,12 @@ apply from: "${rootProject.projectDir}/deterministic.gradle"
description 'Test utilities for deterministic contract verification' description 'Test utilities for deterministic contract verification'
configurations { configurations {
deterministicArtifacts deterministicArtifacts {
canBeResolved = false
}
// Compile against the deterministic artifacts to ensure that we use only the deterministic API subset.
compileOnly.extendsFrom deterministicArtifacts
runtimeArtifacts.extendsFrom api runtimeArtifacts.extendsFrom api
} }
@ -20,8 +25,6 @@ dependencies {
runtimeArtifacts project(':serialization') runtimeArtifacts project(':serialization')
runtimeArtifacts project(':core') runtimeArtifacts project(':core')
// Compile against the deterministic artifacts to ensure that we use only the deterministic API subset.
compileOnly configurations.deterministicArtifacts
api "junit:junit:$junit_version" api "junit:junit:$junit_version"
runtimeOnly "org.junit.vintage:junit-vintage-engine:$junit_vintage_version" runtimeOnly "org.junit.vintage:junit-vintage-engine:$junit_vintage_version"
} }

View File

@ -13,7 +13,7 @@ import net.corda.core.transactions.WireTransaction
@Suppress("MemberVisibilityCanBePrivate") @Suppress("MemberVisibilityCanBePrivate")
//TODO the use of deprecated toLedgerTransaction need to be revisited as resolveContractAttachment requires attachments of the transactions which created input states... //TODO the use of deprecated toLedgerTransaction need to be revisited as resolveContractAttachment requires attachments of the transactions which created input states...
//TODO ...to check contract version non downgrade rule, curretly dummy Attachment if not fund is used which sets contract version to '1' //TODO ...to check contract version non downgrade rule, currently dummy Attachment if not fund is used which sets contract version to '1'
@CordaSerializable @CordaSerializable
class TransactionVerificationRequest(val wtxToVerify: SerializedBytes<WireTransaction>, class TransactionVerificationRequest(val wtxToVerify: SerializedBytes<WireTransaction>,
val dependencies: Array<SerializedBytes<WireTransaction>>, val dependencies: Array<SerializedBytes<WireTransaction>>,

View File

@ -5,8 +5,11 @@ import com.nhaarman.mockito_kotlin.mock
import com.nhaarman.mockito_kotlin.whenever import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.crypto.* import net.corda.core.crypto.*
import net.corda.core.crypto.internal.DigestAlgorithmFactory
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.BLAKE2s256DigestAlgorithm
import net.corda.core.internal.SHA256BLAKE2s256DigestAlgorithm
import net.corda.core.node.NotaryInfo import net.corda.core.node.NotaryInfo
import net.corda.core.node.services.IdentityService import net.corda.core.node.services.IdentityService
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
@ -49,12 +52,19 @@ class PartialMerkleTreeTest(private var digestService: DigestService) {
val MINI_CORP get() = miniCorp.party val MINI_CORP get() = miniCorp.party
val MINI_CORP_PUBKEY get() = miniCorp.publicKey val MINI_CORP_PUBKEY get() = miniCorp.publicKey
init {
DigestAlgorithmFactory.registerClass(BLAKE2s256DigestAlgorithm::class.java.name)
DigestAlgorithmFactory.registerClass(SHA256BLAKE2s256DigestAlgorithm::class.java.name)
}
@JvmStatic @JvmStatic
@Parameterized.Parameters @Parameterized.Parameters
fun data(): Collection<DigestService> = listOf( fun data(): Collection<DigestService> = listOf(
DigestService.sha2_256, DigestService.sha2_256,
DigestService.sha2_384, DigestService.sha2_384,
DigestService.sha2_512 DigestService.sha2_512,
DigestService("BLAKE_TEST"),
DigestService("SHA256-BLAKE2S256-TEST")
) )
} }

View File

@ -140,7 +140,7 @@ class PartialMerkleTreeWithNamedHashTest {
fun `building Merkle tree one node`() { fun `building Merkle tree one node`() {
val node = 'a'.serialize().sha2_384() val node = 'a'.serialize().sha2_384()
val mt = MerkleTree.getMerkleTree(listOf(node), DigestService.sha2_384) val mt = MerkleTree.getMerkleTree(listOf(node), DigestService.sha2_384)
assertEquals(node, mt.hash) assertNotEquals(node, mt.hash)
} }
@Test(timeout=300_000) @Test(timeout=300_000)

View File

@ -0,0 +1,29 @@
package net.corda.coretests.crypto.internal
import net.corda.coretesting.internal.DEV_ROOT_CA
import net.corda.testing.core.createCRL
import org.assertj.core.api.Assertions.assertThatIllegalArgumentException
import org.junit.Test
class ProviderMapTest {
// https://github.com/corda/corda/pull/3997
@Test(timeout = 300_000)
fun `verify CRL algorithms`() {
val crl = createCRL(
issuer = DEV_ROOT_CA,
revokedCerts = emptyList(),
signatureAlgorithm = "SHA256withECDSA"
)
// This should pass.
crl.verify(DEV_ROOT_CA.keyPair.public)
// Try changing the algorithm to EC will fail.
assertThatIllegalArgumentException().isThrownBy {
createCRL(
issuer = DEV_ROOT_CA,
revokedCerts = emptyList(),
signatureAlgorithm = "EC"
)
}.withMessage("Unknown signature type requested: EC")
}
}

View File

@ -55,7 +55,8 @@ class AttachmentsClassLoaderSerializationTests {
arrayOf(isolatedId, att1, att2).map { storage.openAttachment(it)!! }, arrayOf(isolatedId, att1, att2).map { storage.openAttachment(it)!! },
testNetworkParameters(), testNetworkParameters(),
SecureHash.zeroHash, SecureHash.zeroHash,
{ attachmentTrustCalculator.calculate(it) }, attachmentsClassLoaderCache = null) { classLoader -> { attachmentTrustCalculator.calculate(it) }, attachmentsClassLoaderCache = null) { serializationContext ->
val classLoader = serializationContext.deserializationClassLoader
val contractClass = Class.forName(ISOLATED_CONTRACT_CLASS_NAME, true, classLoader) val contractClass = Class.forName(ISOLATED_CONTRACT_CLASS_NAME, true, classLoader)
val contract = contractClass.getDeclaredConstructor().newInstance() as Contract val contract = contractClass.getDeclaredConstructor().newInstance() as Contract
assertEquals("helloworld", contract.declaredField<Any?>("magicString").value) assertEquals("helloworld", contract.declaredField<Any?>("magicString").value)

View File

@ -24,8 +24,9 @@ import net.corda.core.node.NetworkParameters
import net.corda.core.node.services.AttachmentId import net.corda.core.node.services.AttachmentId
import net.corda.core.serialization.internal.AttachmentsClassLoader import net.corda.core.serialization.internal.AttachmentsClassLoader
import net.corda.core.serialization.internal.AttachmentsClassLoaderCacheImpl import net.corda.core.serialization.internal.AttachmentsClassLoaderCacheImpl
import net.corda.testing.common.internal.testNetworkParameters import net.corda.core.transactions.LedgerTransaction
import net.corda.node.services.attachments.NodeAttachmentTrustCalculator import net.corda.node.services.attachments.NodeAttachmentTrustCalculator
import net.corda.testing.common.internal.testNetworkParameters
import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyContract
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME import net.corda.testing.core.BOB_NAME
@ -74,7 +75,7 @@ class AttachmentsClassLoaderTests {
val BOB = TestIdentity(BOB_NAME, 80).party val BOB = TestIdentity(BOB_NAME, 80).party
val dummyNotary = TestIdentity(DUMMY_NOTARY_NAME, 20) val dummyNotary = TestIdentity(DUMMY_NOTARY_NAME, 20)
val DUMMY_NOTARY get() = dummyNotary.party val DUMMY_NOTARY get() = dummyNotary.party
val PROGRAM_ID: String = "net.corda.testing.contracts.MyDummyContract" const val PROGRAM_ID = "net.corda.testing.contracts.MyDummyContract"
} }
@Rule @Rule
@ -89,7 +90,7 @@ class AttachmentsClassLoaderTests {
private lateinit var internalStorage: InternalMockAttachmentStorage private lateinit var internalStorage: InternalMockAttachmentStorage
private lateinit var attachmentTrustCalculator: AttachmentTrustCalculator private lateinit var attachmentTrustCalculator: AttachmentTrustCalculator
private val networkParameters = testNetworkParameters() private val networkParameters = testNetworkParameters()
private val cacheFactory = TestingNamedCacheFactory() private val cacheFactory = TestingNamedCacheFactory(1)
private fun createClassloader( private fun createClassloader(
attachment: AttachmentId, attachment: AttachmentId,
@ -541,6 +542,50 @@ class AttachmentsClassLoaderTests {
} }
} }
@Test(timeout=300_000)
fun `class loader not closed after cache starts evicting`() {
tempFolder.root.toPath().let { path ->
val transactions = mutableListOf<LedgerTransaction>()
val iterations = 10
val baseOutState = TransactionState(DummyContract.SingleOwnerState(0, ALICE), PROGRAM_ID, DUMMY_NOTARY, constraint = AlwaysAcceptAttachmentConstraint)
val inputs = emptyList<StateAndRef<*>>()
val outputs = listOf(baseOutState, baseOutState.copy(notary = ALICE), baseOutState.copy(notary = BOB))
val commands = emptyList<CommandWithParties<CommandData>>()
val content = createContractString(PROGRAM_ID)
val timeWindow: TimeWindow? = null
val attachmentsClassLoaderCache = AttachmentsClassLoaderCacheImpl(cacheFactory)
val contractJarPath = ContractJarTestUtils.makeTestContractJar(path, PROGRAM_ID, content = content)
val attachments = createAttachments(contractJarPath)
for(i in 1 .. iterations) {
val id = SecureHash.randomSHA256()
val privacySalt = PrivacySalt()
val transaction = createLedgerTransaction(
inputs,
outputs,
commands,
attachments,
id,
null,
timeWindow,
privacySalt,
testNetworkParameters(),
emptyList(),
isAttachmentTrusted = { true },
attachmentsClassLoaderCache = attachmentsClassLoaderCache
)
transactions.add(transaction)
System.gc()
Thread.sleep(1)
}
transactions.forEach {
it.verify()
}
}
}
private fun createContractString(contractName: String, versionSeed: Int = 0): String { private fun createContractString(contractName: String, versionSeed: Int = 0): String {
val pkgs = contractName.split(".") val pkgs = contractName.split(".")
val className = pkgs.last() val className = pkgs.last()
@ -563,7 +608,7 @@ class AttachmentsClassLoaderTests {
} }
""".trimIndent() """.trimIndent()
System.out.println(output) println(output)
return output return output
} }
@ -571,6 +616,7 @@ class AttachmentsClassLoaderTests {
val attachment = object : AbstractAttachment({contractJarPath.inputStream().readBytes()}, uploader = "app") { val attachment = object : AbstractAttachment({contractJarPath.inputStream().readBytes()}, uploader = "app") {
@Suppress("OverridingDeprecatedMember") @Suppress("OverridingDeprecatedMember")
@Deprecated("Use signerKeys. There is no requirement that attachment signers are Corda parties.")
override val signers: List<Party> = emptyList() override val signers: List<Party> = emptyList()
override val signerKeys: List<PublicKey> = emptyList() override val signerKeys: List<PublicKey> = emptyList()
override val size: Int = 1234 override val size: Int = 1234
@ -581,6 +627,7 @@ class AttachmentsClassLoaderTests {
return listOf( return listOf(
object : AbstractAttachment({ISOLATED_CONTRACTS_JAR_PATH.openStream().readBytes()}, uploader = "app") { object : AbstractAttachment({ISOLATED_CONTRACTS_JAR_PATH.openStream().readBytes()}, uploader = "app") {
@Suppress("OverridingDeprecatedMember") @Suppress("OverridingDeprecatedMember")
@Deprecated("Use signerKeys. There is no requirement that attachment signers are Corda parties.")
override val signers: List<Party> = emptyList() override val signers: List<Party> = emptyList()
override val signerKeys: List<PublicKey> = emptyList() override val signerKeys: List<PublicKey> = emptyList()
override val size: Int = 1234 override val size: Int = 1234
@ -589,6 +636,7 @@ class AttachmentsClassLoaderTests {
object : AbstractAttachment({fakeAttachment("importantDoc.pdf", "I am a pdf!").inputStream().readBytes() object : AbstractAttachment({fakeAttachment("importantDoc.pdf", "I am a pdf!").inputStream().readBytes()
}, uploader = "app") { }, uploader = "app") {
@Suppress("OverridingDeprecatedMember") @Suppress("OverridingDeprecatedMember")
@Deprecated("Use signerKeys. There is no requirement that attachment signers are Corda parties.")
override val signers: List<Party> = emptyList() override val signers: List<Party> = emptyList()
override val signerKeys: List<PublicKey> = emptyList() override val signerKeys: List<PublicKey> = emptyList()
override val size: Int = 1234 override val size: Int = 1234

View File

@ -0,0 +1,186 @@
package net.corda.coretests.transactions
import net.corda.core.contracts.ComponentGroupEnum.COMMANDS_GROUP
import net.corda.core.contracts.ComponentGroupEnum.INPUTS_GROUP
import net.corda.core.contracts.ComponentGroupEnum.NOTARY_GROUP
import net.corda.core.contracts.ComponentGroupEnum.OUTPUTS_GROUP
import net.corda.core.contracts.ComponentGroupEnum.SIGNERS_GROUP
import net.corda.core.contracts.ComponentGroupEnum.TIMEWINDOW_GROUP
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.PrivacySalt
import net.corda.core.contracts.TimeWindow
import net.corda.core.contracts.TransactionState
import net.corda.core.crypto.DigestService
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.generateKeyPair
import net.corda.core.crypto.internal.DigestAlgorithmFactory
import net.corda.core.internal.SHA256BLAKE2s256DigestAlgorithm
import net.corda.core.internal.accessAvailableComponentHashes
import net.corda.core.internal.accessAvailableComponentNonces
import net.corda.core.internal.accessGroupMerkleRoots
import net.corda.core.serialization.serialize
import net.corda.core.transactions.ComponentGroup
import net.corda.core.transactions.WireTransaction
import net.corda.testing.contracts.DummyContract
import net.corda.testing.contracts.DummyState
import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.DUMMY_NOTARY_NAME
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.core.TestIdentity
import net.corda.testing.core.dummyCommand
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import java.time.Instant
import kotlin.test.assertEquals
class MerkleTreeAgilityTest {
private companion object {
val DUMMY_KEY_1 = generateKeyPair()
val DUMMY_KEY_2 = generateKeyPair()
val BOB = TestIdentity(BOB_NAME, 80).party
val DUMMY_NOTARY = TestIdentity(DUMMY_NOTARY_NAME, 20).party
}
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule()
private val dummyOutState = TransactionState(DummyState(0), DummyContract.PROGRAM_ID, DUMMY_NOTARY)
private val stateRef1 = StateRef(SecureHash.randomSHA256(), 0)
private val stateRef2 = StateRef(SecureHash.randomSHA256(), 1)
private val stateRef3 = StateRef(SecureHash.randomSHA256(), 2)
private val stateRef4 = StateRef(SecureHash.randomSHA256(), 3)
private val singleInput = listOf(stateRef1) // 1 elements.
private val threeInputs = listOf(stateRef1, stateRef2, stateRef3) // 3 elements.
private val fourInputs = listOf(stateRef1, stateRef2, stateRef3, stateRef4) // 4 elements.
private val outputs = listOf(dummyOutState, dummyOutState.copy(notary = BOB)) // 2 elements.
private val commands = listOf(dummyCommand(DUMMY_KEY_1.public, DUMMY_KEY_2.public)) // 1 element.
private val notary = DUMMY_NOTARY
private val timeWindow = TimeWindow.fromOnly(Instant.now())
private val privacySalt: PrivacySalt = PrivacySalt()
private val singleInputsGroup by lazy { ComponentGroup(INPUTS_GROUP.ordinal, singleInput.map { it.serialize() }) }
private val threeInputsGroup by lazy { ComponentGroup(INPUTS_GROUP.ordinal, threeInputs.map { it.serialize() }) }
private val fourInputsGroup by lazy { ComponentGroup(INPUTS_GROUP.ordinal, fourInputs.map { it.serialize() }) }
private val outputGroup by lazy { ComponentGroup(OUTPUTS_GROUP.ordinal, outputs.map { it.serialize() }) }
private val commandGroup by lazy { ComponentGroup(COMMANDS_GROUP.ordinal, commands.map { it.value.serialize() }) }
private val notaryGroup by lazy { ComponentGroup(NOTARY_GROUP.ordinal, listOf(notary.serialize())) }
private val timeWindowGroup by lazy { ComponentGroup(TIMEWINDOW_GROUP.ordinal, listOf(timeWindow.serialize())) }
private val signersGroup by lazy { ComponentGroup(SIGNERS_GROUP.ordinal, commands.map { it.signers.serialize() }) }
private val componentGroupsSingle by lazy {
listOf(singleInputsGroup, outputGroup, commandGroup, notaryGroup, timeWindowGroup, signersGroup)
}
private val componentGroupsFourInputs by lazy {
listOf(fourInputsGroup, outputGroup, commandGroup, notaryGroup, timeWindowGroup, signersGroup)
}
private val componentGroupsThreeInputs by lazy {
listOf(threeInputsGroup, outputGroup, commandGroup, notaryGroup, timeWindowGroup, signersGroup)
}
private val defaultDigestService = DigestService.sha2_256
private val customDigestService = DigestService("SHA256-BLAKE2S256-TEST")
@Before
fun before() {
DigestAlgorithmFactory.registerClass(SHA256BLAKE2s256DigestAlgorithm::class.java.name)
}
@Test(timeout = 300_000)
fun `component nonces are correct for custom preimage resistant hash algo`() {
val wireTransaction = WireTransaction(componentGroups = componentGroupsFourInputs, privacySalt = privacySalt, digestService = customDigestService)
val expected = componentGroupsFourInputs.associate {
it.groupIndex to it.components.mapIndexed { componentIndexInGroup, _ ->
customDigestService.computeNonce(privacySalt, it.groupIndex, componentIndexInGroup)
}
}
assertEquals(expected, wireTransaction.accessAvailableComponentNonces())
}
@Test(timeout = 300_000)
fun `component nonces are correct for default SHA256 hash algo`() {
val wireTransaction = WireTransaction(componentGroups = componentGroupsFourInputs, privacySalt = privacySalt)
val expected = componentGroupsFourInputs.associate {
it.groupIndex to it.components.mapIndexed { componentIndexInGroup, componentBytes ->
defaultDigestService.componentHash(componentBytes, privacySalt, it.groupIndex, componentIndexInGroup)
}
}
assertEquals(expected, wireTransaction.accessAvailableComponentNonces())
}
@Test(timeout = 300_000)
fun `custom algorithm transaction pads leaf in single component component group`() {
val wtx = WireTransaction(componentGroups = componentGroupsSingle, privacySalt = privacySalt, digestService = customDigestService)
val inputsTreeLeaves: List<SecureHash> = wtx.accessAvailableComponentHashes()[INPUTS_GROUP.ordinal]!!
val expected = customDigestService.hash(inputsTreeLeaves[0].bytes + customDigestService.zeroHash.bytes)
assertEquals(expected, wtx.accessGroupMerkleRoots()[INPUTS_GROUP.ordinal]!!)
}
@Test(timeout = 300_000)
fun `default algorithm transaction does not pad leaf in single component component group`() {
val wtx = WireTransaction(componentGroups = componentGroupsSingle, privacySalt = privacySalt, digestService = defaultDigestService)
val inputsTreeLeaves: List<SecureHash> = wtx.accessAvailableComponentHashes()[INPUTS_GROUP.ordinal]!!
val expected = inputsTreeLeaves[0]
assertEquals(expected, wtx.accessGroupMerkleRoots()[INPUTS_GROUP.ordinal]!!)
}
@Test(timeout = 300_000)
fun `custom algorithm transaction has expected root for four components component group tree`() {
val wtx = WireTransaction(componentGroups = componentGroupsFourInputs, privacySalt = privacySalt, digestService = customDigestService)
val inputsTreeLeaves: List<SecureHash> = wtx.accessAvailableComponentHashes()[INPUTS_GROUP.ordinal]!!
val h1 = customDigestService.hash(inputsTreeLeaves[0].bytes + inputsTreeLeaves[1].bytes)
val h2 = customDigestService.hash(inputsTreeLeaves[2].bytes + inputsTreeLeaves[3].bytes)
val expected = customDigestService.hash(h1.bytes + h2.bytes)
assertEquals(expected, wtx.accessGroupMerkleRoots()[INPUTS_GROUP.ordinal]!!)
}
@Test(timeout = 300_000)
fun `default algorithm transaction has expected root for four components component group tree`() {
val wtx = WireTransaction(componentGroups = componentGroupsFourInputs, privacySalt = privacySalt, digestService = defaultDigestService)
val inputsTreeLeaves: List<SecureHash> = wtx.accessAvailableComponentHashes()[INPUTS_GROUP.ordinal]!!
val h1 = defaultDigestService.hash(inputsTreeLeaves[0].bytes + inputsTreeLeaves[1].bytes)
val h2 = defaultDigestService.hash(inputsTreeLeaves[2].bytes + inputsTreeLeaves[3].bytes)
val expected = defaultDigestService.hash(h1.bytes + h2.bytes)
assertEquals(expected, wtx.accessGroupMerkleRoots()[INPUTS_GROUP.ordinal]!!)
}
@Test(timeout = 300_000)
fun `custom algorithm transaction has expected root for three components component group tree`() {
val wtx = WireTransaction(componentGroups = componentGroupsThreeInputs, privacySalt = privacySalt, digestService = customDigestService)
val inputsTreeLeaves: List<SecureHash> = wtx.accessAvailableComponentHashes()[INPUTS_GROUP.ordinal]!!
val h1 = customDigestService.hash(inputsTreeLeaves[0].bytes + inputsTreeLeaves[1].bytes)
val h2 = customDigestService.hash(inputsTreeLeaves[2].bytes + customDigestService.zeroHash.bytes)
val expected = customDigestService.hash(h1.bytes + h2.bytes)
assertEquals(expected, wtx.accessGroupMerkleRoots()[INPUTS_GROUP.ordinal]!!)
}
@Test(timeout = 300_000)
fun `default algorithm transaction has expected root for three components component group tree`() {
val wtx = WireTransaction(componentGroups = componentGroupsThreeInputs, privacySalt = privacySalt, digestService = defaultDigestService)
val inputsTreeLeaves: List<SecureHash> = wtx.accessAvailableComponentHashes()[INPUTS_GROUP.ordinal]!!
val h1 = defaultDigestService.hash(inputsTreeLeaves[0].bytes + inputsTreeLeaves[1].bytes)
val h2 = defaultDigestService.hash(inputsTreeLeaves[2].bytes + defaultDigestService.zeroHash.bytes)
val expected = defaultDigestService.hash(h1.bytes + h2.bytes)
assertEquals(expected, wtx.accessGroupMerkleRoots()[INPUTS_GROUP.ordinal]!!)
}
}

View File

@ -3,7 +3,16 @@ package net.corda.coretests.transactions
import com.nhaarman.mockito_kotlin.doReturn import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.mock import com.nhaarman.mockito_kotlin.mock
import com.nhaarman.mockito_kotlin.whenever import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.contracts.* import net.corda.core.contracts.Command
import net.corda.core.contracts.ContractAttachment
import net.corda.core.contracts.HashAttachmentConstraint
import net.corda.core.contracts.PrivacySalt
import net.corda.core.contracts.SignatureAttachmentConstraint
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TimeWindow
import net.corda.core.contracts.TransactionState
import net.corda.core.contracts.TransactionVerificationException
import net.corda.core.cordapp.CordappProvider import net.corda.core.cordapp.CordappProvider
import net.corda.core.crypto.CompositeKey import net.corda.core.crypto.CompositeKey
import net.corda.core.crypto.DigestService import net.corda.core.crypto.DigestService
@ -20,11 +29,16 @@ import net.corda.core.node.services.IdentityService
import net.corda.core.node.services.NetworkParametersService import net.corda.core.node.services.NetworkParametersService
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.TransactionBuilder
import net.corda.coretesting.internal.rigorousMock
import net.corda.testing.common.internal.testNetworkParameters import net.corda.testing.common.internal.testNetworkParameters
import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyContract
import net.corda.testing.contracts.DummyState import net.corda.testing.contracts.DummyState
import net.corda.testing.core.* import net.corda.testing.core.ALICE_NAME
import net.corda.coretesting.internal.rigorousMock import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.DUMMY_NOTARY_NAME
import net.corda.testing.core.DummyCommandData
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.core.TestIdentity
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.Assert.assertFalse import org.junit.Assert.assertFalse
@ -35,6 +49,7 @@ import org.junit.Rule
import org.junit.Test import org.junit.Test
import java.security.PublicKey import java.security.PublicKey
import java.time.Instant import java.time.Instant
import kotlin.test.assertFailsWith
class TransactionBuilderTest { class TransactionBuilderTest {
@Rule @Rule
@ -299,4 +314,22 @@ class TransactionBuilderTest {
HashAgility.init() HashAgility.init()
} }
} }
@Test(timeout=300_000)
fun `toWireTransaction fails if no scheme is registered with schemeId`() {
val outputState = TransactionState(
data = DummyState(),
contract = DummyContract.PROGRAM_ID,
notary = notary,
constraint = HashAttachmentConstraint(contractAttachmentId)
)
val builder = TransactionBuilder()
.addOutputState(outputState)
.addCommand(DummyCommandData, notary.owningKey)
val schemeId = 7
assertFailsWith<UnsupportedOperationException>("Could not find custom serialization scheme with SchemeId = $schemeId.") {
builder.toWireTransaction(services, schemeId)
}
}
} }

View File

@ -21,6 +21,7 @@ import net.corda.testing.internal.createWireTransaction
import net.corda.testing.internal.fakeAttachment import net.corda.testing.internal.fakeAttachment
import net.corda.coretesting.internal.rigorousMock import net.corda.coretesting.internal.rigorousMock
import net.corda.testing.internal.TestingNamedCacheFactory import net.corda.testing.internal.TestingNamedCacheFactory
import org.assertj.core.api.Assertions.fail
import org.junit.Before import org.junit.Before
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
@ -36,6 +37,7 @@ import kotlin.test.assertNotEquals
@RunWith(Parameterized::class) @RunWith(Parameterized::class)
class TransactionTests(private val digestService : DigestService) { class TransactionTests(private val digestService : DigestService) {
private companion object { private companion object {
const val ISOLATED_JAR = "isolated-4.0.jar"
val DUMMY_KEY_1 = generateKeyPair() val DUMMY_KEY_1 = generateKeyPair()
val DUMMY_KEY_2 = generateKeyPair() val DUMMY_KEY_2 = generateKeyPair()
val DUMMY_CASH_ISSUER_KEY = entropyToKeyPair(BigInteger.valueOf(10)) val DUMMY_CASH_ISSUER_KEY = entropyToKeyPair(BigInteger.valueOf(10))
@ -200,15 +202,15 @@ class TransactionTests(private val digestService : DigestService) {
val outputs = listOf(outState) val outputs = listOf(outState)
val commands = emptyList<CommandWithParties<CommandData>>() val commands = emptyList<CommandWithParties<CommandData>>()
val attachments = listOf(object : AbstractAttachment({ val attachments = listOf(ContractAttachment(object : AbstractAttachment({
AttachmentsClassLoaderTests::class.java.getResource("isolated-4.0.jar").openStream().readBytes() (AttachmentsClassLoaderTests::class.java.getResource(ISOLATED_JAR) ?: fail("Missing $ISOLATED_JAR")).openStream().readBytes()
}, TESTDSL_UPLOADER) { }, TESTDSL_UPLOADER) {
@Suppress("OverridingDeprecatedMember") @Suppress("OverridingDeprecatedMember")
override val signers: List<Party> = emptyList() override val signers: List<Party> = emptyList()
override val signerKeys: List<PublicKey> = emptyList() override val signerKeys: List<PublicKey> = emptyList()
override val size: Int = 1234 override val size: Int = 1234
override val id: SecureHash = SecureHash.zeroHash override val id: SecureHash = SecureHash.zeroHash
}) }, DummyContract.PROGRAM_ID))
val id = digestService.randomHash() val id = digestService.randomHash()
val timeWindow: TimeWindow? = null val timeWindow: TimeWindow? = null
val privacySalt = PrivacySalt(digestService.digestLength) val privacySalt = PrivacySalt(digestService.digestLength)

View File

@ -12,6 +12,10 @@ description 'Corda core'
// required by DJVM and Avian JVM (for running inside the SGX enclave) which only supports Java 8. // required by DJVM and Avian JVM (for running inside the SGX enclave) which only supports Java 8.
targetCompatibility = VERSION_1_8 targetCompatibility = VERSION_1_8
sourceSets {
obfuscator
}
configurations { configurations {
integrationTestCompile.extendsFrom testCompile integrationTestCompile.extendsFrom testCompile
integrationTestRuntimeOnly.extendsFrom testRuntimeOnly integrationTestRuntimeOnly.extendsFrom testRuntimeOnly
@ -22,6 +26,9 @@ configurations {
dependencies { dependencies {
obfuscatorImplementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8:$kotlin_version"
testImplementation sourceSets.obfuscator.output
testImplementation "org.junit.jupiter:junit-jupiter-api:${junit_jupiter_version}" testImplementation "org.junit.jupiter:junit-jupiter-api:${junit_jupiter_version}"
testImplementation "junit:junit:$junit_version" testImplementation "junit:junit:$junit_version"
testRuntimeOnly "org.junit.vintage:junit-vintage-engine:${junit_vintage_version}" testRuntimeOnly "org.junit.vintage:junit-vintage-engine:${junit_vintage_version}"
@ -110,7 +117,16 @@ configurations {
} }
test{ processTestResources {
inputs.files(jar)
into("zip") {
from(jar) {
rename { "core.jar" }
}
}
}
test {
maxParallelForks = (System.env.CORDA_CORE_TESTING_FORKS == null) ? 1 : "$System.env.CORDA_CORE_TESTING_FORKS".toInteger() maxParallelForks = (System.env.CORDA_CORE_TESTING_FORKS == null) ? 1 : "$System.env.CORDA_CORE_TESTING_FORKS".toInteger()
} }
@ -163,3 +179,10 @@ scanApi {
publish { publish {
name jar.baseName name jar.baseName
} }
tasks.register("writeTestResources", JavaExec) {
classpath sourceSets.obfuscator.output
classpath sourceSets.obfuscator.runtimeClasspath
main 'net.corda.core.internal.utilities.TestResourceWriter'
args new File(sourceSets.test.resources.srcDirs.first(), "zip").toString()
}

View File

@ -343,7 +343,11 @@ abstract class TransactionVerificationException(val txId: SecureHash, message: S
"You will need to manually install the CorDapp to whitelist it for use.") "You will need to manually install the CorDapp to whitelist it for use.")
@KeepForDJVM @KeepForDJVM
class UnsupportedHashTypeException(txId: SecureHash) : TransactionVerificationException(txId, "The transaction Id is defined by an unsupported hash type", null); class UnsupportedHashTypeException(txId: SecureHash) : TransactionVerificationException(txId, "The transaction Id is defined by an unsupported hash type", null)
@KeepForDJVM
class AttachmentTooBigException(txId: SecureHash) : TransactionVerificationException(
txId, "The transaction attachments are too large and exceed both max transaction size and the maximum allowed compression ratio", null)
/* /*
If you add a new class extending [TransactionVerificationException], please add a test in `TransactionVerificationExceptionSerializationTests` If you add a new class extending [TransactionVerificationException], please add a test in `TransactionVerificationExceptionSerializationTests`

View File

@ -25,8 +25,16 @@ interface DigestAlgorithm {
fun digest(bytes: ByteArray): ByteArray fun digest(bytes: ByteArray): ByteArray
/** /**
* Computes the digest of the [ByteArray] which is resistant to pre-image attacks. * Computes the digest of the [ByteArray] which is resistant to pre-image attacks. Only used to calculate the hash of the leaves of the
* ComponentGroup Merkle tree, starting from its serialized components.
* Default implementation provides double hashing, but can it be changed to single hashing or something else for better performance. * Default implementation provides double hashing, but can it be changed to single hashing or something else for better performance.
*/ */
fun preImageResistantDigest(bytes: ByteArray): ByteArray = digest(digest(bytes)) fun componentDigest(bytes: ByteArray): ByteArray = digest(digest(bytes))
/**
* Computes the digest of the [ByteArray] which is resistant to pre-image attacks. Only used to calculate the nonces for the leaves of
* the ComponentGroup Merkle tree.
* Default implementation provides double hashing, but can it be changed to single hashing or something else for better performance.
*/
fun nonceDigest(bytes: ByteArray): ByteArray = digest(digest(bytes))
} }

View File

@ -75,23 +75,20 @@ data class DigestService(val hashAlgorithm: String) {
val zeroHash: SecureHash val zeroHash: SecureHash
get() = SecureHash.zeroHashFor(hashAlgorithm) get() = SecureHash.zeroHashFor(hashAlgorithm)
// val privacySalt: PrivacySalt
// get() = PrivacySalt.createFor(hashAlgorithm)
/** /**
* Compute the hash of each serialised component so as to be used as Merkle tree leaf. The resultant output (leaf) is * Compute the hash of each serialised component so as to be used as Merkle tree leaf. The resultant output (leaf) is
* calculated using the service's hash algorithm, thus HASH(HASH(nonce || serializedComponent)) for SHA2-256 and other * calculated using the service's hash algorithm, thus HASH(HASH(nonce || serializedComponent)) for SHA2-256 and other
* algorithms loaded via JCA [MessageDigest], or DigestAlgorithm.preImageResistantDigest(nonce || serializedComponent) * algorithms loaded via JCA [MessageDigest], or DigestAlgorithm.componentDigest(nonce || serializedComponent)
* otherwise, where nonce is computed from [computeNonce]. * otherwise, where nonce is computed from [computeNonce].
*/ */
fun componentHash(opaqueBytes: OpaqueBytes, privacySalt: PrivacySalt, componentGroupIndex: Int, internalIndex: Int): SecureHash = fun componentHash(opaqueBytes: OpaqueBytes, privacySalt: PrivacySalt, componentGroupIndex: Int, internalIndex: Int): SecureHash =
componentHash(computeNonce(privacySalt, componentGroupIndex, internalIndex), opaqueBytes) componentHash(computeNonce(privacySalt, componentGroupIndex, internalIndex), opaqueBytes)
/** Return the HASH(HASH(nonce || serializedComponent)) for SHA2-256 and other algorithms loaded via JCA [MessageDigest], /** Return the HASH(HASH(nonce || serializedComponent)) for SHA2-256 and other algorithms loaded via JCA [MessageDigest],
* otherwise it's defined by DigestAlgorithm.preImageResistantDigest(nonce || serializedComponent). */ * otherwise it's defined by DigestAlgorithm.componentDigest(nonce || serializedComponent). */
fun componentHash(nonce: SecureHash, opaqueBytes: OpaqueBytes): SecureHash { fun componentHash(nonce: SecureHash, opaqueBytes: OpaqueBytes): SecureHash {
val data = nonce.bytes + opaqueBytes.bytes val data = nonce.bytes + opaqueBytes.bytes
return SecureHash.preImageResistantHashAs(hashAlgorithm, data) return SecureHash.componentHashAs(hashAlgorithm, data)
} }
/** /**
@ -109,11 +106,11 @@ data class DigestService(val hashAlgorithm: String) {
* @param groupIndex the fixed index (ordinal) of this component group. * @param groupIndex the fixed index (ordinal) of this component group.
* @param internalIndex the internal index of this object in its corresponding components list. * @param internalIndex the internal index of this object in its corresponding components list.
* @return HASH(HASH(privacySalt || groupIndex || internalIndex)) for SHA2-256 and other algorithms loaded via JCA [MessageDigest], * @return HASH(HASH(privacySalt || groupIndex || internalIndex)) for SHA2-256 and other algorithms loaded via JCA [MessageDigest],
* otherwise it's defined by DigestAlgorithm.preImageResistantDigest(privacySalt || groupIndex || internalIndex). * otherwise it's defined by DigestAlgorithm.nonceDigest(privacySalt || groupIndex || internalIndex).
*/ */
fun computeNonce(privacySalt: PrivacySalt, groupIndex: Int, internalIndex: Int) : SecureHash { fun computeNonce(privacySalt: PrivacySalt, groupIndex: Int, internalIndex: Int) : SecureHash {
val data = (privacySalt.bytes + ByteBuffer.allocate(NONCE_SIZE).putInt(groupIndex).putInt(internalIndex).array()) val data = (privacySalt.bytes + ByteBuffer.allocate(NONCE_SIZE).putInt(groupIndex).putInt(internalIndex).array())
return SecureHash.preImageResistantHashAs(hashAlgorithm, data) return SecureHash.nonceHashAs(hashAlgorithm, data)
} }
} }

View File

@ -37,19 +37,19 @@ sealed class MerkleTree {
require(algorithms.size == 1) { require(algorithms.size == 1) {
"Cannot build Merkle tree with multiple hash algorithms: $algorithms" "Cannot build Merkle tree with multiple hash algorithms: $algorithms"
} }
val leaves = padWithZeros(allLeavesHashes).map { Leaf(it) } val leaves = padWithZeros(allLeavesHashes, nodeDigestService.hashAlgorithm == SecureHash.SHA2_256).map { Leaf(it) }
return buildMerkleTree(leaves, nodeDigestService) return buildMerkleTree(leaves, nodeDigestService)
} }
// If number of leaves in the tree is not a power of 2, we need to pad it with zero hashes. // If number of leaves in the tree is not a power of 2, we need to pad it with zero hashes.
private fun padWithZeros(allLeavesHashes: List<SecureHash>): List<SecureHash> { private fun padWithZeros(allLeavesHashes: List<SecureHash>, singleLeafWithoutPadding: Boolean): List<SecureHash> {
var n = allLeavesHashes.size var n = allLeavesHashes.size
if (isPow2(n)) return allLeavesHashes if (isPow2(n) && (n > 1 || singleLeafWithoutPadding)) return allLeavesHashes
val paddedHashes = ArrayList(allLeavesHashes) val paddedHashes = ArrayList(allLeavesHashes)
val zeroHash = SecureHash.zeroHashFor(paddedHashes[0].algorithm) val zeroHash = SecureHash.zeroHashFor(paddedHashes[0].algorithm)
while (!isPow2(n++)) { do {
paddedHashes.add(zeroHash) paddedHashes.add(zeroHash)
} } while (!isPow2(++n))
return paddedHashes return paddedHashes
} }

View File

@ -216,13 +216,31 @@ sealed class SecureHash(bytes: ByteArray) : OpaqueBytes(bytes) {
* @param bytes The [ByteArray] to hash. * @param bytes The [ByteArray] to hash.
*/ */
@JvmStatic @JvmStatic
fun preImageResistantHashAs(algorithm: String, bytes: ByteArray): SecureHash { fun componentHashAs(algorithm: String, bytes: ByteArray): SecureHash {
return if (algorithm == SHA2_256) { return if (algorithm == SHA2_256) {
sha256Twice(bytes) sha256Twice(bytes)
} else { } else {
val digest = digestFor(algorithm).get() val digest = digestFor(algorithm).get()
val firstHash = digest.preImageResistantDigest(bytes) val hash = digest.componentDigest(bytes)
HASH(algorithm, digest.digest(firstHash)) HASH(algorithm, hash)
}
}
/**
* Computes the digest of the [ByteArray] which is resistant to pre-image attacks.
* It computes the hash of the hash for SHA2-256 and other algorithms loaded via JCA [MessageDigest].
* For custom algorithms the strategy can be modified via [DigestAlgorithm].
* @param algorithm The [MessageDigest] algorithm to use.
* @param bytes The [ByteArray] to hash.
*/
@JvmStatic
fun nonceHashAs(algorithm: String, bytes: ByteArray): SecureHash {
return if (algorithm == SHA2_256) {
sha256Twice(bytes)
} else {
val digest = digestFor(algorithm).get()
val hash = digest.nonceDigest(bytes)
HASH(algorithm, hash)
} }
} }

View File

@ -28,9 +28,7 @@ sealed class DigestAlgorithmFactory {
} }
private class CustomAlgorithmFactory(className: String) : DigestAlgorithmFactory() { private class CustomAlgorithmFactory(className: String) : DigestAlgorithmFactory() {
val constructor: Constructor<out DigestAlgorithm> = javaClass val constructor: Constructor<out DigestAlgorithm> = Class.forName(className, false, javaClass.classLoader)
.classLoader
.loadClass(className)
.asSubclass(DigestAlgorithm::class.java) .asSubclass(DigestAlgorithm::class.java)
.getConstructor() .getConstructor()
override val algorithm: String = constructor.newInstance().algorithm override val algorithm: String = constructor.newInstance().algorithm

View File

@ -11,9 +11,9 @@ import kotlin.concurrent.withLock
private val pooledScanMutex = ReentrantLock() private val pooledScanMutex = ReentrantLock()
/** /**
* Use this rather than the built in implementation of [scan] on [ClassGraph]. The built in implementation of [scan] creates * Use this rather than the built-in implementation of [ClassGraph.scan]. The built-in implementation creates
* a thread pool every time resulting in too many threads. This one uses a mutex to restrict concurrency. * a thread pool every time, resulting in too many threads. This one uses a mutex to restrict concurrency.
*/ */
fun ClassGraph.pooledScan(): ScanResult { fun ClassGraph.pooledScan(): ScanResult {
return pooledScanMutex.withLock { this@pooledScan.scan() } return pooledScanMutex.withLock(::scan)
} }

View File

@ -23,7 +23,7 @@ import net.corda.core.serialization.internal.AttachmentURLStreamHandlerFactory.a
fun <T: Any> createInstancesOfClassesImplementing(classloader: ClassLoader, clazz: Class<T>, fun <T: Any> createInstancesOfClassesImplementing(classloader: ClassLoader, clazz: Class<T>,
classVersionRange: IntRange? = null): Set<T> { classVersionRange: IntRange? = null): Set<T> {
return getNamesOfClassesImplementing(classloader, clazz, classVersionRange) return getNamesOfClassesImplementing(classloader, clazz, classVersionRange)
.map { classloader.loadClass(it).asSubclass(clazz) } .map { Class.forName(it, false, classloader).asSubclass(clazz) }
.mapTo(LinkedHashSet()) { it.kotlin.objectOrNewInstance() } .mapTo(LinkedHashSet()) { it.kotlin.objectOrNewInstance() }
} }

View File

@ -16,6 +16,7 @@ typealias Version = Int
* Attention: this value affects consensus, so it requires a minimum platform version bump in order to be changed. * Attention: this value affects consensus, so it requires a minimum platform version bump in order to be changed.
*/ */
const val MAX_NUMBER_OF_KEYS_IN_SIGNATURE_CONSTRAINT = 20 const val MAX_NUMBER_OF_KEYS_IN_SIGNATURE_CONSTRAINT = 20
private const val DJVM_SANDBOX_PREFIX = "sandbox."
private val log = loggerFor<AttachmentConstraint>() private val log = loggerFor<AttachmentConstraint>()
@ -29,10 +30,14 @@ val Attachment.contractVersion: Version get() = if (this is ContractAttachment)
val ContractState.requiredContractClassName: String? get() { val ContractState.requiredContractClassName: String? get() {
val annotation = javaClass.getAnnotation(BelongsToContract::class.java) val annotation = javaClass.getAnnotation(BelongsToContract::class.java)
if (annotation != null) { if (annotation != null) {
return annotation.value.java.typeName return annotation.value.java.typeName.removePrefix(DJVM_SANDBOX_PREFIX)
} }
val enclosingClass = javaClass.enclosingClass ?: return null val enclosingClass = javaClass.enclosingClass ?: return null
return if (Contract::class.java.isAssignableFrom(enclosingClass)) enclosingClass.typeName else null return if (Contract::class.java.isAssignableFrom(enclosingClass)) {
enclosingClass.typeName.removePrefix(DJVM_SANDBOX_PREFIX)
} else {
null
}
} }
/** /**

View File

@ -28,7 +28,7 @@ import java.util.jar.JarInputStream
// *Internal* Corda-specific utilities. // *Internal* Corda-specific utilities.
const val PLATFORM_VERSION = 9 const val PLATFORM_VERSION = 10
fun ServicesForResolution.ensureMinimumPlatformVersion(requiredMinPlatformVersion: Int, feature: String) { fun ServicesForResolution.ensureMinimumPlatformVersion(requiredMinPlatformVersion: Int, feature: String) {
checkMinimumPlatformVersion(networkParameters.minimumPlatformVersion, requiredMinPlatformVersion, feature) checkMinimumPlatformVersion(networkParameters.minimumPlatformVersion, requiredMinPlatformVersion, feature)

View File

@ -56,7 +56,9 @@ import java.security.cert.TrustAnchor
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.time.Duration import java.time.Duration
import java.time.temporal.Temporal import java.time.temporal.Temporal
import java.util.* import java.util.Collections
import java.util.PrimitiveIterator
import java.util.Spliterator
import java.util.Spliterator.DISTINCT import java.util.Spliterator.DISTINCT
import java.util.Spliterator.IMMUTABLE import java.util.Spliterator.IMMUTABLE
import java.util.Spliterator.NONNULL import java.util.Spliterator.NONNULL
@ -64,6 +66,7 @@ import java.util.Spliterator.ORDERED
import java.util.Spliterator.SIZED import java.util.Spliterator.SIZED
import java.util.Spliterator.SORTED import java.util.Spliterator.SORTED
import java.util.Spliterator.SUBSIZED import java.util.Spliterator.SUBSIZED
import java.util.Spliterators
import java.util.concurrent.ExecutorService import java.util.concurrent.ExecutorService
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import java.util.stream.Collectors import java.util.stream.Collectors

View File

@ -16,6 +16,6 @@ object PlatformVersionSwitches {
const val LIMIT_KEYS_IN_SIGNATURE_CONSTRAINTS = 5 const val LIMIT_KEYS_IN_SIGNATURE_CONSTRAINTS = 5
const val BATCH_DOWNLOAD_COUNTERPARTY_BACKCHAIN = 6 const val BATCH_DOWNLOAD_COUNTERPARTY_BACKCHAIN = 6
const val ENABLE_P2P_COMPRESSION = 7 const val ENABLE_P2P_COMPRESSION = 7
const val CERTIFICATE_ROTATION = 9
const val RESTRICTED_DATABASE_OPERATIONS = 7 const val RESTRICTED_DATABASE_OPERATIONS = 7
const val CERTIFICATE_ROTATION = 9
} }

View File

@ -54,7 +54,7 @@ fun combinedHash(components: Iterable<SecureHash>, digestService: DigestService)
components.forEach { components.forEach {
stream.write(it.bytes) stream.write(it.bytes)
} }
return digestService.hash(stream.toByteArray()); return digestService.hash(stream.toByteArray())
} }
/** /**
@ -114,14 +114,14 @@ fun deserialiseCommands(
componentGroups: List<ComponentGroup>, componentGroups: List<ComponentGroup>,
forceDeserialize: Boolean = false, forceDeserialize: Boolean = false,
factory: SerializationFactory = SerializationFactory.defaultFactory, factory: SerializationFactory = SerializationFactory.defaultFactory,
@Suppress("UNUSED_PARAMETER") context: SerializationContext = factory.defaultContext, context: SerializationContext = factory.defaultContext,
digestService: DigestService = DigestService.sha2_256 digestService: DigestService = DigestService.sha2_256
): List<Command<*>> { ): List<Command<*>> {
// TODO: we could avoid deserialising unrelated signers. // TODO: we could avoid deserialising unrelated signers.
// However, current approach ensures the transaction is not malformed // However, current approach ensures the transaction is not malformed
// and it will throw if any of the signers objects is not List of public keys). // and it will throw if any of the signers objects is not List of public keys).
val signersList: List<List<PublicKey>> = uncheckedCast(deserialiseComponentGroup(componentGroups, List::class, ComponentGroupEnum.SIGNERS_GROUP, forceDeserialize)) val signersList: List<List<PublicKey>> = uncheckedCast(deserialiseComponentGroup(componentGroups, List::class, ComponentGroupEnum.SIGNERS_GROUP, forceDeserialize, factory, context))
val commandDataList: List<CommandData> = deserialiseComponentGroup(componentGroups, CommandData::class, ComponentGroupEnum.COMMANDS_GROUP, forceDeserialize) val commandDataList: List<CommandData> = deserialiseComponentGroup(componentGroups, CommandData::class, ComponentGroupEnum.COMMANDS_GROUP, forceDeserialize, factory, context)
val group = componentGroups.firstOrNull { it.groupIndex == ComponentGroupEnum.COMMANDS_GROUP.ordinal } val group = componentGroups.firstOrNull { it.groupIndex == ComponentGroupEnum.COMMANDS_GROUP.ordinal }
return if (group is FilteredComponentGroup) { return if (group is FilteredComponentGroup) {
check(commandDataList.size <= signersList.size) { check(commandDataList.size <= signersList.size) {
@ -154,7 +154,9 @@ fun createComponentGroups(inputs: List<StateRef>,
timeWindow: TimeWindow?, timeWindow: TimeWindow?,
references: List<StateRef>, references: List<StateRef>,
networkParametersHash: SecureHash?): List<ComponentGroup> { networkParametersHash: SecureHash?): List<ComponentGroup> {
val serialize = { value: Any, _: Int -> value.serialize() } val serializationFactory = SerializationFactory.defaultFactory
val serializationContext = serializationFactory.defaultContext
val serialize = { value: Any, _: Int -> value.serialize(serializationFactory, serializationContext) }
val componentGroupMap: MutableList<ComponentGroup> = mutableListOf() val componentGroupMap: MutableList<ComponentGroup> = mutableListOf()
if (inputs.isNotEmpty()) componentGroupMap.add(ComponentGroup(ComponentGroupEnum.INPUTS_GROUP.ordinal, inputs.lazyMapped(serialize))) if (inputs.isNotEmpty()) componentGroupMap.add(ComponentGroup(ComponentGroupEnum.INPUTS_GROUP.ordinal, inputs.lazyMapped(serialize)))
if (references.isNotEmpty()) componentGroupMap.add(ComponentGroup(ComponentGroupEnum.REFERENCES_GROUP.ordinal, references.lazyMapped(serialize))) if (references.isNotEmpty()) componentGroupMap.add(ComponentGroup(ComponentGroupEnum.REFERENCES_GROUP.ordinal, references.lazyMapped(serialize)))
@ -177,7 +179,11 @@ fun createComponentGroups(inputs: List<StateRef>,
*/ */
@KeepForDJVM @KeepForDJVM
data class SerializedStateAndRef(val serializedState: SerializedBytes<TransactionState<ContractState>>, val ref: StateRef) { data class SerializedStateAndRef(val serializedState: SerializedBytes<TransactionState<ContractState>>, val ref: StateRef) {
fun toStateAndRef(): StateAndRef<ContractState> = StateAndRef(serializedState.deserialize(), ref) fun toStateAndRef(factory: SerializationFactory, context: SerializationContext) = StateAndRef(serializedState.deserialize(factory, context), ref)
fun toStateAndRef(): StateAndRef<ContractState> {
val factory = SerializationFactory.defaultFactory
return toStateAndRef(factory, factory.defaultContext)
}
} }
/** Check that network parameters hash on this transaction is the current hash for the network. */ /** Check that network parameters hash on this transaction is the current hash for the network. */

View File

@ -3,14 +3,40 @@ package net.corda.core.internal
import net.corda.core.DeleteForDJVM import net.corda.core.DeleteForDJVM
import net.corda.core.KeepForDJVM import net.corda.core.KeepForDJVM
import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.CordaFuture
import net.corda.core.contracts.* import net.corda.core.contracts.Attachment
import net.corda.core.contracts.Contract
import net.corda.core.contracts.ContractAttachment
import net.corda.core.contracts.ContractClassName
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.HashAttachmentConstraint
import net.corda.core.contracts.SignatureAttachmentConstraint
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionState
import net.corda.core.contracts.TransactionVerificationException
import net.corda.core.contracts.TransactionVerificationException.ConflictingAttachmentsRejection
import net.corda.core.contracts.TransactionVerificationException.ConstraintPropagationRejection
import net.corda.core.contracts.TransactionVerificationException.ContractConstraintRejection
import net.corda.core.contracts.TransactionVerificationException.ContractCreationError
import net.corda.core.contracts.TransactionVerificationException.ContractRejection
import net.corda.core.contracts.TransactionVerificationException.Direction
import net.corda.core.contracts.TransactionVerificationException.DuplicateAttachmentsRejection
import net.corda.core.contracts.TransactionVerificationException.InvalidConstraintRejection
import net.corda.core.contracts.TransactionVerificationException.MissingAttachmentRejection
import net.corda.core.contracts.TransactionVerificationException.NotaryChangeInWrongTransactionType
import net.corda.core.contracts.TransactionVerificationException.TransactionContractConflictException import net.corda.core.contracts.TransactionVerificationException.TransactionContractConflictException
import net.corda.core.contracts.TransactionVerificationException.TransactionDuplicateEncumbranceException
import net.corda.core.contracts.TransactionVerificationException.TransactionMissingEncumbranceException
import net.corda.core.contracts.TransactionVerificationException.TransactionNonMatchingEncumbranceException
import net.corda.core.contracts.TransactionVerificationException.TransactionNotaryMismatchEncumbranceException
import net.corda.core.contracts.TransactionVerificationException.TransactionRequiredContractUnspecifiedException
import net.corda.core.crypto.CompositeKey import net.corda.core.crypto.CompositeKey
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.internal.rules.StateContractValidationEnforcementRule import net.corda.core.internal.rules.StateContractValidationEnforcementRule
import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.LedgerTransaction
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.loggerFor
import java.util.function.Function import java.util.function.Function
import java.util.function.Supplier
@DeleteForDJVM @DeleteForDJVM
interface TransactionVerifierServiceInternal { interface TransactionVerifierServiceInternal {
@ -22,16 +48,54 @@ interface TransactionVerifierServiceInternal {
*/ */
fun LedgerTransaction.prepareVerify(attachments: List<Attachment>) = internalPrepareVerify(attachments) fun LedgerTransaction.prepareVerify(attachments: List<Attachment>) = internalPrepareVerify(attachments)
interface Verifier {
/**
* Placeholder function for the verification logic.
*/
fun verify()
}
// This class allows us unit-test transaction verification more easily.
abstract class AbstractVerifier(
protected val ltx: LedgerTransaction,
protected val transactionClassLoader: ClassLoader
) : Verifier {
protected abstract val transaction: Supplier<LedgerTransaction>
protected companion object {
@JvmField
val logger = loggerFor<Verifier>()
}
/**
* Check that the transaction is internally consistent, and then check that it is
* contract-valid by running verify() for each input and output state contract.
* If any contract fails to verify, the whole transaction is considered to be invalid.
*
* Note: Reference states are not verified.
*/
final override fun verify() {
try {
TransactionVerifier(transactionClassLoader).apply(transaction)
} catch (e: TransactionVerificationException) {
logger.error("Error validating transaction ${ltx.id}.", e.cause)
throw e
}
}
}
/** /**
* Because we create a separate [LedgerTransaction] onto which we need to perform verification, it becomes important we don't verify the * Because we create a separate [LedgerTransaction] onto which we need to perform verification, it becomes important we don't verify the
* wrong object instance. This class helps avoid that. * wrong object instance. This class helps avoid that.
*/ */
abstract class Verifier(val ltx: LedgerTransaction, protected val transactionClassLoader: ClassLoader) { @KeepForDJVM
private val inputStates: List<TransactionState<*>> = ltx.inputs.map { it.state } private class Validator(private val ltx: LedgerTransaction, private val transactionClassLoader: ClassLoader) {
private val allStates: List<TransactionState<*>> = inputStates + ltx.references.map { it.state } + ltx.outputs private val inputStates: List<TransactionState<*>> = ltx.inputs.map(StateAndRef<ContractState>::state)
private val allStates: List<TransactionState<*>> = inputStates + ltx.references.map(StateAndRef<ContractState>::state) + ltx.outputs
companion object { private companion object {
val logger = contextLogger() private val logger = loggerFor<Validator>()
} }
/** /**
@ -39,9 +103,9 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla
* *
* It is a critical piece of the security of the platform. * It is a critical piece of the security of the platform.
* *
* @throws TransactionVerificationException * @throws net.corda.core.contracts.TransactionVerificationException
*/ */
fun verify() { fun validate() {
// checkNoNotaryChange and checkEncumbrancesValid are called here, and not in the c'tor, as they need access to the "outputs" // checkNoNotaryChange and checkEncumbrancesValid are called here, and not in the c'tor, as they need access to the "outputs"
// list, the contents of which need to be deserialized under the correct classloader. // list, the contents of which need to be deserialized under the correct classloader.
checkNoNotaryChange() checkNoNotaryChange()
@ -68,8 +132,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla
// 4. Check that the [TransactionState] objects are correctly formed. // 4. Check that the [TransactionState] objects are correctly formed.
validateStatesAgainstContract() validateStatesAgainstContract()
// 5. Final step is to run the contract code. After the first 4 steps we are now sure that we are running the correct code. // 5. Final step will be to run the contract code.
verifyContracts()
} }
private fun checkTransactionWithTimeWindowIsNotarised() { private fun checkTransactionWithTimeWindowIsNotarised() {
@ -81,11 +144,12 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla
* It makes sure there is one and only one. * It makes sure there is one and only one.
* This is an important piece of the security of transactions. * This is an important piece of the security of transactions.
*/ */
@Suppress("ThrowsCount")
private fun getUniqueContractAttachmentsByContract(): Map<ContractClassName, ContractAttachment> { private fun getUniqueContractAttachmentsByContract(): Map<ContractClassName, ContractAttachment> {
val contractClasses = allStates.map { it.contract }.toSet() val contractClasses = allStates.mapTo(LinkedHashSet(), TransactionState<*>::contract)
// Check that there are no duplicate attachments added. // Check that there are no duplicate attachments added.
if (ltx.attachments.size != ltx.attachments.toSet().size) throw TransactionVerificationException.DuplicateAttachmentsRejection(ltx.id, ltx.attachments.groupBy { it }.filterValues { it.size > 1 }.keys.first()) if (ltx.attachments.size != ltx.attachments.toSet().size) throw DuplicateAttachmentsRejection(ltx.id, ltx.attachments.groupBy { it }.filterValues { it.size > 1 }.keys.first())
// For each attachment this finds all the relevant state contracts that it provides. // For each attachment this finds all the relevant state contracts that it provides.
// And then maps them to the attachment. // And then maps them to the attachment.
@ -103,12 +167,12 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla
.groupBy { it.first } // Group by contract. .groupBy { it.first } // Group by contract.
.filter { (_, attachments) -> attachments.size > 1 } // And only keep contracts that are in multiple attachments. It's guaranteed that attachments were unique by a previous check. .filter { (_, attachments) -> attachments.size > 1 } // And only keep contracts that are in multiple attachments. It's guaranteed that attachments were unique by a previous check.
.keys.firstOrNull() // keep the first one - if any - to throw a meaningful exception. .keys.firstOrNull() // keep the first one - if any - to throw a meaningful exception.
if (contractWithMultipleAttachments != null) throw TransactionVerificationException.ConflictingAttachmentsRejection(ltx.id, contractWithMultipleAttachments) if (contractWithMultipleAttachments != null) throw ConflictingAttachmentsRejection(ltx.id, contractWithMultipleAttachments)
val result = contractAttachmentsPerContract.toMap() val result = contractAttachmentsPerContract.toMap()
// Check that there is an attachment for each contract. // Check that there is an attachment for each contract.
if (result.keys != contractClasses) throw TransactionVerificationException.MissingAttachmentRejection(ltx.id, contractClasses.minus(result.keys).first()) if (result.keys != contractClasses) throw MissingAttachmentRejection(ltx.id, contractClasses.minus(result.keys).first())
return result return result
} }
@ -124,7 +188,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla
if (ltx.notary != null && (ltx.inputs.isNotEmpty() || ltx.references.isNotEmpty())) { if (ltx.notary != null && (ltx.inputs.isNotEmpty() || ltx.references.isNotEmpty())) {
ltx.outputs.forEach { ltx.outputs.forEach {
if (it.notary != ltx.notary) { if (it.notary != ltx.notary) {
throw TransactionVerificationException.NotaryChangeInWrongTransactionType(ltx.id, ltx.notary, it.notary) throw NotaryChangeInWrongTransactionType(ltx.id, ltx.notary, it.notary)
} }
} }
} }
@ -156,10 +220,10 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla
it.ref.txhash == ref.txhash && it.ref.index == state.encumbrance it.ref.txhash == ref.txhash && it.ref.index == state.encumbrance
} }
if (!encumbranceStateExists) { if (!encumbranceStateExists) {
throw TransactionVerificationException.TransactionMissingEncumbranceException( throw TransactionMissingEncumbranceException(
ltx.id, ltx.id,
state.encumbrance!!, state.encumbrance!!,
TransactionVerificationException.Direction.INPUT Direction.INPUT
) )
} }
} }
@ -185,6 +249,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla
// b -> c and c -> b // b -> c and c -> b
// c -> a b -> a // c -> a b -> a
// and form a full cycle, meaning that the bi-directionality property is satisfied. // and form a full cycle, meaning that the bi-directionality property is satisfied.
@Suppress("ThrowsCount")
private fun checkBidirectionalOutputEncumbrances(statesAndEncumbrance: List<Pair<Int, Int>>) { private fun checkBidirectionalOutputEncumbrances(statesAndEncumbrance: List<Pair<Int, Int>>) {
// [Set] of "from" (encumbered states). // [Set] of "from" (encumbered states).
val encumberedSet = mutableSetOf<Int>() val encumberedSet = mutableSetOf<Int>()
@ -194,15 +259,15 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla
statesAndEncumbrance.forEach { (statePosition, encumbrance) -> statesAndEncumbrance.forEach { (statePosition, encumbrance) ->
// Check it does not refer to itself. // Check it does not refer to itself.
if (statePosition == encumbrance || encumbrance >= ltx.outputs.size) { if (statePosition == encumbrance || encumbrance >= ltx.outputs.size) {
throw TransactionVerificationException.TransactionMissingEncumbranceException( throw TransactionMissingEncumbranceException(
ltx.id, ltx.id,
encumbrance, encumbrance,
TransactionVerificationException.Direction.OUTPUT Direction.OUTPUT
) )
} else { } else {
encumberedSet.add(statePosition) // Guaranteed to have unique elements. encumberedSet.add(statePosition) // Guaranteed to have unique elements.
if (!encumbranceSet.add(encumbrance)) { if (!encumbranceSet.add(encumbrance)) {
throw TransactionVerificationException.TransactionDuplicateEncumbranceException(ltx.id, encumbrance) throw TransactionDuplicateEncumbranceException(ltx.id, encumbrance)
} }
} }
} }
@ -211,7 +276,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla
val symmetricDifference = (encumberedSet union encumbranceSet).subtract(encumberedSet intersect encumbranceSet) val symmetricDifference = (encumberedSet union encumbranceSet).subtract(encumberedSet intersect encumbranceSet)
if (symmetricDifference.isNotEmpty()) { if (symmetricDifference.isNotEmpty()) {
// At least one encumbered state is not in the [encumbranceSet] and vice versa. // At least one encumbered state is not in the [encumbranceSet] and vice versa.
throw TransactionVerificationException.TransactionNonMatchingEncumbranceException(ltx.id, symmetricDifference) throw TransactionNonMatchingEncumbranceException(ltx.id, symmetricDifference)
} }
} }
@ -235,7 +300,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla
if (indicesAlreadyChecked.add(index)) { if (indicesAlreadyChecked.add(index)) {
val encumbranceIndex = ltx.outputs[index].encumbrance!! val encumbranceIndex = ltx.outputs[index].encumbrance!!
if (ltx.outputs[index].notary != ltx.outputs[encumbranceIndex].notary) { if (ltx.outputs[index].notary != ltx.outputs[encumbranceIndex].notary) {
throw TransactionVerificationException.TransactionNotaryMismatchEncumbranceException( throw TransactionNotaryMismatchEncumbranceException(
ltx.id, ltx.id,
index, index,
encumbranceIndex, encumbranceIndex,
@ -263,7 +328,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla
val shouldEnforce = StateContractValidationEnforcementRule.shouldEnforce(state.data) val shouldEnforce = StateContractValidationEnforcementRule.shouldEnforce(state.data)
val requiredContractClassName = state.data.requiredContractClassName val requiredContractClassName = state.data.requiredContractClassName
?: if (shouldEnforce) throw TransactionVerificationException.TransactionRequiredContractUnspecifiedException(ltx.id, state) else return ?: if (shouldEnforce) throw TransactionRequiredContractUnspecifiedException(ltx.id, state) else return
if (state.contract != requiredContractClassName) if (state.contract != requiredContractClassName)
if (shouldEnforce) { if (shouldEnforce) {
@ -281,6 +346,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla
* - Constraints should be one of the valid supported ones. * - Constraints should be one of the valid supported ones.
* - Constraints should propagate correctly if not marked otherwise (in that case it is the responsibility of the contract to ensure that the output states are created properly). * - Constraints should propagate correctly if not marked otherwise (in that case it is the responsibility of the contract to ensure that the output states are created properly).
*/ */
@Suppress("NestedBlockDepth")
private fun verifyConstraintsValidity(contractAttachmentsByContract: Map<ContractClassName, ContractAttachment>) { private fun verifyConstraintsValidity(contractAttachmentsByContract: Map<ContractClassName, ContractAttachment>) {
// First check that the constraints are valid. // First check that the constraints are valid.
@ -310,7 +376,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla
outputConstraints.forEach { outputConstraint -> outputConstraints.forEach { outputConstraint ->
inputConstraints.forEach { inputConstraint -> inputConstraints.forEach { inputConstraint ->
if (!(outputConstraint.canBeTransitionedFrom(inputConstraint, contractAttachment))) { if (!(outputConstraint.canBeTransitionedFrom(inputConstraint, contractAttachment))) {
throw TransactionVerificationException.ConstraintPropagationRejection( throw ConstraintPropagationRejection(
ltx.id, ltx.id,
contractClassName, contractClassName,
inputConstraint, inputConstraint,
@ -331,7 +397,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla
@Suppress("NestedBlockDepth", "MagicNumber") @Suppress("NestedBlockDepth", "MagicNumber")
private fun verifyConstraints(contractAttachmentsByContract: Map<ContractClassName, ContractAttachment>) { private fun verifyConstraints(contractAttachmentsByContract: Map<ContractClassName, ContractAttachment>) {
// For each contract/constraint pair check that the relevant attachment is valid. // For each contract/constraint pair check that the relevant attachment is valid.
allStates.map { it.contract to it.constraint }.toSet().forEach { (contract, constraint) -> allStates.mapTo(LinkedHashSet()) { it.contract to it.constraint }.forEach { (contract, constraint) ->
if (constraint is SignatureAttachmentConstraint) { if (constraint is SignatureAttachmentConstraint) {
/** /**
* Support for signature constraints has been added on * Support for signature constraints has been added on
@ -346,9 +412,9 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla
"Signature constraints" "Signature constraints"
) )
val constraintKey = constraint.key val constraintKey = constraint.key
if (ltx.networkParameters?.minimumPlatformVersion ?: 1 >= PlatformVersionSwitches.LIMIT_KEYS_IN_SIGNATURE_CONSTRAINTS) { if ((ltx.networkParameters?.minimumPlatformVersion ?: 1) >= PlatformVersionSwitches.LIMIT_KEYS_IN_SIGNATURE_CONSTRAINTS) {
if (constraintKey is CompositeKey && constraintKey.leafKeys.size > MAX_NUMBER_OF_KEYS_IN_SIGNATURE_CONSTRAINT) { if (constraintKey is CompositeKey && constraintKey.leafKeys.size > MAX_NUMBER_OF_KEYS_IN_SIGNATURE_CONSTRAINT) {
throw TransactionVerificationException.InvalidConstraintRejection(ltx.id, contract, throw InvalidConstraintRejection(ltx.id, contract,
"Signature constraint contains composite key with ${constraintKey.leafKeys.size} leaf keys, " + "Signature constraint contains composite key with ${constraintKey.leafKeys.size} leaf keys, " +
"which is more than the maximum allowed number of keys " + "which is more than the maximum allowed number of keys " +
"($MAX_NUMBER_OF_KEYS_IN_SIGNATURE_CONSTRAINT).") "($MAX_NUMBER_OF_KEYS_IN_SIGNATURE_CONSTRAINT).")
@ -364,39 +430,18 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla
if (HashAttachmentConstraint.disableHashConstraints && constraint is HashAttachmentConstraint) if (HashAttachmentConstraint.disableHashConstraints && constraint is HashAttachmentConstraint)
logger.warnOnce("Skipping hash constraints verification.") logger.warnOnce("Skipping hash constraints verification.")
else if (!constraint.isSatisfiedBy(constraintAttachment)) else if (!constraint.isSatisfiedBy(constraintAttachment))
throw TransactionVerificationException.ContractConstraintRejection(ltx.id, contract) throw ContractConstraintRejection(ltx.id, contract)
}
}
/**
* Placeholder function for the contract verification logic.
*/
abstract fun verifyContracts()
}
class BasicVerifier(ltx: LedgerTransaction, transactionClassLoader: ClassLoader) : Verifier(ltx, transactionClassLoader) {
/**
* Check the transaction is contract-valid by running the verify() for each input and output state contract.
* If any contract fails to verify, the whole transaction is considered to be invalid.
*
* Note: Reference states are not verified.
*/
override fun verifyContracts() {
try {
ContractVerifier(transactionClassLoader).apply(ltx)
} catch (e: TransactionVerificationException.ContractRejection) {
logger.error("Error validating transaction ${ltx.id}.", e.cause)
throw e
} }
} }
} }
/** /**
* Verify all of the contracts on the given [LedgerTransaction]. * Verify the given [LedgerTransaction]. This includes validating
* its contents, as well as executing all of its smart contracts.
*/ */
@Suppress("TooGenericExceptionCaught") @Suppress("TooGenericExceptionCaught")
@KeepForDJVM @KeepForDJVM
class ContractVerifier(private val transactionClassLoader: ClassLoader) : Function<LedgerTransaction, Unit> { class TransactionVerifier(private val transactionClassLoader: ClassLoader) : Function<Supplier<LedgerTransaction>, Unit> {
// This constructor is used inside the DJVM's sandbox. // This constructor is used inside the DJVM's sandbox.
@Suppress("unused") @Suppress("unused")
constructor() : this(ClassLoader.getSystemClassLoader()) constructor() : this(ClassLoader.getSystemClassLoader())
@ -406,34 +451,62 @@ class ContractVerifier(private val transactionClassLoader: ClassLoader) : Functi
return try { return try {
Class.forName(contractClassName, false, transactionClassLoader).asSubclass(Contract::class.java) Class.forName(contractClassName, false, transactionClassLoader).asSubclass(Contract::class.java)
} catch (e: Exception) { } catch (e: Exception) {
throw TransactionVerificationException.ContractCreationError(id, contractClassName, e) throw ContractCreationError(id, contractClassName, e)
} }
} }
override fun apply(ltx: LedgerTransaction) { private fun generateContracts(ltx: LedgerTransaction): List<Contract> {
val contractClassNames = (ltx.inputs.map(StateAndRef<ContractState>::state) + ltx.outputs) return (ltx.inputs.map(StateAndRef<ContractState>::state) + ltx.outputs)
.mapTo(LinkedHashSet(), TransactionState<*>::contract) .mapTo(LinkedHashSet(), TransactionState<*>::contract)
.map { contractClassName ->
contractClassNames.associateBy( createContractClass(ltx.id, contractClassName)
{ it }, { createContractClass(ltx.id, it) } }.map { contractClass ->
).map { (contractClassName, contractClass) ->
try { try {
/** /**
* This function must execute with the DJVM's sandbox, which does not * This function must execute within the DJVM's sandbox, which does not
* permit user code to invoke [java.lang.Class.getDeclaredConstructor]. * permit user code to invoke [java.lang.reflect.Constructor.newInstance].
* (This would be fixable now, provided the constructor is public.)
* *
* [Class.newInstance] is deprecated as of Java 9. * [Class.newInstance] is deprecated as of Java 9.
*/ */
@Suppress("deprecation") @Suppress("deprecation")
contractClass.newInstance() contractClass.newInstance()
} catch (e: Exception) { } catch (e: Exception) {
throw TransactionVerificationException.ContractCreationError(ltx.id, contractClassName, e) throw ContractCreationError(ltx.id, contractClass.name, e)
} }
}
}
private fun validateTransaction(ltx: LedgerTransaction) {
Validator(ltx, transactionClassLoader).validate()
}
override fun apply(transactionFactory: Supplier<LedgerTransaction>) {
var firstLtx: LedgerTransaction? = null
transactionFactory.get().let { ltx ->
firstLtx = ltx
/**
* Check that this transaction is correctly formed.
* We only need to run these checks once.
*/
validateTransaction(ltx)
/**
* Generate the list of unique contracts
* within this transaction.
*/
generateContracts(ltx)
}.forEach { contract -> }.forEach { contract ->
val ltx = firstLtx ?: transactionFactory.get()
firstLtx = null
try { try {
// Final step is to run the contract code. Having validated the
// transaction, we are now sure that we are running the correct code.
contract.verify(ltx) contract.verify(ltx)
} catch (e: Exception) { } catch (e: Exception) {
throw TransactionVerificationException.ContractRejection(ltx.id, contract, e) throw ContractRejection(ltx.id, contract, e)
} }
} }
} }

View File

@ -14,6 +14,14 @@ abstract class NotaryService : SingletonSerializeAsToken() {
abstract val services: ServiceHub abstract val services: ServiceHub
abstract val notaryIdentityKey: PublicKey abstract val notaryIdentityKey: PublicKey
/**
* Mapping between @InitiatingFlow classes and factory methods that produce responder flows.
* Can be overridden in case of advanced notary service that serves both custom and standard flows.
*/
open val initiatingFlows = mapOf(
NotaryFlow.Client::class to ::createServiceFlow
)
/** /**
* Interfaces for the request and result formats of queries supported by notary services. To * Interfaces for the request and result formats of queries supported by notary services. To
* implement a new query, you must: * implement a new query, you must:

View File

@ -28,7 +28,11 @@ import java.time.Duration
* @param etaThreshold If the ETA for processing the request, according to the service, is greater than this, notify the client. * @param etaThreshold If the ETA for processing the request, according to the service, is greater than this, notify the client.
*/ */
// See AbstractStateReplacementFlow.Acceptor for why it's Void? // See AbstractStateReplacementFlow.Acceptor for why it's Void?
abstract class NotaryServiceFlow(val otherSideSession: FlowSession, val service: SinglePartyNotaryService, private val etaThreshold: Duration) : FlowLogic<Void?>() { abstract class NotaryServiceFlow(
val otherSideSession: FlowSession,
val service: SinglePartyNotaryService,
private val etaThreshold: Duration
) : FlowLogic<Void?>() {
companion object { companion object {
// TODO: Determine an appropriate limit and also enforce in the network parameters and the transaction builder. // TODO: Determine an appropriate limit and also enforce in the network parameters and the transaction builder.
private const val maxAllowedInputsAndReferences = 10_000 private const val maxAllowedInputsAndReferences = 10_000

View File

@ -0,0 +1,63 @@
package net.corda.core.internal.utilities
import java.io.FilterInputStream
import java.io.InputStream
import java.util.zip.ZipInputStream
object ZipBombDetector {
private class CounterInputStream(source : InputStream) : FilterInputStream(source) {
private var byteCount : Long = 0
val count : Long
get() = byteCount
override fun read(): Int {
return super.read().also { byte ->
if(byte >= 0) byteCount += 1
}
}
override fun read(b: ByteArray): Int {
return super.read(b).also { bytesRead ->
if(bytesRead > 0) byteCount += bytesRead
}
}
override fun read(b: ByteArray, off: Int, len: Int): Int {
return super.read(b, off, len).also { bytesRead ->
if(bytesRead > 0) byteCount += bytesRead
}
}
}
/**
* Check if a zip file is a potential malicious zip bomb
* @param source the zip archive file content
* @param maxUncompressedSize the maximum allowable uncompressed archive size
* @param maxCompressionRatio the maximum allowable compression ratio
* @return true if the zip file total uncompressed size exceeds [maxUncompressedSize] and the
* average entry compression ratio is larger than [maxCompressionRatio], false otherwise
*/
@Suppress("NestedBlockDepth")
fun scanZip(source : InputStream, maxUncompressedSize : Long, maxCompressionRatio : Float = 10.0f) : Boolean {
val counterInputStream = CounterInputStream(source)
var uncompressedByteCount : Long = 0
val buffer = ByteArray(DEFAULT_BUFFER_SIZE)
ZipInputStream(counterInputStream).use { zipInputStream ->
while(true) {
zipInputStream.nextEntry ?: break
while(true) {
val read = zipInputStream.read(buffer)
if(read <= 0) break
uncompressedByteCount += read
if(uncompressedByteCount > maxUncompressedSize &&
uncompressedByteCount.toFloat() / counterInputStream.count.toFloat() > maxCompressionRatio) {
return true
}
}
}
}
return false
}
}

View File

@ -13,6 +13,8 @@ import net.corda.core.utilities.days
import java.security.PublicKey import java.security.PublicKey
import java.time.Duration import java.time.Duration
import java.time.Instant import java.time.Instant
import java.util.Collections.unmodifiableList
import java.util.Collections.unmodifiableMap
// DOCSTART 1 // DOCSTART 1
/** /**
@ -166,6 +168,38 @@ data class NetworkParameters(
epoch=$epoch epoch=$epoch
}""" }"""
} }
fun toImmutable(): NetworkParameters {
return NetworkParameters(
minimumPlatformVersion = minimumPlatformVersion,
notaries = unmodifiable(notaries),
maxMessageSize = maxMessageSize,
maxTransactionSize = maxTransactionSize,
modifiedTime = modifiedTime,
epoch = epoch,
whitelistedContractImplementations = unmodifiable(whitelistedContractImplementations) { entry ->
unmodifiableList(entry.value)
},
eventHorizon = eventHorizon,
packageOwnership = unmodifiable(packageOwnership)
)
}
}
private fun <T> unmodifiable(list: List<T>): List<T> {
return if (list.isEmpty()) {
emptyList()
} else {
unmodifiableList(list)
}
}
private inline fun <K, V> unmodifiable(map: Map<K, V>, transform: (Map.Entry<K, V>) -> V = Map.Entry<K, V>::value): Map<K, V> {
return if (map.isEmpty()) {
emptyMap()
} else {
unmodifiableMap(map.mapValues(transform))
}
} }
/** /**

View File

@ -64,7 +64,7 @@ interface ServicesForResolution {
/** /**
* Given a [Set] of [StateRef]'s loads the referenced transaction and looks up the specified output [ContractState]. * Given a [Set] of [StateRef]'s loads the referenced transaction and looks up the specified output [ContractState].
* *
* @throws TransactionResolutionException if [stateRef] points to a non-existent transaction. * @throws TransactionResolutionException if any of the [stateRefs] point to a non-existent transaction.
*/ */
// TODO: future implementation to use a Vault state ref -> contract state BLOB table and perform single query bulk load // TODO: future implementation to use a Vault state ref -> contract state BLOB table and perform single query bulk load
// as the existing transaction store will become encrypted at some point // as the existing transaction store will become encrypted at some point

View File

@ -1,3 +1,5 @@
@file:Suppress("LongParameterList")
package net.corda.core.node.services package net.corda.core.node.services
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
@ -197,8 +199,7 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
* 4) Status types used in this query: [StateStatus.UNCONSUMED], [StateStatus.CONSUMED], [StateStatus.ALL]. * 4) Status types used in this query: [StateStatus.UNCONSUMED], [StateStatus.CONSUMED], [StateStatus.ALL].
* 5) Other results as a [List] of any type (eg. aggregate function results with/without group by). * 5) Other results as a [List] of any type (eg. aggregate function results with/without group by).
* *
* Note: currently otherResults are used only for Aggregate Functions (in which case, the states and statesMetadata * Note: currently [otherResults] is used only for aggregate functions (in which case, [states] and [statesMetadata] will be empty).
* results will be empty).
*/ */
@CordaSerializable @CordaSerializable
data class Page<out T : ContractState>(val states: List<StateAndRef<T>>, data class Page<out T : ContractState>(val states: List<StateAndRef<T>>,
@ -213,11 +214,11 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
val contractStateClassName: String, val contractStateClassName: String,
val recordedTime: Instant, val recordedTime: Instant,
val consumedTime: Instant?, val consumedTime: Instant?,
val status: Vault.StateStatus, val status: StateStatus,
val notary: AbstractParty?, val notary: AbstractParty?,
val lockId: String?, val lockId: String?,
val lockUpdateTime: Instant?, val lockUpdateTime: Instant?,
val relevancyStatus: Vault.RelevancyStatus? = null, val relevancyStatus: RelevancyStatus? = null,
val constraintInfo: ConstraintInfo? = null val constraintInfo: ConstraintInfo? = null
) { ) {
fun copy( fun copy(
@ -225,7 +226,7 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
contractStateClassName: String = this.contractStateClassName, contractStateClassName: String = this.contractStateClassName,
recordedTime: Instant = this.recordedTime, recordedTime: Instant = this.recordedTime,
consumedTime: Instant? = this.consumedTime, consumedTime: Instant? = this.consumedTime,
status: Vault.StateStatus = this.status, status: StateStatus = this.status,
notary: AbstractParty? = this.notary, notary: AbstractParty? = this.notary,
lockId: String? = this.lockId, lockId: String? = this.lockId,
lockUpdateTime: Instant? = this.lockUpdateTime lockUpdateTime: Instant? = this.lockUpdateTime
@ -237,11 +238,11 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
contractStateClassName: String = this.contractStateClassName, contractStateClassName: String = this.contractStateClassName,
recordedTime: Instant = this.recordedTime, recordedTime: Instant = this.recordedTime,
consumedTime: Instant? = this.consumedTime, consumedTime: Instant? = this.consumedTime,
status: Vault.StateStatus = this.status, status: StateStatus = this.status,
notary: AbstractParty? = this.notary, notary: AbstractParty? = this.notary,
lockId: String? = this.lockId, lockId: String? = this.lockId,
lockUpdateTime: Instant? = this.lockUpdateTime, lockUpdateTime: Instant? = this.lockUpdateTime,
relevancyStatus: Vault.RelevancyStatus? relevancyStatus: RelevancyStatus?
): StateMetadata { ): StateMetadata {
return StateMetadata(ref, contractStateClassName, recordedTime, consumedTime, status, notary, lockId, lockUpdateTime, relevancyStatus, ConstraintInfo(AlwaysAcceptAttachmentConstraint)) return StateMetadata(ref, contractStateClassName, recordedTime, consumedTime, status, notary, lockId, lockUpdateTime, relevancyStatus, ConstraintInfo(AlwaysAcceptAttachmentConstraint))
} }
@ -249,9 +250,9 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
companion object { companion object {
@Deprecated("No longer used. The vault does not emit empty updates") @Deprecated("No longer used. The vault does not emit empty updates")
val NoUpdate = Update(emptySet(), emptySet(), type = Vault.UpdateType.GENERAL, references = emptySet()) val NoUpdate = Update(emptySet(), emptySet(), type = UpdateType.GENERAL, references = emptySet())
@Deprecated("No longer used. The vault does not emit empty updates") @Deprecated("No longer used. The vault does not emit empty updates")
val NoNotaryUpdate = Vault.Update(emptySet(), emptySet(), type = Vault.UpdateType.NOTARY_CHANGE, references = emptySet()) val NoNotaryUpdate = Update(emptySet(), emptySet(), type = UpdateType.NOTARY_CHANGE, references = emptySet())
} }
} }
@ -302,7 +303,7 @@ interface VaultService {
fun whenConsumed(ref: StateRef): CordaFuture<Vault.Update<ContractState>> { fun whenConsumed(ref: StateRef): CordaFuture<Vault.Update<ContractState>> {
val query = QueryCriteria.VaultQueryCriteria( val query = QueryCriteria.VaultQueryCriteria(
stateRefs = listOf(ref), stateRefs = listOf(ref),
status = Vault.StateStatus.CONSUMED status = StateStatus.CONSUMED
) )
val result = trackBy<ContractState>(query) val result = trackBy<ContractState>(query)
val snapshot = result.snapshot.states val snapshot = result.snapshot.states
@ -358,8 +359,8 @@ interface VaultService {
/** /**
* Helper function to determine spendable states and soft locking them. * Helper function to determine spendable states and soft locking them.
* Currently performance will be worse than for the hand optimised version in * Currently performance will be worse than for the hand optimised version in
* [Cash.unconsumedCashStatesForSpending]. However, this is fully generic and can operate with custom [FungibleState] * [net.corda.finance.workflows.asset.selection.AbstractCashSelection.unconsumedCashStatesForSpending]. However, this is fully generic
* and [FungibleAsset] states. * and can operate with custom [FungibleState] and [FungibleAsset] states.
* @param lockId The [FlowLogic.runId]'s [UUID] of the current flow used to soft lock the states. * @param lockId The [FlowLogic.runId]'s [UUID] of the current flow used to soft lock the states.
* @param eligibleStatesQuery A custom query object that selects down to the appropriate subset of all states of the * @param eligibleStatesQuery A custom query object that selects down to the appropriate subset of all states of the
* [contractStateType]. e.g. by selecting on account, issuer, etc. The query is internally augmented with the * [contractStateType]. e.g. by selecting on account, issuer, etc. The query is internally augmented with the

View File

@ -0,0 +1,35 @@
package net.corda.core.serialization
import net.corda.core.utilities.ByteSequence
import java.io.NotSerializableException
/***
* Implement this interface to add your own Serialization Scheme. This is an experimental feature. All methods in this class MUST be
* thread safe i.e. methods from the same instance of this class can be called in different threads simultaneously.
*/
interface CustomSerializationScheme {
/**
* This method must return an id used to uniquely identify the Scheme. This should be unique within a network as serialized data might
* be sent over the wire.
*/
fun getSchemeId(): Int
/**
* This method must deserialize the data stored [bytes] into an instance of [T].
*
* @param bytes the serialized data.
* @param clazz the class to instantiate.
* @param context used to pass information about how the object should be deserialized.
*/
@Throws(NotSerializableException::class)
fun <T : Any> deserialize(bytes: ByteSequence, clazz: Class<T>, context: SerializationSchemeContext): T
/**
* This method must be able to serialize any object [T] into a ByteSequence.
*
* @param obj the object to be serialized.
* @param context used to pass information about how the object should be serialized.
*/
@Throws(NotSerializableException::class)
fun <T : Any> serialize(obj: T, context: SerializationSchemeContext): ByteSequence
}

View File

@ -13,6 +13,10 @@ import net.corda.core.utilities.sequence
import java.io.NotSerializableException import java.io.NotSerializableException
import java.sql.Blob import java.sql.Blob
const val DESERIALIZATION_CACHE_PROPERTY = "DESERIALIZATION_CACHE"
const val AMQP_ENVELOPE_CACHE_PROPERTY = "AMQP_ENVELOPE_CACHE"
const val AMQP_ENVELOPE_CACHE_INITIAL_CAPACITY = 256
data class ObjectWithCompatibleContext<out T : Any>(val obj: T, val context: SerializationContext) data class ObjectWithCompatibleContext<out T : Any>(val obj: T, val context: SerializationContext)
/** /**
@ -65,12 +69,16 @@ abstract class SerializationFactory {
* Change the current context inside the block to that supplied. * Change the current context inside the block to that supplied.
*/ */
fun <T> withCurrentContext(context: SerializationContext?, block: () -> T): T { fun <T> withCurrentContext(context: SerializationContext?, block: () -> T): T {
return if (context == null) {
block()
} else {
val priorContext = _currentContext.get() val priorContext = _currentContext.get()
if (context != null) _currentContext.set(context) _currentContext.set(context)
try { try {
return block() block()
} finally { } finally {
if (context != null) _currentContext.set(priorContext) _currentContext.set(priorContext)
}
} }
} }
@ -134,7 +142,7 @@ interface SerializationContext {
*/ */
val encodingWhitelist: EncodingWhitelist val encodingWhitelist: EncodingWhitelist
/** /**
* A map of any addition properties specific to the particular use case. * A map of any additional properties specific to the particular use case.
*/ */
val properties: Map<Any, Any> val properties: Map<Any, Any>
/** /**
@ -178,6 +186,11 @@ interface SerializationContext {
*/ */
fun withProperty(property: Any, value: Any): SerializationContext fun withProperty(property: Any, value: Any): SerializationContext
/**
* Helper method to return a new context based on this context with the extra properties added.
*/
fun withProperties(extraProperties: Map<Any, Any>): SerializationContext
/** /**
* Helper method to return a new context based on this context with object references disabled. * Helper method to return a new context based on this context with object references disabled.
*/ */

View File

@ -0,0 +1,30 @@
package net.corda.core.serialization
import net.corda.core.DoNotImplement
/**
* This is used to pass information into [CustomSerializationScheme] about how the object should be (de)serialized.
* This context can change depending on the specific circumstances in the node when (de)serialization occurs.
*/
@DoNotImplement
interface SerializationSchemeContext {
/**
* The class loader to use for deserialization. This is guaranteed to be able to load all the required classes passed into
* [CustomSerializationScheme.deserialize].
*/
val deserializationClassLoader: ClassLoader
/**
* A whitelist that contains (mostly for security purposes) which classes are authorised to be deserialized.
* A secure implementation will not instantiate any object which is not either whitelisted or annotated with [CordaSerializable] when
* deserializing. To catch classes missing from the whitelist as early as possible it is HIGHLY recommended to also check this
* whitelist when serializing (as well as deserializing) objects.
*/
val whitelist: ClassWhitelist
/**
* A map of any additional properties specific to the particular use case. If these properties are set via
* [toWireTransaction][net.corda.core.transactions.TransactionBuilder.toWireTransaction] then they might not be available when
* deserializing. If the properties are required when deserializing, they can be added into the blob when serializing and read back
* when deserializing.
*/
val properties: Map<Any, Any>
}

View File

@ -9,21 +9,46 @@ import net.corda.core.contracts.TransactionVerificationException
import net.corda.core.contracts.TransactionVerificationException.OverlappingAttachmentsException import net.corda.core.contracts.TransactionVerificationException.OverlappingAttachmentsException
import net.corda.core.contracts.TransactionVerificationException.PackageOwnershipException import net.corda.core.contracts.TransactionVerificationException.PackageOwnershipException
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.sha256 import net.corda.core.internal.JDK1_2_CLASS_FILE_FORMAT_MAJOR_VERSION
import net.corda.core.internal.* import net.corda.core.internal.JDK8_CLASS_FILE_FORMAT_MAJOR_VERSION
import net.corda.core.internal.JarSignatureCollector
import net.corda.core.internal.NamedCacheFactory
import net.corda.core.internal.PlatformVersionSwitches
import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.cordapp.targetPlatformVersion import net.corda.core.internal.cordapp.targetPlatformVersion
import net.corda.core.internal.createInstancesOfClassesImplementing
import net.corda.core.internal.createSimpleCache
import net.corda.core.internal.toSynchronised
import net.corda.core.node.NetworkParameters import net.corda.core.node.NetworkParameters
import net.corda.core.serialization.* import net.corda.core.serialization.AMQP_ENVELOPE_CACHE_INITIAL_CAPACITY
import net.corda.core.serialization.AMQP_ENVELOPE_CACHE_PROPERTY
import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationCustomSerializer
import net.corda.core.serialization.SerializationFactory
import net.corda.core.serialization.SerializationWhitelist
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.serialization.internal.AttachmentURLStreamHandlerFactory.toUrl import net.corda.core.serialization.internal.AttachmentURLStreamHandlerFactory.toUrl
import net.corda.core.serialization.withWhitelist
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
import java.io.ByteArrayOutputStream import net.corda.core.utilities.loggerFor
import java.io.IOException import java.io.IOException
import java.io.InputStream import java.io.InputStream
import java.lang.ref.ReferenceQueue
import java.lang.ref.WeakReference import java.lang.ref.WeakReference
import java.net.* import java.net.URL
import java.net.URLClassLoader
import java.net.URLConnection
import java.net.URLStreamHandler
import java.net.URLStreamHandlerFactory
import java.security.MessageDigest
import java.security.Permission import java.security.Permission
import java.util.* import java.util.Locale
import java.util.ServiceLoader
import java.util.WeakHashMap
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong
import java.util.function.Function import java.util.function.Function
/** /**
@ -51,12 +76,15 @@ class AttachmentsClassLoader(attachments: List<Attachment>,
init { init {
// Apply our own URLStreamHandlerFactory to resolve attachments // Apply our own URLStreamHandlerFactory to resolve attachments
setOrDecorateURLStreamHandlerFactory() setOrDecorateURLStreamHandlerFactory()
// Allow AttachmentsClassLoader to be used concurrently.
registerAsParallelCapable()
} }
// Jolokia and Json-simple are dependencies that were bundled by mistake within contract jars. // Jolokia and Json-simple are dependencies that were bundled by mistake within contract jars.
// In the AttachmentsClassLoader we just block any class in those 2 packages. // In the AttachmentsClassLoader we just block any class in those 2 packages.
private val ignoreDirectories = listOf("org/jolokia/", "org/json/simple/") private val ignoreDirectories = listOf("org/jolokia/", "org/json/simple/")
private val ignorePackages = ignoreDirectories.map { it.replace("/", ".") } private val ignorePackages = ignoreDirectories.map { it.replace('/', '.') }
/** /**
* Apply our custom factory either directly, if `URL.setURLStreamHandlerFactory` has not been called yet, * Apply our custom factory either directly, if `URL.setURLStreamHandlerFactory` has not been called yet,
@ -128,6 +156,20 @@ class AttachmentsClassLoader(attachments: List<Attachment>,
checkAttachments(attachments) checkAttachments(attachments)
} }
private class AttachmentHashContext(
val txId: SecureHash,
val buffer: ByteArray = ByteArray(DEFAULT_BUFFER_SIZE))
private fun hash(inputStream : InputStream, ctx : AttachmentHashContext) : SecureHash.SHA256 {
val md = MessageDigest.getInstance(SecureHash.SHA2_256)
while(true) {
val read = inputStream.read(ctx.buffer)
if(read <= 0) break
md.update(ctx.buffer, 0, read)
}
return SecureHash.SHA256(md.digest())
}
private fun isZipOrJar(attachment: Attachment) = attachment.openAsJAR().use { jar -> private fun isZipOrJar(attachment: Attachment) = attachment.openAsJAR().use { jar ->
jar.nextEntry != null jar.nextEntry != null
} }
@ -146,10 +188,10 @@ class AttachmentsClassLoader(attachments: List<Attachment>,
// TODO - investigate potential exploits. // TODO - investigate potential exploits.
private fun shouldCheckForNoOverlap(path: String, targetPlatformVersion: Int): Boolean { private fun shouldCheckForNoOverlap(path: String, targetPlatformVersion: Int): Boolean {
require(path.toLowerCase() == path) require(path.toLowerCase() == path)
require(!path.contains("\\")) require(!path.contains('\\'))
return when { return when {
path.endsWith("/") -> false // Directories (packages) can overlap. path.endsWith('/') -> false // Directories (packages) can overlap.
targetPlatformVersion < PlatformVersionSwitches.IGNORE_JOLOKIA_JSON_SIMPLE_IN_CORDAPPS && targetPlatformVersion < PlatformVersionSwitches.IGNORE_JOLOKIA_JSON_SIMPLE_IN_CORDAPPS &&
ignoreDirectories.any { path.startsWith(it) } -> false // Ignore jolokia and json-simple for old cordapps. ignoreDirectories.any { path.startsWith(it) } -> false // Ignore jolokia and json-simple for old cordapps.
path.endsWith(".class") -> true // All class files need to be unique. path.endsWith(".class") -> true // All class files need to be unique.
@ -160,6 +202,7 @@ class AttachmentsClassLoader(attachments: List<Attachment>,
} }
} }
@Suppress("ThrowsCount", "ComplexMethod", "NestedBlockDepth")
private fun checkAttachments(attachments: List<Attachment>) { private fun checkAttachments(attachments: List<Attachment>) {
require(attachments.isNotEmpty()) { "attachments list is empty" } require(attachments.isNotEmpty()) { "attachments list is empty" }
@ -188,7 +231,8 @@ class AttachmentsClassLoader(attachments: List<Attachment>,
// attacks on externally connected systems that only consider type names, we allow people to formally // attacks on externally connected systems that only consider type names, we allow people to formally
// claim their parts of the Java package namespace via registration with the zone operator. // claim their parts of the Java package namespace via registration with the zone operator.
val classLoaderEntries = mutableMapOf<String, SecureHash.SHA256>() val classLoaderEntries = mutableMapOf<String, SecureHash>()
val ctx = AttachmentHashContext(sampleTxId)
for (attachment in attachments) { for (attachment in attachments) {
// We may have been given an attachment loaded from the database in which case, important info like // We may have been given an attachment loaded from the database in which case, important info like
// signers is already calculated. // signers is already calculated.
@ -206,10 +250,12 @@ class AttachmentsClassLoader(attachments: List<Attachment>,
// signed by the owners of the packages, even if it's not. We'd eventually discover that fact // signed by the owners of the packages, even if it's not. We'd eventually discover that fact
// when trying to read the class file to use it, but if we'd made any decisions based on // when trying to read the class file to use it, but if we'd made any decisions based on
// perceived correctness of the signatures or package ownership already, that would be too late. // perceived correctness of the signatures or package ownership already, that would be too late.
attachment.openAsJAR().use { JarSignatureCollector.collectSigners(it) } attachment.openAsJAR().use(JarSignatureCollector::collectSigners)
} }
// Now open it again to compute the overlap and package ownership data. // Now open it again to compute the overlap and package ownership data.
attachment.openAsJAR().use { jar -> attachment.openAsJAR().use { jar ->
val targetPlatformVersion = jar.manifest?.targetPlatformVersion ?: 1 val targetPlatformVersion = jar.manifest?.targetPlatformVersion ?: 1
while (true) { while (true) {
val entry = jar.nextJarEntry ?: break val entry = jar.nextJarEntry ?: break
@ -250,13 +296,9 @@ class AttachmentsClassLoader(attachments: List<Attachment>,
if (!shouldCheckForNoOverlap(path, targetPlatformVersion)) continue if (!shouldCheckForNoOverlap(path, targetPlatformVersion)) continue
// This calculates the hash of the current entry because the JarInputStream returns only the current entry. // This calculates the hash of the current entry because the JarInputStream returns only the current entry.
fun entryHash() = ByteArrayOutputStream().use { val currentHash = hash(jar, ctx)
jar.copyTo(it)
it.toByteArray()
}.sha256()
// If 2 entries are identical, it means the same file is present in both attachments, so that is ok. // If 2 entries are identical, it means the same file is present in both attachments, so that is ok.
val currentHash = entryHash()
val previousFileHash = classLoaderEntries[path] val previousFileHash = classLoaderEntries[path]
when { when {
previousFileHash == null -> { previousFileHash == null -> {
@ -279,11 +321,11 @@ class AttachmentsClassLoader(attachments: List<Attachment>,
* Required to prevent classes that were excluded from the no-overlap check from being loaded by contract code. * Required to prevent classes that were excluded from the no-overlap check from being loaded by contract code.
* As it can lead to non-determinism. * As it can lead to non-determinism.
*/ */
override fun loadClass(name: String?): Class<*> { override fun loadClass(name: String, resolve: Boolean): Class<*>? {
if (ignorePackages.any { name!!.startsWith(it) }) { if (ignorePackages.any { name.startsWith(it) }) {
throw ClassNotFoundException(name) throw ClassNotFoundException(name)
} }
return super.loadClass(name) return super.loadClass(name, resolve)
} }
} }
@ -293,7 +335,8 @@ class AttachmentsClassLoader(attachments: List<Attachment>,
*/ */
@VisibleForTesting @VisibleForTesting
object AttachmentsClassLoaderBuilder { object AttachmentsClassLoaderBuilder {
const val CACHE_SIZE = 16 private const val CACHE_SIZE = 16
private const val STRONG_REFERENCE_TO_CACHED_SERIALIZATION_CONTEXT = "cachedSerializationContext"
private val fallBackCache: AttachmentsClassLoaderCache = AttachmentsClassLoaderSimpleCacheImpl(CACHE_SIZE) private val fallBackCache: AttachmentsClassLoaderCache = AttachmentsClassLoaderSimpleCacheImpl(CACHE_SIZE)
@ -309,18 +352,17 @@ object AttachmentsClassLoaderBuilder {
isAttachmentTrusted: (Attachment) -> Boolean, isAttachmentTrusted: (Attachment) -> Boolean,
parent: ClassLoader = ClassLoader.getSystemClassLoader(), parent: ClassLoader = ClassLoader.getSystemClassLoader(),
attachmentsClassLoaderCache: AttachmentsClassLoaderCache?, attachmentsClassLoaderCache: AttachmentsClassLoaderCache?,
block: (ClassLoader) -> T): T { block: (SerializationContext) -> T): T {
val attachmentIds = attachments.map(Attachment::id).toSet() val attachmentIds = attachments.mapTo(LinkedHashSet(), Attachment::id)
val cache = attachmentsClassLoaderCache ?: fallBackCache val cache = attachmentsClassLoaderCache ?: fallBackCache
val serializationContext = cache.computeIfAbsent(AttachmentsClassLoaderKey(attachmentIds, params), Function { val cachedSerializationContext = cache.computeIfAbsent(AttachmentsClassLoaderKey(attachmentIds, params), Function { key ->
// Create classloader and load serializers, whitelisted classes // Create classloader and load serializers, whitelisted classes
val transactionClassLoader = AttachmentsClassLoader(attachments, params, txId, isAttachmentTrusted, parent) val transactionClassLoader = AttachmentsClassLoader(attachments, key.params, txId, isAttachmentTrusted, parent)
val serializers = try { val serializers = try {
createInstancesOfClassesImplementing(transactionClassLoader, SerializationCustomSerializer::class.java, createInstancesOfClassesImplementing(transactionClassLoader, SerializationCustomSerializer::class.java,
JDK1_2_CLASS_FILE_FORMAT_MAJOR_VERSION..JDK8_CLASS_FILE_FORMAT_MAJOR_VERSION) JDK1_2_CLASS_FILE_FORMAT_MAJOR_VERSION..JDK8_CLASS_FILE_FORMAT_MAJOR_VERSION)
} } catch (ex: UnsupportedClassVersionError) {
catch(ex: UnsupportedClassVersionError) {
throw TransactionVerificationException.UnsupportedClassVersionError(txId, ex.message!!, ex) throw TransactionVerificationException.UnsupportedClassVersionError(txId, ex.message!!, ex)
} }
val whitelistedClasses = ServiceLoader.load(SerializationWhitelist::class.java, transactionClassLoader) val whitelistedClasses = ServiceLoader.load(SerializationWhitelist::class.java, transactionClassLoader)
@ -338,9 +380,20 @@ object AttachmentsClassLoaderBuilder {
.withoutCarpenter() .withoutCarpenter()
}) })
val serializationContext = cachedSerializationContext.withProperties(mapOf<Any, Any>(
// Duplicate the SerializationContext from the cache and give
// it these extra properties, just for this transaction.
// However, keep a strong reference to the cached SerializationContext so we can
// leverage the power of WeakReferences in the AttachmentsClassLoaderCacheImpl to figure
// out when all these have gone out of scope by the BasicVerifier going out of scope.
AMQP_ENVELOPE_CACHE_PROPERTY to HashMap<Any, Any>(AMQP_ENVELOPE_CACHE_INITIAL_CAPACITY),
DESERIALIZATION_CACHE_PROPERTY to HashMap<Any, Any>(),
STRONG_REFERENCE_TO_CACHED_SERIALIZATION_CONTEXT to cachedSerializationContext
))
// Deserialize all relevant classes in the transaction classloader. // Deserialize all relevant classes in the transaction classloader.
return SerializationFactory.defaultFactory.withCurrentContext(serializationContext) { return SerializationFactory.defaultFactory.withCurrentContext(serializationContext) {
block(serializationContext.deserializationClassLoader) block(serializationContext)
} }
} }
} }
@ -352,6 +405,8 @@ object AttachmentsClassLoaderBuilder {
object AttachmentURLStreamHandlerFactory : URLStreamHandlerFactory { object AttachmentURLStreamHandlerFactory : URLStreamHandlerFactory {
internal const val attachmentScheme = "attachment" internal const val attachmentScheme = "attachment"
private val uniqueness = AtomicLong(0)
private val loadedAttachments: AttachmentsHolder = AttachmentsHolderImpl() private val loadedAttachments: AttachmentsHolder = AttachmentsHolderImpl()
override fun createURLStreamHandler(protocol: String): URLStreamHandler? { override fun createURLStreamHandler(protocol: String): URLStreamHandler? {
@ -362,14 +417,9 @@ object AttachmentURLStreamHandlerFactory : URLStreamHandlerFactory {
@Synchronized @Synchronized
fun toUrl(attachment: Attachment): URL { fun toUrl(attachment: Attachment): URL {
val proposedURL = URL(attachmentScheme, "", -1, attachment.id.toString(), AttachmentURLStreamHandler) val uniqueURL = URL(attachmentScheme, "", -1, attachment.id.toString()+ "?" + uniqueness.getAndIncrement(), AttachmentURLStreamHandler)
val existingURL = loadedAttachments.getKey(proposedURL) loadedAttachments[uniqueURL] = attachment
return if (existingURL == null) { return uniqueURL
loadedAttachments[proposedURL] = attachment
proposedURL
} else {
existingURL
}
} }
@VisibleForTesting @VisibleForTesting
@ -427,9 +477,52 @@ interface AttachmentsClassLoaderCache {
@DeleteForDJVM @DeleteForDJVM
class AttachmentsClassLoaderCacheImpl(cacheFactory: NamedCacheFactory) : SingletonSerializeAsToken(), AttachmentsClassLoaderCache { class AttachmentsClassLoaderCacheImpl(cacheFactory: NamedCacheFactory) : SingletonSerializeAsToken(), AttachmentsClassLoaderCache {
private val cache: Cache<AttachmentsClassLoaderKey, SerializationContext> = cacheFactory.buildNamed(Caffeine.newBuilder(), "AttachmentsClassLoader_cache") private class ToBeClosed(
serializationContext: SerializationContext,
val classLoaderToClose: AutoCloseable,
val cacheKey: AttachmentsClassLoaderKey,
queue: ReferenceQueue<SerializationContext>
) : WeakReference<SerializationContext>(serializationContext, queue)
private val logger = loggerFor<AttachmentsClassLoaderCacheImpl>()
private val toBeClosed = ConcurrentHashMap.newKeySet<ToBeClosed>()
private val expiryQueue = ReferenceQueue<SerializationContext>()
@Suppress("TooGenericExceptionCaught")
private fun purgeExpiryQueue() {
// Close the AttachmentsClassLoader for every SerializationContext
// that has already been garbage-collected.
while (true) {
val head = expiryQueue.poll() as? ToBeClosed ?: break
if (!toBeClosed.remove(head)) {
logger.warn("Reaped unexpected serialization context for {}", head.cacheKey)
}
try {
head.classLoaderToClose.close()
} catch (e: Exception) {
logger.warn("Error destroying serialization context for ${head.cacheKey}", e)
}
}
}
private val cache: Cache<AttachmentsClassLoaderKey, SerializationContext> = cacheFactory.buildNamed(
// Schedule for closing the deserialization classloaders when we evict them
// to release any resources they may be holding.
Caffeine.newBuilder().removalListener { key, context, _ ->
(context?.deserializationClassLoader as? AutoCloseable)?.also { autoCloseable ->
// ClassLoader to be closed once the BasicVerifier, which has a strong
// reference chain to this SerializationContext, has gone out of scope.
toBeClosed += ToBeClosed(context, autoCloseable, key!!, expiryQueue)
}
// Reap any entries which have been garbage-collected.
purgeExpiryQueue()
}, "AttachmentsClassLoader_cache"
)
override fun computeIfAbsent(key: AttachmentsClassLoaderKey, mappingFunction: Function<in AttachmentsClassLoaderKey, out SerializationContext>): SerializationContext { override fun computeIfAbsent(key: AttachmentsClassLoaderKey, mappingFunction: Function<in AttachmentsClassLoaderKey, out SerializationContext>): SerializationContext {
purgeExpiryQueue()
return cache.get(key, mappingFunction) ?: throw NullPointerException("null returned from cache mapping function") return cache.get(key, mappingFunction) ?: throw NullPointerException("null returned from cache mapping function")
} }
} }

View File

@ -0,0 +1,28 @@
package net.corda.core.serialization.internal
import net.corda.core.KeepForDJVM
import net.corda.core.serialization.SerializationMagic
import net.corda.core.utilities.ByteSequence
import java.nio.ByteBuffer
class CustomSerializationSchemeUtils {
@KeepForDJVM
companion object {
private const val SERIALIZATION_SCHEME_ID_SIZE = 4
private val PREFIX = "CUS".toByteArray()
fun getCustomSerializationMagicFromSchemeId(schemeId: Int) : SerializationMagic {
return SerializationMagic.of(PREFIX + ByteBuffer.allocate(SERIALIZATION_SCHEME_ID_SIZE).putInt(schemeId).array())
}
fun getSchemeIdIfCustomSerializationMagic(magic: SerializationMagic): Int? {
return if (magic.take(PREFIX.size) != ByteSequence.of(PREFIX)) {
null
} else {
return magic.slice(start = PREFIX.size).int
}
}
}
}

View File

@ -145,7 +145,7 @@ data class ContractUpgradeWireTransaction(
private fun upgradedContract(className: ContractClassName, classLoader: ClassLoader): UpgradedContract<ContractState, ContractState> = try { private fun upgradedContract(className: ContractClassName, classLoader: ClassLoader): UpgradedContract<ContractState, ContractState> = try {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
classLoader.loadClass(className).asSubclass(UpgradedContract::class.java).getDeclaredConstructor().newInstance() as UpgradedContract<ContractState, ContractState> Class.forName(className, false, classLoader).asSubclass(UpgradedContract::class.java).getDeclaredConstructor().newInstance() as UpgradedContract<ContractState, ContractState>
} catch (e: Exception) { } catch (e: Exception) {
throw TransactionVerificationException.ContractCreationError(id, className, e) throw TransactionVerificationException.ContractCreationError(id, className, e)
} }
@ -166,9 +166,9 @@ data class ContractUpgradeWireTransaction(
params, params,
id, id,
{ (services as ServiceHubCoreInternal).attachmentTrustCalculator.calculate(it) }, { (services as ServiceHubCoreInternal).attachmentTrustCalculator.calculate(it) },
attachmentsClassLoaderCache = (services as ServiceHubCoreInternal).attachmentsClassLoaderCache) { transactionClassLoader -> attachmentsClassLoaderCache = (services as ServiceHubCoreInternal).attachmentsClassLoaderCache) { serializationContext ->
val resolvedInput = binaryInput.deserialize() val resolvedInput = binaryInput.deserialize()
val upgradedContract = upgradedContract(upgradedContractClassName, transactionClassLoader) val upgradedContract = upgradedContract(upgradedContractClassName, serializationContext.deserializationClassLoader)
val outputState = calculateUpgradedState(resolvedInput, upgradedContract, upgradedAttachment) val outputState = calculateUpgradedState(resolvedInput, upgradedContract, upgradedAttachment)
outputState.serialize() outputState.serialize()
} }
@ -311,8 +311,7 @@ private constructor(
@CordaInternal @CordaInternal
internal fun loadUpgradedContract(upgradedContractClassName: ContractClassName, classLoader: ClassLoader): UpgradedContract<ContractState, *> { internal fun loadUpgradedContract(upgradedContractClassName: ContractClassName, classLoader: ClassLoader): UpgradedContract<ContractState, *> {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
return classLoader return Class.forName(upgradedContractClassName, false, classLoader)
.loadClass(upgradedContractClassName)
.asSubclass(Contract::class.java) .asSubclass(Contract::class.java)
.getConstructor() .getConstructor()
.newInstance() as UpgradedContract<ContractState, *> .newInstance() as UpgradedContract<ContractState, *>

View File

@ -18,21 +18,25 @@ import net.corda.core.crypto.DigestService
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.BasicVerifier import net.corda.core.internal.AbstractVerifier
import net.corda.core.internal.SerializedStateAndRef import net.corda.core.internal.SerializedStateAndRef
import net.corda.core.internal.Verifier import net.corda.core.internal.Verifier
import net.corda.core.internal.castIfPossible import net.corda.core.internal.castIfPossible
import net.corda.core.internal.deserialiseCommands import net.corda.core.internal.deserialiseCommands
import net.corda.core.internal.deserialiseComponentGroup import net.corda.core.internal.deserialiseComponentGroup
import net.corda.core.internal.eagerDeserialise
import net.corda.core.internal.isUploaderTrusted import net.corda.core.internal.isUploaderTrusted
import net.corda.core.internal.uncheckedCast import net.corda.core.internal.uncheckedCast
import net.corda.core.node.NetworkParameters import net.corda.core.node.NetworkParameters
import net.corda.core.serialization.DeprecatedConstructorForDeserialization import net.corda.core.serialization.DeprecatedConstructorForDeserialization
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationFactory
import net.corda.core.serialization.internal.AttachmentsClassLoaderCache import net.corda.core.serialization.internal.AttachmentsClassLoaderCache
import net.corda.core.serialization.internal.AttachmentsClassLoaderBuilder import net.corda.core.serialization.internal.AttachmentsClassLoaderBuilder
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import java.util.Collections.unmodifiableList import java.util.Collections.unmodifiableList
import java.util.function.Predicate import java.util.function.Predicate
import java.util.function.Supplier
/** /**
* A LedgerTransaction is derived from a [WireTransaction]. It is the result of doing the following operations: * A LedgerTransaction is derived from a [WireTransaction]. It is the result of doing the following operations:
@ -90,7 +94,7 @@ private constructor(
private val serializedInputs: List<SerializedStateAndRef>?, private val serializedInputs: List<SerializedStateAndRef>?,
private val serializedReferences: List<SerializedStateAndRef>?, private val serializedReferences: List<SerializedStateAndRef>?,
private val isAttachmentTrusted: (Attachment) -> Boolean, private val isAttachmentTrusted: (Attachment) -> Boolean,
private val verifierFactory: (LedgerTransaction, ClassLoader) -> Verifier, private val verifierFactory: (LedgerTransaction, SerializationContext) -> Verifier,
private val attachmentsClassLoaderCache: AttachmentsClassLoaderCache?, private val attachmentsClassLoaderCache: AttachmentsClassLoaderCache?,
val digestService: DigestService val digestService: DigestService
) : FullTransaction() { ) : FullTransaction() {
@ -114,8 +118,9 @@ private constructor(
serializedInputs: List<SerializedStateAndRef>?, serializedInputs: List<SerializedStateAndRef>?,
serializedReferences: List<SerializedStateAndRef>?, serializedReferences: List<SerializedStateAndRef>?,
isAttachmentTrusted: (Attachment) -> Boolean, isAttachmentTrusted: (Attachment) -> Boolean,
verifierFactory: (LedgerTransaction, ClassLoader) -> Verifier, verifierFactory: (LedgerTransaction, SerializationContext) -> Verifier,
attachmentsClassLoaderCache: AttachmentsClassLoaderCache?) : this( attachmentsClassLoaderCache: AttachmentsClassLoaderCache?
) : this(
inputs, outputs, commands, attachments, id, notary, timeWindow, privacySalt, inputs, outputs, commands, attachments, id, notary, timeWindow, privacySalt,
networkParameters, references, componentGroups, serializedInputs, serializedReferences, networkParameters, references, componentGroups, serializedInputs, serializedReferences,
isAttachmentTrusted, verifierFactory, attachmentsClassLoaderCache, DigestService.sha2_256) isAttachmentTrusted, verifierFactory, attachmentsClassLoaderCache, DigestService.sha2_256)
@ -124,8 +129,8 @@ private constructor(
companion object { companion object {
private val logger = contextLogger() private val logger = contextLogger()
private fun <T> protect(list: List<T>?): List<T>? { private fun <T> protect(list: List<T>): List<T> {
return list?.run { return list.run {
if (isEmpty()) { if (isEmpty()) {
emptyList() emptyList()
} else { } else {
@ -134,6 +139,8 @@ private constructor(
} }
} }
private fun <T> protectOrNull(list: List<T>?): List<T>? = list?.let(::protect)
@CordaInternal @CordaInternal
internal fun create( internal fun create(
inputs: List<StateAndRef<ContractState>>, inputs: List<StateAndRef<ContractState>>,
@ -164,9 +171,9 @@ private constructor(
privacySalt = privacySalt, privacySalt = privacySalt,
networkParameters = networkParameters, networkParameters = networkParameters,
references = references, references = references,
componentGroups = protect(componentGroups), componentGroups = protectOrNull(componentGroups),
serializedInputs = protect(serializedInputs), serializedInputs = protectOrNull(serializedInputs),
serializedReferences = protect(serializedReferences), serializedReferences = protectOrNull(serializedReferences),
isAttachmentTrusted = isAttachmentTrusted, isAttachmentTrusted = isAttachmentTrusted,
verifierFactory = ::BasicVerifier, verifierFactory = ::BasicVerifier,
attachmentsClassLoaderCache = attachmentsClassLoaderCache, attachmentsClassLoaderCache = attachmentsClassLoaderCache,
@ -176,10 +183,11 @@ private constructor(
/** /**
* This factory function will create an instance of [LedgerTransaction] * This factory function will create an instance of [LedgerTransaction]
* that will be used inside the DJVM sandbox. * that will be used for contract verification. See [BasicVerifier] and
* [DeterministicVerifier][net.corda.node.internal.djvm.DeterministicVerifier].
*/ */
@CordaInternal @CordaInternal
fun createForSandbox( fun createForContractVerify(
inputs: List<StateAndRef<ContractState>>, inputs: List<StateAndRef<ContractState>>,
outputs: List<TransactionState<ContractState>>, outputs: List<TransactionState<ContractState>>,
commands: List<CommandWithParties<CommandData>>, commands: List<CommandWithParties<CommandData>>,
@ -188,28 +196,31 @@ private constructor(
notary: Party?, notary: Party?,
timeWindow: TimeWindow?, timeWindow: TimeWindow?,
privacySalt: PrivacySalt, privacySalt: PrivacySalt,
networkParameters: NetworkParameters, networkParameters: NetworkParameters?,
references: List<StateAndRef<ContractState>>, references: List<StateAndRef<ContractState>>,
digestService: DigestService): LedgerTransaction { digestService: DigestService): LedgerTransaction {
return LedgerTransaction( return LedgerTransaction(
inputs = inputs, inputs = protect(inputs),
outputs = outputs, outputs = protect(outputs),
commands = commands, commands = protect(commands),
attachments = attachments, attachments = protect(attachments),
id = id, id = id,
notary = notary, notary = notary,
timeWindow = timeWindow, timeWindow = timeWindow,
privacySalt = privacySalt, privacySalt = privacySalt,
networkParameters = networkParameters, networkParameters = networkParameters,
references = references, references = protect(references),
componentGroups = null, componentGroups = null,
serializedInputs = null, serializedInputs = null,
serializedReferences = null, serializedReferences = null,
isAttachmentTrusted = { true }, isAttachmentTrusted = { true },
verifierFactory = ::BasicVerifier, verifierFactory = ::NoOpVerifier,
attachmentsClassLoaderCache = null, attachmentsClassLoaderCache = null,
digestService = digestService digestService = digestService
) // This check accesses input states and must run on the LedgerTransaction
// instance that is verified, not on the outer LedgerTransaction shell.
// All states must also deserialize using the correct SerializationContext.
).also(LedgerTransaction::checkBaseInvariants)
} }
} }
@ -251,11 +262,17 @@ private constructor(
getParamsWithGoo(), getParamsWithGoo(),
id, id,
isAttachmentTrusted = isAttachmentTrusted, isAttachmentTrusted = isAttachmentTrusted,
attachmentsClassLoaderCache = attachmentsClassLoaderCache) { transactionClassLoader -> attachmentsClassLoaderCache = attachmentsClassLoaderCache) { serializationContext ->
// Create a copy of the outer LedgerTransaction which deserializes all fields inside the [transactionClassLoader].
// Legacy check - warns if the LedgerTransaction was created incorrectly.
checkLtxForVerification()
// Create a copy of the outer LedgerTransaction which deserializes all fields using
// the serialization context (or its deserializationClassloader).
// Only the copy will be used for verification, and the outer shell will be discarded. // Only the copy will be used for verification, and the outer shell will be discarded.
// This artifice is required to preserve backwards compatibility. // This artifice is required to preserve backwards compatibility.
verifierFactory(createLtxForVerification(), transactionClassLoader) // NOTE: The Verifier creates the copies of the LedgerTransaction object now.
verifierFactory(this, serializationContext)
} }
} }
@ -272,7 +289,7 @@ private constructor(
* Node without changing either the wire format or any public APIs. * Node without changing either the wire format or any public APIs.
*/ */
@CordaInternal @CordaInternal
fun specialise(alternateVerifier: (LedgerTransaction, ClassLoader) -> Verifier): LedgerTransaction = LedgerTransaction( fun specialise(alternateVerifier: (LedgerTransaction, SerializationContext) -> Verifier): LedgerTransaction = LedgerTransaction(
inputs = inputs, inputs = inputs,
outputs = outputs, outputs = outputs,
commands = commands, commands = commands,
@ -287,7 +304,11 @@ private constructor(
serializedInputs = serializedInputs, serializedInputs = serializedInputs,
serializedReferences = serializedReferences, serializedReferences = serializedReferences,
isAttachmentTrusted = isAttachmentTrusted, isAttachmentTrusted = isAttachmentTrusted,
verifierFactory = alternateVerifier, verifierFactory = if (verifierFactory == ::NoOpVerifier) {
throw IllegalStateException("Cannot specialise transaction while verifying contracts")
} else {
alternateVerifier
},
attachmentsClassLoaderCache = attachmentsClassLoaderCache, attachmentsClassLoaderCache = attachmentsClassLoaderCache,
digestService = digestService digestService = digestService
) )
@ -319,58 +340,12 @@ private constructor(
} }
/** /**
* Create the [LedgerTransaction] instance that will be used by contract verification.
*
* This method needs to run in the special transaction attachments classloader context.
*/ */
private fun createLtxForVerification(): LedgerTransaction { private fun checkLtxForVerification() {
val serializedInputs = this.serializedInputs if (serializedInputs == null || serializedReferences == null || componentGroups == null) {
val serializedReferences = this.serializedReferences
val componentGroups = this.componentGroups
val transaction= if (serializedInputs != null && serializedReferences != null && componentGroups != null) {
// Deserialize all relevant classes in the transaction classloader.
val deserializedInputs = serializedInputs.map { it.toStateAndRef() }
val deserializedReferences = serializedReferences.map { it.toStateAndRef() }
val deserializedOutputs = deserialiseComponentGroup(componentGroups, TransactionState::class, ComponentGroupEnum.OUTPUTS_GROUP, forceDeserialize = true)
val deserializedCommands = deserialiseCommands(componentGroups, forceDeserialize = true, digestService = digestService)
val authenticatedDeserializedCommands = deserializedCommands.map { cmd ->
@Suppress("DEPRECATION") // Deprecated feature.
val parties = commands.find { it.value.javaClass.name == cmd.value.javaClass.name }!!.signingParties
CommandWithParties(cmd.signers, parties, cmd.value)
}
LedgerTransaction(
inputs = deserializedInputs,
outputs = deserializedOutputs,
commands = authenticatedDeserializedCommands,
attachments = this.attachments,
id = this.id,
notary = this.notary,
timeWindow = this.timeWindow,
privacySalt = this.privacySalt,
networkParameters = this.networkParameters,
references = deserializedReferences,
componentGroups = componentGroups,
serializedInputs = serializedInputs,
serializedReferences = serializedReferences,
isAttachmentTrusted = isAttachmentTrusted,
verifierFactory = verifierFactory,
attachmentsClassLoaderCache = attachmentsClassLoaderCache,
digestService = digestService
)
} else {
// This branch is only present for backwards compatibility.
logger.warn("The LedgerTransaction should not be instantiated directly from client code. Please use WireTransaction.toLedgerTransaction." + logger.warn("The LedgerTransaction should not be instantiated directly from client code. Please use WireTransaction.toLedgerTransaction." +
"The result of the verify method might not be accurate.") "The result of the verify method might not be accurate.")
this
} }
// This check accesses input states and must be run in this context.
// It must run on the instance that is verified, not on the outer LedgerTransaction shell.
transaction.checkBaseInvariants()
return transaction
} }
/** /**
@ -740,7 +715,7 @@ private constructor(
componentGroups = null, componentGroups = null,
serializedInputs = null, serializedInputs = null,
serializedReferences = null, serializedReferences = null,
isAttachmentTrusted = { it.isUploaderTrusted() }, isAttachmentTrusted = Attachment::isUploaderTrusted,
verifierFactory = ::BasicVerifier, verifierFactory = ::BasicVerifier,
attachmentsClassLoaderCache = null attachmentsClassLoaderCache = null
) )
@ -770,7 +745,7 @@ private constructor(
componentGroups = null, componentGroups = null,
serializedInputs = null, serializedInputs = null,
serializedReferences = null, serializedReferences = null,
isAttachmentTrusted = { it.isUploaderTrusted() }, isAttachmentTrusted = Attachment::isUploaderTrusted,
verifierFactory = ::BasicVerifier, verifierFactory = ::BasicVerifier,
attachmentsClassLoaderCache = null attachmentsClassLoaderCache = null
) )
@ -838,3 +813,80 @@ private constructor(
) )
} }
} }
/**
* This is the default [Verifier] that configures Corda
* to execute [Contract.verify(LedgerTransaction)].
*
* THIS CLASS IS NOT PUBLIC API, AND IS DELIBERATELY PRIVATE!
*/
@CordaInternal
private class BasicVerifier(
ltx: LedgerTransaction,
private val serializationContext: SerializationContext
) : AbstractVerifier(ltx, serializationContext.deserializationClassLoader) {
init {
// This is a sanity check: We should only instantiate this
// class from [LedgerTransaction.internalPrepareVerify].
require(serializationContext === SerializationFactory.defaultFactory.currentContext) {
"BasicVerifier for TX ${ltx.id} created outside its SerializationContext"
}
// Fetch these commands' signing parties from the database.
// Corda forbids database access during contract verification,
// and so we must load the commands here eagerly instead.
// THIS ALSO DESERIALISES THE COMMANDS USING THE WRONG CONTEXT
// BECAUSE THAT CONTEXT WAS CHOSEN WHEN THE LAZY MAP WAS CREATED,
// AND CHANGING THE DEFAULT CONTEXT HERE DOES NOT AFFECT IT.
ltx.commands.eagerDeserialise()
}
override val transaction: Supplier<LedgerTransaction>
get() = Supplier(::createTransaction)
private fun createTransaction(): LedgerTransaction {
// Deserialize all relevant classes using the serializationContext.
return SerializationFactory.defaultFactory.withCurrentContext(serializationContext) {
ltx.transform { componentGroups, serializedInputs, serializedReferences ->
val deserializedInputs = serializedInputs.map(SerializedStateAndRef::toStateAndRef)
val deserializedReferences = serializedReferences.map(SerializedStateAndRef::toStateAndRef)
val deserializedOutputs = deserialiseComponentGroup(componentGroups, TransactionState::class, ComponentGroupEnum.OUTPUTS_GROUP, forceDeserialize = true)
val deserializedCommands = deserialiseCommands(componentGroups, forceDeserialize = true, digestService = ltx.digestService)
val authenticatedDeserializedCommands = deserializedCommands.mapIndexed { idx, cmd ->
// Requires ltx.commands to have been deserialized already.
@Suppress("DEPRECATION") // Deprecated feature.
val parties = ltx.commands[idx].signingParties
CommandWithParties(cmd.signers, parties, cmd.value)
}
LedgerTransaction.createForContractVerify(
inputs = deserializedInputs,
outputs = deserializedOutputs,
commands = authenticatedDeserializedCommands,
attachments = ltx.attachments,
id = ltx.id,
notary = ltx.notary,
timeWindow = ltx.timeWindow,
privacySalt = ltx.privacySalt,
networkParameters = ltx.networkParameters,
references = deserializedReferences,
digestService = ltx.digestService
)
}
}
}
}
/**
* A "do nothing" [Verifier] installed for contract verification.
*
* THIS CLASS IS NOT PUBLIC API, AND IS DELIBERATELY PRIVATE!
*/
@Suppress("unused_parameter")
@CordaInternal
private class NoOpVerifier(ltx: LedgerTransaction, serializationContext: SerializationContext) : Verifier {
// Invoking LedgerTransaction.verify() from Contract.verify(LedgerTransaction)
// will execute this function. But why would anyone do that?!
override fun verify() {}
}

View File

@ -16,14 +16,18 @@ import net.corda.core.node.ServicesForResolution
import net.corda.core.node.ZoneVersionTooLowException import net.corda.core.node.ZoneVersionTooLowException
import net.corda.core.node.services.AttachmentId import net.corda.core.node.services.AttachmentId
import net.corda.core.node.services.KeyManagementService import net.corda.core.node.services.KeyManagementService
import net.corda.core.serialization.CustomSerializationScheme
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializationFactory import net.corda.core.serialization.SerializationFactory
import net.corda.core.serialization.SerializationMagic
import net.corda.core.serialization.SerializationSchemeContext
import net.corda.core.serialization.internal.CustomSerializationSchemeUtils.Companion.getCustomSerializationMagicFromSchemeId
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import java.security.PublicKey import java.security.PublicKey
import java.time.Duration import java.time.Duration
import java.time.Instant import java.time.Instant
import java.util.ArrayDeque import java.util.*
import java.util.UUID
import java.util.regex.Pattern import java.util.regex.Pattern
import kotlin.collections.ArrayList import kotlin.collections.ArrayList
import kotlin.collections.component1 import kotlin.collections.component1
@ -140,6 +144,41 @@ open class TransactionBuilder(
fun toWireTransaction(services: ServicesForResolution): WireTransaction = toWireTransactionWithContext(services, null) fun toWireTransaction(services: ServicesForResolution): WireTransaction = toWireTransactionWithContext(services, null)
.apply { checkSupportedHashType() } .apply { checkSupportedHashType() }
/**
* Generates a [WireTransaction] from this builder, resolves any [AutomaticPlaceholderConstraint], and selects the attachments to use for this transaction.
*
* @param [schemeId] is used to specify the [CustomSerializationScheme] used to serialize each component of the componentGroups of the [WireTransaction].
* This is an experimental feature.
*
* @returns A new [WireTransaction] that will be unaffected by further changes to this [TransactionBuilder].
*
* @throws [ZoneVersionTooLowException] if there are reference states and the zone minimum platform version is less than 4.
*/
@Throws(MissingContractAttachments::class)
fun toWireTransaction(services: ServicesForResolution, schemeId: Int): WireTransaction {
return toWireTransaction(services, schemeId, emptyMap()).apply { checkSupportedHashType() }
}
/**
* Generates a [WireTransaction] from this builder, resolves any [AutomaticPlaceholderConstraint], and selects the attachments to use for this transaction.
*
* @param [schemeId] is used to specify the [CustomSerializationScheme] used to serialize each component of the componentGroups of the [WireTransaction].
* This is an experimental feature.
*
* @param [properties] a list of properties to add to the [SerializationSchemeContext] these properties can be accessed in [CustomSerializationScheme.serialize]
* when serializing the componentGroups of the wire transaction but might not be available when deserializing.
*
* @returns A new [WireTransaction] that will be unaffected by further changes to this [TransactionBuilder].
*
* @throws [ZoneVersionTooLowException] if there are reference states and the zone minimum platform version is less than 4.
*/
@Throws(MissingContractAttachments::class)
fun toWireTransaction(services: ServicesForResolution, schemeId: Int, properties: Map<Any, Any>): WireTransaction {
val magic: SerializationMagic = getCustomSerializationMagicFromSchemeId(schemeId)
val serializationContext = SerializationDefaults.P2P_CONTEXT.withPreferredSerializationVersion(magic).withProperties(properties)
return toWireTransactionWithContext(services, serializationContext).apply { checkSupportedHashType() }
}
@CordaInternal @CordaInternal
internal fun toWireTransactionWithContext( internal fun toWireTransactionWithContext(
services: ServicesForResolution, services: ServicesForResolution,

View File

@ -15,6 +15,7 @@ import net.corda.core.node.ServicesForResolution
import net.corda.core.node.services.AttachmentId import net.corda.core.node.services.AttachmentId
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.DeprecatedConstructorForDeserialization import net.corda.core.serialization.DeprecatedConstructorForDeserialization
import net.corda.core.serialization.SerializationFactory
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.internal.AttachmentsClassLoaderCache import net.corda.core.serialization.internal.AttachmentsClassLoaderCache
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
@ -154,7 +155,7 @@ class WireTransaction(componentGroups: List<ComponentGroup>, val privacySalt: Pr
resolveAttachment, resolveAttachment,
{ stateRef -> resolveStateRef(stateRef)?.serialize() }, { stateRef -> resolveStateRef(stateRef)?.serialize() },
{ null }, { null },
{ it.isUploaderTrusted() }, Attachment::isUploaderTrusted,
null null
) )
} }
@ -187,19 +188,26 @@ class WireTransaction(componentGroups: List<ComponentGroup>, val privacySalt: Pr
): LedgerTransaction { ): LedgerTransaction {
// Look up public keys to authenticated identities. // Look up public keys to authenticated identities.
val authenticatedCommands = commands.lazyMapped { cmd, _ -> val authenticatedCommands = commands.lazyMapped { cmd, _ ->
val parties = cmd.signers.mapNotNull { pk -> resolveIdentity(pk) } val parties = cmd.signers.mapNotNull(resolveIdentity)
CommandWithParties(cmd.signers, parties, cmd.value) CommandWithParties(cmd.signers, parties, cmd.value)
} }
// Ensure that the lazy mappings will use the correct SerializationContext.
val serializationFactory = SerializationFactory.defaultFactory
val serializationContext = serializationFactory.defaultContext
val toStateAndRef = { ssar: SerializedStateAndRef, _: Int ->
ssar.toStateAndRef(serializationFactory, serializationContext)
}
val serializedResolvedInputs = inputs.map { ref -> val serializedResolvedInputs = inputs.map { ref ->
SerializedStateAndRef(resolveStateRefAsSerialized(ref) ?: throw TransactionResolutionException(ref.txhash), ref) SerializedStateAndRef(resolveStateRefAsSerialized(ref) ?: throw TransactionResolutionException(ref.txhash), ref)
} }
val resolvedInputs = serializedResolvedInputs.lazyMapped { star, _ -> star.toStateAndRef() } val resolvedInputs = serializedResolvedInputs.lazyMapped(toStateAndRef)
val serializedResolvedReferences = references.map { ref -> val serializedResolvedReferences = references.map { ref ->
SerializedStateAndRef(resolveStateRefAsSerialized(ref) ?: throw TransactionResolutionException(ref.txhash), ref) SerializedStateAndRef(resolveStateRefAsSerialized(ref) ?: throw TransactionResolutionException(ref.txhash), ref)
} }
val resolvedReferences = serializedResolvedReferences.lazyMapped { star, _ -> star.toStateAndRef() } val resolvedReferences = serializedResolvedReferences.lazyMapped(toStateAndRef)
val resolvedAttachments = attachments.lazyMapped { att, _ -> resolveAttachment(att) ?: throw AttachmentResolutionException(att) } val resolvedAttachments = attachments.lazyMapped { att, _ -> resolveAttachment(att) ?: throw AttachmentResolutionException(att) }
@ -214,7 +222,7 @@ class WireTransaction(componentGroups: List<ComponentGroup>, val privacySalt: Pr
notary, notary,
timeWindow, timeWindow,
privacySalt, privacySalt,
resolvedNetworkParameters, resolvedNetworkParameters.toImmutable(),
resolvedReferences, resolvedReferences,
componentGroups, componentGroups,
serializedResolvedInputs, serializedResolvedInputs,
@ -318,7 +326,11 @@ class WireTransaction(componentGroups: List<ComponentGroup>, val privacySalt: Pr
* nothing about the rest. * nothing about the rest.
*/ */
internal val availableComponentNonces: Map<Int, List<SecureHash>> by lazy { internal val availableComponentNonces: Map<Int, List<SecureHash>> by lazy {
if(digestService.hashAlgorithm == SecureHash.SHA2_256) {
componentGroups.associate { it.groupIndex to it.components.mapIndexed { internalIndex, internalIt -> digestService.componentHash(internalIt, privacySalt, it.groupIndex, internalIndex) } } componentGroups.associate { it.groupIndex to it.components.mapIndexed { internalIndex, internalIt -> digestService.componentHash(internalIt, privacySalt, it.groupIndex, internalIndex) } }
} else {
componentGroups.associate { it.groupIndex to it.components.mapIndexed { internalIndex, _ -> digestService.computeNonce(privacySalt, it.groupIndex, internalIndex) } }
}
} }
/** /**

View File

@ -0,0 +1,54 @@
package net.corda.core.internal.utilities
import net.corda.core.obfuscator.XorOutputStream
import java.net.URL
import java.nio.file.Files
import java.nio.file.Paths
import java.util.zip.ZipEntry
import java.util.zip.ZipOutputStream
object TestResourceWriter {
private val externalZipBombUrls = arrayOf(
URL("https://www.bamsoftware.com/hacks/zipbomb/zbsm.zip"),
URL("https://www.bamsoftware.com/hacks/zipbomb/zblg.zip"),
URL("https://www.bamsoftware.com/hacks/zipbomb/zbxl.zip")
)
@JvmStatic
@Suppress("NestedBlockDepth", "MagicNumber")
fun main(vararg args : String) {
for(arg in args) {
/**
* Download zip bombs
*/
for(url in externalZipBombUrls) {
url.openStream().use { inputStream ->
val destination = Paths.get(arg).resolve(Paths.get(url.path + ".xor").fileName)
Files.newOutputStream(destination).buffered().let(::XorOutputStream).use { outputStream ->
inputStream.copyTo(outputStream)
}
}
}
/**
* Create a jar archive with a huge manifest file, used in unit tests to check that it is also identified as a zip bomb.
* This is because {@link java.util.jar.JarInputStream}
* <a href="https://github.com/openjdk/jdk/blob/4dedba9ebe11750f4b39c41feb4a4314ccdb3a08/src/java.base/share/classes/java/util/jar/JarInputStream.java#L95">eagerly loads the manifest file in memory</a>
* which would make such a jar dangerous if used as an attachment
*/
val destination = Paths.get(arg).resolve(Paths.get("big-manifest.jar.xor").fileName)
ZipOutputStream(XorOutputStream((Files.newOutputStream(destination).buffered()))).use { zos ->
val zipEntry = ZipEntry("MANIFEST.MF")
zipEntry.method = ZipEntry.DEFLATED
zos.putNextEntry(zipEntry)
val buffer = ByteArray(0x100000) { 0x0 }
var written = 0L
while(written < 10_000_000_000) {
zos.write(buffer)
written += buffer.size
}
zos.closeEntry()
}
}
}
}

View File

@ -0,0 +1,30 @@
package net.corda.core.obfuscator
import java.io.FilterInputStream
import java.io.InputStream
@Suppress("MagicNumber")
class XorInputStream(private val source : InputStream) : FilterInputStream(source) {
var prev : Int = 0
override fun read(): Int {
prev = source.read() xor prev
return prev - 0x80
}
override fun read(buffer: ByteArray): Int {
return read(buffer, 0, buffer.size)
}
override fun read(buffer: ByteArray, off: Int, len: Int): Int {
var read = 0
while(true) {
val b = source.read()
if(b < 0) break
buffer[off + read++] = ((b xor prev) - 0x80).toByte()
prev = b
if(read == len) break
}
return read
}
}

View File

@ -0,0 +1,30 @@
package net.corda.core.obfuscator
import java.io.FilterOutputStream
import java.io.OutputStream
@Suppress("MagicNumber")
class XorOutputStream(private val destination : OutputStream) : FilterOutputStream(destination) {
var prev : Int = 0
override fun write(byte: Int) {
val b = (byte + 0x80) xor prev
destination.write(b)
prev = b
}
override fun write(buffer: ByteArray) {
write(buffer, 0, buffer.size)
}
override fun write(buffer: ByteArray, off: Int, len: Int) {
var written = 0
while(true) {
val b = (buffer[written] + 0x80) xor prev
destination.write(b)
prev = b
++written
if(written == len) break
}
}
}

View File

@ -8,5 +8,17 @@ the context of a node. However, as everything else depends on the core module, w
this module. Therefore, any tests that require further Corda dependencies need to be defined in the module this module. Therefore, any tests that require further Corda dependencies need to be defined in the module
`core-tests`, which has the full set of dependencies including `node-driver`. `core-tests`, which has the full set of dependencies including `node-driver`.
# ZipBomb tests
There is a unit test that checks the zip bomb detector in `net.corda.core.internal.utilities.ZipBombDetector` works correctly.
This test (`core/src/test/kotlin/net/corda/core/internal/utilities/ZipBombDetectorTest.kt`) uses real zip bombs, provided by `https://www.bamsoftware.com/hacks/zipbomb/`.
As it is undesirable to have unit test depends on external internet resources we do not control, those files are included as resources in
`core/src/test/resources/zip/`, however some Windows antivirus software correctly identifies those files as zip bombs,
raising an alert to the user. To mitigate this, those files have been obfuscated using `net.corda.core.obfuscator.XorOutputStream`
(which simply XORs every byte of the file with the previous one, except for the first byte that is XORed with zero)
to prevent antivirus software from detecting them as zip bombs and are de-obfuscated on the fly in unit tests using
`net.corda.core.obfuscator.XorInputStream`.
There is a dedicated Gradle task to re-download and re-obfuscate all the test resource files named `writeTestResources`,
its source code is in `core/src/obfuscator/kotlin/net/corda/core/internal/utilities/TestResourceWriter.kt`

View File

@ -1,33 +1,19 @@
package net.corda.core.crypto package net.corda.core.crypto
import net.corda.core.crypto.internal.DigestAlgorithmFactory import net.corda.core.crypto.internal.DigestAlgorithmFactory
import org.bouncycastle.crypto.digests.Blake2sDigest import net.corda.core.internal.BLAKE2s256DigestAlgorithm
import org.junit.Assert.assertArrayEquals import org.junit.Assert.assertArrayEquals
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class Blake2s256DigestServiceTest { class Blake2s256DigestServiceTest {
class BLAKE2s256DigestService : DigestAlgorithm {
override val algorithm = "BLAKE_TEST"
override val digestLength = 32
override fun digest(bytes: ByteArray): ByteArray {
val blake2s256 = Blake2sDigest(null, digestLength, null, "12345678".toByteArray())
blake2s256.reset()
blake2s256.update(bytes, 0, bytes.size)
val hash = ByteArray(digestLength)
blake2s256.doFinal(hash, 0)
return hash
}
}
private val service = DigestService("BLAKE_TEST") private val service = DigestService("BLAKE_TEST")
@Before @Before
fun before() { fun before() {
DigestAlgorithmFactory.registerClass(BLAKE2s256DigestService::class.java.name) DigestAlgorithmFactory.registerClass(BLAKE2s256DigestAlgorithm::class.java.name)
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)

View File

@ -0,0 +1,36 @@
package net.corda.core.internal
import net.corda.core.crypto.DigestAlgorithm
import net.corda.core.crypto.SecureHash
import org.bouncycastle.crypto.digests.Blake2sDigest
/**
* A set of custom hash algorithms
*/
open class BLAKE2s256DigestAlgorithm : DigestAlgorithm {
override val algorithm = "BLAKE_TEST"
override val digestLength = 32
protected fun blake2sHash(bytes: ByteArray): ByteArray {
val blake2s256 = Blake2sDigest(null, digestLength, null, "12345678".toByteArray())
blake2s256.reset()
blake2s256.update(bytes, 0, bytes.size)
val hash = ByteArray(digestLength)
blake2s256.doFinal(hash, 0)
return hash
}
override fun digest(bytes: ByteArray): ByteArray = blake2sHash(bytes)
}
class SHA256BLAKE2s256DigestAlgorithm : BLAKE2s256DigestAlgorithm() {
override val algorithm = "SHA256-BLAKE2S256-TEST"
override fun digest(bytes: ByteArray): ByteArray = SecureHash.hashAs(SecureHash.SHA2_256, bytes).bytes
override fun componentDigest(bytes: ByteArray): ByteArray = blake2sHash(bytes)
override fun nonceDigest(bytes: ByteArray): ByteArray = blake2sHash(bytes)
}

View File

@ -5,10 +5,12 @@ import net.corda.core.crypto.DigestService
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.node.NetworkParameters import net.corda.core.node.NetworkParameters
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.internal.AttachmentsClassLoaderCache import net.corda.core.serialization.internal.AttachmentsClassLoaderCache
import net.corda.core.transactions.ComponentGroup import net.corda.core.transactions.ComponentGroup
import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction
import java.util.function.Supplier
/** /**
* A set of functions in core:test that allows testing of core internal classes in the core-tests project. * A set of functions in core:test that allows testing of core internal classes in the core-tests project.
@ -18,6 +20,7 @@ fun WireTransaction.accessGroupHashes() = this.groupHashes
fun WireTransaction.accessGroupMerkleRoots() = this.groupsMerkleRoots fun WireTransaction.accessGroupMerkleRoots() = this.groupsMerkleRoots
fun WireTransaction.accessAvailableComponentHashes() = this.availableComponentHashes fun WireTransaction.accessAvailableComponentHashes() = this.availableComponentHashes
fun WireTransaction.accessAvailableComponentNonces() = this.availableComponentNonces
@Suppress("LongParameterList") @Suppress("LongParameterList")
fun createLedgerTransaction( fun createLedgerTransaction(
@ -37,7 +40,17 @@ fun createLedgerTransaction(
isAttachmentTrusted: (Attachment) -> Boolean, isAttachmentTrusted: (Attachment) -> Boolean,
attachmentsClassLoaderCache: AttachmentsClassLoaderCache, attachmentsClassLoaderCache: AttachmentsClassLoaderCache,
digestService: DigestService = DigestService.default digestService: DigestService = DigestService.default
): LedgerTransaction = LedgerTransaction.create(inputs, outputs, commands, attachments, id, notary, timeWindow, privacySalt, networkParameters, references, componentGroups, serializedInputs, serializedReferences, isAttachmentTrusted, attachmentsClassLoaderCache, digestService) ): LedgerTransaction = LedgerTransaction.create(
inputs, outputs, commands, attachments, id, notary, timeWindow, privacySalt, networkParameters, references, componentGroups, serializedInputs, serializedReferences, isAttachmentTrusted, attachmentsClassLoaderCache, digestService
).specialise(::PassthroughVerifier)
fun createContractCreationError(txId: SecureHash, contractClass: String, cause: Throwable) = TransactionVerificationException.ContractCreationError(txId, contractClass, cause) fun createContractCreationError(txId: SecureHash, contractClass: String, cause: Throwable) = TransactionVerificationException.ContractCreationError(txId, contractClass, cause)
fun createContractRejection(txId: SecureHash, contract: Contract, cause: Throwable) = TransactionVerificationException.ContractRejection(txId, contract, cause) fun createContractRejection(txId: SecureHash, contract: Contract, cause: Throwable) = TransactionVerificationException.ContractRejection(txId, contract, cause)
/**
* Verify the [LedgerTransaction] we already have.
*/
private class PassthroughVerifier(ltx: LedgerTransaction, context: SerializationContext) : AbstractVerifier(ltx, context.deserializationClassLoader) {
override val transaction: Supplier<LedgerTransaction>
get() = Supplier { ltx }
}

View File

@ -0,0 +1,70 @@
package net.corda.core.internal.utilities
import net.corda.core.obfuscator.XorInputStream
import org.junit.Assert
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
@RunWith(Parameterized::class)
class ZipBombDetectorTest(private val case : TestCase) {
enum class TestCase(
val description : String,
val zipResource : String,
val maxUncompressedSize : Long,
val maxCompressionRatio : Float,
val expectedOutcome : Boolean
) {
LEGIT_JAR("This project's jar file", "zip/core.jar", 128_000, 10f, false),
// This is not detected as a zip bomb as ZipInputStream is unable to read all of its entries
// (https://stackoverflow.com/questions/69286786/zipinputstream-cannot-parse-a-281-tb-zip-bomb),
// so the total uncompressed size doesn't exceed maxUncompressedSize
SMALL_BOMB(
"A large (5.5 GB) zip archive",
"zip/zbsm.zip.xor", 64_000_000, 10f, false),
// Decreasing maxUncompressedSize leads to a successful detection
SMALL_BOMB2(
"A large (5.5 GB) zip archive, with 1MB maxUncompressedSize",
"zip/zbsm.zip.xor", 1_000_000, 10f, true),
// ZipInputStream is also unable to read all entries of zblg.zip, but since the first one is already bigger than 4GB,
// that is enough to exceed maxUncompressedSize
LARGE_BOMB(
"A huge (281 TB) Zip bomb, this is the biggest possible non-recursive non-Zip64 archive",
"zip/zblg.zip.xor", 64_000_000, 10f, true),
//Same for this, but its entries are 22GB each
EXTRA_LARGE_BOMB(
"A humongous (4.5 PB) Zip64 bomb",
"zip/zbxl.zip.xor", 64_000_000, 10f, true),
//This is a jar file containing a single 10GB manifest
BIG_MANIFEST(
"A jar file with a huge manifest",
"zip/big-manifest.jar.xor", 64_000_000, 10f, true);
override fun toString() = description
}
companion object {
@JvmStatic
@Parameterized.Parameters(name = "{0}")
fun generateTestCases(): Collection<*> {
return TestCase.values().toList()
}
}
@Test(timeout=10_000)
fun test() {
(javaClass.classLoader.getResourceAsStream(case.zipResource) ?:
throw IllegalStateException("Missing test resource file ${case.zipResource}"))
.buffered()
.let(::XorInputStream)
.let {
Assert.assertEquals(case.expectedOutcome, ZipBombDetector.scanZip(it, case.maxUncompressedSize, case.maxCompressionRatio))
}
}
}

View File

@ -0,0 +1,50 @@
package net.corda.core.obfuscator
import org.junit.Assert
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.security.DigestInputStream
import java.security.DigestOutputStream
import java.security.MessageDigest
import java.util.*
@RunWith(Parameterized::class)
class XorStreamTest(private val size : Int) {
private val random = Random(0)
companion object {
@JvmStatic
@Parameterized.Parameters(name = "{0}")
fun generateTestCases(): Collection<*> {
return listOf(0, 16, 31, 127, 1000, 1024)
}
}
@Test(timeout = 5000)
fun test() {
val baos = ByteArrayOutputStream(size)
val md = MessageDigest.getInstance("MD5")
val buffer = ByteArray(DEFAULT_BUFFER_SIZE)
DigestOutputStream(XorOutputStream(baos), md).use { outputStream ->
var written = 0
while(written < size) {
random.nextBytes(buffer)
val bytesToWrite = (size - written).coerceAtMost(buffer.size)
outputStream.write(buffer, 0, bytesToWrite)
written += bytesToWrite
}
}
val digest = md.digest()
md.reset()
DigestInputStream(XorInputStream(ByteArrayInputStream(baos.toByteArray())), md).use { inputStream ->
while(true) {
val read = inputStream.read(buffer)
if(read <= 0) break
}
}
Assert.assertArrayEquals(digest, md.digest())
}
}

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1162,6 +1162,7 @@
<ID>MatchingDeclarationName:NamedCache.kt$net.corda.core.internal.NamedCache.kt</ID> <ID>MatchingDeclarationName:NamedCache.kt$net.corda.core.internal.NamedCache.kt</ID>
<ID>MatchingDeclarationName:NetParams.kt$net.corda.netparams.NetParams.kt</ID> <ID>MatchingDeclarationName:NetParams.kt$net.corda.netparams.NetParams.kt</ID>
<ID>MatchingDeclarationName:NetworkParametersServiceInternal.kt$net.corda.core.internal.NetworkParametersServiceInternal.kt</ID> <ID>MatchingDeclarationName:NetworkParametersServiceInternal.kt$net.corda.core.internal.NetworkParametersServiceInternal.kt</ID>
<ID>MatchingDeclarationName:NotaryQueries.kt$net.corda.nodeapi.notary.NotaryQueries.kt</ID>
<ID>MatchingDeclarationName:OGSwapPricingCcpExample.kt$net.corda.vega.analytics.example.OGSwapPricingCcpExample.kt</ID> <ID>MatchingDeclarationName:OGSwapPricingCcpExample.kt$net.corda.vega.analytics.example.OGSwapPricingCcpExample.kt</ID>
<ID>MatchingDeclarationName:OGSwapPricingExample.kt$net.corda.vega.analytics.example.OGSwapPricingExample.kt</ID> <ID>MatchingDeclarationName:OGSwapPricingExample.kt$net.corda.vega.analytics.example.OGSwapPricingExample.kt</ID>
<ID>MatchingDeclarationName:PlatformSecureRandom.kt$net.corda.core.crypto.internal.PlatformSecureRandom.kt</ID> <ID>MatchingDeclarationName:PlatformSecureRandom.kt$net.corda.core.crypto.internal.PlatformSecureRandom.kt</ID>

View File

@ -1,8 +1,8 @@
#!/usr/bin/env bash #!/usr/bin/env bash
NODE_LIST=("dockerNode1" "dockerNode2" "dockerNode3") NODE_LIST=("dockerNode1" "dockerNode2" "dockerNode3")
NETWORK_NAME=mininet NETWORK_NAME=mininet
CORDAPP_VERSION="4.6-SNAPSHOT" CORDAPP_VERSION="4.8-SNAPSHOT"
DOCKER_IMAGE_VERSION="corda-zulu-4.6-snapshot" DOCKER_IMAGE_VERSION="corda-zulu-4.8-snapshot"
mkdir cordapps mkdir cordapps
rm -f cordapps/* rm -f cordapps/*

View File

@ -1,10 +1,11 @@
FROM azul/zulu-openjdk:8u192 FROM azul/zulu-openjdk:8u312
## Remove Azul Zulu repo, as it is gone by now ## Remove Azul Zulu repo, as it is gone by now
RUN rm -rf /etc/apt/sources.list.d/zulu.list RUN rm -rf /etc/apt/sources.list.d/zulu.list
## Add packages, clean cache, create dirs, create corda user and change ownership ## Add packages, clean cache, create dirs, create corda user and change ownership
RUN apt-get update && \ RUN apt-get update && \
apt-mark hold zulu8-jdk && \
apt-get -y upgrade && \ apt-get -y upgrade && \
apt-get -y install bash curl unzip && \ apt-get -y install bash curl unzip && \
rm -rf /var/lib/apt/lists/* && \ rm -rf /var/lib/apt/lists/* && \

View File

@ -1,7 +1,8 @@
FROM azul/zulu-openjdk:8u192 FROM azul/zulu-openjdk:8u312
## Add packages, clean cache, create dirs, create corda user and change ownership ## Add packages, clean cache, create dirs, create corda user and change ownership
RUN apt-get update && \ RUN apt-get update && \
apt-mark hold zulu8-jdk && \
apt-get -y upgrade && \ apt-get -y upgrade && \
apt-get -y install bash curl unzip netstat lsof telnet netcat && \ apt-get -y install bash curl unzip netstat lsof telnet netcat && \
rm -rf /var/lib/apt/lists/* && \ rm -rf /var/lib/apt/lists/* && \

View File

@ -5,6 +5,10 @@ apply plugin: 'net.corda.plugins.publish-utils'
apply plugin: 'maven-publish' apply plugin: 'maven-publish'
apply plugin: 'com.jfrog.artifactory' apply plugin: 'com.jfrog.artifactory'
dependencies {
compile rootProject
}
def internalPackagePrefixes(sourceDirs) { def internalPackagePrefixes(sourceDirs) {
def prefixes = [] def prefixes = []
// Kotlin allows packages to deviate from the directory structure, but let's assume they don't: // Kotlin allows packages to deviate from the directory structure, but let's assume they don't:
@ -36,10 +40,13 @@ task dokkaJavadoc(type: org.jetbrains.dokka.gradle.DokkaTask) {
} }
[dokka, dokkaJavadoc].collect { [dokka, dokkaJavadoc].collect {
it.configure { it.configuration {
moduleName = 'corda' moduleName = 'corda'
processConfigurations = ['compile'] dokkaSourceDirs.collect { sourceDir ->
sourceDirs = dokkaSourceDirs sourceRoot {
path = sourceDir.path
}
}
includes = ['packages.md'] includes = ['packages.md']
jdkVersion = 8 jdkVersion = 8
externalDocumentationLink { externalDocumentationLink {
@ -52,7 +59,7 @@ task dokkaJavadoc(type: org.jetbrains.dokka.gradle.DokkaTask) {
url = new URL("https://www.bouncycastle.org/docs/docs1.5on/") url = new URL("https://www.bouncycastle.org/docs/docs1.5on/")
} }
internalPackagePrefixes.collect { packagePrefix -> internalPackagePrefixes.collect { packagePrefix ->
packageOptions { perPackageOption {
prefix = packagePrefix prefix = packagePrefix
suppress = true suppress = true
} }

View File

@ -21,14 +21,29 @@ import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.utilities.days import net.corda.core.utilities.days
import net.corda.core.utilities.hours import net.corda.core.utilities.hours
import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme import net.corda.coretesting.internal.NettyTestClient
import net.corda.coretesting.internal.NettyTestHandler
import net.corda.coretesting.internal.NettyTestServer
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.createDevNodeCa import net.corda.nodeapi.internal.createDevNodeCa
import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_IDENTITY_SIGNATURE_SCHEME import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_IDENTITY_SIGNATURE_SCHEME
import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME
import net.corda.nodeapi.internal.crypto.checkValidity
import net.corda.nodeapi.internal.crypto.getSupportedKey
import net.corda.nodeapi.internal.crypto.loadOrCreateKeyStore
import net.corda.nodeapi.internal.crypto.save
import net.corda.nodeapi.internal.crypto.toBc
import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.crypto.x509Certificates
import net.corda.nodeapi.internal.installDevNodeCaCertPath import net.corda.nodeapi.internal.installDevNodeCaCertPath
import net.corda.nodeapi.internal.protonwrapper.netty.init import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import net.corda.nodeapi.internal.registerDevP2pCertificates import net.corda.nodeapi.internal.registerDevP2pCertificates
import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme
import net.corda.serialization.internal.AllWhitelist import net.corda.serialization.internal.AllWhitelist
import net.corda.serialization.internal.SerializationContextImpl import net.corda.serialization.internal.SerializationContextImpl
import net.corda.serialization.internal.SerializationFactoryImpl import net.corda.serialization.internal.SerializationFactoryImpl
@ -37,25 +52,16 @@ import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.TestIdentity import net.corda.testing.core.TestIdentity
import net.corda.testing.driver.internal.incrementalPortAllocation import net.corda.testing.driver.internal.incrementalPortAllocation
import net.corda.coretesting.internal.NettyTestClient
import net.corda.coretesting.internal.NettyTestHandler
import net.corda.coretesting.internal.NettyTestServer
import net.corda.testing.internal.createDevIntermediateCaCertPath
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.crypto.checkValidity
import net.corda.nodeapi.internal.crypto.getSupportedKey
import net.corda.nodeapi.internal.crypto.loadOrCreateKeyStore
import net.corda.nodeapi.internal.crypto.save
import net.corda.nodeapi.internal.crypto.toBc
import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.crypto.x509Certificates
import net.corda.testing.internal.IS_OPENJ9 import net.corda.testing.internal.IS_OPENJ9
import net.corda.testing.internal.createDevIntermediateCaCertPath
import net.i2p.crypto.eddsa.EdDSAPrivateKey import net.i2p.crypto.eddsa.EdDSAPrivateKey
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.bouncycastle.asn1.x509.* import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier
import org.bouncycastle.asn1.x509.BasicConstraints
import org.bouncycastle.asn1.x509.CRLDistPoint
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.KeyUsage
import org.bouncycastle.asn1.x509.SubjectKeyIdentifier
import org.bouncycastle.jcajce.provider.asymmetric.edec.BCEdDSAPrivateKey import org.bouncycastle.jcajce.provider.asymmetric.edec.BCEdDSAPrivateKey
import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PrivateKey import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PrivateKey
import org.junit.Assume import org.junit.Assume
@ -74,10 +80,19 @@ import java.security.PrivateKey
import java.security.cert.CertPath import java.security.cert.CertPath
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.util.* import java.util.*
import javax.net.ssl.* import javax.net.ssl.SSLContext
import javax.net.ssl.SSLParameters
import javax.net.ssl.SSLServerSocket
import javax.net.ssl.SSLSocket
import javax.security.auth.x500.X500Principal import javax.security.auth.x500.X500Principal
import kotlin.concurrent.thread import kotlin.concurrent.thread
import kotlin.test.* import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertFalse
import kotlin.test.assertNotNull
import kotlin.test.assertNull
import kotlin.test.assertTrue
import kotlin.test.fail
class X509UtilitiesTest { class X509UtilitiesTest {
private companion object { private companion object {
@ -295,15 +310,10 @@ class X509UtilitiesTest {
sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa) sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa)
sslConfig.createTrustStore(rootCa.certificate) sslConfig.createTrustStore(rootCa.certificate)
val keyStore = sslConfig.keyStore.get()
val trustStore = sslConfig.trustStore.get()
val context = SSLContext.getInstance("TLS") val context = SSLContext.getInstance("TLS")
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
keyManagerFactory.init(keyStore)
val keyManagers = keyManagerFactory.keyManagers val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) val trustMgrFactory = trustManagerFactory(sslConfig.trustStore.get())
trustMgrFactory.init(trustStore)
val trustManagers = trustMgrFactory.trustManagers val trustManagers = trustMgrFactory.trustManagers
context.init(keyManagers, trustManagers, newSecureRandom()) context.init(keyManagers, trustManagers, newSecureRandom())
@ -388,15 +398,8 @@ class X509UtilitiesTest {
sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa) sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa)
sslConfig.createTrustStore(rootCa.certificate) sslConfig.createTrustStore(rootCa.certificate)
val keyStore = sslConfig.keyStore.get() val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
val trustStore = sslConfig.trustStore.get() val trustManagerFactory = trustManagerFactory(sslConfig.trustStore.get())
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(trustStore)
val sslServerContext = SslContextBuilder val sslServerContext = SslContextBuilder
.forServer(keyManagerFactory) .forServer(keyManagerFactory)

View File

@ -54,6 +54,9 @@ dependencies {
testRuntimeOnly "org.junit.vintage:junit-vintage-engine:${junit_vintage_version}" testRuntimeOnly "org.junit.vintage:junit-vintage-engine:${junit_vintage_version}"
testRuntimeOnly "org.junit.jupiter:junit-jupiter-engine:${junit_jupiter_version}" testRuntimeOnly "org.junit.jupiter:junit-jupiter-engine:${junit_jupiter_version}"
testRuntimeOnly "org.junit.platform:junit-platform-launcher:${junit_platform_version}" testRuntimeOnly "org.junit.platform:junit-platform-launcher:${junit_platform_version}"
testCompile project(':node-driver')
// Unit testing helpers. // Unit testing helpers.
testCompile "org.assertj:assertj-core:$assertj_version" testCompile "org.assertj:assertj-core:$assertj_version"
testCompile "org.jetbrains.kotlin:kotlin-test:$kotlin_version" testCompile "org.jetbrains.kotlin:kotlin-test:$kotlin_version"

View File

@ -5,7 +5,6 @@ import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.loggerFor import net.corda.core.utilities.loggerFor
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.NODE_P2P_USER import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.NODE_P2P_USER
import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.p2pConnectorTcpTransport import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.p2pConnectorTcpTransport
import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.p2pConnectorTcpTransportFromList
import net.corda.nodeapi.internal.config.MessagingServerConnectionConfiguration import net.corda.nodeapi.internal.config.MessagingServerConnectionConfiguration
import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.config.MutualSslConfiguration
import org.apache.activemq.artemis.api.core.client.* import org.apache.activemq.artemis.api.core.client.*
@ -25,7 +24,9 @@ class ArtemisMessagingClient(private val config: MutualSslConfiguration,
private val confirmationWindowSize: Int = -1, private val confirmationWindowSize: Int = -1,
private val messagingServerConnectionConfig: MessagingServerConnectionConfiguration? = null, private val messagingServerConnectionConfig: MessagingServerConnectionConfiguration? = null,
private val backupServerAddressPool: List<NetworkHostAndPort> = emptyList(), private val backupServerAddressPool: List<NetworkHostAndPort> = emptyList(),
private val failoverCallback: ((FailoverEventType) -> Unit)? = null private val failoverCallback: ((FailoverEventType) -> Unit)? = null,
private val threadPoolName: String = "ArtemisClient",
private val trace: Boolean = false
) : ArtemisSessionProvider { ) : ArtemisSessionProvider {
companion object { companion object {
private val log = loggerFor<ArtemisMessagingClient>() private val log = loggerFor<ArtemisMessagingClient>()
@ -40,8 +41,10 @@ class ArtemisMessagingClient(private val config: MutualSslConfiguration,
override fun start(): Started = synchronized(this) { override fun start(): Started = synchronized(this) {
check(started == null) { "start can't be called twice" } check(started == null) { "start can't be called twice" }
val tcpTransport = p2pConnectorTcpTransport(serverAddress, config) val tcpTransport = p2pConnectorTcpTransport(serverAddress, config, threadPoolName = threadPoolName, trace = trace)
val backupTransports = p2pConnectorTcpTransportFromList(backupServerAddressPool, config) val backupTransports = backupServerAddressPool.mapIndexed { index, address ->
p2pConnectorTcpTransport(address, config, threadPoolName = "$threadPoolName-backup${index+1}", trace = trace)
}
log.info("Connecting to message broker: $serverAddress") log.info("Connecting to message broker: $serverAddress")
if (backupTransports.isNotEmpty()) { if (backupTransports.isNotEmpty()) {
@ -50,8 +53,6 @@ class ArtemisMessagingClient(private val config: MutualSslConfiguration,
// If back-up artemis addresses are configured, the locator will be created using HA mode. // If back-up artemis addresses are configured, the locator will be created using HA mode.
@Suppress("SpreadOperator") @Suppress("SpreadOperator")
val locator = ActiveMQClient.createServerLocator(backupTransports.isNotEmpty(), *(listOf(tcpTransport) + backupTransports).toTypedArray()).apply { val locator = ActiveMQClient.createServerLocator(backupTransports.isNotEmpty(), *(listOf(tcpTransport) + backupTransports).toTypedArray()).apply {
// Never time out on our loopback Artemis connections. If we switch back to using the InVM transport this
// would be the default and the two lines below can be deleted.
connectionTTL = 60000 connectionTTL = 60000
clientFailureCheckPeriod = 30000 clientFailureCheckPeriod = 30000
callFailoverTimeout = java.lang.Long.getLong(CORDA_ARTEMIS_CALL_TIMEOUT_PROP_NAME, CORDA_ARTEMIS_CALL_TIMEOUT_DEFAULT) callFailoverTimeout = java.lang.Long.getLong(CORDA_ARTEMIS_CALL_TIMEOUT_PROP_NAME, CORDA_ARTEMIS_CALL_TIMEOUT_DEFAULT)

View File

@ -1,19 +1,20 @@
@file:Suppress("LongParameterList")
package net.corda.nodeapi.internal package net.corda.nodeapi.internal
import net.corda.core.messaging.ClientRpcSslOptions import net.corda.core.messaging.ClientRpcSslOptions
import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.BrokerRpcSslOptions import net.corda.nodeapi.BrokerRpcSslOptions
import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.config.DEFAULT_SSL_HANDSHAKE_TIMEOUT
import net.corda.nodeapi.internal.config.FileBasedCertificateStoreSupplier
import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.config.SslConfiguration import net.corda.nodeapi.internal.config.SslConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.apache.activemq.artemis.api.core.TransportConfiguration import org.apache.activemq.artemis.api.core.TransportConfiguration
import org.apache.activemq.artemis.core.remoting.impl.netty.NettyConnectorFactory
import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants
import java.nio.file.Path import javax.net.ssl.TrustManagerFactory
// This avoids internal types from leaking in the public API. The "external" ArtemisTcpTransport delegates to this internal one. @Suppress("LongParameterList")
class ArtemisTcpTransport { class ArtemisTcpTransport {
companion object { companion object {
val CIPHER_SUITES = listOf( val CIPHER_SUITES = listOf(
@ -23,65 +24,52 @@ class ArtemisTcpTransport {
val TLS_VERSIONS = listOf("TLSv1.2") val TLS_VERSIONS = listOf("TLSv1.2")
internal fun defaultArtemisOptions(hostAndPort: NetworkHostAndPort) = mapOf( const val SSL_HANDSHAKE_TIMEOUT_NAME = "Corda-SSLHandshakeTimeout"
// Basic TCP target details. const val TRUST_MANAGER_FACTORY_NAME = "Corda-TrustManagerFactory"
TransportConstants.HOST_PROP_NAME to hostAndPort.host, const val TRACE_NAME = "Corda-Trace"
TransportConstants.PORT_PROP_NAME to hostAndPort.port, const val THREAD_POOL_NAME_NAME = "Corda-ThreadPoolName"
// Turn on AMQP support, which needs the protocol jar on the classpath. // Turn on AMQP support, which needs the protocol jar on the classpath.
// Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop. // Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop.
// It does not use AMQP messages for its own messages e.g. topology and heartbeats. // It does not use AMQP messages for its own messages e.g. topology and heartbeats.
// TODO further investigate how to ensure we use a well defined wire level protocol for Node to Node communications. private const val P2P_PROTOCOLS = "CORE,AMQP"
TransportConstants.PROTOCOLS_PROP_NAME to "CORE,AMQP", private const val RPC_PROTOCOLS = "CORE"
private fun defaultArtemisOptions(hostAndPort: NetworkHostAndPort, protocols: String) = mapOf(
// Basic TCP target details.
TransportConstants.HOST_PROP_NAME to hostAndPort.host,
TransportConstants.PORT_PROP_NAME to hostAndPort.port,
TransportConstants.PROTOCOLS_PROP_NAME to protocols,
TransportConstants.USE_GLOBAL_WORKER_POOL_PROP_NAME to (nodeSerializationEnv != null), TransportConstants.USE_GLOBAL_WORKER_POOL_PROP_NAME to (nodeSerializationEnv != null),
TransportConstants.REMOTING_THREADS_PROPNAME to (if (nodeSerializationEnv != null) -1 else 1),
// turn off direct delivery in Artemis - this is latency optimisation that can lead to // turn off direct delivery in Artemis - this is latency optimisation that can lead to
//hick-ups under high load (CORDA-1336) //hick-ups under high load (CORDA-1336)
TransportConstants.DIRECT_DELIVER to false) TransportConstants.DIRECT_DELIVER to false)
internal val defaultSSLOptions = mapOf( private fun SslConfiguration.addToTransportOptions(options: MutableMap<String, Any>) {
TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME to CIPHER_SUITES.joinToString(","), if (keyStore != null || trustStore != null) {
TransportConstants.ENABLED_PROTOCOLS_PROP_NAME to TLS_VERSIONS.joinToString(",")) options[TransportConstants.SSL_ENABLED_PROP_NAME] = true
options[TransportConstants.NEED_CLIENT_AUTH_PROP_NAME] = true
private fun SslConfiguration.toTransportOptions(): Map<String, Any> {
val options = mutableMapOf<String, Any>()
(keyStore to trustStore).addToTransportOptions(options)
return options
} }
private fun Pair<FileBasedCertificateStoreSupplier?, FileBasedCertificateStoreSupplier?>.addToTransportOptions(options: MutableMap<String, Any>) {
val keyStore = first
val trustStore = second
keyStore?.let { keyStore?.let {
with (it) { with (it) {
path.requireOnDefaultFileSystem() path.requireOnDefaultFileSystem()
options.putAll(get().toKeyStoreTransportOptions(path)) options[TransportConstants.KEYSTORE_PROVIDER_PROP_NAME] = "JKS"
options[TransportConstants.KEYSTORE_PATH_PROP_NAME] = path
options[TransportConstants.KEYSTORE_PASSWORD_PROP_NAME] = get().password
} }
} }
trustStore?.let { trustStore?.let {
with (it) { with (it) {
path.requireOnDefaultFileSystem() path.requireOnDefaultFileSystem()
options.putAll(get().toTrustStoreTransportOptions(path)) options[TransportConstants.TRUSTSTORE_PROVIDER_PROP_NAME] = "JKS"
options[TransportConstants.TRUSTSTORE_PATH_PROP_NAME] = path
options[TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME] = get().password
} }
} }
options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER
options[SSL_HANDSHAKE_TIMEOUT_NAME] = handshakeTimeout ?: DEFAULT_SSL_HANDSHAKE_TIMEOUT
} }
private fun CertificateStore.toKeyStoreTransportOptions(path: Path) = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.KEYSTORE_PROVIDER_PROP_NAME to "JKS",
TransportConstants.KEYSTORE_PATH_PROP_NAME to path,
TransportConstants.KEYSTORE_PASSWORD_PROP_NAME to password,
TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true)
private fun CertificateStore.toTrustStoreTransportOptions(path: Path) = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.TRUSTSTORE_PROVIDER_PROP_NAME to "JKS",
TransportConstants.TRUSTSTORE_PATH_PROP_NAME to path,
TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME to password,
TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true)
private fun ClientRpcSslOptions.toTransportOptions() = mapOf( private fun ClientRpcSslOptions.toTransportOptions() = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true, TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.TRUSTSTORE_PROVIDER_PROP_NAME to trustStoreProvider, TransportConstants.TRUSTSTORE_PROVIDER_PROP_NAME to trustStoreProvider,
@ -95,86 +83,164 @@ class ArtemisTcpTransport {
TransportConstants.KEYSTORE_PASSWORD_PROP_NAME to keyStorePassword, TransportConstants.KEYSTORE_PASSWORD_PROP_NAME to keyStorePassword,
TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to false) TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to false)
internal val acceptorFactoryClassName = "org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptorFactory" fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
internal val connectorFactoryClassName = NettyConnectorFactory::class.java.name config: MutualSslConfiguration?,
trustManagerFactory: TrustManagerFactory? = config?.trustStore?.get()?.let(::trustManagerFactory),
fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: MutualSslConfiguration?, enableSSL: Boolean = true): TransportConfiguration { enableSSL: Boolean = true,
threadPoolName: String = "P2PServer",
return p2pAcceptorTcpTransport(hostAndPort, config?.keyStore, config?.trustStore, enableSSL = enableSSL, useOpenSsl = config?.useOpenSsl ?: false) trace: Boolean = false,
} remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort, config: MutualSslConfiguration?, enableSSL: Boolean = true, keyStoreProvider: String? = null): TransportConfiguration {
return p2pConnectorTcpTransport(hostAndPort, config?.keyStore, config?.trustStore, enableSSL = enableSSL, useOpenSsl = config?.useOpenSsl ?: false, keyStoreProvider = keyStoreProvider)
}
fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, keyStore: FileBasedCertificateStoreSupplier?, trustStore: FileBasedCertificateStoreSupplier?, enableSSL: Boolean = true, useOpenSsl: Boolean = false): TransportConfiguration {
val options = defaultArtemisOptions(hostAndPort).toMutableMap()
if (enableSSL) { if (enableSSL) {
options.putAll(defaultSSLOptions) config?.addToTransportOptions(options)
(keyStore to trustStore).addToTransportOptions(options)
options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER
} }
options[TransportConstants.HANDSHAKE_TIMEOUT] = 0 // Suppress core.server.lambda$channelActive$0 - AMQ224088 error from load balancer type connections return createAcceptorTransport(
return TransportConfiguration(acceptorFactoryClassName, options) hostAndPort,
P2P_PROTOCOLS,
options,
trustManagerFactory,
enableSSL,
threadPoolName,
trace,
remotingThreads
)
} }
@Suppress("LongParameterList") fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort,
fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort, keyStore: FileBasedCertificateStoreSupplier?, trustStore: FileBasedCertificateStoreSupplier?, enableSSL: Boolean = true, useOpenSsl: Boolean = false, keyStoreProvider: String? = null): TransportConfiguration { config: MutualSslConfiguration?,
enableSSL: Boolean = true,
val options = defaultArtemisOptions(hostAndPort).toMutableMap() threadPoolName: String = "P2PClient",
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
if (enableSSL) { if (enableSSL) {
options.putAll(defaultSSLOptions) config?.addToTransportOptions(options)
(keyStore to trustStore).addToTransportOptions(options)
options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER
keyStoreProvider?.let { options.put(TransportConstants.KEYSTORE_PROVIDER_PROP_NAME, keyStoreProvider) }
} }
return TransportConfiguration(connectorFactoryClassName, options) return createConnectorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace, remotingThreads)
} }
fun p2pConnectorTcpTransportFromList(hostAndPortList: List<NetworkHostAndPort>, config: MutualSslConfiguration?, enableSSL: Boolean = true, keyStoreProvider: String? = null): List<TransportConfiguration> = hostAndPortList.map { fun rpcAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
p2pConnectorTcpTransport(it, config, enableSSL, keyStoreProvider) config: BrokerRpcSslOptions?,
} enableSSL: Boolean = true,
threadPoolName: String = "RPCServer",
fun rpcAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: BrokerRpcSslOptions?, enableSSL: Boolean = true): TransportConfiguration { trace: Boolean = false,
val options = defaultArtemisOptions(hostAndPort).toMutableMap() remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
if (config != null && enableSSL) { if (config != null && enableSSL) {
config.keyStorePath.requireOnDefaultFileSystem() config.keyStorePath.requireOnDefaultFileSystem()
options.putAll(config.toTransportOptions()) options.putAll(config.toTransportOptions())
options.putAll(defaultSSLOptions)
} }
options[TransportConstants.HANDSHAKE_TIMEOUT] = 0 // Suppress core.server.lambda$channelActive$0 - AMQ224088 error from load balancer type connections return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, null, enableSSL, threadPoolName, trace, remotingThreads)
return TransportConfiguration(acceptorFactoryClassName, options)
} }
fun rpcConnectorTcpTransport(hostAndPort: NetworkHostAndPort, config: ClientRpcSslOptions?, enableSSL: Boolean = true): TransportConfiguration { fun rpcConnectorTcpTransport(hostAndPort: NetworkHostAndPort,
val options = defaultArtemisOptions(hostAndPort).toMutableMap() config: ClientRpcSslOptions?,
enableSSL: Boolean = true,
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
if (config != null && enableSSL) { if (config != null && enableSSL) {
config.trustStorePath.requireOnDefaultFileSystem() config.trustStorePath.requireOnDefaultFileSystem()
options.putAll(config.toTransportOptions()) options.putAll(config.toTransportOptions())
options.putAll(defaultSSLOptions)
} }
return TransportConfiguration(connectorFactoryClassName, options) return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCClient", trace, remotingThreads)
} }
fun rpcConnectorTcpTransportsFromList(hostAndPortList: List<NetworkHostAndPort>, config: ClientRpcSslOptions?, enableSSL: Boolean = true): List<TransportConfiguration> = hostAndPortList.map { fun rpcInternalClientTcpTransport(hostAndPort: NetworkHostAndPort,
rpcConnectorTcpTransport(it, config, enableSSL) config: SslConfiguration,
threadPoolName: String = "Internal-RPCClient",
trace: Boolean = false): TransportConfiguration {
val options = mutableMapOf<String, Any>()
config.addToTransportOptions(options)
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, threadPoolName, trace, null)
} }
fun rpcInternalClientTcpTransport(hostAndPort: NetworkHostAndPort, config: SslConfiguration, keyStoreProvider: String? = null): TransportConfiguration { fun rpcInternalAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
return TransportConfiguration(connectorFactoryClassName, defaultArtemisOptions(hostAndPort) + defaultSSLOptions + config.toTransportOptions() + asMap(keyStoreProvider)) config: SslConfiguration,
threadPoolName: String = "Internal-RPCServer",
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
config.addToTransportOptions(options)
return createAcceptorTransport(
hostAndPort,
RPC_PROTOCOLS,
options,
trustManagerFactory(requireNotNull(config.trustStore).get()),
true,
threadPoolName,
trace,
remotingThreads
)
} }
fun rpcInternalAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: SslConfiguration, keyStoreProvider: String? = null): TransportConfiguration { private fun createAcceptorTransport(hostAndPort: NetworkHostAndPort,
return TransportConfiguration(acceptorFactoryClassName, defaultArtemisOptions(hostAndPort) + defaultSSLOptions + protocols: String,
config.toTransportOptions() + (TransportConstants.HANDSHAKE_TIMEOUT to 0) + asMap(keyStoreProvider)) options: MutableMap<String, Any>,
trustManagerFactory: TrustManagerFactory?,
enableSSL: Boolean,
threadPoolName: String,
trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
// Suppress core.server.lambda$channelActive$0 - AMQ224088 error from load balancer type connections
options[TransportConstants.HANDSHAKE_TIMEOUT] = 0
if (trustManagerFactory != null) {
// NettyAcceptor only creates default TrustManagerFactorys with the provided trust store details. However, we need to use
// more customised instances which use our revocation checkers, which we pass directly into NodeNettyAcceptorFactory.
//
// This, however, requires copying a lot of code from NettyAcceptor into NodeNettyAcceptor. The version of Artemis in
// Corda 4.9 solves this problem by introducing a "trustManagerFactoryPlugin" config option.
options[TRUST_MANAGER_FACTORY_NAME] = trustManagerFactory
}
return createTransport(
"net.corda.node.services.messaging.NodeNettyAcceptorFactory",
hostAndPort,
protocols,
options,
enableSSL,
threadPoolName,
trace,
remotingThreads
)
} }
private fun asMap(keyStoreProvider: String?): Map<String, String> { private fun createConnectorTransport(hostAndPort: NetworkHostAndPort,
return keyStoreProvider?.let {mutableMapOf(TransportConstants.KEYSTORE_PROVIDER_PROP_NAME to it)} ?: emptyMap() protocols: String,
options: MutableMap<String, Any>,
enableSSL: Boolean,
threadPoolName: String,
trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
return createTransport(
CordaNettyConnectorFactory::class.java.name,
hostAndPort,
protocols,
options,
enableSSL,
threadPoolName,
trace,
remotingThreads
)
}
private fun createTransport(className: String,
hostAndPort: NetworkHostAndPort,
protocols: String,
options: MutableMap<String, Any>,
enableSSL: Boolean,
threadPoolName: String,
trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
options += defaultArtemisOptions(hostAndPort, protocols)
if (enableSSL) {
options[TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME] = CIPHER_SUITES.joinToString(",")
options[TransportConstants.ENABLED_PROTOCOLS_PROP_NAME] = TLS_VERSIONS.joinToString(",")
}
// By default, use only one remoting thread in tests (https://github.com/corda/corda/pull/2357)
options[TransportConstants.REMOTING_THREADS_PROPNAME] = remotingThreads ?: if (nodeSerializationEnv == null) 1 else -1
options[THREAD_POOL_NAME_NAME] = threadPoolName
options[TRACE_NAME] = trace
return TransportConfiguration(className, options)
} }
} }
} }

View File

@ -1,8 +1,14 @@
@file:JvmName("ArtemisUtils") @file:JvmName("ArtemisUtils")
package net.corda.nodeapi.internal package net.corda.nodeapi.internal
import net.corda.core.internal.declaredField
import org.apache.activemq.artemis.utils.actors.ProcessorBase
import java.nio.file.FileSystems import java.nio.file.FileSystems
import java.nio.file.Path import java.nio.file.Path
import java.util.concurrent.Executor
import java.util.concurrent.ThreadFactory
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.atomic.AtomicInteger
/** /**
* Require that the [Path] is on a default file system, and therefore is one that Artemis is willing to use. * Require that the [Path] is on a default file system, and therefore is one that Artemis is willing to use.
@ -16,3 +22,29 @@ fun requireMessageSize(messageSize: Int, limit: Int) {
require(messageSize <= limit) { "Message exceeds maxMessageSize network parameter, maxMessageSize: [$limit], message size: [$messageSize]" } require(messageSize <= limit) { "Message exceeds maxMessageSize network parameter, maxMessageSize: [$limit], message size: [$messageSize]" }
} }
val Executor.rootExecutor: Executor get() {
var executor: Executor = this
while (executor is ProcessorBase<*>) {
executor = executor.declaredField<Executor>("delegate").value
}
return executor
}
fun Executor.setThreadPoolName(threadPoolName: String) {
(rootExecutor as? ThreadPoolExecutor)?.let { it.threadFactory = NamedThreadFactory(threadPoolName, it.threadFactory) }
}
private class NamedThreadFactory(poolName: String, private val delegate: ThreadFactory) : ThreadFactory {
companion object {
private val poolId = AtomicInteger(0)
}
private val prefix = "$poolName-${poolId.incrementAndGet()}-"
private val nextId = AtomicInteger(0)
override fun newThread(r: Runnable): Thread {
val thread = delegate.newThread(r)
thread.name = "$prefix${nextId.incrementAndGet()}"
return thread
}
}

View File

@ -0,0 +1,73 @@
package net.corda.nodeapi.internal
import io.netty.channel.ChannelPipeline
import io.netty.handler.logging.LogLevel
import io.netty.handler.logging.LoggingHandler
import org.apache.activemq.artemis.core.protocol.core.impl.ActiveMQClientProtocolManager
import org.apache.activemq.artemis.core.remoting.impl.netty.NettyConnector
import org.apache.activemq.artemis.spi.core.remoting.BufferHandler
import org.apache.activemq.artemis.spi.core.remoting.ClientConnectionLifeCycleListener
import org.apache.activemq.artemis.spi.core.remoting.ClientProtocolManager
import org.apache.activemq.artemis.spi.core.remoting.Connector
import org.apache.activemq.artemis.spi.core.remoting.ConnectorFactory
import org.apache.activemq.artemis.utils.ConfigurationHelper
import java.util.concurrent.Executor
import java.util.concurrent.ScheduledExecutorService
class CordaNettyConnectorFactory : ConnectorFactory {
override fun createConnector(configuration: MutableMap<String, Any>?,
handler: BufferHandler?,
listener: ClientConnectionLifeCycleListener?,
closeExecutor: Executor,
threadPool: Executor,
scheduledThreadPool: ScheduledExecutorService,
protocolManager: ClientProtocolManager?): Connector {
val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "Connector", configuration)
setThreadPoolName(threadPool, closeExecutor, scheduledThreadPool, threadPoolName)
val trace = ConfigurationHelper.getBooleanProperty(ArtemisTcpTransport.TRACE_NAME, false, configuration)
return NettyConnector(
configuration,
handler,
listener,
closeExecutor,
threadPool,
scheduledThreadPool,
MyClientProtocolManager("$threadPoolName-netty", trace)
)
}
override fun isReliable(): Boolean = false
override fun getDefaults(): Map<String?, Any?> = NettyConnector.DEFAULT_CONFIG
private fun setThreadPoolName(threadPool: Executor, closeExecutor: Executor, scheduledThreadPool: ScheduledExecutorService, name: String) {
threadPool.setThreadPoolName("$name-artemis")
// Artemis will actually wrap the same backing Executor to create multiple "OrderedExecutors". In this scenerio both the threadPool
// and the closeExecutor are the same when it comes to the pool names. If however they are different then given them separate names.
if (threadPool.rootExecutor !== closeExecutor.rootExecutor) {
closeExecutor.setThreadPoolName("$name-artemis-closer")
}
// The scheduler is separate
scheduledThreadPool.setThreadPoolName("$name-artemis-scheduler")
}
private class MyClientProtocolManager(private val threadPoolName: String, private val trace: Boolean) : ActiveMQClientProtocolManager() {
override fun addChannelHandlers(pipeline: ChannelPipeline) {
applyThreadPoolName()
super.addChannelHandlers(pipeline)
if (trace) {
pipeline.addLast("logger", LoggingHandler(LogLevel.INFO))
}
}
/**
* [NettyConnector.start] does not provide a way to configure the thread pool name, so we modify the thread name accordingly.
*/
private fun applyThreadPoolName() {
with(Thread.currentThread()) {
name = name.replace("nioEventLoopGroup", threadPoolName) // pool and thread numbers are preserved
}
}
}
}

View File

@ -0,0 +1,32 @@
@file:Suppress("LongParameterList", "MagicNumber")
package net.corda.nodeapi.internal
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.core.utilities.seconds
import java.time.Duration
import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit
/**
* Creates a [ThreadPoolExecutor] which will use a maximum of [maxPoolSize] threads at any given time and will by default idle down to 0
* threads.
*/
fun namedThreadPoolExecutor(maxPoolSize: Int,
corePoolSize: Int = 0,
idleKeepAlive: Duration = 30.seconds,
workQueue: BlockingQueue<Runnable> = LinkedBlockingQueue(),
poolName: String = "pool",
daemonThreads: Boolean = false,
threadPriority: Int = Thread.NORM_PRIORITY): ThreadPoolExecutor {
return ThreadPoolExecutor(
corePoolSize,
maxPoolSize,
idleKeepAlive.toNanos(),
TimeUnit.NANOSECONDS,
workQueue,
DefaultThreadFactory(poolName, daemonThreads, threadPriority)
)
}

View File

@ -5,22 +5,24 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder
import io.netty.channel.EventLoop import io.netty.channel.EventLoop
import io.netty.channel.EventLoopGroup import io.netty.channel.EventLoopGroup
import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.nio.NioEventLoopGroup
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.VisibleForTesting
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.ArtemisConstants.MESSAGE_ID_KEY
import net.corda.nodeapi.internal.ArtemisMessagingClient import net.corda.nodeapi.internal.ArtemisMessagingClient
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.NODE_P2P_USER import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.NODE_P2P_USER
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders
import net.corda.nodeapi.internal.ArtemisMessagingComponent.RemoteInboxAddress.Companion.translateLocalQueueToInboxAddress import net.corda.nodeapi.internal.ArtemisMessagingComponent.RemoteInboxAddress.Companion.translateLocalQueueToInboxAddress
import net.corda.nodeapi.internal.ArtemisSessionProvider import net.corda.nodeapi.internal.ArtemisSessionProvider
import net.corda.nodeapi.internal.ArtemisConstants.MESSAGE_ID_KEY
import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.ProxyConfig import net.corda.nodeapi.internal.protonwrapper.netty.ProxyConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.sslDelegatedTaskExecutor
import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException
import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE
@ -29,6 +31,8 @@ import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ClientSession import org.apache.activemq.artemis.api.core.client.ClientSession
import org.slf4j.MDC import org.slf4j.MDC
import rx.Subscription import rx.Subscription
import java.time.Duration
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors import java.util.concurrent.Executors
import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.ScheduledFuture import java.util.concurrent.ScheduledFuture
@ -51,10 +55,10 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
maxMessageSize: Int, maxMessageSize: Int,
revocationConfig: RevocationConfig, revocationConfig: RevocationConfig,
enableSNI: Boolean, enableSNI: Boolean,
private val artemisMessageClientFactory: () -> ArtemisSessionProvider, private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider,
private val bridgeMetricsService: BridgeMetricsService? = null, private val bridgeMetricsService: BridgeMetricsService? = null,
trace: Boolean, trace: Boolean,
sslHandshakeTimeout: Long?, sslHandshakeTimeout: Duration?,
private val bridgeConnectionTTLSeconds: Int) : BridgeManager { private val bridgeConnectionTTLSeconds: Int) : BridgeManager {
private val lock = ReentrantLock() private val lock = ReentrantLock()
@ -69,16 +73,18 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
override val enableSNI: Boolean, override val enableSNI: Boolean,
override val sourceX500Name: String? = null, override val sourceX500Name: String? = null,
override val trace: Boolean, override val trace: Boolean,
private val _sslHandshakeTimeout: Long?) : AMQPConfiguration { private val _sslHandshakeTimeout: Duration?) : AMQPConfiguration {
override val sslHandshakeTimeout: Long override val sslHandshakeTimeout: Duration
get() = _sslHandshakeTimeout ?: super.sslHandshakeTimeout get() = _sslHandshakeTimeout ?: super.sslHandshakeTimeout
} }
private val amqpConfig: AMQPConfiguration = AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize, revocationConfig,useOpenSSL, enableSNI, trace = trace, _sslHandshakeTimeout = sslHandshakeTimeout) private val amqpConfig: AMQPConfiguration = AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize, revocationConfig,useOpenSSL, enableSNI, trace = trace, _sslHandshakeTimeout = sslHandshakeTimeout)
private var sharedEventLoopGroup: EventLoopGroup? = null private var sharedEventLoopGroup: EventLoopGroup? = null
private var sslDelegatedTaskExecutor: ExecutorService? = null
private var artemis: ArtemisSessionProvider? = null private var artemis: ArtemisSessionProvider? = null
companion object { companion object {
private val log = contextLogger()
private const val CORDA_NUM_BRIDGE_THREADS_PROP_NAME = "net.corda.nodeapi.amqpbridgemanager.NumBridgeThreads" private const val CORDA_NUM_BRIDGE_THREADS_PROP_NAME = "net.corda.nodeapi.amqpbridgemanager.NumBridgeThreads"
@ -95,18 +101,11 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
* however Artemis and the remote Corda instanced will deduplicate these messages. * however Artemis and the remote Corda instanced will deduplicate these messages.
*/ */
@Suppress("TooManyFunctions") @Suppress("TooManyFunctions")
private class AMQPBridge(val sourceX500Name: String, private inner class AMQPBridge(val sourceX500Name: String,
val queueName: String, val queueName: String,
val targets: List<NetworkHostAndPort>, val targets: List<NetworkHostAndPort>,
val legalNames: Set<CordaX500Name>, val allowedRemoteLegalNames: Set<CordaX500Name>,
private val amqpConfig: AMQPConfiguration, private val amqpConfig: AMQPConfiguration) {
sharedEventGroup: EventLoopGroup,
private val artemis: ArtemisSessionProvider,
private val bridgeMetricsService: BridgeMetricsService?,
private val bridgeConnectionTTLSeconds: Int) {
companion object {
private val log = contextLogger()
}
private fun withMDC(block: () -> Unit) { private fun withMDC(block: () -> Unit) {
val oldMDC = MDC.getCopyOfContextMap() ?: emptyMap<String, String>() val oldMDC = MDC.getCopyOfContextMap() ?: emptyMap<String, String>()
@ -114,7 +113,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
MDC.put("queueName", queueName) MDC.put("queueName", queueName)
MDC.put("source", amqpConfig.sourceX500Name) MDC.put("source", amqpConfig.sourceX500Name)
MDC.put("targets", targets.joinToString(separator = ";") { it.toString() }) MDC.put("targets", targets.joinToString(separator = ";") { it.toString() })
MDC.put("legalNames", legalNames.joinToString(separator = ";") { it.toString() }) MDC.put("allowedRemoteLegalNames", allowedRemoteLegalNames.joinToString(separator = ";") { it.toString() })
MDC.put("maxMessageSize", amqpConfig.maxMessageSize.toString()) MDC.put("maxMessageSize", amqpConfig.maxMessageSize.toString())
block() block()
} finally { } finally {
@ -132,13 +131,18 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
private fun logWarnWithMDC(msg: String) = withMDC { log.warn(msg) } private fun logWarnWithMDC(msg: String) = withMDC { log.warn(msg) }
val amqpClient = AMQPClient(targets, legalNames, amqpConfig, sharedThreadPool = sharedEventGroup) val amqpClient = AMQPClient(
targets,
allowedRemoteLegalNames,
amqpConfig,
AMQPClient.NettyThreading.Shared(sharedEventLoopGroup!!, sslDelegatedTaskExecutor!!)
)
private var session: ClientSession? = null private var session: ClientSession? = null
private var consumer: ClientConsumer? = null private var consumer: ClientConsumer? = null
private var connectedSubscription: Subscription? = null private var connectedSubscription: Subscription? = null
@Volatile @Volatile
private var messagesReceived: Boolean = false private var messagesReceived: Boolean = false
private val eventLoop: EventLoop = sharedEventGroup.next() private val eventLoop: EventLoop = sharedEventLoopGroup!!.next()
private var artemisState: ArtemisState = ArtemisState.STOPPED private var artemisState: ArtemisState = ArtemisState.STOPPED
set(value) { set(value) {
logDebugWithMDC { "State change $field to $value" } logDebugWithMDC { "State change $field to $value" }
@ -150,32 +154,9 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
private var scheduledExecutorService: ScheduledExecutorService private var scheduledExecutorService: ScheduledExecutorService
= Executors.newSingleThreadScheduledExecutor(ThreadFactoryBuilder().setNameFormat("bridge-connection-reset-%d").build()) = Executors.newSingleThreadScheduledExecutor(ThreadFactoryBuilder().setNameFormat("bridge-connection-reset-%d").build())
@Suppress("ClassNaming")
private sealed class ArtemisState {
object STARTING : ArtemisState()
data class STARTED(override val pending: ScheduledFuture<Unit>) : ArtemisState()
object CHECKING : ArtemisState()
object RESTARTED : ArtemisState()
object RECEIVING : ArtemisState()
object AMQP_STOPPED : ArtemisState()
object AMQP_STARTING : ArtemisState()
object AMQP_STARTED : ArtemisState()
object AMQP_RESTARTED : ArtemisState()
object STOPPING : ArtemisState()
object STOPPED : ArtemisState()
data class STOPPED_AMQP_START_SCHEDULED(override val pending: ScheduledFuture<Unit>) : ArtemisState()
open val pending: ScheduledFuture<Unit>? = null
override fun toString(): String = javaClass.simpleName
}
private fun artemis(inProgress: ArtemisState, block: (precedingState: ArtemisState) -> ArtemisState) { private fun artemis(inProgress: ArtemisState, block: (precedingState: ArtemisState) -> ArtemisState) {
val runnable = { val runnable = {
synchronized(artemis) { synchronized(artemis!!) {
try { try {
val precedingState = artemisState val precedingState = artemisState
artemisState.pending?.cancel(false) artemisState.pending?.cancel(false)
@ -229,7 +210,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
} }
ArtemisState.STOPPING ArtemisState.STOPPING
} }
bridgeMetricsService?.bridgeDisconnected(targets, legalNames) bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames)
connectedSubscription?.unsubscribe() connectedSubscription?.unsubscribe()
connectedSubscription = null connectedSubscription = null
// Do this last because we already scheduled the Artemis stop, so it's okay to unsubscribe onConnected first. // Do this last because we already scheduled the Artemis stop, so it's okay to unsubscribe onConnected first.
@ -241,7 +222,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
if (connected) { if (connected) {
logInfoWithMDC("Bridge Connected") logInfoWithMDC("Bridge Connected")
bridgeMetricsService?.bridgeConnected(targets, legalNames) bridgeMetricsService?.bridgeConnected(targets, allowedRemoteLegalNames)
if (bridgeConnectionTTLSeconds > 0) { if (bridgeConnectionTTLSeconds > 0) {
// AMQP outbound connection will be restarted periodically with bridgeConnectionTTLSeconds interval // AMQP outbound connection will be restarted periodically with bridgeConnectionTTLSeconds interval
amqpRestartEvent = scheduledArtemisInExecutor(bridgeConnectionTTLSeconds.toLong(), TimeUnit.SECONDS, amqpRestartEvent = scheduledArtemisInExecutor(bridgeConnectionTTLSeconds.toLong(), TimeUnit.SECONDS,
@ -251,7 +232,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
} }
} }
artemis(ArtemisState.STARTING) { artemis(ArtemisState.STARTING) {
val startedArtemis = artemis.started val startedArtemis = artemis!!.started
if (startedArtemis == null) { if (startedArtemis == null) {
logInfoWithMDC("Bridge Connected but Artemis is disconnected") logInfoWithMDC("Bridge Connected but Artemis is disconnected")
ArtemisState.STOPPED ArtemisState.STOPPED
@ -284,7 +265,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
logInfoWithMDC("Bridge Disconnected") logInfoWithMDC("Bridge Disconnected")
amqpRestartEvent?.cancel(false) amqpRestartEvent?.cancel(false)
if (artemisState != ArtemisState.AMQP_STARTING && artemisState != ArtemisState.STOPPED) { if (artemisState != ArtemisState.AMQP_STARTING && artemisState != ArtemisState.STOPPED) {
bridgeMetricsService?.bridgeDisconnected(targets, legalNames) bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames)
} }
artemis(ArtemisState.STOPPING) { precedingState: ArtemisState -> artemis(ArtemisState.STOPPING) { precedingState: ArtemisState ->
logInfoWithMDC("Stopping Artemis because AMQP bridge disconnected") logInfoWithMDC("Stopping Artemis because AMQP bridge disconnected")
@ -416,10 +397,10 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
properties[key] = value properties[key] = value
} }
} }
logDebugWithMDC { "Bridged Send to ${legalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" } logDebugWithMDC { "Bridged Send to ${allowedRemoteLegalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" }
val peerInbox = translateLocalQueueToInboxAddress(queueName) val peerInbox = translateLocalQueueToInboxAddress(queueName)
val sendableMessage = amqpClient.createMessage(artemisMessage.payload(), peerInbox, val sendableMessage = amqpClient.createMessage(artemisMessage.payload(), peerInbox,
legalNames.first().toString(), allowedRemoteLegalNames.first().toString(),
properties) properties)
sendableMessage.onComplete.then { sendableMessage.onComplete.then {
logDebugWithMDC { "Bridge ACK ${sendableMessage.onComplete.get()}" } logDebugWithMDC { "Bridge ACK ${sendableMessage.onComplete.get()}" }
@ -455,6 +436,29 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
} }
} }
@Suppress("ClassNaming")
private sealed class ArtemisState {
object STARTING : ArtemisState()
data class STARTED(override val pending: ScheduledFuture<Unit>) : ArtemisState()
object CHECKING : ArtemisState()
object RESTARTED : ArtemisState()
object RECEIVING : ArtemisState()
object AMQP_STOPPED : ArtemisState()
object AMQP_STARTING : ArtemisState()
object AMQP_STARTED : ArtemisState()
object AMQP_RESTARTED : ArtemisState()
object STOPPING : ArtemisState()
object STOPPED : ArtemisState()
data class STOPPED_AMQP_START_SCHEDULED(override val pending: ScheduledFuture<Unit>) : ArtemisState()
open val pending: ScheduledFuture<Unit>? = null
override fun toString(): String = javaClass.simpleName
}
override fun deployBridge(sourceX500Name: String, queueName: String, targets: List<NetworkHostAndPort>, legalNames: Set<CordaX500Name>) { override fun deployBridge(sourceX500Name: String, queueName: String, targets: List<NetworkHostAndPort>, legalNames: Set<CordaX500Name>) {
lock.withLock { lock.withLock {
val bridges = queueNamesToBridgesMap.getOrPut(queueName) { mutableListOf() } val bridges = queueNamesToBridgesMap.getOrPut(queueName) { mutableListOf() }
@ -465,8 +469,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
} }
val newAMQPConfig = with(amqpConfig) { AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize, val newAMQPConfig = with(amqpConfig) { AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize,
revocationConfig, useOpenSsl, enableSNI, sourceX500Name, trace, sslHandshakeTimeout) } revocationConfig, useOpenSsl, enableSNI, sourceX500Name, trace, sslHandshakeTimeout) }
val newBridge = AMQPBridge(sourceX500Name, queueName, targets, legalNames, newAMQPConfig, sharedEventLoopGroup!!, artemis!!, val newBridge = AMQPBridge(sourceX500Name, queueName, targets, legalNames, newAMQPConfig)
bridgeMetricsService, bridgeConnectionTTLSeconds)
bridges += newBridge bridges += newBridge
bridgeMetricsService?.bridgeCreated(targets, legalNames) bridgeMetricsService?.bridgeCreated(targets, legalNames)
newBridge newBridge
@ -484,7 +487,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
queueNamesToBridgesMap.remove(queueName) queueNamesToBridgesMap.remove(queueName)
} }
bridge.stop() bridge.stop()
bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.legalNames) bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.allowedRemoteLegalNames)
} }
} }
} }
@ -495,15 +498,16 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
// queueNamesToBridgesMap returns a mutable list, .toList converts it to a immutable list so it won't be changed by the [destroyBridge] method. // queueNamesToBridgesMap returns a mutable list, .toList converts it to a immutable list so it won't be changed by the [destroyBridge] method.
val bridges = queueNamesToBridgesMap[queueName]?.toList() val bridges = queueNamesToBridgesMap[queueName]?.toList()
destroyBridge(queueName, bridges?.flatMap { it.targets } ?: emptyList()) destroyBridge(queueName, bridges?.flatMap { it.targets } ?: emptyList())
bridges?.map { bridges?.associate {
it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.legalNames.toList(), serviceAddress = false) it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.allowedRemoteLegalNames.toList(), serviceAddress = false)
}?.toMap() ?: emptyMap() } ?: emptyMap()
} }
} }
override fun start() { override fun start() {
sharedEventLoopGroup = NioEventLoopGroup(NUM_BRIDGE_THREADS) sharedEventLoopGroup = NioEventLoopGroup(NUM_BRIDGE_THREADS, DefaultThreadFactory("NettyBridge", Thread.MAX_PRIORITY))
val artemis = artemisMessageClientFactory() sslDelegatedTaskExecutor = sslDelegatedTaskExecutor("NettyBridge")
val artemis = artemisMessageClientFactory("ArtemisBridge")
this.artemis = artemis this.artemis = artemis
artemis.start() artemis.start()
} }
@ -520,6 +524,8 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
sharedEventLoopGroup = null sharedEventLoopGroup = null
queueNamesToBridgesMap.clear() queueNamesToBridgesMap.clear()
artemis?.stop() artemis?.stop()
sslDelegatedTaskExecutor?.shutdown()
sslDelegatedTaskExecutor = null
} }
} }
} }

View File

@ -5,16 +5,13 @@ import net.corda.core.identity.CordaX500Name
import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.ArtemisMessagingClient
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.BRIDGE_CONTROL import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.BRIDGE_CONTROL
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.BRIDGE_NOTIFY import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.BRIDGE_NOTIFY
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.PEERS_PREFIX import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.PEERS_PREFIX
import net.corda.nodeapi.internal.ArtemisSessionProvider import net.corda.nodeapi.internal.ArtemisSessionProvider
import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.crypto.x509 import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.protonwrapper.netty.ProxyConfig import net.corda.nodeapi.internal.protonwrapper.netty.ProxyConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
@ -27,6 +24,7 @@ import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ClientSession import org.apache.activemq.artemis.api.core.client.ClientSession
import rx.Observable import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import java.time.Duration
import java.util.* import java.util.*
class BridgeControlListener(private val keyStore: CertificateStore, class BridgeControlListener(private val keyStore: CertificateStore,
@ -36,10 +34,10 @@ class BridgeControlListener(private val keyStore: CertificateStore,
maxMessageSize: Int, maxMessageSize: Int,
revocationConfig: RevocationConfig, revocationConfig: RevocationConfig,
enableSNI: Boolean, enableSNI: Boolean,
private val artemisMessageClientFactory: () -> ArtemisSessionProvider, private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider,
bridgeMetricsService: BridgeMetricsService? = null, bridgeMetricsService: BridgeMetricsService? = null,
trace: Boolean = false, trace: Boolean = false,
sslHandshakeTimeout: Long? = null, sslHandshakeTimeout: Duration? = null,
bridgeConnectionTTLSeconds: Int = 0) : AutoCloseable { bridgeConnectionTTLSeconds: Int = 0) : AutoCloseable {
private val bridgeId: String = UUID.randomUUID().toString() private val bridgeId: String = UUID.randomUUID().toString()
private var bridgeControlQueue = "$BRIDGE_CONTROL.$bridgeId" private var bridgeControlQueue = "$BRIDGE_CONTROL.$bridgeId"
@ -57,13 +55,6 @@ class BridgeControlListener(private val keyStore: CertificateStore,
private var controlConsumer: ClientConsumer? = null private var controlConsumer: ClientConsumer? = null
private var notifyConsumer: ClientConsumer? = null private var notifyConsumer: ClientConsumer? = null
constructor(config: MutualSslConfiguration,
p2pAddress: NetworkHostAndPort,
maxMessageSize: Int,
revocationConfig: RevocationConfig,
enableSNI: Boolean,
proxy: ProxyConfig? = null) : this(config.keyStore.get(), config.trustStore.get(), config.useOpenSsl, proxy, maxMessageSize, revocationConfig, enableSNI, { ArtemisMessagingClient(config, p2pAddress, maxMessageSize) })
companion object { companion object {
private val log = contextLogger() private val log = contextLogger()
} }
@ -88,7 +79,7 @@ class BridgeControlListener(private val keyStore: CertificateStore,
bridgeNotifyQueue = "$BRIDGE_NOTIFY.$queueDisambiguityId" bridgeNotifyQueue = "$BRIDGE_NOTIFY.$queueDisambiguityId"
bridgeManager.start() bridgeManager.start()
val artemis = artemisMessageClientFactory() val artemis = artemisMessageClientFactory("BridgeControl")
this.artemis = artemis this.artemis = artemis
artemis.start() artemis.start()
val artemisClient = artemis.started!! val artemisClient = artemis.started!!

View File

@ -23,6 +23,7 @@ import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ClientProducer import org.apache.activemq.artemis.api.core.client.ClientProducer
import org.apache.activemq.artemis.api.core.client.ClientSession import org.apache.activemq.artemis.api.core.client.ClientSession
import org.slf4j.MDC import org.slf4j.MDC
import java.time.Duration
/** /**
* The LoopbackBridgeManager holds the list of independent LoopbackBridge objects that actively loopback messages to local Artemis * The LoopbackBridgeManager holds the list of independent LoopbackBridge objects that actively loopback messages to local Artemis
@ -36,11 +37,11 @@ class LoopbackBridgeManager(keyStore: CertificateStore,
maxMessageSize: Int, maxMessageSize: Int,
revocationConfig: RevocationConfig, revocationConfig: RevocationConfig,
enableSNI: Boolean, enableSNI: Boolean,
private val artemisMessageClientFactory: () -> ArtemisSessionProvider, private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider,
private val bridgeMetricsService: BridgeMetricsService? = null, private val bridgeMetricsService: BridgeMetricsService? = null,
private val isLocalInbox: (String) -> Boolean, private val isLocalInbox: (String) -> Boolean,
trace: Boolean, trace: Boolean,
sslHandshakeTimeout: Long? = null, sslHandshakeTimeout: Duration? = null,
bridgeConnectionTTLSeconds: Int = 0) : AMQPBridgeManager(keyStore, trustStore, useOpenSSL, proxyConfig, bridgeConnectionTTLSeconds: Int = 0) : AMQPBridgeManager(keyStore, trustStore, useOpenSSL, proxyConfig,
maxMessageSize, revocationConfig, enableSNI, maxMessageSize, revocationConfig, enableSNI,
artemisMessageClientFactory, bridgeMetricsService, artemisMessageClientFactory, bridgeMetricsService,
@ -203,7 +204,7 @@ class LoopbackBridgeManager(keyStore: CertificateStore,
override fun start() { override fun start() {
super.start() super.start()
val artemis = artemisMessageClientFactory() val artemis = artemisMessageClientFactory("LoopbackBridge")
this.artemis = artemis this.artemis = artemis
artemis.start() artemis.start()
} }

View File

@ -1,16 +1,20 @@
package net.corda.nodeapi.internal.config package net.corda.nodeapi.internal.config
import net.corda.core.utilities.seconds
import java.time.Duration
interface SslConfiguration { interface SslConfiguration {
val keyStore: FileBasedCertificateStoreSupplier? val keyStore: FileBasedCertificateStoreSupplier?
val trustStore: FileBasedCertificateStoreSupplier? val trustStore: FileBasedCertificateStoreSupplier?
val useOpenSsl: Boolean val useOpenSsl: Boolean
val handshakeTimeout: Duration?
companion object { companion object {
fun mutual(keyStore: FileBasedCertificateStoreSupplier,
fun mutual(keyStore: FileBasedCertificateStoreSupplier, trustStore: FileBasedCertificateStoreSupplier): MutualSslConfiguration { trustStore: FileBasedCertificateStoreSupplier,
handshakeTimeout: Duration? = null): MutualSslConfiguration {
return MutualSslOptions(keyStore, trustStore) return MutualSslOptions(keyStore, trustStore, handshakeTimeout)
} }
} }
} }
@ -21,9 +25,10 @@ interface MutualSslConfiguration : SslConfiguration {
} }
private class MutualSslOptions(override val keyStore: FileBasedCertificateStoreSupplier, private class MutualSslOptions(override val keyStore: FileBasedCertificateStoreSupplier,
override val trustStore: FileBasedCertificateStoreSupplier) : MutualSslConfiguration { override val trustStore: FileBasedCertificateStoreSupplier,
override val handshakeTimeout: Duration?) : MutualSslConfiguration {
override val useOpenSsl: Boolean = false override val useOpenSsl: Boolean = false
} }
const val DEFAULT_SSL_HANDSHAKE_TIMEOUT_MILLIS = 60000L // Set at least 3 times higher than sun.security.provider.certpath.URICertStore.DEFAULT_CRL_CONNECT_TIMEOUT which is 15 sec @Suppress("MagicNumber")
val DEFAULT_SSL_HANDSHAKE_TIMEOUT: Duration = 60.seconds // Set at least 3 times higher than sun.security.provider.certpath.URICertStore.DEFAULT_CRL_CONNECT_TIMEOUT which is 15 sec

View File

@ -1,16 +1,41 @@
@file:Suppress("MagicNumber", "TooGenericExceptionCaught")
package net.corda.nodeapi.internal.crypto package net.corda.nodeapi.internal.crypto
import net.corda.core.CordaOID import net.corda.core.CordaOID
import net.corda.core.crypto.Crypto import net.corda.core.crypto.Crypto
import net.corda.core.crypto.newSecureRandom import net.corda.core.crypto.newSecureRandom
import net.corda.core.internal.* import net.corda.core.internal.CertRole
import net.corda.core.internal.SignedDataWithCert
import net.corda.core.internal.reader
import net.corda.core.internal.signWithCert
import net.corda.core.internal.uncheckedCast
import net.corda.core.internal.validate
import net.corda.core.internal.writer
import net.corda.core.utilities.days import net.corda.core.utilities.days
import net.corda.core.utilities.millis import net.corda.core.utilities.millis
import org.bouncycastle.asn1.* import net.corda.core.utilities.toHex
import net.corda.nodeapi.internal.protonwrapper.netty.distributionPointsToString
import org.bouncycastle.asn1.ASN1EncodableVector
import org.bouncycastle.asn1.ASN1ObjectIdentifier
import org.bouncycastle.asn1.ASN1Sequence
import org.bouncycastle.asn1.DERSequence
import org.bouncycastle.asn1.DERUTF8String
import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x500.style.BCStyle import org.bouncycastle.asn1.x500.style.BCStyle
import org.bouncycastle.asn1.x509.* import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier
import org.bouncycastle.asn1.x509.BasicConstraints
import org.bouncycastle.asn1.x509.CRLDistPoint
import org.bouncycastle.asn1.x509.DistributionPoint
import org.bouncycastle.asn1.x509.DistributionPointName
import org.bouncycastle.asn1.x509.Extension import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.KeyPurposeId
import org.bouncycastle.asn1.x509.KeyUsage
import org.bouncycastle.asn1.x509.NameConstraints
import org.bouncycastle.asn1.x509.SubjectKeyIdentifier
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo
import org.bouncycastle.cert.X509CertificateHolder import org.bouncycastle.cert.X509CertificateHolder
import org.bouncycastle.cert.X509v3CertificateBuilder import org.bouncycastle.cert.X509v3CertificateBuilder
import org.bouncycastle.cert.bc.BcX509ExtensionUtils import org.bouncycastle.cert.bc.BcX509ExtensionUtils
@ -28,8 +53,13 @@ import java.nio.file.Path
import java.security.KeyPair import java.security.KeyPair
import java.security.PublicKey import java.security.PublicKey
import java.security.SignatureException import java.security.SignatureException
import java.security.cert.* import java.security.cert.CertPath
import java.security.cert.Certificate import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.CertificateFactory
import java.security.cert.TrustAnchor
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.time.Duration import java.time.Duration
import java.time.Instant import java.time.Instant
import java.time.temporal.ChronoUnit import java.time.temporal.ChronoUnit
@ -355,7 +385,7 @@ object X509Utilities {
private fun addCrlInfo(builder: X509v3CertificateBuilder, crlDistPoint: String?, crlIssuer: X500Name?) { private fun addCrlInfo(builder: X509v3CertificateBuilder, crlDistPoint: String?, crlIssuer: X500Name?) {
if (crlDistPoint != null) { if (crlDistPoint != null) {
val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, crlDistPoint))) val distPointName = DistributionPointName(toGeneralNames(crlDistPoint, GeneralName.uniformResourceIdentifier))
val crlIssuerGeneralNames = crlIssuer?.let { val crlIssuerGeneralNames = crlIssuer?.let {
GeneralNames(GeneralName(crlIssuer)) GeneralNames(GeneralName(crlIssuer))
} }
@ -368,7 +398,6 @@ object X509Utilities {
} }
} }
@Suppress("MagicNumber")
private fun generateCertificateSerialNumber(): BigInteger { private fun generateCertificateSerialNumber(): BigInteger {
val bytes = ByteArray(CERTIFICATE_SERIAL_NUMBER_LENGTH) val bytes = ByteArray(CERTIFICATE_SERIAL_NUMBER_LENGTH)
newSecureRandom().nextBytes(bytes) newSecureRandom().nextBytes(bytes)
@ -376,6 +405,8 @@ object X509Utilities {
bytes[0] = bytes[0].and(0x3F).or(0x40) bytes[0] = bytes[0].and(0x3F).or(0x40)
return BigInteger(bytes) return BigInteger(bytes)
} }
fun toGeneralNames(string: String, tag: Int = GeneralName.directoryName): GeneralNames = GeneralNames(GeneralName(tag, string))
} }
// Assuming cert type to role is 1:1 // Assuming cert type to role is 1:1
@ -408,6 +439,29 @@ fun PKCS10CertificationRequest.isSignatureValid(): Boolean {
return this.isSignatureValid(JcaContentVerifierProviderBuilder().build(this.subjectPublicKeyInfo)) return this.isSignatureValid(JcaContentVerifierProviderBuilder().build(this.subjectPublicKeyInfo))
} }
fun X509Certificate.toSimpleString(): String {
val bcCert = toBc()
val keyIdentifier = try {
SubjectKeyIdentifier.getInstance(bcCert.getExtension(Extension.subjectKeyIdentifier).parsedValue).keyIdentifier.toHex()
} catch (e: Exception) {
"null"
}
val authorityKeyIdentifier = try {
AuthorityKeyIdentifier.getInstance(bcCert.getExtension(Extension.authorityKeyIdentifier).parsedValue).keyIdentifier.toHex()
} catch (e: Exception) {
"null"
}
val subject = bcCert.subject
val issuer = bcCert.issuer
val role = CertRole.extract(this)
return "$subject[$keyIdentifier] issued by $issuer[$authorityKeyIdentifier] $role $serialNumber [${distributionPointsToString()}]"
}
fun X509CRL.toSimpleString(): String {
val revokedSerialNumbers = revokedCertificates?.map { it.serialNumber }
return "$issuerX500Principal ${thisUpdate.toInstant()} ${nextUpdate.toInstant()} ${revokedSerialNumbers ?: "[]"}"
}
/** /**
* Check certificate validity or print warning if expiry is within 30 days * Check certificate validity or print warning if expiry is within 30 days
*/ */
@ -438,6 +492,8 @@ class X509CertificateFactory {
fun generateCertPath(vararg certificates: X509Certificate): CertPath = generateCertPath(certificates.asList()) fun generateCertPath(vararg certificates: X509Certificate): CertPath = generateCertPath(certificates.asList())
fun generateCertPath(certificates: List<X509Certificate>): CertPath = delegate.generateCertPath(certificates) fun generateCertPath(certificates: List<X509Certificate>): CertPath = delegate.generateCertPath(certificates)
fun generateCRL(input: InputStream): X509CRL = delegate.generateCRL(input) as X509CRL
} }
enum class CertificateType(val keyUsage: KeyUsage, vararg val purposes: KeyPurposeId, val isCA: Boolean, val role: CertRole?) { enum class CertificateType(val keyUsage: KeyUsage, vararg val purposes: KeyPurposeId, val isCA: Boolean, val role: CertRole?) {

View File

@ -115,11 +115,10 @@ internal class AMQPChannelHandler(private val serverMode: Boolean,
val transport = connection.transport as ProtonJTransport val transport = connection.transport as ProtonJTransport
transport.protocolTracer = object : ProtocolTracer { transport.protocolTracer = object : ProtocolTracer {
override fun sentFrame(transportFrame: TransportFrame) { override fun sentFrame(transportFrame: TransportFrame) {
logInfoWithMDC { "${transportFrame.body}" } logInfoWithMDC { "sentFrame: ${transportFrame.body}" }
} }
override fun receivedFrame(transportFrame: TransportFrame) { override fun receivedFrame(transportFrame: TransportFrame) {
logInfoWithMDC { "${transportFrame.body}" } logInfoWithMDC { "receivedFrame: ${transportFrame.body}" }
} }
} }
} }
@ -186,7 +185,7 @@ internal class AMQPChannelHandler(private val serverMode: Boolean,
} }
} }
@Suppress("OverridingDeprecatedMember") @Deprecated("Deprecated in Java")
override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
logWarnWithMDC("Closing channel due to nonrecoverable exception ${cause.message}") logWarnWithMDC("Closing channel due to nonrecoverable exception ${cause.message}")
if (log.isTraceEnabled) { if (log.isTraceEnabled) {
@ -298,16 +297,15 @@ internal class AMQPChannelHandler(private val serverMode: Boolean,
cause is ClosedChannelException -> logWarnWithMDC("SSL Handshake closed early.") cause is ClosedChannelException -> logWarnWithMDC("SSL Handshake closed early.")
cause is SslHandshakeTimeoutException -> logWarnWithMDC("SSL Handshake timed out") cause is SslHandshakeTimeoutException -> logWarnWithMDC("SSL Handshake timed out")
// Sadly the exception thrown by Netty wrapper requires that we check the message. // Sadly the exception thrown by Netty wrapper requires that we check the message.
cause is SSLException && (cause.message?.contains("close_notify") == true) cause is SSLException && (cause.message?.contains("close_notify") == true) -> logWarnWithMDC("Received close_notify during handshake")
-> logWarnWithMDC("Received close_notify during handshake")
// io.netty.handler.ssl.SslHandler.setHandshakeFailureTransportFailure() // io.netty.handler.ssl.SslHandler.setHandshakeFailureTransportFailure()
cause is SSLException && (cause.message?.contains("writing TLS control frames") == true) -> logWarnWithMDC(cause.message!!) cause is SSLException && (cause.message?.contains("writing TLS control frames") == true) -> logWarnWithMDC(cause.message!!)
else -> badCert = true else -> badCert = true
} }
logWarnWithMDC("Handshake failure: ${evt.cause().message}")
if (log.isTraceEnabled) { if (log.isTraceEnabled) {
withMDC { log.trace("Handshake failure", evt.cause()) } withMDC { log.trace("Handshake failure", cause) }
} else {
logWarnWithMDC("Handshake failure: ${cause.message}")
} }
ctx.close() ctx.close()
} }

View File

@ -1,7 +1,11 @@
package net.corda.nodeapi.internal.protonwrapper.netty package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.bootstrap.Bootstrap import io.netty.bootstrap.Bootstrap
import io.netty.channel.* import io.netty.channel.Channel
import io.netty.channel.ChannelFutureListener
import io.netty.channel.ChannelHandler
import io.netty.channel.ChannelInitializer
import io.netty.channel.EventLoopGroup
import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioSocketChannel import io.netty.channel.socket.nio.NioSocketChannel
@ -11,6 +15,7 @@ import io.netty.handler.proxy.HttpProxyHandler
import io.netty.handler.proxy.Socks4ProxyHandler import io.netty.handler.proxy.Socks4ProxyHandler
import io.netty.handler.proxy.Socks5ProxyHandler import io.netty.handler.proxy.Socks5ProxyHandler
import io.netty.resolver.NoopAddressResolverGroup import io.netty.resolver.NoopAddressResolverGroup
import io.netty.util.concurrent.DefaultThreadFactory
import io.netty.util.internal.logging.InternalLoggerFactory import io.netty.util.internal.logging.InternalLoggerFactory
import io.netty.util.internal.logging.Slf4JLoggerFactory import io.netty.util.internal.logging.Slf4JLoggerFactory
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
@ -22,14 +27,16 @@ import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPChannelHandler.Companion.PROXY_LOGGER_NAME import net.corda.nodeapi.internal.protonwrapper.netty.AMQPChannelHandler.Companion.PROXY_LOGGER_NAME
import net.corda.nodeapi.internal.requireMessageSize import net.corda.nodeapi.internal.requireMessageSize
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import rx.Observable import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import java.lang.Long.min import java.lang.Long.min
import java.net.InetSocketAddress import java.net.InetSocketAddress
import java.util.concurrent.Executor
import java.util.concurrent.ExecutorService
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.TrustManagerFactory
import kotlin.concurrent.withLock import kotlin.concurrent.withLock
enum class ProxyVersion { enum class ProxyVersion {
@ -53,10 +60,11 @@ data class ProxyConfig(val version: ProxyVersion, val proxyAddress: NetworkHostA
* otherwise it creates a self-contained Netty thraed pool and socket objects. * otherwise it creates a self-contained Netty thraed pool and socket objects.
* Once connected it can accept application packets to send via the AMQP protocol. * Once connected it can accept application packets to send via the AMQP protocol.
*/ */
class AMQPClient(val targets: List<NetworkHostAndPort>, class AMQPClient(private val targets: List<NetworkHostAndPort>,
val allowedRemoteLegalNames: Set<CordaX500Name>, val allowedRemoteLegalNames: Set<CordaX500Name>,
private val configuration: AMQPConfiguration, private val configuration: AMQPConfiguration,
private val sharedThreadPool: EventLoopGroup? = null) : AutoCloseable { private val nettyThreading: NettyThreading = NettyThreading.NonShared("AMQPClient"),
private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON) : AutoCloseable {
companion object { companion object {
init { init {
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
@ -75,7 +83,6 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
private val lock = ReentrantLock() private val lock = ReentrantLock()
@Volatile @Volatile
private var started: Boolean = false private var started: Boolean = false
private var workerGroup: EventLoopGroup? = null
@Volatile @Volatile
private var clientChannel: Channel? = null private var clientChannel: Channel? = null
// Offset into the list of targets, so that we can implement round-robin reconnect logic. // Offset into the list of targets, so that we can implement round-robin reconnect logic.
@ -109,14 +116,13 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
retryInterval = min(MAX_RETRY_INTERVAL, retryInterval * BACKOFF_MULTIPLIER) retryInterval = min(MAX_RETRY_INTERVAL, retryInterval * BACKOFF_MULTIPLIER)
} }
private val connectListener = object : ChannelFutureListener { private val connectListener = ChannelFutureListener { future ->
override fun operationComplete(future: ChannelFuture) {
amqpActive = false amqpActive = false
if (!future.isSuccess) { if (!future.isSuccess) {
log.info("Failed to connect to $currentTarget", future.cause()) log.info("Failed to connect to $currentTarget", future.cause())
if (started) { if (started) {
workerGroup?.schedule({ nettyThreading.eventLoopGroup.schedule({
nextTarget() nextTarget()
restart() restart()
}, retryInterval, TimeUnit.MILLISECONDS) }, retryInterval, TimeUnit.MILLISECONDS)
@ -128,7 +134,6 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
log.info("Connected to $currentTarget, Local address: $localAddressString") log.info("Connected to $currentTarget, Local address: $localAddressString")
} }
} }
}
private val closeListener = ChannelFutureListener { future -> private val closeListener = ChannelFutureListener { future ->
log.info("Disconnected from $currentTarget, Local address: $localAddressString") log.info("Disconnected from $currentTarget, Local address: $localAddressString")
@ -136,7 +141,7 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
clientChannel = null clientChannel = null
if (started && !amqpActive) { if (started && !amqpActive) {
log.debug { "Scheduling restart of $currentTarget (AMQP inactive)" } log.debug { "Scheduling restart of $currentTarget (AMQP inactive)" }
workerGroup?.schedule({ nettyThreading.eventLoopGroup.schedule({
nextTarget() nextTarget()
restart() restart()
}, retryInterval, TimeUnit.MILLISECONDS) }, retryInterval, TimeUnit.MILLISECONDS)
@ -144,17 +149,16 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
} }
private class ClientChannelInitializer(val parent: AMQPClient) : ChannelInitializer<SocketChannel>() { private class ClientChannelInitializer(val parent: AMQPClient) : ChannelInitializer<SocketChannel>() {
private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore)
private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) private val trustManagerFactory = trustManagerFactoryWithRevocation(
parent.configuration.trustStore,
parent.configuration.revocationConfig,
parent.distPointCrlSource
)
private val conf = parent.configuration private val conf = parent.configuration
@Volatile @Volatile
private lateinit var amqpChannelHandler: AMQPChannelHandler private lateinit var amqpChannelHandler: AMQPChannelHandler
init {
keyManagerFactory.init(conf.keyStore)
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, conf.revocationConfig))
}
@Suppress("ComplexMethod") @Suppress("ComplexMethod")
override fun initChannel(ch: SocketChannel) { override fun initChannel(ch: SocketChannel) {
val pipeline = ch.pipeline() val pipeline = ch.pipeline()
@ -194,14 +198,28 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, parent.configuration) val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, parent.configuration)
val target = parent.currentTarget val target = parent.currentTarget
val handler = if (parent.configuration.useOpenSsl) { val handler = if (parent.configuration.useOpenSsl) {
createClientOpenSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory, ch.alloc()) createClientOpenSslHandler(
target,
parent.allowedRemoteLegalNames,
wrappedKeyManagerFactory,
trustManagerFactory,
ch.alloc(),
parent.nettyThreading.sslDelegatedTaskExecutor
)
} else { } else {
createClientSslHelper(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory) createClientSslHandler(
target,
parent.allowedRemoteLegalNames,
wrappedKeyManagerFactory,
trustManagerFactory,
parent.nettyThreading.sslDelegatedTaskExecutor
)
} }
handler.handshakeTimeoutMillis = conf.sslHandshakeTimeout handler.handshakeTimeoutMillis = conf.sslHandshakeTimeout.toMillis()
pipeline.addLast("sslHandler", handler) pipeline.addLast("sslHandler", handler)
if (conf.trace) pipeline.addLast("logger", LoggingHandler(LogLevel.INFO)) if (conf.trace) pipeline.addLast("logger", LoggingHandler(LogLevel.INFO))
amqpChannelHandler = AMQPChannelHandler(false, amqpChannelHandler = AMQPChannelHandler(
false,
parent.allowedRemoteLegalNames, parent.allowedRemoteLegalNames,
// Single entry, key can be anything. // Single entry, key can be anything.
mapOf(DEFAULT to wrappedKeyManagerFactory), mapOf(DEFAULT to wrappedKeyManagerFactory),
@ -209,15 +227,24 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
conf.password, conf.password,
conf.trace, conf.trace,
false, false,
onOpen = { _, change -> onOpen = { _, change -> onChannelOpen(change) },
onClose = { _, change -> onChannelClose(change, target) },
onReceive = parent._onReceive::onNext
)
parent.amqpChannelHandler = amqpChannelHandler
pipeline.addLast(amqpChannelHandler)
}
private fun onChannelOpen(change: ConnectionChange) {
parent.run { parent.run {
amqpActive = true amqpActive = true
retryInterval = MIN_RETRY_INTERVAL // reset to fast reconnect if we connect properly retryInterval = MIN_RETRY_INTERVAL // reset to fast reconnect if we connect properly
_onConnection.onNext(change) _onConnection.onNext(change)
} }
}, }
onClose = { _, change ->
if (parent.amqpChannelHandler == amqpChannelHandler) { private fun onChannelClose(change: ConnectionChange, target: NetworkHostAndPort) {
if (parent.amqpChannelHandler != amqpChannelHandler) return
parent.run { parent.run {
_onConnection.onNext(change) _onConnection.onNext(change)
if (change.badCert) { if (change.badCert) {
@ -227,7 +254,7 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
if (started && amqpActive) { if (started && amqpActive) {
log.debug { "Scheduling restart of $currentTarget (AMQP active)" } log.debug { "Scheduling restart of $currentTarget (AMQP active)" }
workerGroup?.schedule({ nettyThreading.eventLoopGroup.schedule({
nextTarget() nextTarget()
restart() restart()
}, retryInterval, TimeUnit.MILLISECONDS) }, retryInterval, TimeUnit.MILLISECONDS)
@ -235,11 +262,6 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
amqpActive = false amqpActive = false
} }
} }
},
onReceive = { rcv -> parent._onReceive.onNext(rcv) })
parent.amqpChannelHandler = amqpChannelHandler
pipeline.addLast(amqpChannelHandler)
}
} }
fun start() { fun start() {
@ -249,7 +271,7 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
return return
} }
log.info("Connect to: $currentTarget") log.info("Connect to: $currentTarget")
workerGroup = sharedThreadPool ?: NioEventLoopGroup(NUM_CLIENT_THREADS) (nettyThreading as? NettyThreading.NonShared)?.start()
started = true started = true
restart() restart()
} }
@ -261,7 +283,7 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
} }
val bootstrap = Bootstrap() val bootstrap = Bootstrap()
// TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux // TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux
bootstrap.group(workerGroup).channel(NioSocketChannel::class.java).handler(ClientChannelInitializer(this)) bootstrap.group(nettyThreading.eventLoopGroup).channel(NioSocketChannel::class.java).handler(ClientChannelInitializer(this))
// Delegate DNS Resolution to the proxy side, if we are using proxy. // Delegate DNS Resolution to the proxy side, if we are using proxy.
if (configuration.proxyConfig != null) { if (configuration.proxyConfig != null) {
bootstrap.resolver(NoopAddressResolverGroup.INSTANCE) bootstrap.resolver(NoopAddressResolverGroup.INSTANCE)
@ -275,14 +297,12 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
lock.withLock { lock.withLock {
log.info("Stopping connection to: $currentTarget, Local address: $localAddressString") log.info("Stopping connection to: $currentTarget, Local address: $localAddressString")
started = false started = false
if (sharedThreadPool == null) { if (nettyThreading is NettyThreading.NonShared) {
workerGroup?.shutdownGracefully() nettyThreading.stop()
workerGroup?.terminationFuture()?.sync()
} else { } else {
clientChannel?.close()?.sync() clientChannel?.close()?.sync()
} }
clientChannel = null clientChannel = null
workerGroup = null
log.info("Stopped connection to $currentTarget") log.info("Stopped connection to $currentTarget")
} }
} }
@ -323,4 +343,36 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
private val _onConnection = PublishSubject.create<ConnectionChange>().toSerialized() private val _onConnection = PublishSubject.create<ConnectionChange>().toSerialized()
val onConnection: Observable<ConnectionChange> val onConnection: Observable<ConnectionChange>
get() = _onConnection get() = _onConnection
sealed class NettyThreading {
abstract val eventLoopGroup: EventLoopGroup
abstract val sslDelegatedTaskExecutor: Executor
class Shared(override val eventLoopGroup: EventLoopGroup,
override val sslDelegatedTaskExecutor: ExecutorService = sslDelegatedTaskExecutor("AMQPClient")) : NettyThreading()
class NonShared(val threadPoolName: String) : NettyThreading() {
private var _eventLoopGroup: NioEventLoopGroup? = null
override val eventLoopGroup: EventLoopGroup get() = checkNotNull(_eventLoopGroup)
private var _sslDelegatedTaskExecutor: ThreadPoolExecutor? = null
override val sslDelegatedTaskExecutor: ExecutorService get() = checkNotNull(_sslDelegatedTaskExecutor)
fun start() {
check(_eventLoopGroup == null)
check(_sslDelegatedTaskExecutor == null)
_eventLoopGroup = NioEventLoopGroup(NUM_CLIENT_THREADS, DefaultThreadFactory(threadPoolName, Thread.MAX_PRIORITY))
_sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName)
}
fun stop() {
eventLoopGroup.shutdownGracefully()
eventLoopGroup.terminationFuture().sync()
sslDelegatedTaskExecutor.shutdown()
_eventLoopGroup = null
_sslDelegatedTaskExecutor = null
}
}
}
} }

View File

@ -2,7 +2,8 @@ package net.corda.nodeapi.internal.protonwrapper.netty
import net.corda.nodeapi.internal.ArtemisMessagingComponent import net.corda.nodeapi.internal.ArtemisMessagingComponent
import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.config.DEFAULT_SSL_HANDSHAKE_TIMEOUT_MILLIS import net.corda.nodeapi.internal.config.DEFAULT_SSL_HANDSHAKE_TIMEOUT
import java.time.Duration
interface AMQPConfiguration { interface AMQPConfiguration {
/** /**
@ -67,8 +68,8 @@ interface AMQPConfiguration {
get() = false get() = false
@JvmDefault @JvmDefault
val sslHandshakeTimeout: Long val sslHandshakeTimeout: Duration
get() = DEFAULT_SSL_HANDSHAKE_TIMEOUT_MILLIS // Aligned with sun.security.provider.certpath.URICertStore.DEFAULT_CRL_CONNECT_TIMEOUT get() = DEFAULT_SSL_HANDSHAKE_TIMEOUT // Aligned with sun.security.provider.certpath.URICertStore.DEFAULT_CRL_CONNECT_TIMEOUT
/** /**
* An optional Health Check Phrase which if passed through the channel will cause AMQP Server to echo it back instead of doing normal pipeline processing * An optional Health Check Phrase which if passed through the channel will cause AMQP Server to echo it back instead of doing normal pipeline processing

View File

@ -11,6 +11,7 @@ import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.logging.LogLevel import io.netty.handler.logging.LogLevel
import io.netty.handler.logging.LoggingHandler import io.netty.handler.logging.LoggingHandler
import io.netty.util.concurrent.DefaultThreadFactory
import io.netty.util.internal.logging.InternalLoggerFactory import io.netty.util.internal.logging.InternalLoggerFactory
import io.netty.util.internal.logging.Slf4JLoggerFactory import io.netty.util.internal.logging.Slf4JLoggerFactory
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
@ -20,15 +21,15 @@ import net.corda.nodeapi.internal.protonwrapper.messages.ReceivedMessage
import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import net.corda.nodeapi.internal.requireMessageSize import net.corda.nodeapi.internal.requireMessageSize
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import org.apache.qpid.proton.engine.Delivery import org.apache.qpid.proton.engine.Delivery
import rx.Observable import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import java.net.BindException import java.net.BindException
import java.net.InetSocketAddress import java.net.InetSocketAddress
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import java.util.concurrent.locks.ReentrantLock import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.TrustManagerFactory
import kotlin.concurrent.withLock import kotlin.concurrent.withLock
/** /**
@ -36,37 +37,35 @@ import kotlin.concurrent.withLock
*/ */
class AMQPServer(val hostName: String, class AMQPServer(val hostName: String,
val port: Int, val port: Int,
private val configuration: AMQPConfiguration) : AutoCloseable { private val configuration: AMQPConfiguration,
private val threadPoolName: String = "AMQPServer",
private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON,
private val remotingThreads: Int? = null) : AutoCloseable {
companion object { companion object {
init { init {
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
} }
private const val CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME = "net.corda.nodeapi.amqpserver.NumServerThreads"
private val log = contextLogger() private val log = contextLogger()
private val NUM_SERVER_THREADS = Integer.getInteger(CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME, 4) private val DEFAULT_REMOTING_THREADS = Integer.getInteger("net.corda.nodeapi.amqpserver.NumServerThreads", 4)
} }
private val lock = ReentrantLock() private val lock = ReentrantLock()
@Volatile
private var stopping: Boolean = false
private var bossGroup: EventLoopGroup? = null private var bossGroup: EventLoopGroup? = null
private var workerGroup: EventLoopGroup? = null private var workerGroup: EventLoopGroup? = null
private var serverChannel: Channel? = null private var serverChannel: Channel? = null
private var sslDelegatedTaskExecutor: ExecutorService? = null
private val clientChannels = ConcurrentHashMap<InetSocketAddress, SocketChannel>() private val clientChannels = ConcurrentHashMap<InetSocketAddress, SocketChannel>()
private class ServerChannelInitializer(val parent: AMQPServer) : ChannelInitializer<SocketChannel>() { private class ServerChannelInitializer(val parent: AMQPServer) : ChannelInitializer<SocketChannel>() {
private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore)
private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) private val trustManagerFactory = trustManagerFactoryWithRevocation(
parent.configuration.trustStore,
parent.configuration.revocationConfig,
parent.distPointCrlSource
)
private val conf = parent.configuration private val conf = parent.configuration
init {
keyManagerFactory.init(conf.keyStore.value.internal, conf.keyStore.entryPassword.toCharArray())
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, conf.revocationConfig))
}
override fun initChannel(ch: SocketChannel) { override fun initChannel(ch: SocketChannel) {
val amqpConfiguration = parent.configuration val amqpConfiguration = parent.configuration
val pipeline = ch.pipeline() val pipeline = ch.pipeline()
@ -75,7 +74,8 @@ class AMQPServer(val hostName: String,
pipeline.addLast("sslHandler", sslHandler) pipeline.addLast("sslHandler", sslHandler)
if (conf.trace) pipeline.addLast("logger", LoggingHandler(LogLevel.INFO)) if (conf.trace) pipeline.addLast("logger", LoggingHandler(LogLevel.INFO))
val suppressLogs = ch.remoteAddress()?.hostString in amqpConfiguration.silencedIPs val suppressLogs = ch.remoteAddress()?.hostString in amqpConfiguration.silencedIPs
pipeline.addLast(AMQPChannelHandler(true, pipeline.addLast(AMQPChannelHandler(
true,
null, null,
// Passing a mapping of legal names to key managers to be able to pick the correct one after // Passing a mapping of legal names to key managers to be able to pick the correct one after
// SNI completion event is fired up. // SNI completion event is fired up.
@ -84,36 +84,42 @@ class AMQPServer(val hostName: String,
conf.password, conf.password,
conf.trace, conf.trace,
suppressLogs, suppressLogs,
onOpen = { channel, change -> onOpen = ::onChannelOpen,
onClose = ::onChannelClose,
onReceive = parent._onReceive::onNext
))
}
private fun onChannelOpen(channel: SocketChannel, change: ConnectionChange) {
parent.run { parent.run {
clientChannels[channel.remoteAddress()] = channel clientChannels[channel.remoteAddress()] = channel
_onConnection.onNext(change) _onConnection.onNext(change)
} }
}, }
onClose = { channel, change ->
private fun onChannelClose(channel: SocketChannel, change: ConnectionChange) {
parent.run { parent.run {
val remoteAddress = channel.remoteAddress() val remoteAddress = channel.remoteAddress()
clientChannels.remove(remoteAddress) clientChannels.remove(remoteAddress)
_onConnection.onNext(change) _onConnection.onNext(change)
} }
},
onReceive = { rcv -> parent._onReceive.onNext(rcv) }))
} }
private fun createSSLHandler(amqpConfig: AMQPConfiguration, ch: SocketChannel): Pair<ChannelHandler, Map<String, CertHoldingKeyManagerFactoryWrapper>> { private fun createSSLHandler(amqpConfig: AMQPConfiguration, ch: SocketChannel): Pair<ChannelHandler, Map<String, CertHoldingKeyManagerFactoryWrapper>> {
return if (amqpConfig.useOpenSsl && amqpConfig.enableSNI && amqpConfig.keyStore.aliases().size > 1) { return if (amqpConfig.useOpenSsl && amqpConfig.enableSNI && amqpConfig.keyStore.aliases().size > 1) {
val keyManagerFactoriesMap = splitKeystore(amqpConfig) val keyManagerFactoriesMap = splitKeystore(amqpConfig)
// SNI matching needed only when multiple nodes exist behind the server. // SNI matching needed only when multiple nodes exist behind the server.
Pair(createServerSNIOpenSslHandler(keyManagerFactoriesMap, trustManagerFactory), keyManagerFactoriesMap) Pair(createServerSNIOpenSniHandler(keyManagerFactoriesMap, trustManagerFactory), keyManagerFactoriesMap)
} else { } else {
val keyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, amqpConfig) val keyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, amqpConfig)
val delegatedTaskExecutor = checkNotNull(parent.sslDelegatedTaskExecutor)
val handler = if (amqpConfig.useOpenSsl) { val handler = if (amqpConfig.useOpenSsl) {
createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc()) createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc(), delegatedTaskExecutor)
} else { } else {
// For javaSSL, SNI matching is handled at key manager level. // For javaSSL, SNI matching is handled at key manager level.
createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory) createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory, delegatedTaskExecutor)
} }
handler.handshakeTimeoutMillis = amqpConfig.sslHandshakeTimeout handler.handshakeTimeoutMillis = amqpConfig.sslHandshakeTimeout.toMillis()
Pair(handler, mapOf(DEFAULT to keyManagerFactory)) Pair(handler, mapOf(DEFAULT to keyManagerFactory))
} }
} }
@ -123,8 +129,13 @@ class AMQPServer(val hostName: String,
lock.withLock { lock.withLock {
stop() stop()
bossGroup = NioEventLoopGroup(1) sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName)
workerGroup = NioEventLoopGroup(NUM_SERVER_THREADS)
bossGroup = NioEventLoopGroup(1, DefaultThreadFactory("$threadPoolName-boss", Thread.MAX_PRIORITY))
workerGroup = NioEventLoopGroup(
remotingThreads ?: DEFAULT_REMOTING_THREADS,
DefaultThreadFactory("$threadPoolName-worker", Thread.MAX_PRIORITY)
)
val server = ServerBootstrap() val server = ServerBootstrap()
// TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux // TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux
@ -145,22 +156,19 @@ class AMQPServer(val hostName: String,
fun stop() { fun stop() {
lock.withLock { lock.withLock {
try { serverChannel?.close()
stopping = true
serverChannel?.apply { close() }
serverChannel = null serverChannel = null
workerGroup?.shutdownGracefully() workerGroup?.shutdownGracefully()
workerGroup?.terminationFuture()?.sync() workerGroup?.terminationFuture()?.sync()
workerGroup = null
bossGroup?.shutdownGracefully() bossGroup?.shutdownGracefully()
bossGroup?.terminationFuture()?.sync() bossGroup?.terminationFuture()?.sync()
workerGroup = null
bossGroup = null bossGroup = null
} finally {
stopping = false sslDelegatedTaskExecutor?.shutdown()
} sslDelegatedTaskExecutor = null
} }
} }

View File

@ -11,7 +11,7 @@ object AllowAllRevocationChecker : PKIXRevocationChecker() {
private val logger = LoggerFactory.getLogger(AllowAllRevocationChecker::class.java) private val logger = LoggerFactory.getLogger(AllowAllRevocationChecker::class.java)
override fun check(cert: Certificate?, unresolvedCritExts: MutableCollection<String>?) { override fun check(cert: Certificate, unresolvedCritExts: Collection<String>) {
logger.debug {"Passing certificate check for: $cert"} logger.debug {"Passing certificate check for: $cert"}
// Nothing to do // Nothing to do
} }
@ -20,7 +20,7 @@ object AllowAllRevocationChecker : PKIXRevocationChecker() {
return true return true
} }
override fun getSupportedExtensions(): MutableSet<String>? { override fun getSupportedExtensions(): Set<String>? {
return null return null
} }
@ -28,7 +28,9 @@ object AllowAllRevocationChecker : PKIXRevocationChecker() {
// Nothing to do // Nothing to do
} }
override fun getSoftFailExceptions(): MutableList<CertPathValidatorException> { override fun getSoftFailExceptions(): List<CertPathValidatorException> {
return LinkedList() return Collections.emptyList()
} }
override fun clone(): AllowAllRevocationChecker = this
} }

View File

@ -3,10 +3,11 @@ package net.corda.nodeapi.internal.protonwrapper.netty
import java.security.cert.X509CRL import java.security.cert.X509CRL
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
interface ExternalCrlSource { @FunctionalInterface
interface CrlSource {
/** /**
* Given certificate provides a set of CRLs, potentially performing remote communication. * Given certificate provides a set of CRLs, potentially performing remote communication.
*/ */
fun fetch(certificate: X509Certificate) : Set<X509CRL> fun fetch(certificate: X509Certificate): Set<X509CRL>
} }

View File

@ -26,7 +26,7 @@ interface RevocationConfig {
/** /**
* CRLs are obtained from external source * CRLs are obtained from external source
* @see ExternalCrlSource * @see CrlSource
*/ */
EXTERNAL_SOURCE, EXTERNAL_SOURCE,
@ -39,14 +39,9 @@ interface RevocationConfig {
val mode: Mode val mode: Mode
/** /**
* Optional `ExternalCrlSource` which only makes sense with `mode` = `EXTERNAL_SOURCE` * Optional [CrlSource] which only makes sense with `mode` = `EXTERNAL_SOURCE`
*/ */
val externalCrlSource: ExternalCrlSource? val externalCrlSource: CrlSource?
/**
* Creates a copy of `RevocationConfig` with ExternalCrlSource enriched
*/
fun enrichExternalCrlSource(sourceFunc: (() -> ExternalCrlSource)?): RevocationConfig
} }
/** /**
@ -54,16 +49,7 @@ interface RevocationConfig {
*/ */
fun Boolean.toRevocationConfig() = if(this) RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL) else RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL) fun Boolean.toRevocationConfig() = if(this) RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL) else RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL)
data class RevocationConfigImpl(override val mode: RevocationConfig.Mode, override val externalCrlSource: ExternalCrlSource? = null) : RevocationConfig { data class RevocationConfigImpl(override val mode: RevocationConfig.Mode, override val externalCrlSource: CrlSource? = null) : RevocationConfig
override fun enrichExternalCrlSource(sourceFunc: (() -> ExternalCrlSource)?): RevocationConfig {
return if(mode != RevocationConfig.Mode.EXTERNAL_SOURCE) {
this
} else {
assert(sourceFunc != null) { "There should be a way to obtain ExternalCrlSource" }
copy(externalCrlSource = sourceFunc!!())
}
}
}
class RevocationConfigParser : ConfigParser<RevocationConfig> { class RevocationConfigParser : ConfigParser<RevocationConfig> {
override fun parse(config: Config): RevocationConfig { override fun parse(config: Config): RevocationConfig {

View File

@ -1,3 +1,5 @@
@file:Suppress("ComplexMethod", "LongParameterList")
package net.corda.nodeapi.internal.protonwrapper.netty package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.buffer.ByteBufAllocator import io.netty.buffer.ByteBufAllocator
@ -13,31 +15,41 @@ import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.VisibleForTesting
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.toHex import net.corda.core.utilities.debug
import net.corda.nodeapi.internal.ArtemisTcpTransport import net.corda.nodeapi.internal.ArtemisTcpTransport
import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.toBc import net.corda.nodeapi.internal.crypto.toSimpleString
import net.corda.nodeapi.internal.crypto.x509 import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.protonwrapper.netty.revocation.ExternalSourceRevocationChecker import net.corda.nodeapi.internal.namedThreadPoolExecutor
import net.corda.nodeapi.internal.revocation.CordaRevocationChecker
import org.bouncycastle.asn1.ASN1InputStream import org.bouncycastle.asn1.ASN1InputStream
import org.bouncycastle.asn1.ASN1Primitive
import org.bouncycastle.asn1.DERIA5String import org.bouncycastle.asn1.DERIA5String
import org.bouncycastle.asn1.DEROctetString import org.bouncycastle.asn1.DEROctetString
import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x509.CRLDistPoint import org.bouncycastle.asn1.x509.CRLDistPoint
import org.bouncycastle.asn1.x509.DistributionPointName import org.bouncycastle.asn1.x509.DistributionPointName
import org.bouncycastle.asn1.x509.Extension import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.GeneralName import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.SubjectKeyIdentifier
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import java.io.ByteArrayInputStream
import java.net.Socket import java.net.Socket
import java.net.URI
import java.security.KeyStore import java.security.KeyStore
import java.security.cert.* import java.security.cert.CertificateException
import java.util.* import java.security.cert.PKIXBuilderParameters
import java.security.cert.X509CertSelector
import java.security.cert.X509Certificate
import java.util.concurrent.Executor import java.util.concurrent.Executor
import javax.net.ssl.* import java.util.concurrent.ThreadPoolExecutor
import kotlin.system.measureTimeMillis import javax.net.ssl.CertPathTrustManagerParameters
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SNIHostName
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLEngine
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509ExtendedTrustManager
import javax.security.auth.x500.X500Principal
private const val HOSTNAME_FORMAT = "%s.corda.net" private const val HOSTNAME_FORMAT = "%s.corda.net"
internal const val DEFAULT = "default" internal const val DEFAULT = "default"
@ -46,65 +58,73 @@ internal const val DP_DEFAULT_ANSWER = "NO CRLDP ext"
internal val logger = LoggerFactory.getLogger("net.corda.nodeapi.internal.protonwrapper.netty.SSLHelper") internal val logger = LoggerFactory.getLogger("net.corda.nodeapi.internal.protonwrapper.netty.SSLHelper")
fun X509Certificate.distributionPoints() : Set<String>? { /**
logger.debug("Checking CRLDPs for $subjectX500Principal") * Returns all the CRL distribution points in the certificate as [URI]s along with the CRL issuer names, if any.
*/
fun X509Certificate.distributionPoints(): Map<URI, List<X500Principal>?> {
logger.debug { "Checking CRLDPs for $subjectX500Principal" }
val crldpExtBytes = getExtensionValue(Extension.cRLDistributionPoints.id) val crldpExtBytes = getExtensionValue(Extension.cRLDistributionPoints.id)
if (crldpExtBytes == null) { if (crldpExtBytes == null) {
logger.debug(DP_DEFAULT_ANSWER) logger.debug(DP_DEFAULT_ANSWER)
return emptySet() return emptyMap()
} }
val derObjCrlDP = ASN1InputStream(ByteArrayInputStream(crldpExtBytes)).readObject() val derObjCrlDP = crldpExtBytes.toAsn1Object()
val dosCrlDP = derObjCrlDP as? DEROctetString val dosCrlDP = derObjCrlDP as? DEROctetString
if (dosCrlDP == null) { if (dosCrlDP == null) {
logger.error("Expected to have DEROctetString, actual type: ${derObjCrlDP.javaClass}") logger.error("Expected to have DEROctetString, actual type: ${derObjCrlDP.javaClass}")
return emptySet() return emptyMap()
} }
val crldpExtOctetsBytes = dosCrlDP.octets val dpObj = dosCrlDP.octets.toAsn1Object()
val dpObj = ASN1InputStream(ByteArrayInputStream(crldpExtOctetsBytes)).readObject() val crlDistPoint = CRLDistPoint.getInstance(dpObj)
val distPoint = CRLDistPoint.getInstance(dpObj) if (crlDistPoint == null) {
if (distPoint == null) {
logger.error("Could not instantiate CRLDistPoint, from: $dpObj") logger.error("Could not instantiate CRLDistPoint, from: $dpObj")
return emptySet() return emptyMap()
} }
val dpNames = distPoint.distributionPoints.mapNotNull { it.distributionPoint }.filter { it.type == DistributionPointName.FULL_NAME } val dpMap = HashMap<URI, List<X500Principal>?>()
val generalNames = dpNames.flatMap { GeneralNames.getInstance(it.name).names.asList() } for (distributionPoint in crlDistPoint.distributionPoints) {
return generalNames.filter { it.tagNo == GeneralName.uniformResourceIdentifier}.map { DERIA5String.getInstance(it.name).string }.toSet() val distributionPointName = distributionPoint.distributionPoint
} if (distributionPointName?.type != DistributionPointName.FULL_NAME) continue
val issuerNames = distributionPoint.crlIssuer?.names?.mapNotNull {
fun X509Certificate.distributionPointsToString() : String { if (it.tagNo == GeneralName.directoryName) {
return with(distributionPoints()) { X500Principal(X500Name.getInstance(it.name).encoded)
if(this == null || isEmpty()) {
DP_DEFAULT_ANSWER
} else { } else {
sorted().joinToString() null
} }
} }
for (generalName in GeneralNames.getInstance(distributionPointName.name).names) {
if (generalName.tagNo == GeneralName.uniformResourceIdentifier) {
val uri = URI(DERIA5String.getInstance(generalName.name).string)
dpMap[uri] = issuerNames
}
}
}
return dpMap
}
fun X509Certificate.distributionPointsToString(): String {
return with(distributionPoints().keys) {
if (isEmpty()) DP_DEFAULT_ANSWER else sorted().joinToString()
}
} }
fun ByteArray.toAsn1Object(): ASN1Primitive = ASN1InputStream(this).readObject()
fun certPathToString(certPath: Array<out X509Certificate>?): String { fun certPathToString(certPath: Array<out X509Certificate>?): String {
if (certPath == null) { if (certPath == null) {
return "<empty certpath>" return "<empty certpath>"
} }
val certs = certPath.map { return certPath.joinToString(System.lineSeparator()) { " ${it.toSimpleString()}" }
val bcCert = it.toBc() }
val subject = bcCert.subject.toString()
val issuer = bcCert.issuer.toString() /**
val keyIdentifier = try { * Create an executor for processing SSL handshake tasks asynchronously (see [SSLEngine.getDelegatedTask]). The max number of threads is 3,
SubjectKeyIdentifier.getInstance(bcCert.getExtension(Extension.subjectKeyIdentifier).parsedValue).keyIdentifier.toHex() * which is the typical number of CRLs expected in a Corda TLS cert path. The executor needs to be passed to the [SslHandler] constructor.
} catch (ex: Exception) { */
"null" fun sslDelegatedTaskExecutor(parentPoolName: String): ThreadPoolExecutor {
} return namedThreadPoolExecutor(maxPoolSize = 3, poolName = "$parentPoolName-ssltask")
val authorityKeyIdentifier = try {
AuthorityKeyIdentifier.getInstance(bcCert.getExtension(Extension.authorityKeyIdentifier).parsedValue).keyIdentifier.toHex()
} catch (ex: Exception) {
"null"
}
" $subject[$keyIdentifier] issued by $issuer[$authorityKeyIdentifier] [${it.distributionPointsToString()}]"
}
return certs.joinToString("\r\n")
} }
@VisibleForTesting @VisibleForTesting
@ -117,7 +137,7 @@ class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509Ex
if (chain == null) { if (chain == null) {
return "<empty certpath>" return "<empty certpath>"
} }
return chain.map { it.toString() }.joinToString(", ") return chain.joinToString(", ") { it.toString() }
} }
private fun logErrors(chain: Array<out X509Certificate>?, block: () -> Unit) { private fun logErrors(chain: Array<out X509Certificate>?, block: () -> Unit) {
@ -169,37 +189,11 @@ class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509Ex
} }
private object LoggingImmediateExecutor : Executor { internal fun createClientSslHandler(target: NetworkHostAndPort,
override fun execute(command: Runnable?) {
val log = LoggerFactory.getLogger(javaClass)
if (command == null) {
log.error("SSL handler executor called with a null command")
throw NullPointerException("command")
}
@Suppress("TooGenericExceptionCaught", "MagicNumber") // log and rethrow all exceptions
try {
val commandName = command::class.qualifiedName?.let { "[$it]" } ?: ""
log.debug("Entering SSL command $commandName")
val elapsedTime = measureTimeMillis { command.run() }
log.debug("Exiting SSL command $elapsedTime millis")
if (elapsedTime > 100) {
log.info("Command: $commandName took $elapsedTime millis to execute")
}
}
catch (ex: Exception) {
log.error("Caught exception in SSL handler executor", ex)
throw ex
}
}
}
internal fun createClientSslHelper(target: NetworkHostAndPort,
expectedRemoteLegalNames: Set<CordaX500Name>, expectedRemoteLegalNames: Set<CordaX500Name>,
keyManagerFactory: KeyManagerFactory, keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler { trustManagerFactory: TrustManagerFactory,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory) val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory)
val sslEngine = sslContext.createSSLEngine(target.host, target.port) val sslEngine = sslContext.createSSLEngine(target.host, target.port)
sslEngine.useClientMode = true sslEngine.useClientMode = true
@ -211,15 +205,15 @@ internal fun createClientSslHelper(target: NetworkHostAndPort,
sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single()))) sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single())))
sslEngine.sslParameters = sslParameters sslEngine.sslParameters = sslParameters
} }
@Suppress("DEPRECATION") return SslHandler(sslEngine, false, delegateTaskExecutor)
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
} }
internal fun createClientOpenSslHandler(target: NetworkHostAndPort, internal fun createClientOpenSslHandler(target: NetworkHostAndPort,
expectedRemoteLegalNames: Set<CordaX500Name>, expectedRemoteLegalNames: Set<CordaX500Name>,
keyManagerFactory: KeyManagerFactory, keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory, trustManagerFactory: TrustManagerFactory,
alloc: ByteBufAllocator): SslHandler { alloc: ByteBufAllocator,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL).keyManager(keyManagerFactory).trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)).build() val sslContext = SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL).keyManager(keyManagerFactory).trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)).build()
val sslEngine = sslContext.newEngine(alloc, target.host, target.port) val sslEngine = sslContext.newEngine(alloc, target.host, target.port)
sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray() sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray()
@ -229,13 +223,13 @@ internal fun createClientOpenSslHandler(target: NetworkHostAndPort,
sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single()))) sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single())))
sslEngine.sslParameters = sslParameters sslEngine.sslParameters = sslParameters
} }
@Suppress("DEPRECATION") return SslHandler(sslEngine, false, delegateTaskExecutor)
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
} }
internal fun createServerSslHandler(keyStore: CertificateStore, internal fun createServerSslHandler(keyStore: CertificateStore,
keyManagerFactory: KeyManagerFactory, keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler { trustManagerFactory: TrustManagerFactory,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory) val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory)
val sslEngine = sslContext.createSSLEngine() val sslEngine = sslContext.createSSLEngine()
sslEngine.useClientMode = false sslEngine.useClientMode = false
@ -246,65 +240,34 @@ internal fun createServerSslHandler(keyStore: CertificateStore,
val sslParameters = sslEngine.sslParameters val sslParameters = sslEngine.sslParameters
sslParameters.sniMatchers = listOf(ServerSNIMatcher(keyStore)) sslParameters.sniMatchers = listOf(ServerSNIMatcher(keyStore))
sslEngine.sslParameters = sslParameters sslEngine.sslParameters = sslParameters
@Suppress("DEPRECATION") return SslHandler(sslEngine, false, delegateTaskExecutor)
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
}
fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SSLContext {
val sslContext = SSLContext.getInstance("TLS")
val keyManagers = keyManagerFactory.keyManagers
val trustManagers = trustManagerFactory.trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java)
.map { LoggingTrustManagerWrapper(it) }.toTypedArray()
sslContext.init(keyManagers, trustManagers, newSecureRandom())
return sslContext
}
@VisibleForTesting
fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore, revocationConfig: RevocationConfig): ManagerFactoryParameters {
val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector())
val revocationChecker = when (revocationConfig.mode) {
RevocationConfig.Mode.OFF -> AllowAllRevocationChecker // Custom PKIXRevocationChecker skipping CRL check
RevocationConfig.Mode.EXTERNAL_SOURCE -> {
require(revocationConfig.externalCrlSource != null) { "externalCrlSource must not be null" }
ExternalSourceRevocationChecker(revocationConfig.externalCrlSource!!) { Date() } // Custom PKIXRevocationChecker which uses `externalCrlSource`
}
else -> {
val certPathBuilder = CertPathBuilder.getInstance("PKIX")
val pkixRevocationChecker = certPathBuilder.revocationChecker as PKIXRevocationChecker
pkixRevocationChecker.options = EnumSet.of(
// Prefer CRL over OCSP
PKIXRevocationChecker.Option.PREFER_CRLS,
// Don't fall back to OCSP checking
PKIXRevocationChecker.Option.NO_FALLBACK)
if (revocationConfig.mode == RevocationConfig.Mode.SOFT_FAIL) {
// Allow revocation check to succeed if the revocation status cannot be determined for one of
// the following reasons: The CRL or OCSP response cannot be obtained because of a network error.
pkixRevocationChecker.options = pkixRevocationChecker.options + PKIXRevocationChecker.Option.SOFT_FAIL
}
pkixRevocationChecker
}
}
pkixParams.addCertPathChecker(revocationChecker)
return CertPathTrustManagerParameters(pkixParams)
} }
internal fun createServerOpenSslHandler(keyManagerFactory: KeyManagerFactory, internal fun createServerOpenSslHandler(keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory, trustManagerFactory: TrustManagerFactory,
alloc: ByteBufAllocator): SslHandler { alloc: ByteBufAllocator,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = getServerSslContextBuilder(keyManagerFactory, trustManagerFactory).build() val sslContext = getServerSslContextBuilder(keyManagerFactory, trustManagerFactory).build()
val sslEngine = sslContext.newEngine(alloc) val sslEngine = sslContext.newEngine(alloc)
sslEngine.useClientMode = false sslEngine.useClientMode = false
@Suppress("DEPRECATION") return SslHandler(sslEngine, false, delegateTaskExecutor)
return SslHandler(sslEngine, false, LoggingImmediateExecutor) }
fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory?): SSLContext {
val sslContext = SSLContext.getInstance("TLS")
val trustManagers = trustManagerFactory
?.trustManagers
?.map { if (it is X509ExtendedTrustManager) LoggingTrustManagerWrapper(it) else it }
?.toTypedArray()
sslContext.init(keyManagerFactory.keyManagers, trustManagers, newSecureRandom())
return sslContext
} }
/** /**
* Creates a special SNI handler used only when openSSL is used for AMQPServer * Creates a special SNI handler used only when openSSL is used for AMQPServer
*/ */
internal fun createServerSNIOpenSslHandler(keyManagerFactoriesMap: Map<String, KeyManagerFactory>, internal fun createServerSNIOpenSniHandler(keyManagerFactoriesMap: Map<String, KeyManagerFactory>,
trustManagerFactory: TrustManagerFactory): SniHandler { trustManagerFactory: TrustManagerFactory): SniHandler {
// Default value can be any in the map. // Default value can be any in the map.
val sslCtxBuilder = getServerSslContextBuilder(keyManagerFactoriesMap.values.first(), trustManagerFactory) val sslCtxBuilder = getServerSslContextBuilder(keyManagerFactoriesMap.values.first(), trustManagerFactory)
val mapping = DomainWildcardMappingBuilder(sslCtxBuilder.build()) val mapping = DomainWildcardMappingBuilder(sslCtxBuilder.build())
@ -314,20 +277,19 @@ internal fun createServerSNIOpenSslHandler(keyManagerFactoriesMap: Map<String, K
return SniHandler(mapping.build()) return SniHandler(mapping.build())
} }
@Suppress("SpreadOperator")
private fun getServerSslContextBuilder(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SslContextBuilder { private fun getServerSslContextBuilder(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SslContextBuilder {
return SslContextBuilder.forServer(keyManagerFactory) return SslContextBuilder.forServer(keyManagerFactory)
.sslProvider(SslProvider.OPENSSL) .sslProvider(SslProvider.OPENSSL)
.trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)) .trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory))
.clientAuth(ClientAuth.REQUIRE) .clientAuth(ClientAuth.REQUIRE)
.ciphers(ArtemisTcpTransport.CIPHER_SUITES) .ciphers(ArtemisTcpTransport.CIPHER_SUITES)
.protocols(*ArtemisTcpTransport.TLS_VERSIONS.toTypedArray()) .protocols(ArtemisTcpTransport.TLS_VERSIONS)
} }
internal fun splitKeystore(config: AMQPConfiguration): Map<String, CertHoldingKeyManagerFactoryWrapper> { internal fun splitKeystore(config: AMQPConfiguration): Map<String, CertHoldingKeyManagerFactoryWrapper> {
val keyStore = config.keyStore.value.internal val keyStore = config.keyStore.value.internal
val password = config.keyStore.entryPassword.toCharArray() val password = config.keyStore.entryPassword.toCharArray()
return keyStore.aliases().toList().map { alias -> return keyStore.aliases().toList().associate { alias ->
val key = keyStore.getKey(alias, password) val key = keyStore.getKey(alias, password)
val certs = keyStore.getCertificateChain(alias) val certs = keyStore.getCertificateChain(alias)
val x500Name = keyStore.getCertificate(alias).x509.subjectX500Principal val x500Name = keyStore.getCertificate(alias).x509.subjectX500Principal
@ -338,14 +300,45 @@ internal fun splitKeystore(config: AMQPConfiguration): Map<String, CertHoldingKe
val newKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) val newKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
newKeyManagerFactory.init(newKeyStore, password) newKeyManagerFactory.init(newKeyStore, password)
x500toHostName(cordaX500Name) to CertHoldingKeyManagerFactoryWrapper(newKeyManagerFactory, config) x500toHostName(cordaX500Name) to CertHoldingKeyManagerFactoryWrapper(newKeyManagerFactory, config)
}.toMap() }
} }
// As per Javadoc in: https://docs.oracle.com/javase/8/docs/api/javax/net/ssl/KeyManagerFactory.html `init` method // As per Javadoc in: https://docs.oracle.com/javase/8/docs/api/javax/net/ssl/KeyManagerFactory.html `init` method
// 2nd parameter `password` - the password for recovering keys in the KeyStore // 2nd parameter `password` - the password for recovering keys in the KeyStore
fun KeyManagerFactory.init(keyStore: CertificateStore) = init(keyStore.value.internal, keyStore.entryPassword.toCharArray()) fun KeyManagerFactory.init(keyStore: CertificateStore) = init(keyStore.value.internal, keyStore.entryPassword.toCharArray())
fun TrustManagerFactory.init(trustStore: CertificateStore) = init(trustStore.value.internal) fun keyManagerFactory(keyStore: CertificateStore): KeyManagerFactory {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
return keyManagerFactory
}
fun trustManagerFactory(trustStore: CertificateStore): TrustManagerFactory {
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(trustStore.value.internal)
return trustManagerFactory
}
fun trustManagerFactoryWithRevocation(trustStore: CertificateStore,
revocationConfig: RevocationConfig,
crlSource: CrlSource): TrustManagerFactory {
val revocationChecker = when (revocationConfig.mode) {
RevocationConfig.Mode.OFF -> AllowAllRevocationChecker
RevocationConfig.Mode.EXTERNAL_SOURCE -> {
val externalCrlSource = requireNotNull(revocationConfig.externalCrlSource) {
"externalCrlSource must be specfied for EXTERNAL_SOURCE"
}
CordaRevocationChecker(externalCrlSource, softFail = true)
}
RevocationConfig.Mode.SOFT_FAIL -> CordaRevocationChecker(crlSource, softFail = true)
RevocationConfig.Mode.HARD_FAIL -> CordaRevocationChecker(crlSource, softFail = false)
}
val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector())
pkixParams.addCertPathChecker(revocationChecker)
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(CertPathTrustManagerParameters(pkixParams))
return trustManagerFactory
}
/** /**
* Method that converts a [CordaX500Name] to a a valid hostname (RFC-1035). It's used for SNI to indicate the target * Method that converts a [CordaX500Name] to a a valid hostname (RFC-1035). It's used for SNI to indicate the target

View File

@ -1,88 +0,0 @@
package net.corda.nodeapi.internal.protonwrapper.netty.revocation
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.protonwrapper.netty.ExternalCrlSource
import org.bouncycastle.asn1.x509.Extension
import java.security.cert.CRLReason
import java.security.cert.CertPathValidatorException
import java.security.cert.Certificate
import java.security.cert.CertificateRevokedException
import java.security.cert.PKIXRevocationChecker
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.util.*
/**
* Implementation of [PKIXRevocationChecker] which determines whether certificate is revoked using [externalCrlSource] which knows how to
* obtain a set of CRLs for a given certificate from an external source
*/
class ExternalSourceRevocationChecker(private val externalCrlSource: ExternalCrlSource, private val dateSource: () -> Date) : PKIXRevocationChecker() {
companion object {
private val logger = contextLogger()
}
override fun check(cert: Certificate, unresolvedCritExts: MutableCollection<String>?) {
val x509Certificate = cert as X509Certificate
checkApprovedCRLs(x509Certificate, externalCrlSource.fetch(x509Certificate))
}
/**
* Borrowed from `RevocationChecker.checkApprovedCRLs()`
*/
@Suppress("NestedBlockDepth")
@Throws(CertPathValidatorException::class)
private fun checkApprovedCRLs(cert: X509Certificate, approvedCRLs: Set<X509CRL>) {
// See if the cert is in the set of approved crls.
logger.debug("ExternalSourceRevocationChecker.checkApprovedCRLs() cert SN: ${cert.serialNumber}")
for (crl in approvedCRLs) {
val entry = crl.getRevokedCertificate(cert)
if (entry != null) {
logger.debug("ExternalSourceRevocationChecker.checkApprovedCRLs() CRL entry: $entry")
/*
* Abort CRL validation and throw exception if there are any
* unrecognized critical CRL entry extensions (see section
* 5.3 of RFC 5280).
*/
val unresCritExts = entry.criticalExtensionOIDs
if (unresCritExts != null && !unresCritExts.isEmpty()) {
/* remove any that we will process */
unresCritExts.remove(Extension.cRLDistributionPoints.id)
unresCritExts.remove(Extension.certificateIssuer.id)
if (!unresCritExts.isEmpty()) {
throw CertPathValidatorException(
"Unrecognized critical extension(s) in revoked CRL entry: $unresCritExts")
}
}
val reasonCode = entry.revocationReason ?: CRLReason.UNSPECIFIED
val revocationDate = entry.revocationDate
if (revocationDate.before(dateSource())) {
val t = CertificateRevokedException(
revocationDate, reasonCode,
crl.issuerX500Principal, mutableMapOf())
throw CertPathValidatorException(
t.message, t, null, -1, CertPathValidatorException.BasicReason.REVOKED)
}
}
}
}
override fun isForwardCheckingSupported(): Boolean {
return true
}
override fun getSupportedExtensions(): MutableSet<String>? {
return null
}
override fun init(forward: Boolean) {
// Nothing to do
}
override fun getSoftFailExceptions(): MutableList<CertPathValidatorException> {
return LinkedList()
}
}

Some files were not shown because too many files have changed in this diff Show More